From 7d91bcb158bda7cad2508a7ce57a1b5c78933215 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Thu, 28 Dec 2017 08:42:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dload=EF=BC=88=EF=BC=89?= =?UTF-8?q?=E4=B8=ADtype=3DFalse=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/trader/app/ctaStrategy/ctaGridTrade.py | 35 +++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/vnpy/trader/app/ctaStrategy/ctaGridTrade.py b/vnpy/trader/app/ctaStrategy/ctaGridTrade.py index dfaf2264..f07fbe39 100644 --- a/vnpy/trader/app/ctaStrategy/ctaGridTrade.py +++ b/vnpy/trader/app/ctaStrategy/ctaGridTrade.py @@ -19,7 +19,7 @@ ChangeLog: 170504,增加锁单网格 170707,增加重用选项 170719, 增加网格类型 - 171208,增加openPrices/snapshot + 171208,增加openPrices/snapshot """ # 网格类型 @@ -435,6 +435,7 @@ class CtaGridTrade(object): def getLastOpenedGrid(self, direction,type = EMPTY_STRING,orderby_asc=True): """获取最后一个开仓的网格""" + # highest_short_price_grid = getLastOpenedGrid(DIRECTION_SHORT if direction == DIRECTION_SHORT: opened_short_grids = self.getGrids(direction=direction, opened=True,type=type) @@ -780,14 +781,19 @@ class CtaGridTrade(object): pass def save(self, direction): - """保存网格至本地Json文件""" + """ + 保存网格至本地Json文件" + 2017/11/23 update: 保存时,空的列表也保存 + :param direction: + :return: + """"" # 更新开仓均价 self.recount_avg_open_price() path = os.path.abspath(os.path.dirname(__file__)) # 保存上网格列表 - if len(self.upGrids) > 0 and direction == DIRECTION_SHORT: + if direction == DIRECTION_SHORT: jsonFileName = os.path.join(path, u'data', u'{0}_upGrids.json'.format(self.jsonName)) l = [] @@ -801,7 +807,7 @@ class CtaGridTrade(object): #self.writeCtaLog(u'上网格保存文件{0}完成'.format(jsonFileName)) # 保存上网格列表 - if len(self.dnGrids) > 0 and direction == DIRECTION_LONG: + if direction == DIRECTION_LONG: jsonFileName = os.path.join(path, u'data', u'{0}_dnGrids.json'.format(self.jsonName)) l = [] @@ -814,8 +820,13 @@ class CtaGridTrade(object): #self.writeCtaLog(u'下网格保存文件{0}完成'.format(jsonFileName)) - def load(self, direction): - """加载本地Json至网格""" + def load(self, direction, openStatusFilter=[]): + """ + 加载本地Json至网格 + :param direction: DIRECTION_SHORT,做空网格;DIRECTION_LONG,做多网格 + :param openStatusFilter: 缺省,不做过滤;True,只提取已开仓的数据,False,只提取未开仓的数据 + :return: + """ path = os.path.abspath(os.path.dirname(__file__)) @@ -872,6 +883,13 @@ class CtaGridTrade(object): except KeyError: grid.lockGrids = [] + try: + grid.type = i['type'] + if grid.type == False: + grid.type = EMPTY_STRING + except KeyError: + grid.type = EMPTY_STRING + try: grid.reuse = i['reuse'] except KeyError: @@ -889,6 +907,11 @@ class CtaGridTrade(object): self.writeCtaLog(grid.toStr()) + # 增加对开仓状态的过滤,满足某些策略只提取已开仓的网格数据 + if len(openStatusFilter) > 0: + if grid.openStatus not in openStatusFilter: + continue + grids.append(grid) else: