[Add] load and sync cta strategy data function

This commit is contained in:
vn.py 2019-02-16 11:18:58 +08:00
parent 7ade45e37f
commit 9df1cf9a74
4 changed files with 45 additions and 13 deletions

View File

@ -6,7 +6,7 @@ from vnpy.trader.ui import MainWindow, create_qapp
from vnpy.gateway.bitmex import BitmexGateway
from vnpy.gateway.futu import FutuGateway
#from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway
from vnpy.app.cta_strategy import CtaStrategyApp
@ -20,7 +20,7 @@ def main():
main_engine = MainEngine(event_engine)
main_engine.add_gateway(CtpGateway)
#main_engine.add_gateway(IbGateway)
main_engine.add_gateway(IbGateway)
main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway)

View File

@ -19,7 +19,7 @@ from vnpy.trader.object import (
)
from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE
from vnpy.trader.constant import Direction, PriceType, Interval
from vnpy.trader.utility import get_temp_path
from vnpy.trader.utility import get_temp_path, load_json, save_json
from vnpy.trader.database import DbTickData, DbBarData
from .base import (
@ -41,7 +41,8 @@ class CtaEngine(BaseEngine):
engine_type = EngineType.LIVE # live trading engine
filename = "CtaStrategy.vt"
setting_filename = "cta_strategy_setting.vt"
data_filename = "cta_strategy_data.json"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
@ -49,6 +50,7 @@ class CtaEngine(BaseEngine):
main_engine, event_engine, "CtaStrategy")
self.setting_file = None # setting file object
self.strategy_data = {} # strategy_name: dict
self.classes = {} # class_name: stategy_class
self.strategies = {} # strategy_name: strategy
@ -67,6 +69,7 @@ class CtaEngine(BaseEngine):
"""
self.load_strategy_class()
self.load_strategy_setting()
self.load_strategy_data()
self.register_event()
self.write_log("CTA策略引擎初始化成功")
@ -404,8 +407,17 @@ class CtaEngine(BaseEngine):
Init a strategy.
"""
strategy = self.strategies[strategy_name]
# Call on_init function of strategy
self.call_strategy_func(strategy, strategy.on_init)
strategy.inited = True
# Restore strategy data(variables)
data = self.strategy_data.get(strategy_name, None)
if data:
for name in strategy.variables:
value = data.get(name, None)
if value:
setattr(strategy, name, value)
# Subscribe market data
contract = self.main_engine.get_contract(strategy.vt_symbol)
@ -416,6 +428,7 @@ class CtaEngine(BaseEngine):
else:
self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)
strategy.inited = True
self.put_strategy_event(strategy)
def start_strategy(self, strategy_name: str):
@ -488,9 +501,10 @@ class CtaEngine(BaseEngine):
"""
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
module_name = ".".join(
[module_name, filename.replace(".py", "")])
self.load_strategy_class_from_module(module_name)
if filename.endswith(".py"):
strategy_module_name = ".".join(
[module_name, filename.replace(".py", "")])
self.load_strategy_class_from_module(strategy_module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
@ -501,12 +515,29 @@ class CtaEngine(BaseEngine):
for name in dir(module):
value = getattr(module, name)
if issubclass(value, CtaTemplate) and value is not CtaTemplate:
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
self.classes[value.__name__] = value
except: # noqa
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
self.write_log(msg)
def load_strategy_data(self):
"""
Load strategy data from json file.
"""
self.strategy_data = load_json(self.data_filename)
def sync_strategy_data(self, strategy: CtaTemplate):
"""
Sync strategy data into json file.
"""
data = strategy.get_variables()
data.pop("inited") # Strategy status (inited, trading) should not be synced.
data.pop("trading")
self.strategy_data[strategy.strategy_name] = data
save_json(self.data_filename, self.strategy_data)
def get_all_strategy_class_names(self):
"""
Return names of strategy classes loaded.
@ -554,7 +585,7 @@ class CtaEngine(BaseEngine):
"""
Load setting file.
"""
filepath = str(get_temp_path(self.filename))
filepath = str(get_temp_path(self.setting_filename))
self.setting_file = shelve.open(filepath)
for tp in list(self.setting_file.values()):

View File

@ -233,7 +233,9 @@ class CtaTemplate(ABC):
"""
self.cta_engine.send_email(msg, self)
def save_variables(self):
def sync_data(self):
"""
Sync strategy variables value into disk storage.
"""
self.cta_engine.save_strategy_variables(self)
if self.trading:
self.cta_engine.sync_strategy_data(self)

View File

@ -11,7 +11,6 @@ import numpy as np
import talib
from .object import BarData, TickData
from .constant import interval
class Singleton(type):