Source code for czsc.utils.bar_generator

# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2021/11/14 12:39
describe: 从任意周期K线开始合成更高周期K线的工具类
"""
from datetime import datetime, timedelta
from typing import List
from czsc.objects import RawBar, Freq


def freq_end_time(dt: datetime, freq: Freq) -> datetime:
    """Return the end time of the bar period that ``dt`` falls into.

    Seconds and microseconds of ``dt`` are always discarded first.

    * Minute periods shorter than 60 minutes: ``dt`` is rounded up to the next
      multiple of the period within the hour; 11:30 and 15:00 are returned
      unchanged.
    * 60-minute period: the first entry of a fixed list of intraday boundaries
      that is not earlier than ``dt`` is returned.
    * Day: midnight of the same date.
    * Week: midnight of the Friday of ``dt``'s ISO week (Saturday and Sunday
      therefore map back to the preceding Friday).
    * Month / quarter (``Freq.S``) / year: midnight of the last calendar day of
      the month / quarter / year. Month and quarter results are built with the
      ``datetime`` constructor and subtraction, the year result with the
      constructor alone; none of them carry over any ``tzinfo`` from ``dt``.

    :param dt: timestamp to classify
    :param freq: target bar period
    :return: end time of the period; for an unrecognised ``freq`` a message is
        printed and midnight of ``dt``'s date is returned
    """
    dt = dt.replace(second=0, microsecond=0)

    if freq in (Freq.F1, Freq.F5, Freq.F15, Freq.F30, Freq.F60):
        minutes = int(freq.value.strip("分钟"))

        if minutes >= 60:
            # Fixed intraday boundaries, keyed by the period length in minutes.
            hourly_ends = {
                60: ["01:00", "2:00", "3:00", '10:30', "11:30", "14:00", "15:00", "22:00", "23:00", "23:59"],
            }
            for mark in hourly_ends[minutes]:
                hh, mm = mark.split(":")
                candidate = dt.replace(hour=int(hh), minute=int(mm))
                if candidate >= dt:
                    return candidate
        else:
            # 11:30 and 15:00 are returned as they are.
            if (dt.hour, dt.minute) in ((15, 0), (11, 30)):
                return dt

            overshoot = dt.minute % minutes
            if overshoot == 0:
                return dt
            return dt + timedelta(minutes=minutes - overshoot)

    # Day, week, month, quarter and year periods all end at midnight.
    dt = dt.replace(hour=0, minute=0)
    one_day = timedelta(days=1)

    if freq == Freq.D:
        return dt

    if freq == Freq.W:
        # isoweekday(): Monday == 1 ... Sunday == 7, so Friday is 5.
        return dt + timedelta(days=5 - dt.isoweekday())

    if freq == Freq.M:
        if dt.month == 12:
            next_period_start = datetime(year=dt.year + 1, month=1, day=1)
        else:
            next_period_start = datetime(year=dt.year, month=dt.month + 1, day=1)
        return next_period_start - one_day

    if freq == Freq.S:
        # First month of the following quarter: 4, 7, 10 or 13 (= January next year).
        next_quarter_month = 3 * ((dt.month - 1) // 3) + 4
        if next_quarter_month > 12:
            next_period_start = datetime(year=dt.year + 1, month=1, day=1)
        else:
            next_period_start = datetime(year=dt.year, month=next_quarter_month, day=1)
        return next_period_start - one_day

    if freq == Freq.Y:
        return datetime(year=dt.year, month=12, day=31)

    print(f'freq_end_time error: {dt} - {freq}')
    return dt
class BarGenerator:
    """Build higher-period bars from a stream of finished base-period bars.

    ``self.bars`` maps every period name (the requested ``freqs`` plus
    ``base_freq``) to its list of ``RawBar`` objects. Each call to
    :meth:`update` feeds one finished base-period bar into every series.
    """

    def __init__(self, base_freq: str, freqs: List[str], max_count: int = 5000):
        """Create an empty generator.

        :param base_freq: period name of the bars that will be fed to ``update``
        :param freqs: period names to synthesize; must not contain ``base_freq``
        :param max_count: number of most recent bars kept per period
        :raises AssertionError: if the period configuration is invalid
        """
        self.symbol = None
        self.end_dt = None
        self.base_freq = base_freq
        self.max_count = max_count
        self.freqs = freqs
        # Target periods first, base period last; ``update`` iterates in this order.
        self.bars = {v: [] for v in self.freqs}
        self.bars.update({base_freq: []})
        # Period name (``Freq.value``) -> ``Freq`` member.
        self.freq_map = {f.value: f for f in Freq.__members__.values()}
        self.__validate_freq_params()

    def __validate_freq_params(self):
        """Check that ``base_freq`` and ``freqs`` form a supported combination.

        :raises AssertionError: if a period name is unknown, if a target period
            is not above the base period, or if ``base_freq`` is in ``freqs``
        """
        # Period constraint: input bars at base_freq must be finished bars, so
        # only periods above base_freq can be generated from them.
        self.base_freq_constraint = {
            "1分钟": ['5分钟', '15分钟', '30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'],
            "5分钟": ['15分钟', '30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'],
            "15分钟": ['30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'],
            "30分钟": ['60分钟', '日线', '周线', '月线', '季线', '年线'],
            "60分钟": ['日线', '周线', '月线', '季线', '年线'],
            "日线": ['周线', '月线', '季线', '年线'],
        }
        assert self.base_freq in self.freq_map.keys()
        assert self.base_freq in self.base_freq_constraint.keys()
        bfc = self.base_freq_constraint[self.base_freq]

        for freq in self.freqs:
            assert freq in self.freq_map.keys()
            assert freq in bfc, f"{freq} 不在允许生成的周期列表中"
        assert self.base_freq not in self.freqs, 'base_freq 不能在 freqs 列表中'

    def init_freq_bars(self, freq: str, bars: List[RawBar]):
        """Seed the bar series of one period with existing bars.

        :param freq: period name; must be one of the keys of ``self.bars``
        :param bars: bar series to install (stored as-is, not copied)
        :return: None
        :raises AssertionError: if ``freq`` is unknown or its series is not empty
        """
        assert freq in self.bars.keys()
        assert not self.bars[freq], f"self.bars['{freq}'] 不为空,不允许执行初始化"
        self.bars[freq] = bars
        # An empty seed list carries no symbol; leave the current one untouched
        # instead of failing on ``bars[-1]``.
        if bars:
            self.symbol = bars[-1].symbol

    def __repr__(self):
        return f"<BarGenerator for {self.symbol} @ {self.end_dt}>"

    def _update_freq(self, bar: RawBar, freq: Freq) -> None:
        """Merge one finished base-period bar into the series of ``freq``.

        A new bar is appended when the series is empty or when ``bar`` belongs
        to a later period than the last stored bar; otherwise the last stored
        bar is replaced by its aggregation with ``bar``.

        :param bar: finished base-period bar
        :param freq: target period
        :return: None
        """
        freq_edt = freq_end_time(bar.dt, freq)
        series = self.bars[freq.value]

        if not series:
            series.append(RawBar(symbol=bar.symbol, freq=freq, dt=freq_edt, id=0, open=bar.open,
                                 close=bar.close, high=bar.high, low=bar.low, vol=bar.vol,
                                 amount=bar.amount))
            return

        last: RawBar = series[-1]
        if freq_edt != last.dt:
            # ``bar`` opens a new period.
            series.append(RawBar(symbol=bar.symbol, freq=freq, dt=freq_edt, id=last.id + 1,
                                 open=bar.open, close=bar.close, high=bar.high, low=bar.low,
                                 vol=bar.vol, amount=bar.amount))
            return

        # Same period: aggregate. The amount is only summed when both sides
        # carry a value; a zero amount is a valid value and must not be
        # confused with a missing one.
        if last.amount is None or bar.amount is None:
            amount = None
        else:
            amount = last.amount + bar.amount

        series[-1] = RawBar(symbol=bar.symbol, freq=freq, dt=freq_edt, id=last.id,
                            open=last.open, close=bar.close, high=max(last.high, bar.high),
                            low=min(last.low, bar.low), vol=last.vol + bar.vol, amount=amount)

    def update(self, bar: RawBar) -> None:
        """Feed one finished base-period bar into every period series.

        A bar whose ``dt`` equals the ``dt`` of the last stored base-period bar
        is reported via ``print`` and ignored. Afterwards each series is cut
        down to its ``max_count`` most recent bars.

        :param bar: must be a finished bar of period ``base_freq``
        :return: None
        :raises AssertionError: if ``bar.freq`` does not match ``base_freq``
        """
        base_freq = self.base_freq
        assert bar.freq.value == base_freq
        self.symbol = bar.symbol
        self.end_dt = bar.dt

        if self.bars[base_freq] and self.bars[base_freq][-1].dt == bar.dt:
            print(f"BarGenerator.update: 输入重复K线,基准周期为{base_freq}")
            return

        for freq in self.bars.keys():
            self._update_freq(bar, self.freq_map[freq])

        # Limit the number of bars kept in memory per period.
        for f, b in self.bars.items():
            self.bars[f] = b[-self.max_count:]