diff --git a/examples/vn_trader/run.py b/examples/vn_trader/run.py index 32174775..bc0a304d 100644 --- a/examples/vn_trader/run.py +++ b/examples/vn_trader/run.py @@ -37,7 +37,7 @@ from vnpy.app.cta_strategy import CtaStrategyApp # from vnpy.app.csv_loader import CsvLoaderApp # from vnpy.app.algo_trading import AlgoTradingApp from vnpy.app.cta_backtester import CtaBacktesterApp -# from vnpy.app.data_recorder import DataRecorderApp +from vnpy.app.data_recorder import DataRecorderApp # from vnpy.app.risk_manager import RiskManagerApp from vnpy.app.script_trader import ScriptTraderApp from vnpy.app.rpc_service import RpcServiceApp @@ -61,7 +61,7 @@ def main(): # main_engine.add_gateway(FemasGateway) # main_engine.add_gateway(IbGateway) # main_engine.add_gateway(FutuGateway) - # main_engine.add_gateway(BitmexGateway) + main_engine.add_gateway(BitmexGateway) # main_engine.add_gateway(TigerGateway) # main_engine.add_gateway(OesGateway) # main_engine.add_gateway(OkexGateway) @@ -85,7 +85,7 @@ def main(): main_engine.add_app(CtaBacktesterApp) # main_engine.add_app(CsvLoaderApp) # main_engine.add_app(AlgoTradingApp) - # main_engine.add_app(DataRecorderApp) + main_engine.add_app(DataRecorderApp) # main_engine.add_app(RiskManagerApp) # main_engine.add_app(ScriptTraderApp) # main_engine.add_app(RpcServiceApp) diff --git a/vnpy/app/data_recorder/engine.py b/vnpy/app/data_recorder/engine.py index da4c362c..815bc088 100644 --- a/vnpy/app/data_recorder/engine.py +++ b/vnpy/app/data_recorder/engine.py @@ -6,6 +6,7 @@ from copy import copy from vnpy.event import Event, EventEngine from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.constant import Exchange from vnpy.trader.object import ( SubscribeRequest, TickData, @@ -15,6 +16,7 @@ from vnpy.trader.object import ( from vnpy.trader.event import EVENT_TICK, EVENT_CONTRACT from vnpy.trader.utility import load_json, save_json, BarGenerator from vnpy.trader.database import database_manager +from vnpy.app.spread_trading.base import EVENT_SPREAD_DATA, SpreadData APP_NAME = "DataRecorder" @@ -91,18 +93,22 @@ class RecorderEngine(BaseEngine): self.write_log(f"已在K线记录列表中:{vt_symbol}") return - contract = self.main_engine.get_contract(vt_symbol) - if not contract: - self.write_log(f"找不到合约:{vt_symbol}") - return + if Exchange.LOCAL.value not in vt_symbol: + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f"找不到合约:{vt_symbol}") + return - self.bar_recordings[vt_symbol] = { - "symbol": contract.symbol, - "exchange": contract.exchange.value, - "gateway_name": contract.gateway_name - } + self.bar_recordings[vt_symbol] = { + "symbol": contract.symbol, + "exchange": contract.exchange.value, + "gateway_name": contract.gateway_name + } + + self.subscribe(contract) + else: + self.tick_recordings[vt_symbol] = {} - self.subscribe(contract) self.save_setting() self.put_event() @@ -114,18 +120,24 @@ class RecorderEngine(BaseEngine): self.write_log(f"已在Tick记录列表中:{vt_symbol}") return - contract = self.main_engine.get_contract(vt_symbol) - if not contract: - self.write_log(f"找不到合约:{vt_symbol}") - return + # For normal contract + if Exchange.LOCAL.value not in vt_symbol: + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f"找不到合约:{vt_symbol}") + return - self.tick_recordings[vt_symbol] = { - "symbol": contract.symbol, - "exchange": contract.exchange.value, - "gateway_name": contract.gateway_name - } + self.tick_recordings[vt_symbol] = { + "symbol": contract.symbol, + "exchange": contract.exchange.value, + "gateway_name": contract.gateway_name + } + + self.subscribe(contract) + # No need to subscribe for spread data + else: + self.tick_recordings[vt_symbol] = {} - self.subscribe(contract) self.save_setting() self.put_event() @@ -159,11 +171,11 @@ class RecorderEngine(BaseEngine): """""" self.event_engine.register(EVENT_TICK, self.process_tick_event) self.event_engine.register(EVENT_CONTRACT, self.process_contract_event) + self.event_engine.register( + EVENT_SPREAD_DATA, self.process_spread_event) - def process_tick_event(self, event: Event): + def update_tick(self, tick: TickData): """""" - tick = event.data - if tick.vt_symbol in self.tick_recordings: self.record_tick(tick) @@ -171,6 +183,11 @@ class RecorderEngine(BaseEngine): bg = self.get_bar_generator(tick.vt_symbol) bg.update_tick(tick) + def process_tick_event(self, event: Event): + """""" + tick = event.data + self.update_tick(tick) + def process_contract_event(self, event: Event): """""" contract = event.data @@ -179,6 +196,15 @@ class RecorderEngine(BaseEngine): if (vt_symbol in self.tick_recordings or vt_symbol in self.bar_recordings): self.subscribe(contract) + def process_spread_event(self, event: Event): + """""" + spread: SpreadData = event.data + tick = spread.to_tick() + + # Filter not inited spread data + if tick.datetime: + self.update_tick(tick) + def write_log(self, msg: str): """""" event = Event(