Source code for czsc.utils.signal_analyzer
# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2023/3/30 21:13
describe:
"""
import os
import hashlib
import pandas as pd
from copy import deepcopy
from tqdm import tqdm
from loguru import logger
from typing import List, AnyStr
from concurrent.futures import ProcessPoolExecutor
[docs]class SignalPerformance:
"""信号表现分析"""
def __init__(self, dfs: pd.DataFrame, keys: List[AnyStr]):
"""
:param dfs: 信号表
:param keys: 信号列,支持一个或多个信号列组合分析
"""
base_cols = [x for x in dfs.columns if len(x.split("_")) != 3]
dfs = dfs[base_cols + keys].copy()
if 'year' not in dfs.columns:
y = dfs['dt'].apply(lambda x: x.year)
dfs['year'] = y.values
self.dfs = dfs
self.keys = keys
self.b_cols = [x for x in dfs.columns if x[0] == 'b' and x[-1] == 'b']
self.n_cols = [x for x in dfs.columns if x[0] == 'n' and x[-1] == 'b']
def __return_performance(self, dfs: pd.DataFrame, mode: str = '1b') -> pd.DataFrame:
"""分析信号组合的分类能力,也就是信号出现前后的收益情况
:param dfs: 信号数据表,
:param mode: 分析模式,
0b 截面向前看
0n 截面向后看
1b 时序向前看
1n 时序向后看
:return:
"""
mode = mode.lower()
assert mode in ['0b', '0n', '1b', '1n']
keys = self.keys
len_dfs = len(dfs)
cols = self.b_cols if mode.endswith('b') else self.n_cols
sdt = dfs['dt'].min().strftime("%Y%m%d")
edt = dfs['dt'].max().strftime("%Y%m%d")
def __static(_df, _name):
_res = {"name": _name, "date_span": f"{sdt} ~ {edt}",
"count": len(_df), "cover": round(len(_df) / len_dfs, 4)}
if mode.startswith('0'):
_r = _df.groupby('dt')[cols].mean().mean().to_dict()
else:
_r = _df[cols].mean().to_dict()
_res.update(_r)
return _res
results = [__static(dfs, "基准")]
for values, dfg in dfs.groupby(by=keys if len(keys) > 1 else keys[0]):
if isinstance(values, str):
values = [values]
assert isinstance(keys, (list, tuple)) and isinstance(values, (list, tuple))
assert len(keys) == len(values)
name = "#".join([f"{key1}_{name1}" for key1, name1 in zip(keys, values)])
results.append(__static(dfg, name))
dfr = pd.DataFrame(results)
dfr[cols] = dfr[cols].round(2)
return dfr
[docs] def analyze(self, mode='0b') -> pd.DataFrame:
"""分析信号出现前后的收益情况
:param mode: 分析模式,
0b 截面向前看
0n 截面向后看
1b 时序向前看
1n 时序向后看
:return:
"""
dfr = self.__return_performance(self.dfs, mode)
results = [dfr]
for year, df_ in self.dfs.groupby('year'):
dfr_ = self.__return_performance(df_, mode)
results.append(dfr_)
dfr = pd.concat(results, ignore_index=True)
return dfr
[docs] def report(self, file_xlsx=None):
res = {
'向后看截面': self.analyze('0n'),
'向后看时序': self.analyze('1n'),
}
if file_xlsx:
writer = pd.ExcelWriter(file_xlsx)
for sn, df_ in res.items():
df_.to_excel(writer, sheet_name=sn, index=False)
writer.close()
return res
[docs]class SignalAnalyzer:
def __init__(self, symbols, read_bars, signals_config, results_path, **kwargs):
"""信号分析
:param symbols: 品种列表
:param read_bars: 读取K线的函数
:param signals_config: 信号配置
:param results_path: 结果保存路径
:param kwargs: 其他参数
- sdt: 信号生成的开始时间
- edt: 信号生成的结束时间
- bar_sdt: 读取K线的开始时间
"""
self.version = 'V230520'
self.symbols = symbols
self.read_bars = read_bars
self.signals_config = signals_config
self.results_path = results_path
os.makedirs(self.results_path, exist_ok=True)
self.signals_path = os.path.join(self.results_path, 'signals')
os.makedirs(self.signals_path, exist_ok=True)
self.kwargs = kwargs
self.task_hash = hashlib.sha256((str(signals_config) + str(symbols)).encode('utf-8')).hexdigest()[:8].upper()
[docs] def generate_symbol_signals(self, symbol):
from czsc.traders.sig_parse import get_signals_freqs
from czsc.traders.base import generate_czsc_signals
from czsc.utils.trade import update_nbars
try:
file_cache = os.path.join(self.signals_path, f"{symbol}.parquet")
if os.path.exists(file_cache):
sigs = pd.read_parquet(file_cache)
else:
freqs = get_signals_freqs(deepcopy(self.signals_config))
sdt = self.kwargs.get('sdt', '20170101')
edt = self.kwargs.get('edt', '20220101')
bar_sdt = self.kwargs.get('bar_sdt', '20150101')
bars = self.read_bars(symbol, freqs[0], bar_sdt, edt, fq='后复权')
if len(bars) < 100:
logger.error(f"{symbol} 信号生成失败:数据量不足")
return pd.DataFrame()
sigs: pd.DataFrame = generate_czsc_signals(bars, deepcopy(self.signals_config), sdt=sdt, df=True) # type: ignore
if sigs.empty:
logger.error(f"{symbol} 信号生成失败:数据量不足")
return pd.DataFrame()
sigs.drop(['freq', 'cache'], axis=1, inplace=True)
update_nbars(sigs, price_col='open', move=1,
numbers=(1, 2, 3, 5, 8, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100))
sigs.to_parquet(file_cache)
return sigs
except Exception as e:
logger.exception(e)
logger.error(f"{symbol} 信号生成失败: {e}")
return pd.DataFrame()
[docs] @staticmethod
def find_valuable_signals(dfp):
"""根据信号表现,找出表现好的信号
:param dfp: 信号表现分析结果
:return: 表现好的信号
"""
n_cols = [x for x in dfp.columns if x.startswith('n') and x.endswith('b')]
# 价值描述:1)与基准的差值越大越好;2)与基准的差值越大,且胜率越高越好
rows = []
for _, dfg in dfp.groupby('date_span'):
base = dfg[dfg['name'] == '基准'].iloc[0].to_dict()
olds = dfg[dfg['name'] != '基准'].to_dict(orient='records')
sum_base = sum([base[x] for x in n_cols])
for row in olds:
if '其他' in row['name']:
continue
delta = [row[x] - base[x] for x in n_cols]
win_rate = sum([1 if x > 0 else 0 for x in delta]) / len(delta)
row['delta_win_rate'] = win_rate
sum_delta = sum(delta)
if abs(sum_delta) / sum_base < 0.1:
continue
if (win_rate > 0.7 and sum_delta > 0) or (win_rate < 0.3 and sum_delta < 0):
rows.append(row)
return pd.DataFrame(rows)
[docs] def execute(self, max_workers=10):
"""执行信号分析"""
symbols_sig = []
if max_workers <= 1:
for symbol in tqdm(self.symbols, desc="生成信号"):
sigs = self.generate_symbol_signals(symbol)
if not sigs.empty:
symbols_sig.append(sigs)
else:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = executor.map(self.generate_symbol_signals, self.symbols)
for result in results:
if not result.empty:
symbols_sig.append(result)
results_path = self.results_path
dfs = pd.concat(symbols_sig, ignore_index=True)
sig_keys = [x for x in dfs.columns if len(x.split("_")) == 3]
sps = {'向后看截面': [], '向后看时序': []}
raw_results_path = os.path.join(results_path, 'raw_results')
os.makedirs(raw_results_path, exist_ok=True)
for key in tqdm(sig_keys, desc="分析信号表现"):
sp = SignalPerformance(dfs, keys=[key])
res = sp.report(os.path.join(raw_results_path, f'{key}.xlsx'))
for k, v in res.items():
sps[k].append(v)
for k, v in sps.items():
dfp = pd.concat(v, ignore_index=True)
dfp.drop_duplicates(subset=['name', 'date_span'], inplace=True, ignore_index=True)
dfp.to_excel(os.path.join(results_path, f'{self.task_hash}_{k}_汇总.xlsx'), index=False)
dfp_valuable = self.find_valuable_signals(dfp)
dfp_valuable.to_excel(os.path.join(results_path, f'{self.task_hash}_{k}_有价值信号.xlsx'), index=False)