# Source code for czsc.utils.signal_analyzer
# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2023/3/30 21:13
describe:
"""
import os
import hashlib
import pandas as pd
from copy import deepcopy
from tqdm import tqdm
from loguru import logger
from deprecated import deprecated
from typing import List, AnyStr
from concurrent.futures import ProcessPoolExecutor
@deprecated(version="1.0.0", reason="分析方法不太合理,不再使用")
class SignalPerformance:
    """Summarise return columns for every value of one or more signal columns.

    Deprecated since 1.0.0; the decorator's reason states that the analysis
    method is not considered sound and is no longer used.
    """

    def __init__(self, dfs: pd.DataFrame, keys: List[AnyStr]):
        """Keep the non-signal columns plus the requested signal columns.

        :param dfs: signal table. It must contain a ``dt`` column whose values
            expose ``.year`` and whose min/max expose ``.strftime``.
        :param keys: signal column(s) to analyse; several columns are analysed
            as a combination.
        """
        # A signal column is recognised by a name with exactly three
        # "_"-separated parts; every other column is treated as a base column.
        base_cols = [x for x in dfs.columns if len(x.split("_")) != 3]
        dfs = dfs[base_cols + keys].copy()
        if "year" not in dfs.columns:
            dfs["year"] = dfs["dt"].apply(lambda x: x.year).values

        self.dfs = dfs
        self.keys = keys
        # startswith/endswith rather than x[0]/x[-1]: an empty column name must
        # not raise IndexError, and this matches SignalAnalyzer's column test.
        self.b_cols = [x for x in dfs.columns if x.startswith("b") and x.endswith("b")]
        self.n_cols = [x for x in dfs.columns if x.startswith("n") and x.endswith("b")]

    def __return_performance(self, dfs: pd.DataFrame, mode: str = "1b") -> pd.DataFrame:
        """Compute mean return columns for the baseline and for each signal value.

        :param dfs: signal table to analyse (the full table or one year of it).
        :param mode: two-character analysis mode, case-insensitive.

            - first character ``0``: average per ``dt`` first, then average
              those per-``dt`` means (cross-section);
            - first character ``1``: average over all rows directly (time series);
            - second character ``b``: use ``self.b_cols``;
            - second character ``n``: use ``self.n_cols``.

            NOTE(review): the original notes describe ``b``/``n`` as looking at
            returns before/after the signal appears; confirm against the code
            that creates those columns.
        :return: one row per group, the first being the baseline row named
            ``基准``, with ``name``, ``date_span``, ``count``, ``cover`` and the
            selected return columns rounded to 2 decimals.
        """
        mode = mode.lower()
        assert mode in ["0b", "0n", "1b", "1n"], f"unsupported mode: {mode}"

        keys = self.keys
        total_rows = len(dfs)
        cols = self.b_cols if mode.endswith("b") else self.n_cols
        sdt = dfs["dt"].min().strftime("%Y%m%d")
        edt = dfs["dt"].max().strftime("%Y%m%d")

        def _static(_df, _name):
            # Statistics of one group: size, share of all rows, mean returns.
            _res = {
                "name": _name,
                "date_span": f"{sdt} ~ {edt}",
                "count": len(_df),
                "cover": round(len(_df) / total_rows, 4),
            }
            if mode.startswith("0"):
                _r = _df.groupby("dt")[cols].mean().mean().to_dict()
            else:
                _r = _df[cols].mean().to_dict()
            _res.update(_r)
            return _res

        results = [_static(dfs, "基准")]
        for values, dfg in dfs.groupby(by=keys if len(keys) > 1 else keys[0]):
            # Grouping by a single column yields a scalar key; wrap any scalar
            # (not only str) so non-string signal values are labelled too.
            if not isinstance(values, (list, tuple)):
                values = [values]
            assert isinstance(keys, (list, tuple)) and isinstance(values, (list, tuple))
            assert len(keys) == len(values)

            name = "#".join(f"{key1}_{name1}" for key1, name1 in zip(keys, values))
            results.append(_static(dfg, name))

        dfr = pd.DataFrame(results)
        if cols:
            dfr[cols] = dfr[cols].round(2)
        return dfr

    def analyze(self, mode="0b") -> pd.DataFrame:
        """Run the performance analysis on the full table and on each year.

        :param mode: analysis mode, one of ``0b``, ``0n``, ``1b``, ``1n``
            (first character: ``0`` cross-section / ``1`` time series; second
            character selects ``self.b_cols`` or ``self.n_cols``).
        :return: the full-range result followed by one result per ``year``
            group, concatenated with a fresh index.
        """
        results = [self.__return_performance(self.dfs, mode)]
        for _, df_year in self.dfs.groupby("year"):
            results.append(self.__return_performance(df_year, mode))
        return pd.concat(results, ignore_index=True)

    def report(self, file_xlsx=None):
        """Build the ``0n`` and ``1n`` analyses and optionally write them to Excel.

        :param file_xlsx: target workbook path; when falsy nothing is written.
        :return: dict mapping sheet name to the analysis DataFrame.
        """
        res = {
            "向后看截面": self.analyze("0n"),
            "向后看时序": self.analyze("1n"),
        }
        if file_xlsx:
            # Context manager guarantees the workbook is closed even if a
            # sheet fails to write.
            with pd.ExcelWriter(file_xlsx) as writer:
                for sn, df_ in res.items():
                    df_.to_excel(writer, sheet_name=sn, index=False)
        return res
@deprecated(version="1.0.0", reason="分析方法不太合理,不再使用")
class SignalAnalyzer:
    """Generate signals per symbol, analyse their performance and export reports.

    Deprecated since 1.0.0; the decorator's reason states that the analysis
    method is not considered sound and is no longer used.
    """

    def __init__(self, symbols, read_bars, signals_config, results_path, **kwargs):
        """Store the configuration and create the output directories.

        :param symbols: list of symbols to process
        :param read_bars: function that reads bars; it is called as
            ``read_bars(symbol, freq, bar_sdt, edt, fq="后复权")``
        :param signals_config: signal configuration
        :param results_path: directory where results are saved; per-symbol
            signal caches go into its ``signals`` sub-directory
        :param kwargs: other parameters

            - sdt: start time of signal generation, default ``"20170101"``
            - edt: end time of signal generation, default ``"20220101"``
            - bar_sdt: start time for reading bars, default ``"20150101"``
        """
        self.version = "V230520"
        self.symbols = symbols
        self.read_bars = read_bars
        self.signals_config = signals_config
        self.results_path = results_path
        os.makedirs(self.results_path, exist_ok=True)
        self.signals_path = os.path.join(self.results_path, "signals")
        os.makedirs(self.signals_path, exist_ok=True)
        self.kwargs = kwargs
        # Short identifier of this (config, symbols) combination, used as a
        # prefix for the summary file names.
        self.task_hash = hashlib.sha256((str(signals_config) + str(symbols)).encode("utf-8")).hexdigest()[:8].upper()

    def generate_symbol_signals(self, symbol):
        """Return the signal table of one symbol, using a parquet cache.

        If ``<signals_path>/<symbol>.parquet`` exists it is read and returned.
        Otherwise bars are read, signals are generated, the ``freq`` and
        ``cache`` columns are dropped, ``update_nbars`` is applied and the
        result is written to the cache.

        :param symbol: symbol to process
        :return: signal DataFrame; an empty DataFrame when there are fewer than
            100 bars, when no signals are produced, or when any exception is
            raised (the exception is logged, not propagated).
        """
        from czsc.traders.sig_parse import get_signals_freqs
        from czsc.traders.base import generate_czsc_signals
        from czsc.utils.trade import update_nbars

        try:
            file_cache = os.path.join(self.signals_path, f"{symbol}.parquet")
            if os.path.exists(file_cache):
                return pd.read_parquet(file_cache)

            # deepcopy so the helper cannot mutate the shared configuration
            freqs = get_signals_freqs(deepcopy(self.signals_config))
            sdt = self.kwargs.get("sdt", "20170101")
            edt = self.kwargs.get("edt", "20220101")
            bar_sdt = self.kwargs.get("bar_sdt", "20150101")
            bars = self.read_bars(symbol, freqs[0], bar_sdt, edt, fq="后复权")
            if len(bars) < 100:
                logger.error(f"{symbol} 信号生成失败:数据量不足")
                return pd.DataFrame()

            sigs: pd.DataFrame = generate_czsc_signals(bars, deepcopy(self.signals_config), sdt=sdt, df=True)  # type: ignore
            if sigs.empty:
                logger.error(f"{symbol} 信号生成失败:数据量不足")
                return pd.DataFrame()

            sigs.drop(["freq", "cache"], axis=1, inplace=True)
            # NOTE(review): presumably adds the n*b return columns to ``sigs``
            # in place - confirm against czsc.utils.trade.update_nbars.
            update_nbars(
                sigs, price_col="open", move=1, numbers=(1, 2, 3, 5, 8, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100)
            )
            sigs.to_parquet(file_cache)
            return sigs

        except Exception as e:
            # Best effort per symbol: one failing symbol must not abort the run.
            logger.exception(e)
            logger.error(f"{symbol} 信号生成失败: {e}")
            return pd.DataFrame()

    @staticmethod
    def find_valuable_signals(dfp):
        """Pick the signals whose returns differ consistently from the baseline.

        Within each ``date_span`` group every non-baseline row is compared with
        the baseline row (``name == "基准"``) on the columns that start with
        ``n`` and end with ``b``. A row is kept when

        1. the summed difference is at least 10% of the magnitude of the summed
           baseline, and
        2. more than 70% of the differences are positive with a positive sum,
           or fewer than 30% are positive with a negative sum.

        Rows whose name contains ``其他`` are ignored. Kept rows gain a
        ``delta_win_rate`` column (share of columns beating the baseline).

        :param dfp: signal performance analysis result
        :return: DataFrame of the valuable signals (empty if none qualify)
        """
        n_cols = [x for x in dfp.columns if x.startswith("n") and x.endswith("b")]
        if not n_cols:
            # Nothing to compare on; also avoids dividing by len(delta) == 0.
            return pd.DataFrame()

        rows = []
        for _, dfg in dfp.groupby("date_span"):
            is_base = dfg["name"] == "基准"
            if not is_base.any():
                # No baseline row for this span, so nothing to compare against.
                continue

            base = dfg[is_base].iloc[0].to_dict()
            sum_base = sum(base[x] for x in n_cols)

            for row in dfg[~is_base].to_dict(orient="records"):
                if "其他" in row["name"]:
                    continue

                delta = [row[x] - base[x] for x in n_cols]
                win_rate = sum(1 for x in delta if x > 0) / len(delta)
                row["delta_win_rate"] = win_rate
                sum_delta = sum(delta)

                # Compare against the magnitude of the baseline: with the signed
                # value a negative baseline rejected every signal and a zero
                # baseline raised ZeroDivisionError.
                if sum_base != 0 and abs(sum_delta) / abs(sum_base) < 0.1:
                    continue

                if (win_rate > 0.7 and sum_delta > 0) or (win_rate < 0.3 and sum_delta < 0):
                    rows.append(row)
        return pd.DataFrame(rows)

    def execute(self, max_workers=10):
        """Run the whole signal analysis and write the Excel reports.

        Steps: generate signals for every symbol (serially when
        ``max_workers <= 1``, otherwise in a process pool), analyse each signal
        column with ``SignalPerformance`` and write per-signal workbooks to
        ``raw_results``, then write one summary workbook and one
        valuable-signals workbook per analysis view.

        :param max_workers: number of worker processes; ``<= 1`` runs serially
        :return: None. If no symbol produced signals, or the merged table has
            no signal columns, an error is logged and no summary is written.
        """
        symbols_sig = []
        if max_workers <= 1:
            for symbol in tqdm(self.symbols, desc="生成信号"):
                sigs = self.generate_symbol_signals(symbol)
                if not sigs.empty:
                    symbols_sig.append(sigs)
        else:
            with ProcessPoolExecutor(max_workers=max_workers) as executor:
                results = executor.map(self.generate_symbol_signals, self.symbols)
                for result in results:
                    if not result.empty:
                        symbols_sig.append(result)

        if not symbols_sig:
            # pd.concat raises ValueError on an empty list; report it clearly.
            logger.error("所有品种的信号生成结果为空,无法进行信号表现分析")
            return

        results_path = self.results_path
        dfs = pd.concat(symbols_sig, ignore_index=True)
        # Signal columns are the ones whose name has three "_"-separated parts.
        sig_keys = [x for x in dfs.columns if len(x.split("_")) == 3]
        if not sig_keys:
            logger.error("未找到信号列,无法进行信号表现分析")
            return

        sps = {"向后看截面": [], "向后看时序": []}
        raw_results_path = os.path.join(results_path, "raw_results")
        os.makedirs(raw_results_path, exist_ok=True)

        for key in tqdm(sig_keys, desc="分析信号表现"):
            sp = SignalPerformance(dfs, keys=[key])
            res = sp.report(os.path.join(raw_results_path, f"{key}.xlsx"))
            for k, v in res.items():
                sps[k].append(v)

        for k, v in sps.items():
            dfp = pd.concat(v, ignore_index=True)
            # Every per-signal report repeats the baseline rows; keep one copy.
            dfp.drop_duplicates(subset=["name", "date_span"], inplace=True, ignore_index=True)
            dfp.to_excel(os.path.join(results_path, f"{self.task_hash}_{k}_汇总.xlsx"), index=False)
            dfp_valuable = self.find_valuable_signals(dfp)
            dfp_valuable.to_excel(os.path.join(results_path, f"{self.task_hash}_{k}_有价值信号.xlsx"), index=False)