From 70e5d85b4157596a7a5af1fa868745dfc2f5a216 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Tue, 19 Feb 2019 00:15:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/trader/app/ctaStrategy/ctaEngine.py | 42 +++++++++++++++++-- vnpy/trader/app/ctaStrategy/ctaHistoryData.py | 4 +- vnpy/trader/app/ctaStrategy/ctaLineBar.py | 5 ++- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/vnpy/trader/app/ctaStrategy/ctaEngine.py b/vnpy/trader/app/ctaStrategy/ctaEngine.py index 658ea55a..0dee7597 100644 --- a/vnpy/trader/app/ctaStrategy/ctaEngine.py +++ b/vnpy/trader/app/ctaStrategy/ctaEngine.py @@ -30,6 +30,7 @@ import re import csv import copy import decimal +from copy import copy from vnpy.trader.vtEvent import * from vnpy.trader.vtConstant import * @@ -567,7 +568,10 @@ class CtaEngine(object): # 推送到策略onTrade事件 self.callStrategyFunc(strategy, strategy.onTrade, trade) - # 更新持仓缓存数据 + # 保存策略持仓到数据库 + self.saveSyncData(strategy) + + # 更新持仓缓存数据 if trade.vtSymbol in self.tickStrategyDict: posBuffer = self.posBufferDict.get(trade.vtSymbol, None) if not posBuffer: @@ -1111,6 +1115,7 @@ class CtaEngine(object): self.callStrategyFunc(strategy, strategy.onInit, force) # strategy.onInit(force=force) # strategy.inited = True + self.loadSyncData(strategy) # 初始化完成后加载同步数据 else: self.writeCtaLog(u'请勿重复初始化策略实例:%s' % name) return True @@ -1902,11 +1907,10 @@ class CtaEngine(object): except Exception as ex: self.writeCtaCritical(u'加载策略配置{}:异常{},{}'.format(setting, str(ex), traceback.format_exc())) traceback.print_exc() - self.loadPosition() + except Exception as ex: self.writeCtaCritical(u'加载策略配置异常:{},{}'.format(str(ex),traceback.format_exc())) - # ---------------------------------------------------------------------- # 策略运行监控相关 def getStrategyVar(self, name): @@ -2105,6 +2109,38 @@ class CtaEngine(object): except: self.writeCtaLog(u'loadPosition Exception from Mongodb') + # ---------------------------------------------------------------------- + def saveSyncData(self, strategy): + """保存策略的持仓情况到数据库""" + flt = {'name': strategy.name, + 'vtSymbol': strategy.vtSymbol} + + d = copy(flt) + for key in strategy.syncList: + d[key] = strategy.__getattribute__(key) + + self.mainEngine.dbUpdate(POSITION_DB_NAME, strategy.className, + d, flt, True) + + content = u'策略%s同步数据保存成功,当前持仓%s' % (strategy.name, strategy.pos) + self.writeCtaLog(content) + + # ---------------------------------------------------------------------- + def loadSyncData(self, strategy): + """从数据库载入策略的持仓情况""" + flt = {'name': strategy.name, + 'vtSymbol': strategy.vtSymbol} + syncData = self.mainEngine.dbQuery(POSITION_DB_NAME, strategy.className, flt) + + if not syncData: + return + + d = syncData[0] + + for key in strategy.syncList: + if key in d: + strategy.__setattr__(key, d[key]) + # ---------------------------------------------------------------------- # 公共方法相关 def roundToPriceTick(self, priceTick, price): diff --git a/vnpy/trader/app/ctaStrategy/ctaHistoryData.py b/vnpy/trader/app/ctaStrategy/ctaHistoryData.py index e1402a35..f19e4aaf 100644 --- a/vnpy/trader/app/ctaStrategy/ctaHistoryData.py +++ b/vnpy/trader/app/ctaStrategy/ctaHistoryData.py @@ -40,7 +40,7 @@ class HistoryDataEngine(object): #---------------------------------------------------------------------- def __init__(self): """Constructor""" - host, port = loadMongoSetting() + host, port,_ = loadMongoSetting() self.dbClient = pymongo.MongoClient(host, port) self.datayesClient = DatayesClient() @@ -329,7 +329,7 @@ def loadMcCsv(fileName, dbName, symbol): print( u'开始读取CSV文件%s中的数据插入到%s的%s中' %(fileName, dbName, symbol)) # 锁定集合,并创建索引 - host, port = loadMongoSetting() + host, port,_ = loadMongoSetting() client = pymongo.MongoClient(host, port) collection = client[dbName][symbol] diff --git a/vnpy/trader/app/ctaStrategy/ctaLineBar.py b/vnpy/trader/app/ctaStrategy/ctaLineBar.py index 00c0985a..45b50d00 100644 --- a/vnpy/trader/app/ctaStrategy/ctaLineBar.py +++ b/vnpy/trader/app/ctaStrategy/ctaLineBar.py @@ -3539,8 +3539,11 @@ class CtaLineBar(object): :param:direction,多:检查是否有顶背离,空,检查是否有底背离 :return: """ - if len(self.lineSkTop) < 2 or len(self.lineSkButtom) < 2 or self._rt_SK is None or self._rt_SD is None: + if len(self.lineSkTop) < 2 or len(self.lineSkButtom) < 2 : return False + if runtime: + if self._rt_SK is None or self._rt_SD is None: + return False t1 = self.lineSkTop[-1] t2 = self.get_2nd_item(self.lineSkTop[:-1])