# Source code for czsc.traders.dummy

# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2023/3/23 19:12
describe:
"""
import os
import time
import pandas as pd
from tqdm import tqdm
from loguru import logger
from concurrent.futures import ProcessPoolExecutor
from czsc import fsa
from czsc.traders.base import generate_czsc_signals


class DummyBacktest:
    """Backtest a CZSC timing strategy over many symbols, optionally in parallel.

    Per-symbol signals are cached as parquet files under ``signals_path``;
    per-position pairs/holds are cached under ``<results_path>/poss/<symbol>/``;
    aggregated reports are written to ``results_path``.

    NOTE: ``execute`` submits bound methods of this object to a
    ``ProcessPoolExecutor``, so the instance (including ``read_bars`` and the
    strategy class) has to be picklable.
    """

    def __init__(self, strategy, signals_path, results_path, read_bars, **kwargs):
        """Set up the backtest and create its working directories.

        :param strategy: CZSC timing strategy class; must subclass CzscStrategyBase
        :param signals_path: directory where per-symbol signal files are stored
        :param results_path: directory where backtest results are stored
        :param read_bars: function that loads K-line (bar) data, with the signature
            read_bars(symbol, freq, sdt, edt, fq) -> List[RawBar]
        :param kwargs: other parameters; they are also forwarded to the strategy constructor

            - signals_module_name: name of the signal function module, used to load
              signal functions dynamically; defaults to czsc.signals
            - sdt: backtest start date, defaults to '20100101'
            - edt: backtest end date, defaults to '20230301'
        """
        from czsc.strategies import CzscStrategyBase

        assert issubclass(strategy, CzscStrategyBase), "strategy 必须是 CzscStrategyBase 的子类"
        self.strategy = strategy
        self.results_path = results_path
        os.makedirs(self.results_path, exist_ok=True)

        self.signals_path = signals_path
        os.makedirs(self.signals_path, exist_ok=True)

        # Cache directory for per-symbol position data (pairs / holds).
        self.poss_path = os.path.join(results_path, 'poss')
        os.makedirs(self.poss_path, exist_ok=True)

        logger.add(os.path.join(self.results_path, 'dummy.log'), encoding='utf-8', enqueue=True)
        self.read_bars = read_bars
        self.kwargs = kwargs

        # Backtest start / end dates.
        self.sdt = kwargs.get('sdt', '20100101')
        self.edt = kwargs.get('edt', '20230301')
        # Bars are loaded from 3 * 365 days before sdt
        # (presumably as warm-up history for signal computation - TODO confirm).
        self.bars_sdt = pd.to_datetime(self.sdt) - pd.Timedelta(days=365*3)

    def replay(self, symbol, sdt='20200101'):
        """Replay the trades of a single symbol.

        :param symbol: symbol code
        :param sdt: start date handed to the strategy's ``replay``; defaults to
            '20200101', the value that used to be hard-coded here
        :return: None; output goes to ``<results_path>/<symbol>_replay``
        """
        tactic = self.strategy(symbol=symbol, **self.kwargs)
        bars = self.read_bars(symbol, tactic.base_freq, self.sdt, self.edt, fq='后复权')
        tactic.replay(bars, os.path.join(self.results_path, f"{symbol}_replay"), sdt=sdt)

    def one_symbol_dummy(self, symbol):
        """Backtest a single symbol and cache the pairs/holds of every position.

        A symbol whose result directory already exists is skipped. The directory
        is created only after the backtest itself succeeded, so a failed symbol
        is retried on the next run instead of being reported as already done.

        :param symbol: symbol code
        :return: None
        """
        start_time = time.time()
        symbol_path = os.path.join(self.poss_path, symbol)
        if os.path.exists(symbol_path):
            logger.info(f"{symbol} 已经回测过,跳过")
            return None

        # Built after the skip check so cached symbols do not pay for it.
        tactic = self.strategy(symbol=symbol, **self.kwargs)

        try:
            file_sigs = os.path.join(self.signals_path, f"{symbol}.sigs")
            if not os.path.exists(file_sigs):
                bars = self.read_bars(symbol, tactic.base_freq, self.bars_sdt, self.edt, fq='后复权')
                sigs = generate_czsc_signals(bars, signals_config=tactic.signals_config, sdt=self.sdt, df=True)
                sigs.drop(columns=['freq', 'cache'], inplace=True)
                sigs.to_parquet(file_sigs)
            else:
                sigs = pd.read_parquet(file_sigs)

            # A cached signal file may start earlier than the requested sdt.
            sigs = sigs[sigs['dt'] >= self.sdt]
            sigs = sigs.to_dict('records')
            trader = tactic.dummy(sigs)
        except Exception as e:
            logger.exception(e)
            return None

        # The directory doubles as the "already backtested" marker checked above.
        os.makedirs(symbol_path, exist_ok=True)

        for pos in trader.positions:
            try:
                file_pairs = os.path.join(symbol_path, f"{pos.name}.pairs")
                file_holds = os.path.join(symbol_path, f"{pos.name}.holds")

                pairs = pd.DataFrame(pos.pairs)
                pairs.to_parquet(file_pairs)

                dfh = pd.DataFrame(pos.holds)
                # Return of the next row relative to this one, scaled by 10000.
                dfh['n1b'] = (dfh['price'].shift(-1) / dfh['price'] - 1) * 10000
                dfh.fillna(0, inplace=True)
                dfh['symbol'] = pos.symbol
                dfh.to_parquet(file_holds)
            except Exception as e:
                logger.debug(f"{symbol} {pos.name} 保存失败,原因:{e}")

        logger.info(f"{symbol} 回测完成,共 {len(trader.positions)} 个持仓策略,耗时 {time.time() - start_time:.2f} 秒")

    def one_pos_stats(self, pos_name):
        """Aggregate the performance of one position strategy across all symbols.

        :param pos_name: name of the position strategy
        :return: dict of statistics (including 'pos_name' and '截面等权收益'),
            or None when no trade pairs are available for this position
        """
        from czsc.traders.performance import PairsPerformance

        symbols = os.listdir(self.poss_path)
        pos_pairs = []
        pos_holds = []
        for symbol in tqdm(symbols, desc=f"读取 {pos_name}"):
            try:
                dfp = pd.read_parquet(os.path.join(self.poss_path, f"{symbol}/{pos_name}.pairs"))
                pos_pairs.append(dfp)

                dfh = pd.read_parquet(os.path.join(self.poss_path, f"{symbol}/{pos_name}.holds"))
                pos_holds.append(dfh[dfh['pos'] != 0])
            except Exception as e:
                logger.debug(f"{symbol} 读取失败,原因:{e}")

        # pd.concat raises ValueError on an empty list, so bail out first.
        if not pos_pairs:
            return None

        pairs = pd.concat(pos_pairs, ignore_index=True)
        if pairs.empty:
            return None

        pp = PairsPerformance(pairs)
        pairs.to_feather(os.path.join(self.results_path, f"{pos_name}_pairs.feather"))
        pp.agg_to_excel(os.path.join(self.results_path, f"{pos_name}_回测结果.xlsx"))
        stats = dict(pp.basic_info)

        # Cross-sectional equal-weight evaluation: per timestamp, the mean of
        # n1b * pos over the rows that hold a non-zero position.
        holds = pd.concat(pos_holds, ignore_index=True) if pos_holds else pd.DataFrame()
        if holds.empty:
            # No holdings could be read: empty series, which sums to 0.
            cross = pd.Series(dtype=float)
        else:
            cross = holds.groupby('dt').apply(lambda x: (x['n1b'] * x['pos']).sum() / (x['pos'] != 0).sum())
        stats['截面等权收益'] = cross.sum()
        cross.to_excel(os.path.join(self.results_path, f"{pos_name}_截面等权收益.xlsx"), index=True)
        stats['pos_name'] = pos_name
        return stats

    def execute(self, symbols, n_jobs=2, **kwargs):
        """Backtest multiple symbols and write a summary report.

        :param symbols: list of symbols
        :param n_jobs: number of worker processes, defaults to 2

            Notes:
            1. too many processes may exhaust memory
            2. multiprocessing does not work in PyCharm's IPython console;
               run from the command line instead
        :param kwargs: passed on to ``fsa.push_message``; a notification is only
            pushed when both 'feishu_app_id' and 'feishu_app_secret' are given
        :return: None
        """
        results_path = self.results_path
        # Placeholder instance, only used to enumerate the strategy's positions.
        tactic = self.strategy(symbol="symbol", **self.kwargs)
        dumps_map = {pos.name: pos.dump() for pos in tactic.positions}

        logger.info(f"策略回测,持仓策略数量:{len(tactic.positions)},共 {len(symbols)} 只标的,使用 {n_jobs} 个进程;"
                    f"结果保存在 {results_path}。请耐心等待...")

        with ProcessPoolExecutor(n_jobs) as pool:
            futures = [(symbol, pool.submit(self.one_symbol_dummy, symbol)) for symbol in sorted(symbols)]
            # Inspect every future so worker failures (e.g. pickling errors) are
            # logged instead of vanishing; one bad symbol must not abort the run.
            for symbol, future in futures:
                error = future.exception()
                if error is not None:
                    logger.error(f"{symbol} 回测失败,原因:{error}")

        all_stats = []
        with ProcessPoolExecutor(max_workers=min(n_jobs, 6)) as pool:
            for _s in pool.map(self.one_pos_stats, list(dumps_map)):
                if not _s:
                    continue

                _s['pos_dump'] = dumps_map[_s['pos_name']]
                all_stats.append(_s)

        file_report = os.path.join(results_path, f'{self.strategy.__name__}_回测结果汇总.xlsx')
        if all_stats:
            report_df = pd.DataFrame(all_stats).sort_values(['截面等权收益'], ascending=False, ignore_index=True)
            report_df.to_excel(file_report, index=False)
        else:
            # Sorting an empty frame by a missing column would raise KeyError and
            # hide the "empty result" notification below.
            logger.warning(f"{self.strategy.__name__} 回测结果为空,请检查原因!")
        logger.info(f"策略回测完成,结果保存在 {results_path}。")

        if kwargs.get('feishu_app_id') and kwargs.get('feishu_app_secret'):
            # Require fresh stats so a stale report from an earlier run is not pushed.
            if all_stats and os.path.exists(file_report):
                fsa.push_message(file_report, msg_type='file', **kwargs)
            else:
                fsa.push_message(f"{self.strategy.__name__} 回测结果为空,请检查原因!", **kwargs)