[Add]CTA策略模块新增保存除持仓外其他同步数据到数据库的功能

This commit is contained in:
vn.py 2017-12-02 22:33:31 +08:00
parent d8817247d3
commit 55b57ca781
3 changed files with 33 additions and 12 deletions

View File

@ -564,6 +564,12 @@ class BacktestingEngine(object):
for stopOrderID in self.workingStopOrderDict.keys(): for stopOrderID in self.workingStopOrderDict.keys():
self.cancelStopOrder(stopOrderID) self.cancelStopOrder(stopOrderID)
#----------------------------------------------------------------------
def saveSyncData(self, strategy):
"""保存同步数据(无效)"""
pass
#------------------------------------------------ #------------------------------------------------
# 结果计算相关 # 结果计算相关
#------------------------------------------------ #------------------------------------------------

View File

@ -23,6 +23,7 @@ import os
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from copy import copy
from vnpy.event import Event from vnpy.event import Event
from vnpy.trader.vtEvent import * from vnpy.trader.vtEvent import *
@ -325,7 +326,7 @@ class CtaEngine(object):
self.callStrategyFunc(strategy, strategy.onTrade, trade) self.callStrategyFunc(strategy, strategy.onTrade, trade)
# 保存策略持仓到数据库 # 保存策略持仓到数据库
self.savePosition(strategy) self.saveSyncData(strategy)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def registerEvent(self): def registerEvent(self):
@ -520,7 +521,7 @@ class CtaEngine(object):
for setting in l: for setting in l:
self.loadStrategy(setting) self.loadStrategy(setting)
self.loadPosition() self.loadSyncData()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def getStrategyVar(self, name): def getStrategyVar(self, name):
@ -550,7 +551,7 @@ class CtaEngine(object):
return paramDict return paramDict
else: else:
self.writeCtaLog(u'策略实例不存在:' + name) self.writeCtaLog(u'策略实例不存在:' + name)
return None return None
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def putStrategyEvent(self, name): def putStrategyEvent(self, name):
@ -577,31 +578,37 @@ class CtaEngine(object):
self.writeCtaLog(content) self.writeCtaLog(content)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def savePosition(self, strategy): def saveSyncData(self, strategy):
"""保存策略的持仓情况到数据库""" """保存策略的持仓情况到数据库"""
flt = {'name': strategy.name, flt = {'name': strategy.name,
'vtSymbol': strategy.vtSymbol} 'vtSymbol': strategy.vtSymbol}
d = {'name': strategy.name, d = copy(flt)
'vtSymbol': strategy.vtSymbol, for key in strategy.syncList:
'pos': strategy.pos} d[key] = strategy.__getattribute__(key)
self.mainEngine.dbUpdate(POSITION_DB_NAME, strategy.className, self.mainEngine.dbUpdate(POSITION_DB_NAME, strategy.className,
d, flt, True) d, flt, True)
content = '策略%s持仓保存成功,当前持仓%s' %(strategy.name, strategy.pos) content = '策略%s同步数据保存成功,当前持仓%s' %(strategy.name, strategy.pos)
self.writeCtaLog(content) self.writeCtaLog(content)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def loadPosition(self): def loadSyncData(self):
"""从数据库载入策略的持仓情况""" """从数据库载入策略的持仓情况"""
for strategy in self.strategyDict.values(): for strategy in self.strategyDict.values():
flt = {'name': strategy.name, flt = {'name': strategy.name,
'vtSymbol': strategy.vtSymbol} 'vtSymbol': strategy.vtSymbol}
posData = self.mainEngine.dbQuery(POSITION_DB_NAME, strategy.className, flt) syncData = self.mainEngine.dbQuery(POSITION_DB_NAME, strategy.className, flt)
for d in posData: if not syncData:
strategy.pos = d['pos'] continue
d = syncData[0]
for key in strategy.syncList:
if key in d:
strategy.__setattr__(key, d[key])
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def roundToPriceTick(self, priceTick, price): def roundToPriceTick(self, priceTick, price):

View File

@ -46,6 +46,9 @@ class CtaTemplate(object):
varList = ['inited', varList = ['inited',
'trading', 'trading',
'pos'] 'pos']
# 同步列表,保存了需要保存到数据库的变量名称
syncList = ['pos']
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self, ctaEngine, setting): def __init__(self, ctaEngine, setting):
@ -186,6 +189,11 @@ class CtaTemplate(object):
"""查询当前运行的环境""" """查询当前运行的环境"""
return self.ctaEngine.engineType return self.ctaEngine.engineType
#----------------------------------------------------------------------
def saveSyncData(self):
"""保存同步数据到数据库"""
self.ctaEngine.saveSyncData(self)
######################################################################## ########################################################################
class TargetPosTemplate(CtaTemplate): class TargetPosTemplate(CtaTemplate):