from typing import List, Dict

from tqz_strategy.template import StrategyTemplate
from public_module.utility import BarGenerator
from public_module.object import TickData, BarData


class PairTradingStrategy(StrategyTemplate):
    """
    Relative-value strategy over a basket of symbols.

    Every tick is fed to a per-symbol bar generator. Ticks of
    ``main_vt_symbol`` additionally drive a synthetic "index" tick whose last
    price is the open-interest-weighted average of the latest tick of every
    symbol seen so far. When a bar of that index arrives in ``on_bar``, the
    cached bars are ranked by close price: the highest one is shorted, the
    lowest one is bought, and positions in all other symbols are closed.

    NOTE(review): ``buy`` / ``sell`` / ``short`` / ``cover`` and ``self.pos``
    are not defined in this class; they are presumably provided by
    ``StrategyTemplate`` with the usual open-long / close-long / open-short /
    close-short meaning -- confirm against the template.
    """

    author = "tqz"

    main_vt_symbol = ""
    vt_symbols = []

    parameters = [
        "main_vt_symbol",
        "vt_symbols",
    ]
    variables = [
        "pos",
    ]

    def __init__(
        self,
        strategy_engine,
        strategy_name: str,
        vt_symbols: List[str],
        setting: dict
    ):
        """
        Create the per-symbol state and bar generators.

        :param strategy_engine: engine object, forwarded to the base class.
        :param strategy_name: name of this strategy instance, forwarded to
            the base class.
        :param vt_symbols: symbols traded by the strategy, forwarded to the
            base class.
        :param setting: strategy settings, forwarded to the base class.
        """
        super().__init__(strategy_engine, strategy_name, vt_symbols, setting)

        self.bar_generators: Dict[str, BarGenerator] = {}
        self.tick_map: Dict[str, TickData] = {}
        self.bar_map: Dict[str, BarData] = {}

        # Symbol name carried by the synthetic index ticks/bars.
        self.i_symbol = "i_symbol"
        # Order volume used when opening or flipping a position.
        self.lots = 1

        for vt_symbol in self.vt_symbols:
            self.pos[vt_symbol] = 0
            self.bar_generators[vt_symbol] = BarGenerator(self.on_bar)
        # Dedicated generator for the synthetic index.
        self.bar_generators["i_bg"] = BarGenerator(self.on_bar)

    def on_tick(self, tick: TickData):
        """
        Callback of new tick data update.

        Updates the bar generator and the tick cache of the tick's symbol.
        Only ticks of ``main_vt_symbol`` go on to refresh the index generator.
        """
        self.bar_generators[tick.vt_symbol].update_tick(tick)
        self.tick_map[tick.vt_symbol] = tick

        if tick.vt_symbol != self.main_vt_symbol:
            return

        self.bar_generators["i_bg"].update_tick(
            self.__get_new_index_tick(tick=tick, tick_map=self.tick_map)
        )

    def on_bar(self, bar: BarData):
        """
        Callback passed to every bar generator.

        Caches the bar under its ``vt_symbol`` and, when the bar belongs to
        the synthetic index, runs ``on_bars`` on the whole cache.
        """
        self.bar_map[bar.vt_symbol] = bar

        # Alternative trigger, kept from the original author with the note
        # that it still needs testing in a real environment: run on_bars once
        # no cached bar is older than the incoming one.
        #
        # bar_map_update = True
        # for bar_data in self.bar_map.values():
        #     if bar_data.datetime >= bar.datetime:
        #         continue
        #     bar_map_update = False
        #
        # if bar_map_update:
        #     self.on_bars(bars=self.bar_map)

        if bar.symbol == self.i_symbol:
            self.on_bars(bars=self.bar_map)

    def on_bars(self, bars: Dict[str, BarData]):
        """
        Rebalance on a snapshot of the latest bars.

        :param bars: latest bar per ``vt_symbol``; may include the index bar.

        The bar with the highest close is shorted and the bar with the lowest
        close is bought, unless that bar is the index itself. Any other
        symbol with a non-zero position is closed at its latest close price.
        """
        if not bars:
            return

        ranked = sorted(bars.values(), key=lambda bar: bar.close_price)
        low_bar, high_bar = ranked[0], ranked[-1]
        high_low = []

        if high_bar.symbol != self.i_symbol:
            high_low.append(high_bar.vt_symbol)
            self._enter_short(high_bar)

        if low_bar.symbol != self.i_symbol:
            high_low.append(low_bar.vt_symbol)
            self._enter_long(low_bar)

        # Iterate over a snapshot in case submitting an order updates
        # self.pos while we loop (depends on the engine -- not visible here).
        for vt_symbol, pos in list(self.pos.items()):
            if vt_symbol in high_low or pos == 0:
                continue

            bar = bars.get(vt_symbol)
            if bar is None:
                # No bar cached for this symbol yet, so there is no price to
                # close at; try again on the next index bar.
                continue

            if pos > 0:
                self.sell(vt_symbol=vt_symbol, price=bar.close_price, volume=pos)
            else:
                # pos is negative here; order volume must be positive.
                self.cover(vt_symbol=vt_symbol, price=bar.close_price, volume=abs(pos))

    def _enter_short(self, bar: BarData):
        """Make sure the symbol of ``bar`` is short; no-op if it already is."""
        # Only initialise the position when the symbol is unknown, so an
        # existing position is never overwritten.
        pos = self.pos.setdefault(bar.vt_symbol, 0)
        if pos < 0:
            return

        if pos > 0:
            # Close the long leg before opening the short one.
            self.sell(vt_symbol=bar.vt_symbol, price=bar.close_price, volume=self.lots)
        self.short(vt_symbol=bar.vt_symbol, price=bar.close_price, volume=self.lots)

    def _enter_long(self, bar: BarData):
        """Make sure the symbol of ``bar`` is long; no-op if it already is."""
        pos = self.pos.setdefault(bar.vt_symbol, 0)
        if pos > 0:
            return

        if pos < 0:
            # Close the short leg before opening the long one.
            self.cover(vt_symbol=bar.vt_symbol, price=bar.close_price, volume=self.lots)
        self.buy(vt_symbol=bar.vt_symbol, price=bar.close_price, volume=self.lots)

    # --- helper used with live tick data ("real atm" in the original note)
    def __get_new_index_tick(self, tick: TickData, tick_map: Dict[str, TickData]):
        """
        Build the synthetic index tick from the cached ticks.

        :param tick: the tick that triggered the update; its exchange,
            gateway name and datetime are copied onto the index tick.
        :param tick_map: latest tick per ``vt_symbol``.
        :return: a ``TickData`` with symbol ``self.i_symbol`` whose volumes
            and open interest are sums over ``tick_map`` and whose last price
            is the open-interest-weighted average last price, rounded to
            4 decimals.
        """
        tt_last_volume = 0
        tt_volume = 0
        tt_open_interest = 0
        tt_oi_lp = 0
        for tick_data in tick_map.values():
            tt_last_volume += tick_data.last_volume
            tt_volume += tick_data.volume
            tt_open_interest += tick_data.open_interest
            tt_oi_lp += tick_data.last_price * tick_data.open_interest

        if tt_open_interest:
            last_price = round(tt_oi_lp / tt_open_interest, 4)
        else:
            # No open interest anywhere: the weighted average is undefined,
            # so fall back to the triggering tick's own price instead of
            # raising ZeroDivisionError.
            last_price = tick.last_price

        return TickData(
            symbol=self.i_symbol,
            exchange=tick.exchange,
            gateway_name=tick.gateway_name,
            datetime=tick.datetime,
            last_price=last_price,
            last_volume=tt_last_volume,
            volume=tt_volume,
            open_interest=tt_open_interest
        )

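# Web-page footer residue captured with the source (not part of the strategy):
#   "Logo"
#   "Join the community! Open the door to quantitative trading -- the first courses are now online!"
#   "More recommendations"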