diff --git a/vnpy/trader/app/ctaStrategy/ctaBacktesting.py b/vnpy/trader/app/ctaStrategy/ctaBacktesting.py index 88db33ab..57a64649 100644 --- a/vnpy/trader/app/ctaStrategy/ctaBacktesting.py +++ b/vnpy/trader/app/ctaStrategy/ctaBacktesting.py @@ -564,6 +564,12 @@ class BacktestingEngine(object): for stopOrderID in self.workingStopOrderDict.keys(): self.cancelStopOrder(stopOrderID) + #---------------------------------------------------------------------- + def saveSyncData(self, strategy): + """保存同步数据(无效)""" + pass + + #------------------------------------------------ # 结果计算相关 #------------------------------------------------ diff --git a/vnpy/trader/app/ctaStrategy/ctaEngine.py b/vnpy/trader/app/ctaStrategy/ctaEngine.py index 06f27034..55fb2f8a 100644 --- a/vnpy/trader/app/ctaStrategy/ctaEngine.py +++ b/vnpy/trader/app/ctaStrategy/ctaEngine.py @@ -23,6 +23,7 @@ import os import traceback from collections import OrderedDict from datetime import datetime, timedelta +from copy import copy from vnpy.event import Event from vnpy.trader.vtEvent import * @@ -325,7 +326,7 @@ class CtaEngine(object): self.callStrategyFunc(strategy, strategy.onTrade, trade) # 保存策略持仓到数据库 - self.savePosition(strategy) + self.saveSyncData(strategy) #---------------------------------------------------------------------- def registerEvent(self): @@ -520,7 +521,7 @@ class CtaEngine(object): for setting in l: self.loadStrategy(setting) - self.loadPosition() + self.loadSyncData() #---------------------------------------------------------------------- def getStrategyVar(self, name): @@ -550,7 +551,7 @@ class CtaEngine(object): return paramDict else: self.writeCtaLog(u'策略实例不存在:' + name) - return None + return None #---------------------------------------------------------------------- def putStrategyEvent(self, name): @@ -577,31 +578,37 @@ class CtaEngine(object): self.writeCtaLog(content) #---------------------------------------------------------------------- - def savePosition(self, strategy): + def saveSyncData(self, strategy): """保存策略的持仓情况到数据库""" flt = {'name': strategy.name, 'vtSymbol': strategy.vtSymbol} - d = {'name': strategy.name, - 'vtSymbol': strategy.vtSymbol, - 'pos': strategy.pos} + d = copy(flt) + for key in strategy.syncList: + d[key] = strategy.__getattribute__(key) self.mainEngine.dbUpdate(POSITION_DB_NAME, strategy.className, d, flt, True) - content = '策略%s持仓保存成功,当前持仓%s' %(strategy.name, strategy.pos) + content = '策略%s同步数据保存成功,当前持仓%s' %(strategy.name, strategy.pos) self.writeCtaLog(content) #---------------------------------------------------------------------- - def loadPosition(self): + def loadSyncData(self): """从数据库载入策略的持仓情况""" for strategy in self.strategyDict.values(): flt = {'name': strategy.name, 'vtSymbol': strategy.vtSymbol} - posData = self.mainEngine.dbQuery(POSITION_DB_NAME, strategy.className, flt) + syncData = self.mainEngine.dbQuery(POSITION_DB_NAME, strategy.className, flt) - for d in posData: - strategy.pos = d['pos'] + if not syncData: + continue + + d = syncData[0] + + for key in strategy.syncList: + if key in d: + strategy.__setattr__(key, d[key]) #---------------------------------------------------------------------- def roundToPriceTick(self, priceTick, price): diff --git a/vnpy/trader/app/ctaStrategy/ctaTemplate.py b/vnpy/trader/app/ctaStrategy/ctaTemplate.py index e55eb3c4..9dcf52f5 100644 --- a/vnpy/trader/app/ctaStrategy/ctaTemplate.py +++ b/vnpy/trader/app/ctaStrategy/ctaTemplate.py @@ -46,6 +46,9 @@ class CtaTemplate(object): varList = ['inited', 'trading', 'pos'] + + # 同步列表,保存了需要保存到数据库的变量名称 + syncList = ['pos'] #---------------------------------------------------------------------- def __init__(self, ctaEngine, setting): @@ -186,6 +189,11 @@ class CtaTemplate(object): """查询当前运行的环境""" return self.ctaEngine.engineType + #---------------------------------------------------------------------- + def saveSyncData(self): + """保存同步数据到数据库""" + self.ctaEngine.saveSyncData(self) + ######################################################################## class TargetPosTemplate(CtaTemplate):