diff --git a/vn.okcoin/README.md b/vn.okcoin/README.md new file mode 100644 index 00000000..cd8a1ee5 --- /dev/null +++ b/vn.okcoin/README.md @@ -0,0 +1,36 @@ +# vn.okcoin + +贡献者:量衍投资 + +### 简介 +OkCoin的比特币交易接口,基于Websocket API开发,实现了以下功能: + +1. 发送、撤销委托 + +2. 查询委托、持仓、资金、成交历史 + +3. 实时行情、成交、资金更新的推送 + +### 特点 +相比较于[OkCoin官方](http://github.com/OKCoin/websocket/tree/master/python)给出的Python API实现,vn.okcoin的一些特点: + +1. 同时支持OkCoin的中国站和国际站交易,根据用户连接的站点会在内部自动切换结算货币(CNY、USD) + +2. 采用面向对象的接口设计模式,接近国内CTP接口的风格,并对主动函数的调用参数做了大幅简化 + +3. 数据解包和签名生成两个热点函数使用了更加高效的实现方式 + +### 参数命名 +函数的参数命名针对金融领域用户的习惯做了一些修改,具体对应如下: + +* expiry:原生命名的contract_type +* order: 原生命名的match_price +* leverage:原生命名的lever_rate +* page:原生命名的current_page +* length:原生命名的page_length + +### API版本 +日期:2016-06-29 + +链接:[http://www.okcoin.com/about/ws_getStarted.do](http://www.okcoin.com/about/ws_getStarted.do) + diff --git a/vn.okcoin/test.py b/vn.okcoin/test.py new file mode 100644 index 00000000..6f1bc39a --- /dev/null +++ b/vn.okcoin/test.py @@ -0,0 +1,47 @@ +# encoding: UTF-8 + +from vnokcoin import * + +# 在OkCoin网站申请这两个Key,分别对应用户名和密码 +apiKey = '' +secretKey = '' + +# 创建API对象 +api = OkCoinApi() + +# 连接服务器,并等待1秒 +api.connect(OKCOIN_USD, apiKey, secretKey, True) + +sleep(1) + +# 测试现货行情API +api.subscribeSpotTicker(SYMBOL_BTC) +#api.subscribeSpotTradeData(SYMBOL_BTC) +#api.subscribeSpotDepth(SYMBOL_BTC, DEPTH_20) +#api.subscribeSpotKline(SYMBOL_BTC, INTERVAL_1M) + +# 测试现货交易API +#api.subscribeSpotTrades() +#api.subscribeSpotUserInfo() +#api.spotUserInfo() +#api.spotTrade(symbol, type_, price, amount) +#api.spotCancelOrder(symbol, orderid) +#api.spotOrderInfo(symbol, orderid) + +# 测试期货行情API +#api.subscribeFutureTicker(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK) +#api.subscribeFutureTradeData(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK) +#api.subscribeFutureDepth(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK, DEPTH_20) +#api.subscribeFutureKline(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK, INTERVAL_1M) +#api.subscribeFutureIndex(SYMBOL_BTC) + +# 测试期货交易API +#api.subscribeFutureTrades() +#api.subscribeFutureUserInfo() +#api.subscribeFuturePositions() +#api.futureUserInfo() +#api.futureTrade(symbol, expiry, type_, price, amount, order, leverage) +#api.futureCancelOrder(symbol, expiry, orderid) +#api.futureOrderInfo(symbol, expiry, orderid, status, page, length) + +raw_input() \ No newline at end of file diff --git a/vn.okcoin/vnokcoin.py b/vn.okcoin/vnokcoin.py index 2e0623ab..2f30a766 100644 --- a/vn.okcoin/vnokcoin.py +++ b/vn.okcoin/vnokcoin.py @@ -49,6 +49,32 @@ TYPE_SELL = 'sell' TYPE_BUY_MARKET = 'buy_market' TYPE_SELL_MARKET = 'sell_market' +# 期货合约到期类型 +FUTURE_EXPIRY_THIS_WEEK = 'this_week' +FUTURE_EXPIRY_NEXT_WEEK = 'next_week' +FUTURE_EXPIRY_QUARTER = 'quarter' + +# 期货委托类型 +FUTURE_TYPE_LONG = 1 +FUTURE_TYPE_SHORT = 2 +FUTURE_TYPE_SELL = 3 +FUTURE_TYPE_COVER = 4 + +# 期货是否用现价 +FUTURE_ORDER_MARKET = 1 +FUTURE_ORDER_LIMIT = 0 + +# 期货杠杆 +FUTURE_LEVERAGE_10 = 10 +FUTURE_LEVERAGE_20 = 20 + +# 委托状态 +ORDER_STATUS_NOTTRADED = 0 +ORDER_STATUS_PARTTRADED = 1 +ORDER_STATUS_ALLTRADED = 2 +ORDER_STATUS_CANCELLED = -1 +ORDER_STATUS_CANCELLING = 4 + ######################################################################## class OkCoinApi(object): @@ -64,7 +90,11 @@ class OkCoinApi(object): self.ws = None # websocket应用对象 self.thread = None # 工作线程 - + + ####################### + ## 通用函数 + ####################### + #---------------------------------------------------------------------- def readData(self, evt): """解压缩推送收到的数据""" @@ -134,7 +164,11 @@ class OkCoinApi(object): self.thread = Thread(target=self.ws.run_forever) self.thread.start() - + + ####################### + ## 现货相关 + ####################### + #---------------------------------------------------------------------- def sendMarketDataRequest(self, channel): """发送行情请求""" @@ -241,33 +275,107 @@ class OkCoinApi(object): channel = 'ok_sub_spot%s_userinfo' %(self.currency) self.sendTradingRequest(channel, {}) + + ####################### + ## 期货相关 + ####################### + + #---------------------------------------------------------------------- + def subscribeFutureTicker(self, symbol, expiry): + """订阅期货普通报价""" + self.sendMarketDataRequest('ok_sub_future%s_%s_ticker_%s' %(self.currency, symbol, expiry)) + + #---------------------------------------------------------------------- + def subscribeFutureDepth(self, symbol, expiry, depth): + """订阅期货深度报价""" + self.sendMarketDataRequest('ok_sub_future%s_%s_depth_%s_%s' %(self.currency, symbol, + expiry, depth)) + #---------------------------------------------------------------------- + def subscribeFutureTradeData(self, symbol, expiry): + """订阅期货成交记录""" + self.sendMarketDataRequest('ok_sub_future%s_%s_trade_%s' %(self.currency, symbol, expiry)) + + #---------------------------------------------------------------------- + def subscribeFutureKline(self, symbol, expiry, interval): + """订阅期货K线""" + self.sendMarketDataRequest('ok_sub_future%s_%s_kline_%s_%s' %(self.currency, symbol, + expiry, interval)) + + #---------------------------------------------------------------------- + def subscribeFutureIndex(self, symbol): + """订阅期货指数""" + self.sendMarketDataRequest('ok_sub_future%s_%s_index' %(self.currency, symbol)) + + #---------------------------------------------------------------------- + def futureTrade(self, symbol, expiry, type_, price, amount, order, leverage): + """期货委托""" + params = {} + params['symbol'] = str(symbol+self.currency) + params['type'] = str(type_) + params['price'] = str(price) + params['amount'] = str(amount) + params['contract_type'] = str(expiry) + params['match_price'] = str(order) + params['lever_rate'] = str(leverage) + + channel = 'ok_future%s_trade' %(self.currency) + + self.sendTradingRequest(channel, params) + + #---------------------------------------------------------------------- + def futureCancelOrder(self, symbol, expiry, orderid): + """期货撤单""" + params = {} + params['symbol'] = str(symbol+self.currency) + params['order_id'] = str(orderid) + params['contract_type'] = str(expiry) + + channel = 'ok_future%s_cancel_order' %(self.currency) -if __name__ == "__main__": - # 在OkCoin网站申请这两个Key,分别对应用户名和密码 - apiKey = '' - secretKey = '' + self.sendTradingRequest(channel, params) + + #---------------------------------------------------------------------- + def futureUserInfo(self): + """查询期货账户""" + channel = 'ok_future%s_userinfo' %(self.currency) + + self.sendTradingRequest(channel, {}) + + #---------------------------------------------------------------------- + def futureOrderInfo(self, symbol, expiry, orderid, status, page, length): + """查询期货委托信息""" + params = {} + params['symbol'] = str(symbol+self.currency) + params['order_id'] = str(orderid) + params['contract_type'] = expiry + params['status'] = status + params['current_page'] = page + params['page_length'] = length + + channel = 'ok_future%s_orderinfo' + + self.sendTradingRequest(channel, params) + + #---------------------------------------------------------------------- + def subscribeFutureTrades(self): + """订阅期货成交信息""" + channel = 'ok_sub_future%s_trades' %(self.currency) + + self.sendTradingRequest(channel, {}) + + #---------------------------------------------------------------------- + def subscribeFutureUserInfo(self): + """订阅期货账户信息""" + channel = 'ok_sub_future%s_userinfo' %(self.currency) + + self.sendTradingRequest(channel, {}) + + #---------------------------------------------------------------------- + def subscribeFuturePositions(self): + """订阅期货持仓信息""" + channel = 'ok_sub_future%s_positions' %(self.currency) + + self.sendTradingRequest(channel, {}) - # 创建API对象 - api = OkCoinApi() - - # 连接服务器,并等待1秒 - api.connect(OKCOIN_CNY, apiKey, secretKey, True) - - sleep(1) - - # 测试现货行情API - #api.subscribeSpotTicker(SYMBOL_BTC) - #api.subscribeSpotTradeData(SYMBOL_BTC) - #api.subscribeSpotDepth(SYMBOL_BTC, DEPTH_20) - #api.subscribeSpotKline(SYMBOL_BTC, INTERVAL_1M) - - # 测试现货交易API - #api.subscribeSpotTrades() - #api.subscribeSpotUserInfo() - api.spotUserInfo() - #api.spotTrade(symbol, type_, price, amount) - #api.spotCancelOrder(symbol, orderid) - #api.spotOrderInfo(symbol, orderid) - - raw_input() \ No newline at end of file + diff --git a/vn.trader/ctaAlgo/ctaBacktesting.py b/vn.trader/ctaAlgo/ctaBacktesting.py index 19bf0f6c..684024db 100644 --- a/vn.trader/ctaAlgo/ctaBacktesting.py +++ b/vn.trader/ctaAlgo/ctaBacktesting.py @@ -7,6 +7,7 @@ from datetime import datetime, timedelta from collections import OrderedDict +from itertools import product import pymongo from ctaBase import * @@ -55,6 +56,9 @@ class BacktestingEngine(object): self.initData = [] # 初始化用的数据 #self.backtestingData = [] # 回测用的数据 + self.dbName = '' # 回测数据库名 + self.symbol = '' # 回测集合名 + self.dataStartDate = None # 回测数据开始日期,datetime对象 self.dataEndDate = None # 回测数据结束日期,datetime对象 self.strategyStartDate = None # 策略启动日期(即前面的数据用于初始化),datetime对象 @@ -91,14 +95,20 @@ class BacktestingEngine(object): def setBacktestingMode(self, mode): """设置回测模式""" self.mode = mode - + #---------------------------------------------------------------------- - def loadHistoryData(self, dbName, symbol): + def setDatabase(self, dbName, symbol): + """设置历史数据所用的数据库""" + self.dbName = dbName + self.symbol = symbol + + #---------------------------------------------------------------------- + def loadHistoryData(self): """载入历史数据""" host, port = loadMongoSetting() self.dbClient = pymongo.MongoClient(host, port) - collection = self.dbClient[dbName][symbol] + collection = self.dbClient[self.dbName][self.symbol] self.output(u'开始载入数据') @@ -134,6 +144,9 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def runBacktesting(self): """运行回测""" + # 载入历史数据 + self.loadHistoryData() + # 首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: dataClass = CtaBarData @@ -431,23 +444,20 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def output(self, content): """输出内容""" - print content - + print str(datetime.now()) + "\t" + content + #---------------------------------------------------------------------- - def showBacktestingResult(self): + def calculateBacktestingResult(self): """ - 显示回测结果 + 计算回测结果 """ - self.output(u'显示回测结果') + self.output(u'计算回测结果') # 首先基于回测后的成交记录,计算每笔交易的盈亏 - pnlDict = OrderedDict() # 每笔盈亏的记录 + resultDict = OrderedDict() # 交易结果记录 longTrade = [] # 未平仓的多头交易 shortTrade = [] # 未平仓的空头交易 - # 计算滑点,一个来回包括两次 - totalSlippage = self.slippage * 2 - for trade in self.tradeDict.values(): # 多头交易 if trade.direction == DIRECTION_LONG: @@ -457,12 +467,10 @@ class BacktestingEngine(object): # 当前多头交易为平空 else: entryTrade = shortTrade.pop(0) - # 计算比例佣金 - commission = (trade.price+entryTrade.price) * self.rate - # 计算盈亏 - pnl = ((trade.price - entryTrade.price)*(-1) - totalSlippage - commission) \ - * trade.volume * self.size - pnlDict[trade.dt] = pnl + + result = TradingResult(entryTrade.price, trade.price, -trade.volume, + self.rate, self.slippage, self.size) + resultDict[trade.dt] = result # 空头交易 else: # 如果尚无多头交易 @@ -471,56 +479,93 @@ class BacktestingEngine(object): # 当前空头交易为平多 else: entryTrade = longTrade.pop(0) - # 计算比例佣金 - commission = (trade.price+entryTrade.price) * self.rate - # 计算盈亏 - pnl = ((trade.price - entryTrade.price) - totalSlippage - commission) \ - * trade.volume * self.size - pnlDict[trade.dt] = pnl + + result = TradingResult(entryTrade.price, trade.price, trade.volume, + self.rate, self.slippage, self.size) + resultDict[trade.dt] = result + + # 检查是否有交易 + if not resultDict: + self.output(u'无交易结果') + return {} - # 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等 - timeList = pnlDict.keys() - pnlList = pnlDict.values() + # 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等 + capital = 0 # 资金 + maxCapital = 0 # 资金最高净值 + drawdown = 0 # 回撤 - capital = 0 - maxCapital = 0 - drawdown = 0 + totalResult = 0 # 总成交数量 + totalTurnover = 0 # 总成交金额(合约面值) + totalCommission = 0 # 总手续费 + totalSlippage = 0 # 总滑点 + timeList = [] # 时间序列 + pnlList = [] # 每笔盈亏序列 capitalList = [] # 盈亏汇总的时间序列 - maxCapitalList = [] # 最高盈利的时间序列 drawdownList = [] # 回撤的时间序列 - for pnl in pnlList: - capital += pnl + for time, result in resultDict.items(): + capital += result.pnl maxCapital = max(capital, maxCapital) drawdown = capital - maxCapital + pnlList.append(result.pnl) + timeList.append(time) capitalList.append(capital) - maxCapitalList.append(maxCapital) drawdownList.append(drawdown) + totalResult += 1 + totalTurnover += result.turnover + totalCommission += result.commission + totalSlippage += result.slippage + + # 返回回测结果 + d = {} + d['capital'] = capital + d['maxCapital'] = maxCapital + d['drawdown'] = drawdown + d['totalResult'] = totalResult + d['totalTurnover'] = totalTurnover + d['totalCommission'] = totalCommission + d['totalSlippage'] = totalSlippage + d['timeList'] = timeList + d['pnlList'] = pnlList + d['capitalList'] = capitalList + d['drawdownList'] = drawdownList + return d + + #---------------------------------------------------------------------- + def showBacktestingResult(self): + """显示回测结果""" + d = self.calculateBacktestingResult() + # 输出 - self.output('-' * 50) - self.output(u'第一笔交易时间:%s' % timeList[0]) - self.output(u'最后一笔交易时间:%s' % timeList[-1]) - self.output(u'总交易次数:%s' % len(pnlList)) - self.output(u'总盈亏:%s' % capitalList[-1]) - self.output(u'最大回撤: %s' % min(drawdownList)) + self.output('-' * 30) + self.output(u'第一笔交易:\t%s' % d['timeList'][0]) + self.output(u'最后一笔交易:\t%s' % d['timeList'][-1]) + + self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult'])) + self.output(u'总盈亏:\t%s' % formatNumber(d['capital'])) + self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList']))) + + self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult'])) + self.output(u'平均每笔滑点:\t%s' %formatNumber(d['totalSlippage']/d['totalResult'])) + self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult'])) # 绘图 import matplotlib.pyplot as plt pCapital = plt.subplot(3, 1, 1) pCapital.set_ylabel("capital") - pCapital.plot(capitalList) + pCapital.plot(d['capitalList']) pDD = plt.subplot(3, 1, 2) pDD.set_ylabel("DD") - pDD.bar(range(len(drawdownList)), drawdownList) + pDD.bar(range(len(d['drawdownList'])), d['drawdownList']) pPnl = plt.subplot(3, 1, 3) pPnl.set_ylabel("pnl") - pPnl.hist(pnlList, bins=20) + pPnl.hist(d['pnlList'], bins=50) plt.show() @@ -531,7 +576,7 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def setSlippage(self, slippage): - """设置滑点""" + """设置滑点点数""" self.slippage = slippage #---------------------------------------------------------------------- @@ -544,6 +589,137 @@ class BacktestingEngine(object): """设置佣金比例""" self.rate = rate + #---------------------------------------------------------------------- + def runOptimization(self, strategyClass, optimizationSetting): + """优化参数""" + # 获取优化设置 + settingList = optimizationSetting.generateSetting() + targetName = optimizationSetting.optimizeTarget + + # 检查参数设置问题 + if not settingList or not targetName: + self.output(u'优化设置有问题,请检查') + + # 遍历优化 + resultList = [] + for setting in settingList: + self.clearBacktestingResult() + self.output('-' * 30) + self.output('setting: %s' %str(setting)) + self.initStrategy(strategyClass, setting) + self.runBacktesting() + d = self.calculateBacktestingResult() + try: + targetValue = d[targetName] + except KeyError: + targetValue = 0 + resultList.append(([str(setting)], targetValue)) + + # 显示结果 + resultList.sort(reverse=True, key=lambda result:result[1]) + self.output('-' * 30) + self.output(u'优化结果:') + for result in resultList: + self.output(u'%s: %s' %(result[0], result[1])) + + #---------------------------------------------------------------------- + def clearBacktestingResult(self): + """清空之前回测的结果""" + # 清空限价单相关 + self.limitOrderCount = 0 + self.limitOrderDict.clear() + self.workingLimitOrderDict.clear() + + # 清空停止单相关 + self.stopOrderCount = 0 + self.stopOrderDict.clear() + self.workingStopOrderDict.clear() + + # 清空成交相关 + self.tradeCount = 0 + self.tradeDict.clear() + + +######################################################################## +class TradingResult(object): + """每笔交易的结果""" + + #---------------------------------------------------------------------- + def __init__(self, entry, exit, volume, rate, slippage, size): + """Constructor""" + self.entry = entry # 开仓价格 + self.exit = exit # 平仓价格 + self.volume = volume # 交易数量(+/-代表方向) + + self.turnover = (self.entry+self.exit)*size # 成交金额 + self.commission = self.turnover*rate # 手续费成本 + self.slippage = slippage*2*size # 滑点成本 + self.pnl = ((self.exit - self.entry) * volume * size + - self.commission - self.slippage) # 净盈亏 + + +######################################################################## +class OptimizationSetting(object): + """优化设置""" + + #---------------------------------------------------------------------- + def __init__(self): + """Constructor""" + self.paramDict = OrderedDict() + + self.optimizeTarget = '' # 优化目标字段 + + #---------------------------------------------------------------------- + def addParameter(self, name, start, end, step): + """增加优化参数""" + if end <= start: + print u'参数起始点必须小于终止点' + return + + if step <= 0: + print u'参数布进必须大于0' + return + + l = [] + param = start + + while param <= end: + l.append(param) + param += step + + self.paramDict[name] = l + + #---------------------------------------------------------------------- + def generateSetting(self): + """生成优化参数组合""" + # 参数名的列表 + nameList = self.paramDict.keys() + paramList = self.paramDict.values() + + # 使用迭代工具生产参数对组合 + productList = list(product(*paramList)) + + # 把参数对组合打包到一个个字典组成的列表中 + settingList = [] + for p in productList: + d = dict(zip(nameList, p)) + settingList.append(d) + + return settingList + + #---------------------------------------------------------------------- + def setOptimizeTarget(self, target): + """设置优化目标字段""" + self.optimizeTarget = target + + +#---------------------------------------------------------------------- +def formatNumber(n): + """格式化数字到字符串""" + n = round(n, 2) # 保留两位小数 + return format(n, ',') # 加上千分符 + + if __name__ == '__main__': @@ -562,7 +738,7 @@ if __name__ == '__main__': engine.setStartDate('20110101') # 载入历史数据到引擎中 - engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000') + engine.setDatabase(MINUTE_DB_NAME, 'IF0000') # 设置产品相关参数 engine.setSlippage(0.2) # 股指1跳 diff --git a/vn.trader/ctaAlgo/strategyAtrRsi.py b/vn.trader/ctaAlgo/strategyAtrRsi.py index c838e6a2..b8374c2e 100644 --- a/vn.trader/ctaAlgo/strategyAtrRsi.py +++ b/vn.trader/ctaAlgo/strategyAtrRsi.py @@ -251,21 +251,29 @@ if __name__ == '__main__': # 设置回测用的数据起始日期 engine.setStartDate('20120101') - # 载入历史数据到引擎中 - engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000') - # 设置产品相关参数 engine.setSlippage(0.2) # 股指1跳 engine.setRate(0.3/10000) # 万0.3 - engine.setSize(300) # 股指合约大小 + engine.setSize(300) # 股指合约大小 - # 在引擎中创建策略对象 - engine.initStrategy(AtrRsiStrategy, {}) + # 设置使用的历史数据库 + engine.setDatabase(MINUTE_DB_NAME, 'IF0000') - # 开始跑回测 - engine.runBacktesting() + ## 在引擎中创建策略对象 + #d = {'atrLength': 11} + #engine.initStrategy(AtrRsiStrategy, d) - # 显示回测结果 - engine.showBacktestingResult() + ## 开始跑回测 + #engine.runBacktesting() + + ## 显示回测结果 + #engine.showBacktestingResult() + + # 跑优化 + setting = OptimizationSetting() # 新建一个优化任务设置对象 + setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利 + setting.addParameter('atrLength', 11, 12, 1) # 增加第一个优化参数atrLength,起始11,结束12,步进1 + setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa,起始20,结束30,步进1 + engine.runOptimization(AtrRsiStrategy, setting) # 运行优化函数,自动输出结果 \ No newline at end of file diff --git a/vn.trader/ctpGateway/ctpGateway.py b/vn.trader/ctpGateway/ctpGateway.py index f090a288..d663cf31 100644 --- a/vn.trader/ctpGateway/ctpGateway.py +++ b/vn.trader/ctpGateway/ctpGateway.py @@ -437,6 +437,7 @@ class CtpTdApi(TdApi): self.posBufferDict = {} # 缓存持仓数据的字典 self.symbolExchangeDict = {} # 保存合约代码和交易所的印射关系 + self.symbolSizeDict = {} # 保存合约代码和合约大小的印射关系 #---------------------------------------------------------------------- def onFrontConnected(self): @@ -637,10 +638,11 @@ class CtpTdApi(TdApi): # 更新持仓缓存,并获取VT系统中持仓对象的返回值 exchange = self.symbolExchangeDict.get(data['InstrumentID'], EXCHANGE_UNKNOWN) + size = self.symbolSizeDict.get(data['InstrumentID'], 1) if exchange == EXCHANGE_SHFE: - pos = posBuffer.updateShfeBuffer(data) + pos = posBuffer.updateShfeBuffer(data, size) else: - pos = posBuffer.updateBuffer(data) + pos = posBuffer.updateBuffer(data, size) self.gateway.onPosition(pos) #---------------------------------------------------------------------- @@ -735,6 +737,7 @@ class CtpTdApi(TdApi): # 缓存代码和交易所的印射关系 self.symbolExchangeDict[contract.symbol] = contract.exchange + self.symbolSizeDict[contract.symbol] = contract.size # 推送 self.gateway.onContract(contract) @@ -1318,7 +1321,7 @@ class PositionBuffer(object): self.pos = pos #---------------------------------------------------------------------- - def updateShfeBuffer(self, data): + def updateShfeBuffer(self, data, size): """更新上期所缓存,返回更新后的持仓数据""" # 昨仓和今仓的数据更新是分在两条记录里的,因此需要判断检查该条记录对应仓位 # 因为今仓字段TodayPosition可能变为0(被全部平仓),因此分辨今昨仓需要用YdPosition字段 @@ -1336,7 +1339,7 @@ class PositionBuffer(object): # 如果手头还有持仓,则通过加权平均方式计算持仓均价 if self.todayPosition or self.ydPosition: self.pos.price = ((self.todayPositionCost + self.ydPositionCost)/ - (self.todayPosition + self.ydPosition)) + ((self.todayPosition + self.ydPosition) * size)) # 否则价格为0 else: self.pos.price = 0 @@ -1344,14 +1347,14 @@ class PositionBuffer(object): return copy(self.pos) #---------------------------------------------------------------------- - def updateBuffer(self, data): + def updateBuffer(self, data, size): """更新其他交易所的缓存,返回更新后的持仓数据""" # 其他交易所并不区分今昨,因此只关心总仓位,昨仓设为0 self.pos.position = data['Position'] self.pos.ydPosition = 0 if data['Position']: - self.pos.price = data['PositionCost'] / data['Position'] + self.pos.price = data['PositionCost'] / (data['Position'] * size) else: self.pos.price = 0