Source code for czsc.traders.utils

# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2021/12/12 21:49
"""
import os
import dill
from tqdm import tqdm
from loguru import logger
from typing import List, Callable
from czsc.analyze import CZSC
from czsc.utils import x_round, BarGenerator, kline_pro
from czsc.objects import RawBar
from czsc.traders.advanced import CzscAdvancedTrader


def trade_replay(bg: BarGenerator, raw_bars: List[RawBar], strategy: Callable, res_path):
    """Replay a trading strategy bar by bar and save the trading process.

    A ``CzscAdvancedTrader`` built from ``bg`` and ``strategy`` is fed every bar in
    ``raw_bars``. Whenever the long or short position changes, an HTML snapshot of the
    trader is written into ``res_path``. After the replay, a K-line chart annotated
    with bi / fx points and all position operates is rendered into ``res_path``, and
    the trader object is serialised with ``dill`` to ``res_path/trader.pkl``.

    :param bg: BarGenerator used to initialise the trader
    :param raw_bars: raw K-line bars to replay, in order
    :param strategy: strategy definition function passed to CzscAdvancedTrader
    :param res_path: output directory; created if it does not exist
    :return: None
    """
    os.makedirs(res_path, exist_ok=True)
    trader = CzscAdvancedTrader(bg, strategy)

    for bar in raw_bars:
        trader.update(bar)
        # Long and short positions are checked independently (long first); both may
        # change on the same bar, in which case two snapshots are written.
        for pos in (trader.long_pos, trader.short_pos):
            if pos and pos.pos_changed:
                op = pos.operates[-1]
                _dt = op['dt'].strftime('%Y%m%d#%H%M')
                file_name = f"{op['op'].value}_{_dt}_{op['bid']}_{x_round(op['price'], 2)}_{op['op_desc']}.html"
                file_html = os.path.join(res_path, file_name)
                trader.take_snapshot(file_html)
                logger.info(f'snapshot saved into {file_html}')

    # Rebuild the full analysis over all bars for the final replay chart.
    c = CZSC(raw_bars, max_bi_num=10000)
    kline = [x.__dict__ for x in c.bars_raw]

    # Start point (fx_a) of every bi, plus the end point (fx_b) of the last bi.
    # Guard the last-bi lookup: with too few bars CZSC may produce no bi at all.
    bi = [{'dt': x.fx_a.dt, "bi": x.fx_a.fx} for x in c.bi_list]
    if c.bi_list:
        last_bi = c.bi_list[-1]
        bi.append({'dt': last_bi.fx_b.dt, "bi": last_bi.fx_b.fx})

    # fx points of each bi, skipping the first one of every bi.
    fx = [{'dt': x.dt, "fx": x.fx} for bi_ in c.bi_list for x in bi_.fxs[1:]]

    # Build the BS (buy/sell) sequence from the operates of both positions.
    bs = []
    if trader.long_pos:
        bs.extend(trader.long_pos.operates)
    if trader.short_pos:
        bs.extend(trader.short_pos.operates)

    chart = kline_pro(kline, bi=bi, fx=fx, bs=bs, width="1400px", height='580px',
                      title=f"{strategy.__name__} {bg.symbol} 交易回放")
    chart.render(os.path.join(res_path, f"replay_{strategy.__name__}@{bg.symbol}.html"))

    # Close the file deterministically so the pickle is fully flushed to disk.
    with open(os.path.join(res_path, "trader.pkl"), 'wb') as f:
        dill.dump(trader, f)
    logger.info(f"{trader.strategy.__name__} {trader.results['long_performance']}")
def trader_fast_backtest(bars: List[RawBar],
                         init_n: int,
                         strategy: Callable,
                         html_path: str = None,
                         ):
    """Fast backtest of a pure CTA timing system; both long and short trades are supported.

    :param bars: raw K-line bars; the symbol is taken from ``bars[0]``
    :param init_n: number of leading bars used to initialise the BarGenerator
        (they are not backtested)
    :param strategy: strategy definition function; ``strategy(symbol)`` must return a
        dict containing at least ``base_freq`` and ``freqs``
    :param html_path: directory for trade snapshots, created if missing. Defaults to
        None, in which case no snapshots are saved. Note: saving HTML snapshots is very
        slow; only enable it when checking the buy/sell points of a few symbols.
    :return: dict with key ``signals`` (the trader's ``s`` after each backtested bar),
        updated with the entries of the trader's ``results``
    """
    ts_code = bars[0].symbol
    tactic = strategy(ts_code)
    base_freq = tactic['base_freq']
    freqs = tactic['freqs']
    bg = BarGenerator(base_freq, freqs, max_count=5000)
    for bar in bars[:init_n]:
        bg.update(bar)

    if html_path:
        # Snapshots are written as files inside html_path; create it up front, as
        # trade_replay does for its output directory (take_snapshot presumably does
        # not create parent directories — confirm).
        os.makedirs(html_path, exist_ok=True)

    ct = CzscAdvancedTrader(bg, strategy)
    signals = []
    for bar in tqdm(bars[init_n:], desc=f"{ts_code} bt"):
        ct.update(bar)
        signals.append(ct.s)

        if not html_path:
            continue
        # Long first, then short; both may change on the same bar.
        for pos in (ct.long_pos, ct.short_pos):
            if pos and pos.pos_changed:
                op = pos.operates[-1]
                file_name = f"{op['op'].value}_{op['bid']}_{x_round(op['price'], 2)}_{op['op_desc']}.html"
                file_html = os.path.join(html_path, file_name)
                ct.take_snapshot(file_html)
                logger.info(f'snapshot saved into {file_html}')

    res = {"signals": signals}
    res.update(ct.results)
    return res