diff --git a/tests/trader/run.py b/tests/trader/run.py index 77dc42e5..c616b85a 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -6,7 +6,7 @@ from vnpy.trader.ui import MainWindow, create_qapp from vnpy.gateway.bitmex import BitmexGateway from vnpy.gateway.futu import FutuGateway -#from vnpy.gateway.ib import IbGateway +from vnpy.gateway.ib import IbGateway from vnpy.gateway.ctp import CtpGateway from vnpy.app.cta_strategy import CtaStrategyApp @@ -20,7 +20,7 @@ def main(): main_engine = MainEngine(event_engine) main_engine.add_gateway(CtpGateway) - #main_engine.add_gateway(IbGateway) + main_engine.add_gateway(IbGateway) main_engine.add_gateway(FutuGateway) main_engine.add_gateway(BitmexGateway) diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index c50dde97..e75256cb 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -19,7 +19,7 @@ from vnpy.trader.object import ( ) from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE from vnpy.trader.constant import Direction, PriceType, Interval -from vnpy.trader.utility import get_temp_path +from vnpy.trader.utility import get_temp_path, load_json, save_json from vnpy.trader.database import DbTickData, DbBarData from .base import ( @@ -41,7 +41,8 @@ class CtaEngine(BaseEngine): engine_type = EngineType.LIVE # live trading engine - filename = "CtaStrategy.vt" + setting_filename = "cta_strategy_setting.vt" + data_filename = "cta_strategy_data.json" def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" @@ -49,6 +50,7 @@ class CtaEngine(BaseEngine): main_engine, event_engine, "CtaStrategy") self.setting_file = None # setting file object + self.strategy_data = {} # strategy_name: dict self.classes = {} # class_name: stategy_class self.strategies = {} # strategy_name: strategy @@ -67,6 +69,7 @@ class CtaEngine(BaseEngine): """ self.load_strategy_class() self.load_strategy_setting() + self.load_strategy_data() self.register_event() self.write_log("CTA策略引擎初始化成功") @@ -404,8 +407,17 @@ class CtaEngine(BaseEngine): Init a strategy. """ strategy = self.strategies[strategy_name] + + # Call on_init function of strategy self.call_strategy_func(strategy, strategy.on_init) - strategy.inited = True + + # Restore strategy data(variables) + data = self.strategy_data.get(strategy_name, None) + if data: + for name in strategy.variables: + value = data.get(name, None) + if value: + setattr(strategy, name, value) # Subscribe market data contract = self.main_engine.get_contract(strategy.vt_symbol) @@ -416,6 +428,7 @@ class CtaEngine(BaseEngine): else: self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy) + strategy.inited = True self.put_strategy_event(strategy) def start_strategy(self, strategy_name: str): @@ -488,9 +501,10 @@ class CtaEngine(BaseEngine): """ for dirpath, dirnames, filenames in os.walk(path): for filename in filenames: - module_name = ".".join( - [module_name, filename.replace(".py", "")]) - self.load_strategy_class_from_module(module_name) + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + self.load_strategy_class_from_module(strategy_module_name) def load_strategy_class_from_module(self, module_name: str): """ @@ -501,12 +515,29 @@ class CtaEngine(BaseEngine): for name in dir(module): value = getattr(module, name) - if issubclass(value, CtaTemplate) and value is not CtaTemplate: + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): self.classes[value.__name__] = value except: # noqa msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" self.write_log(msg) + def load_strategy_data(self): + """ + Load strategy data from json file. + """ + self.strategy_data = load_json(self.data_filename) + + def sync_strategy_data(self, strategy: CtaTemplate): + """ + Sync strategy data into json file. + """ + data = strategy.get_variables() + data.pop("inited") # Strategy status (inited, trading) should not be synced. + data.pop("trading") + + self.strategy_data[strategy.strategy_name] = data + save_json(self.data_filename, self.strategy_data) + def get_all_strategy_class_names(self): """ Return names of strategy classes loaded. @@ -554,7 +585,7 @@ class CtaEngine(BaseEngine): """ Load setting file. """ - filepath = str(get_temp_path(self.filename)) + filepath = str(get_temp_path(self.setting_filename)) self.setting_file = shelve.open(filepath) for tp in list(self.setting_file.values()): diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index 211b258f..2b5ab498 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -233,7 +233,9 @@ class CtaTemplate(ABC): """ self.cta_engine.send_email(msg, self) - def save_variables(self): + def sync_data(self): """ + Sync strategy variables value into disk storage. """ - self.cta_engine.save_strategy_variables(self) + if self.trading: + self.cta_engine.sync_strategy_data(self) diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 3be11ed5..fdceff66 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -11,7 +11,6 @@ import numpy as np import talib from .object import BarData, TickData -from .constant import interval class Singleton(type):