[Add] load and sync cta strategy data function
This commit is contained in:
parent
7ade45e37f
commit
9df1cf9a74
@ -6,7 +6,7 @@ from vnpy.trader.ui import MainWindow, create_qapp
|
||||
|
||||
from vnpy.gateway.bitmex import BitmexGateway
|
||||
from vnpy.gateway.futu import FutuGateway
|
||||
#from vnpy.gateway.ib import IbGateway
|
||||
from vnpy.gateway.ib import IbGateway
|
||||
from vnpy.gateway.ctp import CtpGateway
|
||||
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
@ -20,7 +20,7 @@ def main():
|
||||
|
||||
main_engine = MainEngine(event_engine)
|
||||
main_engine.add_gateway(CtpGateway)
|
||||
#main_engine.add_gateway(IbGateway)
|
||||
main_engine.add_gateway(IbGateway)
|
||||
main_engine.add_gateway(FutuGateway)
|
||||
main_engine.add_gateway(BitmexGateway)
|
||||
|
||||
|
@ -19,7 +19,7 @@ from vnpy.trader.object import (
|
||||
)
|
||||
from vnpy.trader.event import EVENT_TICK, EVENT_ORDER, EVENT_TRADE
|
||||
from vnpy.trader.constant import Direction, PriceType, Interval
|
||||
from vnpy.trader.utility import get_temp_path
|
||||
from vnpy.trader.utility import get_temp_path, load_json, save_json
|
||||
from vnpy.trader.database import DbTickData, DbBarData
|
||||
|
||||
from .base import (
|
||||
@ -41,7 +41,8 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
engine_type = EngineType.LIVE # live trading engine
|
||||
|
||||
filename = "CtaStrategy.vt"
|
||||
setting_filename = "cta_strategy_setting.vt"
|
||||
data_filename = "cta_strategy_data.json"
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
@ -49,6 +50,7 @@ class CtaEngine(BaseEngine):
|
||||
main_engine, event_engine, "CtaStrategy")
|
||||
|
||||
self.setting_file = None # setting file object
|
||||
self.strategy_data = {} # strategy_name: dict
|
||||
|
||||
self.classes = {} # class_name: stategy_class
|
||||
self.strategies = {} # strategy_name: strategy
|
||||
@ -67,6 +69,7 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
self.load_strategy_class()
|
||||
self.load_strategy_setting()
|
||||
self.load_strategy_data()
|
||||
self.register_event()
|
||||
self.write_log("CTA策略引擎初始化成功")
|
||||
|
||||
@ -404,8 +407,17 @@ class CtaEngine(BaseEngine):
|
||||
Init a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
|
||||
# Call on_init function of strategy
|
||||
self.call_strategy_func(strategy, strategy.on_init)
|
||||
strategy.inited = True
|
||||
|
||||
# Restore strategy data(variables)
|
||||
data = self.strategy_data.get(strategy_name, None)
|
||||
if data:
|
||||
for name in strategy.variables:
|
||||
value = data.get(name, None)
|
||||
if value:
|
||||
setattr(strategy, name, value)
|
||||
|
||||
# Subscribe market data
|
||||
contract = self.main_engine.get_contract(strategy.vt_symbol)
|
||||
@ -416,6 +428,7 @@ class CtaEngine(BaseEngine):
|
||||
else:
|
||||
self.write_log(f"行情订阅失败,找不到合约{strategy.vt_symbol}", strategy)
|
||||
|
||||
strategy.inited = True
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def start_strategy(self, strategy_name: str):
|
||||
@ -488,9 +501,10 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
for filename in filenames:
|
||||
module_name = ".".join(
|
||||
[module_name, filename.replace(".py", "")])
|
||||
self.load_strategy_class_from_module(module_name)
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
[module_name, filename.replace(".py", "")])
|
||||
self.load_strategy_class_from_module(strategy_module_name)
|
||||
|
||||
def load_strategy_class_from_module(self, module_name: str):
|
||||
"""
|
||||
@ -501,12 +515,29 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
for name in dir(module):
|
||||
value = getattr(module, name)
|
||||
if issubclass(value, CtaTemplate) and value is not CtaTemplate:
|
||||
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
|
||||
self.classes[value.__name__] = value
|
||||
except: # noqa
|
||||
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
|
||||
self.write_log(msg)
|
||||
|
||||
def load_strategy_data(self):
|
||||
"""
|
||||
Load strategy data from json file.
|
||||
"""
|
||||
self.strategy_data = load_json(self.data_filename)
|
||||
|
||||
def sync_strategy_data(self, strategy: CtaTemplate):
|
||||
"""
|
||||
Sync strategy data into json file.
|
||||
"""
|
||||
data = strategy.get_variables()
|
||||
data.pop("inited") # Strategy status (inited, trading) should not be synced.
|
||||
data.pop("trading")
|
||||
|
||||
self.strategy_data[strategy.strategy_name] = data
|
||||
save_json(self.data_filename, self.strategy_data)
|
||||
|
||||
def get_all_strategy_class_names(self):
|
||||
"""
|
||||
Return names of strategy classes loaded.
|
||||
@ -554,7 +585,7 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Load setting file.
|
||||
"""
|
||||
filepath = str(get_temp_path(self.filename))
|
||||
filepath = str(get_temp_path(self.setting_filename))
|
||||
self.setting_file = shelve.open(filepath)
|
||||
|
||||
for tp in list(self.setting_file.values()):
|
||||
|
@ -233,7 +233,9 @@ class CtaTemplate(ABC):
|
||||
"""
|
||||
self.cta_engine.send_email(msg, self)
|
||||
|
||||
def save_variables(self):
|
||||
def sync_data(self):
|
||||
"""
|
||||
Sync strategy variables value into disk storage.
|
||||
"""
|
||||
self.cta_engine.save_strategy_variables(self)
|
||||
if self.trading:
|
||||
self.cta_engine.sync_strategy_data(self)
|
||||
|
@ -11,7 +11,6 @@ import numpy as np
|
||||
import talib
|
||||
|
||||
from .object import BarData, TickData
|
||||
from .constant import interval
|
||||
|
||||
|
||||
class Singleton(type):
|
||||
|
Loading…
Reference in New Issue
Block a user