diff --git a/vnpy/trader/app/ctaStrategy/ctaGridTrade.py b/vnpy/trader/app/ctaStrategy/ctaGridTrade.py index dfaf2264..f07fbe39 100644 --- a/vnpy/trader/app/ctaStrategy/ctaGridTrade.py +++ b/vnpy/trader/app/ctaStrategy/ctaGridTrade.py @@ -19,7 +19,7 @@ ChangeLog: 170504,增加锁单网格 170707,增加重用选项 170719, 增加网格类型 - 171208,增加openPrices/snapshot + 171208,增加openPrices/snapshot """ # 网格类型 @@ -435,6 +435,7 @@ class CtaGridTrade(object): def getLastOpenedGrid(self, direction,type = EMPTY_STRING,orderby_asc=True): """获取最后一个开仓的网格""" + # highest_short_price_grid = getLastOpenedGrid(DIRECTION_SHORT if direction == DIRECTION_SHORT: opened_short_grids = self.getGrids(direction=direction, opened=True,type=type) @@ -780,14 +781,19 @@ class CtaGridTrade(object): pass def save(self, direction): - """保存网格至本地Json文件""" + """ + 保存网格至本地Json文件" + 2017/11/23 update: 保存时,空的列表也保存 + :param direction: + :return: + """"" # 更新开仓均价 self.recount_avg_open_price() path = os.path.abspath(os.path.dirname(__file__)) # 保存上网格列表 - if len(self.upGrids) > 0 and direction == DIRECTION_SHORT: + if direction == DIRECTION_SHORT: jsonFileName = os.path.join(path, u'data', u'{0}_upGrids.json'.format(self.jsonName)) l = [] @@ -801,7 +807,7 @@ class CtaGridTrade(object): #self.writeCtaLog(u'上网格保存文件{0}完成'.format(jsonFileName)) # 保存上网格列表 - if len(self.dnGrids) > 0 and direction == DIRECTION_LONG: + if direction == DIRECTION_LONG: jsonFileName = os.path.join(path, u'data', u'{0}_dnGrids.json'.format(self.jsonName)) l = [] @@ -814,8 +820,13 @@ class CtaGridTrade(object): #self.writeCtaLog(u'下网格保存文件{0}完成'.format(jsonFileName)) - def load(self, direction): - """加载本地Json至网格""" + def load(self, direction, openStatusFilter=[]): + """ + 加载本地Json至网格 + :param direction: DIRECTION_SHORT,做空网格;DIRECTION_LONG,做多网格 + :param openStatusFilter: 缺省,不做过滤;True,只提取已开仓的数据,False,只提取未开仓的数据 + :return: + """ path = os.path.abspath(os.path.dirname(__file__)) @@ -872,6 +883,13 @@ class CtaGridTrade(object): except KeyError: grid.lockGrids = [] + try: + grid.type = i['type'] + if grid.type == False: + grid.type = EMPTY_STRING + except KeyError: + grid.type = EMPTY_STRING + try: grid.reuse = i['reuse'] except KeyError: @@ -889,6 +907,11 @@ class CtaGridTrade(object): self.writeCtaLog(grid.toStr()) + # 增加对开仓状态的过滤,满足某些策略只提取已开仓的网格数据 + if len(openStatusFilter) > 0: + if grid.openStatus not in openStatusFilter: + continue + grids.append(grid) else: