From ecbc5b7e3f4bd462d92ff02012cc09c5d55520cb Mon Sep 17 00:00:00 2001 From: chenxy123 Date: Mon, 28 Mar 2016 20:52:15 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=83=A8=E5=88=86bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vn.oanda/vnoanda.py | 5 ++-- vn.trader/ctaAlgo/ctaBacktesting.py | 13 ++++++++- vn.trader/ctpGateway/ctpGateway.py | 45 ++++++++++++++++++++--------- vn.trader/oandaGateway/vnoanda.py | 5 ++-- 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/vn.oanda/vnoanda.py b/vn.oanda/vnoanda.py index 6272b651..9a77b35c 100644 --- a/vn.oanda/vnoanda.py +++ b/vn.oanda/vnoanda.py @@ -177,8 +177,9 @@ class OandaApi(object): #---------------------------------------------------------------------- def exit(self): """退出接口""" - self.active = False - self.reqThread.join() + if self.active: + self.active = False + self.reqThread.join() #---------------------------------------------------------------------- def initFunctionSetting(self, code, setting): diff --git a/vn.trader/ctaAlgo/ctaBacktesting.py b/vn.trader/ctaAlgo/ctaBacktesting.py index 69060680..3ca8d909 100644 --- a/vn.trader/ctaAlgo/ctaBacktesting.py +++ b/vn.trader/ctaAlgo/ctaBacktesting.py @@ -283,11 +283,14 @@ class BacktestingEngine(object): # 1. 假设当根K线的OHLC分别为:100, 125, 90, 110 # 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105 # 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100 + # 同时更新策略对象的持仓情况 if buyCross: trade.price = min(order.price, bestCrossPrice) + self.strategy.pos += order.totalVolume else: trade.price = max(order.price, bestCrossPrice) - + self.strategy.pos -= order.totalVolume + trade.volume = order.totalVolume trade.tradeTime = str(self.dt) trade.dt = self.dt @@ -322,6 +325,12 @@ class BacktestingEngine(object): # 如果发生了成交 if buyCross or sellCross: + # 更新策略对象的持仓情况 + if buyCross: + self.strategy.pos += order.totalVolume + else: + self.strategy.pos -= order.totalVolume + # 推送成交数据 self.tradeCount += 1 # 成交编号自增1 tradeID = str(self.tradeCount) @@ -463,6 +472,8 @@ class BacktestingEngine(object): pPnl = plt.subplot(3, 1, 3) pPnl.set_ylabel("pnl") pPnl.hist(pnlList, bins=20) + + plt.show() # 输出 self.output('-' * 50) diff --git a/vn.trader/ctpGateway/ctpGateway.py b/vn.trader/ctpGateway/ctpGateway.py index 0a2ff7b1..768ea832 100644 --- a/vn.trader/ctpGateway/ctpGateway.py +++ b/vn.trader/ctpGateway/ctpGateway.py @@ -10,6 +10,7 @@ vtSymbol直接使用symbol import os import json +from copy import copy from vnctpmd import MdApi from vnctptd import TdApi @@ -436,6 +437,8 @@ class CtpTdApi(TdApi): self.frontID = EMPTY_INT # 前置机编号 self.sessionID = EMPTY_INT # 会话编号 + self.posDict = {} # 缓存持仓数据的字典 + #---------------------------------------------------------------------- def onFrontConnected(self): """服务器连接""" @@ -624,33 +627,47 @@ class CtpTdApi(TdApi): #---------------------------------------------------------------------- def onRspQryInvestorPosition(self, data, error, n, last): """持仓查询回报""" - pos = VtPositionData() - pos.gatewayName = self.gatewayName + # 获取缓存字典中的持仓对象,若无则创建并初始化 + positionName = '.'.join([data['InstrumentID'], data['PosiDirection']]) - # 保存代码 - pos.symbol = data['InstrumentID'] - pos.vtSymbol = pos.symbol # 这里因为data中没有ExchangeID这个字段 + if positionName in self.posDict: + pos = self.posDict[positionName] + else: + pos = VtPositionData() + self.posDict[positionName] = pos + + pos.gatewayName = self.gatewayName - # 方向和持仓冻结数量 - pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') + # 保存代码 + pos.symbol = data['InstrumentID'] + pos.vtSymbol = pos.symbol # 这里因为data中没有ExchangeID这个字段 + + # 方向 + pos.direction = posiDirectionMapReverse.get(data['PosiDirection'], '') + + # VT系统持仓名 + pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction]) + + # 持仓冻结数量 if pos.direction == DIRECTION_NET or pos.direction == DIRECTION_LONG: pos.frozen = data['LongFrozen'] elif pos.direction == DIRECTION_SHORT: pos.frozen = data['ShortFrozen'] # 持仓量 - pos.position = data['Position'] - pos.ydPosition = data['YdPosition'] + if data['Position']: + pos.position = data['Position'] + + if data['YdPosition']: + pos.ydPosition = data['YdPosition'] # 持仓均价 if pos.position: pos.price = data['PositionCost'] / pos.position - - # VT系统持仓名 - pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction]) - + # 推送 - self.gateway.onPosition(pos) + newpos = copy(pos) + self.gateway.onPosition(newpos) #---------------------------------------------------------------------- def onRspQryTradingAccount(self, data, error, n, last): diff --git a/vn.trader/oandaGateway/vnoanda.py b/vn.trader/oandaGateway/vnoanda.py index 62f29609..57d1d16e 100644 --- a/vn.trader/oandaGateway/vnoanda.py +++ b/vn.trader/oandaGateway/vnoanda.py @@ -177,8 +177,9 @@ class OandaApi(object): #---------------------------------------------------------------------- def exit(self): """退出接口""" - self.active = False - self.reqThread.join() + if self.active: + self.active = False + self.reqThread.join() #---------------------------------------------------------------------- def initFunctionSetting(self, code, setting):