Merge pull request #123 from vnpy/dev

Dev
This commit is contained in:
vn.py 2016-08-04 23:32:17 +08:00 committed by GitHub
commit 5ea18e54f6
27 changed files with 2733 additions and 669 deletions

View File

@ -1,5 +1,12 @@
# vn.py - 基于python的开源交易平台开发框架
### 论坛
新的论坛[维恩的派](http://www.vnpie.com)已经上线感谢量衍投资对vn.py项目的支持
如果你在使用vn.py的过程中有任何疑问想求助或者经验想分享欢迎到维恩的派上面发帖项目作者和其他主要贡献者也会每天阅帖保证回复的效率。
---
### Quick Start

View File

@ -330,7 +330,7 @@ def main():
api.subscribePublicTopic(1)
# 注册前置机地址,测试通过
api.registerFront("tcp://211.144.195.163:34505")
api.registerFront("tcp://211.144.195.163:54505")
# 初始化api连接前置机测试通过
api.init()
@ -367,30 +367,30 @@ def main():
#sleep(0.5)
###############################################
# 发单测试, 测试通过
reqid = reqid + 1
req = {}
req['InvestorID'] = api.userID
req['UserID'] = api.userID
req['BrokerID'] = api.brokerID
req['InstrumentID'] = '510050'
req['ExchangeID'] = 'SSE'
req['OrderPriceType'] = defineDict['SECURITY_FTDC_OPT_LimitPrice']
req['LimitPrice'] = '0.1850'
req['VolumeTotalOriginal'] = 1
req['Direction'] = defineDict['SECURITY_FTDC_D_Buy']
req['CombOffsetFlag'] = defineDict['SECURITY_FTDC_OF_Open']
req['OrderRef'] = '10'
req['CombHedgeFlag'] = defineDict['SECURITY_FTDC_HF_Speculation']
req['ContingentCondition'] = defineDict['SECURITY_FTDC_CC_Immediately']
req['ForceCloseReason'] = defineDict['SECURITY_FTDC_FCC_NotForceClose']
req['IsAutoSuspend'] = 0
req['UserForceClose'] = 0
req['TimeCondition'] = defineDict['SECURITY_FTDC_TC_GFD']
req['VolumeCondition'] = defineDict['SECURITY_FTDC_VC_AV']
req['MinVolume'] = 1
i = api.reqOrderInsert(req, reqid)
sleep(1.0)
# # 发单测试, 测试通过
# reqid = reqid + 1
# req = {}
# req['InvestorID'] = api.userID
# req['UserID'] = api.userID
# req['BrokerID'] = api.brokerID
# req['InstrumentID'] = '510050'
# req['ExchangeID'] = 'SSE'
# req['OrderPriceType'] = defineDict['SECURITY_FTDC_OPT_LimitPrice']
# req['LimitPrice'] = '0.1850'
# req['VolumeTotalOriginal'] = 1
# req['Direction'] = defineDict['SECURITY_FTDC_D_Buy']
# req['CombOffsetFlag'] = defineDict['SECURITY_FTDC_OF_Open']
# req['OrderRef'] = '10'
# req['CombHedgeFlag'] = defineDict['SECURITY_FTDC_HF_Speculation']
# req['ContingentCondition'] = defineDict['SECURITY_FTDC_CC_Immediately']
# req['ForceCloseReason'] = defineDict['SECURITY_FTDC_FCC_NotForceClose']
# req['IsAutoSuspend'] = 0
# req['UserForceClose'] = 0
# req['TimeCondition'] = defineDict['SECURITY_FTDC_TC_GFD']
# req['VolumeCondition'] = defineDict['SECURITY_FTDC_VC_AV']
# req['MinVolume'] = 1
# i = api.reqOrderInsert(req, reqid)
# sleep(1.0)
# 撤单测试,测试通过
#reqid = reqid + 1

36
vn.okcoin/README.md Normal file
View File

@ -0,0 +1,36 @@
# vn.okcoin
贡献者:量衍投资
### 简介
OkCoin的比特币交易接口基于Websocket API开发实现了以下功能
1. 发送、撤销委托
2. 查询委托、持仓、资金、成交历史
3. 实时行情、成交、资金更新的推送
### 特点
相比较于[OkCoin官方](http://github.com/OKCoin/websocket/tree/master/python)给出的Python API实现vn.okcoin的一些特点
1. 同时支持OkCoin的中国站和国际站交易根据用户连接的站点会在内部自动切换结算货币CNY、USD
2. 采用面向对象的接口设计模式接近国内CTP接口的风格并对主动函数的调用参数做了大幅简化
3. 数据解包和签名生成两个热点函数使用了更加高效的实现方式
### 参数命名
函数的参数命名针对金融领域用户的习惯做了一些修改,具体对应如下:
* expiry原生命名的contract_type
* order: 原生命名的match_price
* leverage原生命名的lever_rate
* page原生命名的current_page
* length原生命名的page_length
### API版本
日期2016-06-29
链接:[http://www.okcoin.com/about/ws_getStarted.do](http://www.okcoin.com/about/ws_getStarted.do)

47
vn.okcoin/test.py Normal file
View File

@ -0,0 +1,47 @@
# encoding: UTF-8
from vnokcoin import *
# 在OkCoin网站申请这两个Key分别对应用户名和密码
apiKey = ''
secretKey = ''
# 创建API对象
api = OkCoinApi()
# 连接服务器并等待1秒
api.connect(OKCOIN_USD, apiKey, secretKey, True)
sleep(1)
# 测试现货行情API
#api.subscribeSpotTicker(SYMBOL_BTC)
#api.subscribeSpotTradeData(SYMBOL_BTC)
api.subscribeSpotDepth(SYMBOL_BTC, DEPTH_20)
#api.subscribeSpotKline(SYMBOL_BTC, INTERVAL_1M)
# 测试现货交易API
#api.subscribeSpotTrades()
#api.subscribeSpotUserInfo()
#api.spotUserInfo()
#api.spotTrade(symbol, type_, price, amount)
#api.spotCancelOrder(symbol, orderid)
#api.spotOrderInfo(symbol, orderid)
# 测试期货行情API
#api.subscribeFutureTicker(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK)
#api.subscribeFutureTradeData(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK)
#api.subscribeFutureDepth(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK, DEPTH_20)
#api.subscribeFutureKline(SYMBOL_BTC, FUTURE_EXPIRY_THIS_WEEK, INTERVAL_1M)
#api.subscribeFutureIndex(SYMBOL_BTC)
# 测试期货交易API
#api.subscribeFutureTrades()
#api.subscribeFutureUserInfo()
#api.subscribeFuturePositions()
#api.futureUserInfo()
#api.futureTrade(symbol, expiry, type_, price, amount, order, leverage)
#api.futureCancelOrder(symbol, expiry, orderid)
#api.futureOrderInfo(symbol, expiry, orderid, status, page, length)
raw_input()

382
vn.okcoin/vnokcoin.py Normal file
View File

@ -0,0 +1,382 @@
# encoding: UTF-8
import hashlib
import zlib
import json
from time import sleep
from threading import Thread
import websocket
# OKCOIN网站
OKCOIN_CNY = 'wss://real.okcoin.cn:10440/websocket/okcoinapi'
OKCOIN_USD = 'wss://real.okcoin.com:10440/websocket/okcoinapi'
# 账户货币代码
CURRENCY_CNY = 'cny'
CURRENCY_USD = 'usd'
# 电子货币代码
SYMBOL_BTC = 'btc'
SYMBOL_LTC = 'ltc'
# 行情深度
DEPTH_20 = 20
DEPTH_60 = 60
# K线时间区间
INTERVAL_1M = '1min'
INTERVAL_3M = '3min'
INTERVAL_5M = '5min'
INTERVAL_15M = '15min'
INTERVAL_30M = '30min'
INTERVAL_1H = '1hour'
INTERVAL_2H = '2hour'
INTERVAL_4H = '4hour'
INTERVAL_6H = '6hour'
INTERVAL_1D = 'day'
INTERVAL_3D = '3day'
INTERVAL_1W = 'week'
# 交易代码,需要后缀货币名才能完整
TRADING_SYMBOL_BTC = 'btc_'
TRADING_SYMBOL_LTC = 'ltc_'
# 委托类型
TYPE_BUY = 'buy'
TYPE_SELL = 'sell'
TYPE_BUY_MARKET = 'buy_market'
TYPE_SELL_MARKET = 'sell_market'
# 期货合约到期类型
FUTURE_EXPIRY_THIS_WEEK = 'this_week'
FUTURE_EXPIRY_NEXT_WEEK = 'next_week'
FUTURE_EXPIRY_QUARTER = 'quarter'
# 期货委托类型
FUTURE_TYPE_LONG = 1
FUTURE_TYPE_SHORT = 2
FUTURE_TYPE_SELL = 3
FUTURE_TYPE_COVER = 4
# 期货是否用现价
FUTURE_ORDER_MARKET = 1
FUTURE_ORDER_LIMIT = 0
# 期货杠杆
FUTURE_LEVERAGE_10 = 10
FUTURE_LEVERAGE_20 = 20
# 委托状态
ORDER_STATUS_NOTTRADED = 0
ORDER_STATUS_PARTTRADED = 1
ORDER_STATUS_ALLTRADED = 2
ORDER_STATUS_CANCELLED = -1
ORDER_STATUS_CANCELLING = 4
########################################################################
class OkCoinApi(object):
"""基于Websocket的API对象"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.apiKey = '' # 用户名
self.secretKey = '' # 密码
self.host = '' # 服务器地址
self.currency = '' # 货币类型usd或者cny
self.ws = None # websocket应用对象
self.thread = None # 工作线程
#######################
## 通用函数
#######################
#----------------------------------------------------------------------
def readData(self, evt):
"""解压缩推送收到的数据"""
# 创建解压器
decompress = zlib.decompressobj(-zlib.MAX_WBITS)
# 将原始数据解压成字符串
inflated = decompress.decompress(evt) + decompress.flush()
# 通过json解析字符串
data = json.loads(inflated)
return data
#----------------------------------------------------------------------
def generateSign(self, params):
"""生成签名"""
l = []
for key in sorted(params.keys()):
l.append('%s=%s' %(key, params[key]))
l.append('secret_key=%s' %self.secretKey)
sign = '&'.join(l)
return hashlib.md5(sign.encode('utf-8')).hexdigest().upper()
#----------------------------------------------------------------------
def onMessage(self, ws, evt):
"""信息推送"""
print 'onMessage'
data = self.readData(evt)
print data
#----------------------------------------------------------------------
def onError(self, ws, evt):
"""错误推送"""
print 'onError'
print evt
#----------------------------------------------------------------------
def onClose(self, ws):
"""接口断开"""
print 'onClose'
#----------------------------------------------------------------------
def onOpen(self, ws):
"""接口打开"""
print 'onOpen'
#----------------------------------------------------------------------
def connect(self, host, apiKey, secretKey, trace=False):
"""连接服务器"""
self.host = host
self.apiKey = apiKey
self.secretKey = secretKey
if self.host == OKCOIN_CNY:
self.currency = CURRENCY_CNY
else:
self.currency = CURRENCY_USD
websocket.enableTrace(trace)
self.ws = websocket.WebSocketApp(host,
on_message=self.onMessage,
on_error=self.onError,
on_close=self.onClose,
on_open=self.onOpen)
self.thread = Thread(target=self.ws.run_forever)
self.thread.start()
#----------------------------------------------------------------------
def sendMarketDataRequest(self, channel):
"""发送行情请求"""
# 生成请求
d = {}
d['event'] = 'addChannel'
d['binary'] = True
d['channel'] = channel
# 使用json打包并发送
j = json.dumps(d)
self.ws.send(j)
#----------------------------------------------------------------------
def sendTradingRequest(self, channel, params):
"""发送交易请求"""
# 在参数字典中加上api_key和签名字段
params['api_key'] = self.apiKey
params['sign'] = self.generateSign(params)
# 生成请求
d = {}
d['event'] = 'addChannel'
d['binary'] = True
d['channel'] = channel
d['parameters'] = params
# 使用json打包并发送
j = json.dumps(d)
self.ws.send(j)
#######################
## 现货相关
#######################
#----------------------------------------------------------------------
def subscribeSpotTicker(self, symbol):
"""订阅现货普通报价"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_ticker' %(self.currency, symbol))
#----------------------------------------------------------------------
def subscribeSpotDepth(self, symbol, depth):
"""订阅现货深度报价"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_depth_%s' %(self.currency, symbol, depth))
#----------------------------------------------------------------------
def subscribeSpotTradeData(self, symbol):
"""订阅现货成交记录"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_trades' %(self.currency, symbol))
#----------------------------------------------------------------------
def subscribeSpotKline(self, symbol, interval):
"""订阅现货K线"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_kline_%s' %(self.currency, symbol, interval))
#----------------------------------------------------------------------
def spotTrade(self, symbol, type_, price, amount):
"""现货委托"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['type'] = str(type_)
params['price'] = str(price)
params['amount'] = str(amount)
channel = 'ok_spot%s_trade' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def spotCancelOrder(self, symbol, orderid):
"""现货撤单"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
channel = 'ok_spot%s_cancel_order' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def spotUserInfo(self):
"""查询现货账户"""
channel = 'ok_spot%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def spotOrderInfo(self, symbol, orderid):
"""查询现货委托信息"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
channel = 'ok_spot%s_orderinfo' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def subscribeSpotTrades(self):
"""订阅现货成交信息"""
channel = 'ok_sub_spot%s_trades' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeSpotUserInfo(self):
"""订阅现货账户信息"""
channel = 'ok_sub_spot%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#######################
## 期货相关
#######################
#----------------------------------------------------------------------
def subscribeFutureTicker(self, symbol, expiry):
"""订阅期货普通报价"""
self.sendMarketDataRequest('ok_sub_future%s_%s_ticker_%s' %(self.currency, symbol, expiry))
#----------------------------------------------------------------------
def subscribeFutureDepth(self, symbol, expiry, depth):
"""订阅期货深度报价"""
self.sendMarketDataRequest('ok_sub_future%s_%s_depth_%s_%s' %(self.currency, symbol,
expiry, depth))
#----------------------------------------------------------------------
def subscribeFutureTradeData(self, symbol, expiry):
"""订阅期货成交记录"""
self.sendMarketDataRequest('ok_sub_future%s_%s_trade_%s' %(self.currency, symbol, expiry))
#----------------------------------------------------------------------
def subscribeFutureKline(self, symbol, expiry, interval):
"""订阅期货K线"""
self.sendMarketDataRequest('ok_sub_future%s_%s_kline_%s_%s' %(self.currency, symbol,
expiry, interval))
#----------------------------------------------------------------------
def subscribeFutureIndex(self, symbol):
"""订阅期货指数"""
self.sendMarketDataRequest('ok_sub_future%s_%s_index' %(self.currency, symbol))
#----------------------------------------------------------------------
def futureTrade(self, symbol, expiry, type_, price, amount, order, leverage):
"""期货委托"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['type'] = str(type_)
params['price'] = str(price)
params['amount'] = str(amount)
params['contract_type'] = str(expiry)
params['match_price'] = str(order)
params['lever_rate'] = str(leverage)
channel = 'ok_future%s_trade' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def futureCancelOrder(self, symbol, expiry, orderid):
"""期货撤单"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
params['contract_type'] = str(expiry)
channel = 'ok_future%s_cancel_order' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def futureUserInfo(self):
"""查询期货账户"""
channel = 'ok_future%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def futureOrderInfo(self, symbol, expiry, orderid, status, page, length):
"""查询期货委托信息"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
params['contract_type'] = expiry
params['status'] = status
params['current_page'] = page
params['page_length'] = length
channel = 'ok_future%s_orderinfo'
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def subscribeFutureTrades(self):
"""订阅期货成交信息"""
channel = 'ok_sub_future%s_trades' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeFutureUserInfo(self):
"""订阅期货账户信息"""
channel = 'ok_sub_future%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeFuturePositions(self):
"""订阅期货持仓信息"""
channel = 'ok_sub_future%s_positions' %(self.currency)
self.sendTradingRequest(channel, {})

View File

@ -4,9 +4,11 @@
本文件中包含的是CTA模块的回测引擎回测引擎的API和CTA引擎一致
可以使用和实盘相同的代码进行回测
'''
from __future__ import division
from datetime import datetime, timedelta
from collections import OrderedDict
from itertools import product
import pymongo
from ctaBase import *
@ -40,6 +42,9 @@ class BacktestingEngine(object):
self.stopOrderDict = {} # 停止单撤销后不会从本字典中删除
self.workingStopOrderDict = {} # 停止单撤销后会从本字典中删除
# 引擎类型为回测
self.engineType = ENGINETYPE_BACKTESTING
# 回测相关
self.strategy = None # 回测策略
self.mode = self.BAR_MODE # 回测模式默认为K线
@ -55,6 +60,9 @@ class BacktestingEngine(object):
self.initData = [] # 初始化用的数据
#self.backtestingData = [] # 回测用的数据
self.dbName = '' # 回测数据库名
self.symbol = '' # 回测集合名
self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
@ -93,12 +101,18 @@ class BacktestingEngine(object):
self.mode = mode
#----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol):
def setDatabase(self, dbName, symbol):
"""设置历史数据所用的数据库"""
self.dbName = dbName
self.symbol = symbol
#----------------------------------------------------------------------
def loadHistoryData(self):
"""载入历史数据"""
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[dbName][symbol]
collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据')
@ -134,6 +148,9 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def runBacktesting(self):
"""运行回测"""
# 载入历史数据
self.loadHistoryData()
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
@ -281,11 +298,13 @@ class BacktestingEngine(object):
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
bestCrossPrice = self.bar.open # 在当前时间点前发出的委托可能的最优成交价
buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
else:
buyCrossPrice = self.tick.lastPrice
sellCrossPrice = self.tick.lastPrice
bestCrossPrice = self.tick.lastPrice
buyCrossPrice = self.tick.askPrice1
sellCrossPrice = self.tick.bidPrice1
buyBestCrossPrice = self.tick.askPrice1
sellBestCrossPrice = self.tick.bidPrice1
# 遍历限价单字典中的所有限价单
for orderID, order in self.workingLimitOrderDict.items():
@ -312,10 +331,10 @@ class BacktestingEngine(object):
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105
# 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100
if buyCross:
trade.price = min(order.price, bestCrossPrice)
trade.price = min(order.price, buyBestCrossPrice)
self.strategy.pos += order.totalVolume
else:
trade.price = max(order.price, bestCrossPrice)
trade.price = max(order.price, sellBestCrossPrice)
self.strategy.pos -= order.totalVolume
trade.volume = order.totalVolume
@ -429,23 +448,21 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def output(self, content):
"""输出内容"""
print content
print str(datetime.now()) + "\t" + content
#----------------------------------------------------------------------
def showBacktestingResult(self):
def calculateBacktestingResult(self):
"""
显示回测结果
计算回测结果
"""
self.output(u'显示回测结果')
self.output(u'计算回测结果')
# 首先基于回测后的成交记录,计算每笔交易的盈亏
pnlDict = OrderedDict() # 每笔盈亏的记录
resultList = [] # 交易结果列表
longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易
# 计算滑点,一个来回包括两次
totalSlippage = self.slippage * 2
for trade in self.tradeDict.values():
# 多头交易
if trade.direction == DIRECTION_LONG:
@ -454,13 +471,40 @@ class BacktestingEngine(object):
longTrade.append(trade)
# 当前多头交易为平空
else:
entryTrade = shortTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate
# 计算盈亏
pnl = ((trade.price - entryTrade.price)*(-1) - totalSlippage - commission) \
* trade.volume * self.size
pnlDict[trade.dt] = pnl
while True:
entryTrade = shortTrade[0]
exitTrade = trade
# 清算开平仓交易
closedVolume = min(exitTrade.volume, entryTrade.volume)
result = TradingResult(entryTrade.price, entryTrade.dt,
exitTrade.price, exitTrade.dt,
-closedVolume, self.rate, self.slippage, self.size)
resultList.append(result)
# 计算未清算部分
entryTrade.volume -= closedVolume
exitTrade.volume -= closedVolume
# 如果开仓交易已经全部清算,则从列表中移除
if not entryTrade.volume:
shortTrade.pop(0)
# 如果平仓交易已经全部清算,则退出循环
if not exitTrade.volume:
break
# 如果平仓交易未全部清算,
if exitTrade.volume:
# 且开仓交易已经全部清算完,则平仓交易剩余的部分
# 等于新的反向开仓交易,添加到队列中
if not shortTrade:
longTrade.append(exitTrade)
break
# 如果开仓交易还有剩余,则进入下一轮循环
else:
pass
# 空头交易
else:
# 如果尚无多头交易
@ -468,57 +512,150 @@ class BacktestingEngine(object):
shortTrade.append(trade)
# 当前空头交易为平多
else:
entryTrade = longTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate
# 计算盈亏
pnl = ((trade.price - entryTrade.price) - totalSlippage - commission) \
* trade.volume * self.size
pnlDict[trade.dt] = pnl
while True:
entryTrade = longTrade[0]
exitTrade = trade
# 清算开平仓交易
closedVolume = min(exitTrade.volume, entryTrade.volume)
result = TradingResult(entryTrade.price, entryTrade.dt,
exitTrade.price, exitTrade.dt,
closedVolume, self.rate, self.slippage, self.size)
resultList.append(result)
# 计算未清算部分
entryTrade.volume -= closedVolume
exitTrade.volume -= closedVolume
# 如果开仓交易已经全部清算,则从列表中移除
if not entryTrade.volume:
longTrade.pop(0)
# 如果平仓交易已经全部清算,则退出循环
if not exitTrade.volume:
break
# 如果平仓交易未全部清算,
if exitTrade.volume:
# 且开仓交易已经全部清算完,则平仓交易剩余的部分
# 等于新的反向开仓交易,添加到队列中
if not longTrade:
shortTrade.append(exitTrade)
break
# 如果开仓交易还有剩余,则进入下一轮循环
else:
pass
# 检查是否有交易
if not resultList:
self.output(u'无交易结果')
return {}
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
timeList = pnlDict.keys()
pnlList = pnlDict.values()
capital = 0 # 资金
maxCapital = 0 # 资金最高净值
drawdown = 0 # 回撤
capital = 0
maxCapital = 0
drawdown = 0
totalResult = 0 # 总成交数量
totalTurnover = 0 # 总成交金额(合约面值)
totalCommission = 0 # 总手续费
totalSlippage = 0 # 总滑点
timeList = [] # 时间序列
pnlList = [] # 每笔盈亏序列
capitalList = [] # 盈亏汇总的时间序列
maxCapitalList = [] # 最高盈利的时间序列
drawdownList = [] # 回撤的时间序列
for pnl in pnlList:
capital += pnl
winningResult = 0 # 盈利次数
losingResult = 0 # 亏损次数
totalWinning = 0 # 总盈利金额
totalLosing = 0 # 总亏损金额
for result in resultList:
capital += result.pnl
maxCapital = max(capital, maxCapital)
drawdown = capital - maxCapital
pnlList.append(result.pnl)
timeList.append(result.exitDt) # 交易的时间戳使用平仓时间
capitalList.append(capital)
maxCapitalList.append(maxCapital)
drawdownList.append(drawdown)
totalResult += 1
totalTurnover += result.turnover
totalCommission += result.commission
totalSlippage += result.slippage
if result.pnl >= 0:
winningResult += 1
totalWinning += result.pnl
else:
losingResult += 1
totalLosing += result.pnl
# 计算盈亏相关数据
winningRate = winningResult/totalResult*100 # 胜率
averageWinning = totalWinning/winningResult # 平均每笔盈利
averageLosing = totalLosing/losingResult # 平均每笔亏损
profitLossRatio = -averageWinning/averageLosing # 盈亏比
# 返回回测结果
d = {}
d['capital'] = capital
d['maxCapital'] = maxCapital
d['drawdown'] = drawdown
d['totalResult'] = totalResult
d['totalTurnover'] = totalTurnover
d['totalCommission'] = totalCommission
d['totalSlippage'] = totalSlippage
d['timeList'] = timeList
d['pnlList'] = pnlList
d['capitalList'] = capitalList
d['drawdownList'] = drawdownList
d['winningRate'] = winningRate
d['averageWinning'] = averageWinning
d['averageLosing'] = averageLosing
d['profitLossRatio'] = profitLossRatio
return d
#----------------------------------------------------------------------
def showBacktestingResult(self):
"""显示回测结果"""
d = self.calculateBacktestingResult()
# 输出
self.output('-' * 50)
self.output(u'第一笔交易时间:%s' % timeList[0])
self.output(u'最后一笔交易时间:%s' % timeList[-1])
self.output(u'总交易次数:%s' % len(pnlList))
self.output(u'总盈亏:%s' % capitalList[-1])
self.output(u'最大回撤: %s' % min(drawdownList))
self.output('-' * 30)
self.output(u'第一笔交易:\t%s' % d['timeList'][0])
self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
self.output(u'平均每笔滑点:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
self.output(u'胜率\t\t%s%%' %formatNumber(d['winningRate']))
self.output(u'平均每笔盈利\t%s' %formatNumber(d['averageWinning']))
self.output(u'平均每笔亏损\t%s' %formatNumber(d['averageLosing']))
self.output(u'盈亏比:\t%s' %formatNumber(d['profitLossRatio']))
# 绘图
import matplotlib.pyplot as plt
pCapital = plt.subplot(3, 1, 1)
pCapital.set_ylabel("capital")
pCapital.plot(capitalList)
pCapital.plot(d['capitalList'])
pDD = plt.subplot(3, 1, 2)
pDD.set_ylabel("DD")
pDD.bar(range(len(drawdownList)), drawdownList)
pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
pPnl = plt.subplot(3, 1, 3)
pPnl.set_ylabel("pnl")
pPnl.hist(pnlList, bins=20)
pPnl.hist(d['pnlList'], bins=50)
plt.show()
@ -529,7 +666,7 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def setSlippage(self, slippage):
"""设置滑点"""
"""设置滑点点数"""
self.slippage = slippage
#----------------------------------------------------------------------
@ -542,6 +679,142 @@ class BacktestingEngine(object):
"""设置佣金比例"""
self.rate = rate
#----------------------------------------------------------------------
def runOptimization(self, strategyClass, optimizationSetting):
"""优化参数"""
# 获取优化设置
settingList = optimizationSetting.generateSetting()
targetName = optimizationSetting.optimizeTarget
# 检查参数设置问题
if not settingList or not targetName:
self.output(u'优化设置有问题,请检查')
# 遍历优化
resultList = []
for setting in settingList:
self.clearBacktestingResult()
self.output('-' * 30)
self.output('setting: %s' %str(setting))
self.initStrategy(strategyClass, setting)
self.runBacktesting()
d = self.calculateBacktestingResult()
try:
targetValue = d[targetName]
except KeyError:
targetValue = 0
resultList.append(([str(setting)], targetValue))
# 显示结果
resultList.sort(reverse=True, key=lambda result:result[1])
self.output('-' * 30)
self.output(u'优化结果:')
for result in resultList:
self.output(u'%s: %s' %(result[0], result[1]))
#----------------------------------------------------------------------
def clearBacktestingResult(self):
"""清空之前回测的结果"""
# 清空限价单相关
self.limitOrderCount = 0
self.limitOrderDict.clear()
self.workingLimitOrderDict.clear()
# 清空停止单相关
self.stopOrderCount = 0
self.stopOrderDict.clear()
self.workingStopOrderDict.clear()
# 清空成交相关
self.tradeCount = 0
self.tradeDict.clear()
########################################################################
class TradingResult(object):
"""每笔交易的结果"""
#----------------------------------------------------------------------
def __init__(self, entryPrice, entryDt, exitPrice,
exitDt, volume, rate, slippage, size):
"""Constructor"""
self.entryPrice = entryPrice # 开仓价格
self.exitPrice = exitPrice # 平仓价格
self.entryDt = entryDt # 开仓时间datetime
self.exitDt = exitDt # 平仓时间
self.volume = volume # 交易数量(+/-代表方向)
self.turnover = (self.entryPrice+self.exitPrice)*size*abs(volume) # 成交金额
self.commission = self.turnover*rate # 手续费成本
self.slippage = slippage*2*size*abs(volume) # 滑点成本
self.pnl = ((self.exitPrice - self.entryPrice) * volume * size
- self.commission - self.slippage) # 净盈亏
########################################################################
class OptimizationSetting(object):
"""优化设置"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.paramDict = OrderedDict()
self.optimizeTarget = '' # 优化目标字段
#----------------------------------------------------------------------
def addParameter(self, name, start, end, step):
"""增加优化参数"""
if end <= start:
print u'参数起始点必须小于终止点'
return
if step <= 0:
print u'参数布进必须大于0'
return
l = []
param = start
while param <= end:
l.append(param)
param += step
self.paramDict[name] = l
#----------------------------------------------------------------------
def generateSetting(self):
"""生成优化参数组合"""
# 参数名的列表
nameList = self.paramDict.keys()
paramList = self.paramDict.values()
# 使用迭代工具生产参数对组合
productList = list(product(*paramList))
# 把参数对组合打包到一个个字典组成的列表中
settingList = []
for p in productList:
d = dict(zip(nameList, p))
settingList.append(d)
return settingList
#----------------------------------------------------------------------
def setOptimizeTarget(self, target):
"""设置优化目标字段"""
self.optimizeTarget = target
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
n = round(n, 2) # 保留两位小数
return format(n, ',') # 加上千分符
if __name__ == '__main__':
@ -560,7 +833,7 @@ if __name__ == '__main__':
engine.setStartDate('20110101')
# 载入历史数据到引擎中
engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000')
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳

View File

@ -33,6 +33,9 @@ TICK_DB_NAME = 'VnTrader_Tick_Db'
DAILY_DB_NAME = 'VnTrader_Daily_Db'
MINUTE_DB_NAME = 'VnTrader_1Min_Db'
# 引擎类型,用于区分当前策略的运行环境
ENGINETYPE_BACKTESTING = 'backtesting' # 回测
ENGINETYPE_TRADING = 'trading' # 实盘
# CTA引擎中涉及的数据类定义
from vtConstant import EMPTY_UNICODE, EMPTY_STRING, EMPTY_FLOAT, EMPTY_INT

View File

@ -61,6 +61,13 @@ class DoubleEmaDemo(CtaTemplate):
"""Constructor"""
super(DoubleEmaDemo, self).__init__(ctaEngine, setting)
# 注意策略类中的可变对象属性通常是list和dict等在策略初始化时需要重新创建
# 否则会出现多个策略实例之间数据共享的情况,有可能导致潜在的策略逻辑错误风险,
# 策略类中的这些可变对象属性可以选择不写全都放在__init__下面写主要是为了阅读
# 策略时方便(更多是个编程习惯的选择)
self.fastMa = []
self.slowMa = []
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""

View File

@ -70,6 +70,9 @@ class CtaEngine(object):
# key为vtSymbolvalue为PositionBuffer对象
self.posBufferDict = {}
# 引擎类型为实盘
self.engineType = ENGINETYPE_TRADING
# 注册事件监听
self.registerEvent()

View File

@ -278,7 +278,7 @@ class HistoryDataEngine(object):
params = {}
params['ticker'] = symbol
if last:
params['startDate'] = last['date']
params['beginDate'] = last['date']
data = self.datayesClient.downloadData(path, params)

View File

@ -8,9 +8,7 @@
在CTA_setting.json中写入具体每个策略对象的类和合约设置
'''
from ctaTemplate import DataRecorder
from ctaDemo import DoubleEmaDemo
STRATEGY_CLASS = {}
STRATEGY_CLASS['DataRecorder'] = DataRecorder
STRATEGY_CLASS['DoubleEmaDemo'] = DoubleEmaDemo

View File

@ -120,11 +120,16 @@ class CtaTemplate(object):
vtOrderID = self.ctaEngine.sendOrder(self.vtSymbol, orderType, price, volume, self)
return vtOrderID
else:
return None
# 交易停止时发单返回空字符串
return ''
#----------------------------------------------------------------------
def cancelOrder(self, vtOrderID):
"""撤单"""
# 如果发单号为空字符串,则不进行后续操作
if not vtOrderID:
return
if STOPORDERPREFIX in vtOrderID:
self.ctaEngine.cancelStopOrder(vtOrderID)
else:
@ -161,111 +166,8 @@ class CtaTemplate(object):
"""发出策略状态变化事件"""
self.ctaEngine.putStrategyEvent(self.name)
########################################################################
class DataRecorder(CtaTemplate):
"""
纯粹用来记录历史数据的工具基于CTA策略
建议运行在实际交易程序外的一个vn.trader实例中
本工具会记录Tick和1分钟K线数据
"""
className = 'DataRecorder'
author = u'用Python的交易员'
# 策略的基本参数
name = EMPTY_UNICODE # 策略实例名称
vtSymbol = EMPTY_STRING # 交易的合约vt系统代码
# 策略的变量
bar = None # K线数据对象
barMinute = EMPTY_STRING # 当前的分钟,初始化设为-1
# 变量列表,保存了变量的名称
varList = ['inited',
'trading',
'pos',
'barMinute']
#----------------------------------------------------------------------
def __init__(self, ctaEngine, setting):
"""Constructor"""
super(DataRecorder, self).__init__(ctaEngine, setting)
#----------------------------------------------------------------------
def onInit(self):
"""初始化"""
self.writeCtaLog(u'数据记录工具初始化')
#----------------------------------------------------------------------
def onStart(self):
"""启动策略(必须由用户继承实现)"""
self.writeCtaLog(u'数据记录工具启动')
self.putEvent()
#----------------------------------------------------------------------
def onStop(self):
"""停止策略(必须由用户继承实现)"""
self.writeCtaLog(u'数据记录工具停止')
self.putEvent()
#----------------------------------------------------------------------
def onTick(self, tick):
"""收到行情TICK推送"""
# 收到Tick后首先插入到数据库里
self.insertTick(tick)
# 计算K线
tickMinute = tick.datetime.minute
if tickMinute != self.barMinute: # 如果分钟变了则把旧的K线插入数据库并生成新的K线
if self.bar:
self.onBar(self.bar)
bar = CtaBarData() # 创建新的K线目的在于防止之前K线对象在插入Mongo中被再次修改导致出错
bar.vtSymbol = tick.vtSymbol
bar.symbol = tick.symbol
bar.exchange = tick.exchange
bar.open = tick.lastPrice
bar.high = tick.lastPrice
bar.low = tick.lastPrice
bar.close = tick.lastPrice
bar.date = tick.date
bar.time = tick.time
bar.datetime = tick.datetime # K线的时间设为第一个Tick的时间
bar.volume = tick.volume
bar.openInterest = tick.openInterest
self.bar = bar # 这种写法为了减少一层访问,加快速度
self.barMinute = tickMinute # 更新当前的分钟
else: # 否则继续累加新的K线
bar = self.bar # 写法同样为了加快速度
bar.high = max(bar.high, tick.lastPrice)
bar.low = min(bar.low, tick.lastPrice)
bar.close = tick.lastPrice
bar.volume = bar.volume + tick.volume # 成交量是累加的
bar.openInterest = tick.openInterest # 持仓量直接更新
#----------------------------------------------------------------------
def onOrder(self, order):
"""收到委托变化推送"""
pass
#----------------------------------------------------------------------
def onTrade(self, trade):
"""收到成交推送"""
pass
#----------------------------------------------------------------------
def onBar(self, bar):
"""收到Bar推送"""
self.insertBar(bar)
def getEngineType(self):
"""查询当前运行的环境"""
return self.ctaEngine.engineType

View File

@ -81,6 +81,11 @@ class AtrRsiStrategy(CtaTemplate):
"""Constructor"""
super(AtrRsiStrategy, self).__init__(ctaEngine, setting)
# 注意策略类中的可变对象属性通常是list和dict等在策略初始化时需要重新创建
# 否则会出现多个策略实例之间数据共享的情况,有可能导致潜在的策略逻辑错误风险,
# 策略类中的这些可变对象属性可以选择不写全都放在__init__下面写主要是为了阅读
# 策略时方便(更多是个编程习惯的选择)
#----------------------------------------------------------------------
def onInit(self):
"""初始化策略(必须由用户继承实现)"""
@ -194,14 +199,12 @@ class AtrRsiStrategy(CtaTemplate):
if self.rsiValue > self.rsiBuy:
# 这里为了保证成交选择超价5个整指数点下单
self.buy(bar.close+5, 1)
return
if self.rsiValue < self.rsiSell:
elif self.rsiValue < self.rsiSell:
self.short(bar.close-5, 1)
return
# 持有多头仓位
if self.pos == 1:
elif self.pos > 0:
# 计算多头持有期内的最高价,以及重置最低价
self.intraTradeHigh = max(self.intraTradeHigh, bar.high)
self.intraTradeLow = bar.low
@ -210,17 +213,15 @@ class AtrRsiStrategy(CtaTemplate):
# 发出本地止损委托,并且把委托号记录下来,用于后续撤单
orderID = self.sell(longStop, 1, stop=True)
self.orderList.append(orderID)
return
# 持有空头仓位
if self.pos == -1:
elif self.pos < 0:
self.intraTradeLow = min(self.intraTradeLow, bar.low)
self.intraTradeHigh = bar.high
shortStop = self.intraTradeLow * (1+self.trailingPercent/100)
orderID = self.cover(shortStop, 1, stop=True)
self.orderList.append(orderID)
return
# 发出状态更新事件
self.putEvent()
@ -250,16 +251,17 @@ if __name__ == '__main__':
# 设置回测用的数据起始日期
engine.setStartDate('20120101')
# 载入历史数据到引擎中
engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 设置使用的历史数据库
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 在引擎中创建策略对象
engine.initStrategy(AtrRsiStrategy, {})
d = {'atrLength': 11}
engine.initStrategy(AtrRsiStrategy, d)
# 开始跑回测
engine.runBacktesting()
@ -267,4 +269,11 @@ if __name__ == '__main__':
# 显示回测结果
engine.showBacktestingResult()
## 跑优化
#setting = OptimizationSetting() # 新建一个优化任务设置对象
#setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利
#setting.addParameter('atrLength', 11, 12, 1) # 增加第一个优化参数atrLength起始11结束12步进1
#setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa起始20结束30步进1
#engine.runOptimization(AtrRsiStrategy, setting) # 运行优化函数,自动输出结果

File diff suppressed because it is too large Load Diff

View File

@ -41,14 +41,11 @@ offsetMapReverse = {v:k for k,v in offsetMap.items()}
# 交易所类型映射
exchangeMap = {}
#exchangeMap[EXCHANGE_CFFEX] = defineDict['THOST_FTDC_EIDT_CFFEX']
#exchangeMap[EXCHANGE_SHFE] = defineDict['THOST_FTDC_EIDT_SHFE']
#exchangeMap[EXCHANGE_CZCE] = defineDict['THOST_FTDC_EIDT_CZCE']
#exchangeMap[EXCHANGE_DCE] = defineDict['THOST_FTDC_EIDT_DCE']
exchangeMap[EXCHANGE_CFFEX] = 'CFFEX'
exchangeMap[EXCHANGE_SHFE] = 'SHFE'
exchangeMap[EXCHANGE_CZCE] = 'CZCE'
exchangeMap[EXCHANGE_DCE] = 'DCE'
exchangeMap[EXCHANGE_SSE] = 'SSE'
exchangeMap[EXCHANGE_UNKNOWN] = ''
exchangeMapReverse = {v:k for k,v in exchangeMap.items()}
@ -59,6 +56,14 @@ posiDirectionMap[DIRECTION_LONG] = defineDict["THOST_FTDC_PD_Long"]
posiDirectionMap[DIRECTION_SHORT] = defineDict["THOST_FTDC_PD_Short"]
posiDirectionMapReverse = {v:k for k,v in posiDirectionMap.items()}
# 产品类型映射
productClassMap = {}
productClassMap[PRODUCT_FUTURES] = defineDict["THOST_FTDC_PC_Futures"]
productClassMap[PRODUCT_OPTION] = defineDict["THOST_FTDC_PC_Options"]
productClassMap[PRODUCT_COMBINATION] = defineDict["THOST_FTDC_PC_Combination"]
productClassMapReverse = {v:k for k,v in productClassMap.items()}
########################################################################
class CtpGateway(VtGateway):
@ -284,7 +289,7 @@ class CtpMdApi(MdApi):
# 如果登出成功,推送日志信息
if error['ErrorID'] == 0:
self.loginStatus = False
self.gateway.tdConnected = False
self.gateway.mdConnected = False
log = VtLogData()
log.gatewayName = self.gatewayName
@ -439,6 +444,8 @@ class CtpTdApi(TdApi):
self.sessionID = EMPTY_INT # 会话编号
self.posBufferDict = {} # 缓存持仓数据的字典
self.symbolExchangeDict = {} # 保存合约代码和交易所的印射关系
self.symbolSizeDict = {} # 保存合约代码和合约大小的印射关系
#----------------------------------------------------------------------
def onFrontConnected(self):
@ -482,7 +489,7 @@ class CtpTdApi(TdApi):
self.frontID = str(data['FrontID'])
self.sessionID = str(data['SessionID'])
self.loginStatus = True
self.gateway.mdConnected = True
self.gateway.tdConnected = True
log = VtLogData()
log.gatewayName = self.gatewayName
@ -615,6 +622,16 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def onRspLockInsert(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspCombActionInsert(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryOrder(self, data, error, n, last):
""""""
@ -638,7 +655,12 @@ class CtpTdApi(TdApi):
self.posBufferDict[positionName] = posBuffer
# 更新持仓缓存并获取VT系统中持仓对象的返回值
pos = posBuffer.updateBuffer(data)
exchange = self.symbolExchangeDict.get(data['InstrumentID'], EXCHANGE_UNKNOWN)
size = self.symbolSizeDict.get(data['InstrumentID'], 1)
if exchange == EXCHANGE_SHFE:
pos = posBuffer.updateShfeBuffer(data, size)
else:
pos = posBuffer.updateBuffer(data, size)
self.gateway.onPosition(pos)
#----------------------------------------------------------------------
@ -670,7 +692,7 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQryInvestor(self, data, error, n, last):
"""投资者查询回报"""
""""""
pass
#----------------------------------------------------------------------
@ -715,15 +737,7 @@ class CtpTdApi(TdApi):
contract.strikePrice = data['StrikePrice']
contract.underlyingSymbol = data['UnderlyingInstrID']
# 合约类型
if data['ProductClass'] == '1':
contract.productClass = PRODUCT_FUTURES
elif data['ProductClass'] == '2':
contract.productClass = PRODUCT_OPTION
elif data['ProductClass'] == '3':
contract.productClass = PRODUCT_COMBINATION
else:
contract.productClass = PRODUCT_UNKNOWN
contract.productClass = productClassMapReverse.get(data['ProductClass'], PRODUCT_UNKNOWN)
# 期权类型
if data['OptionsType'] == '1':
@ -731,6 +745,10 @@ class CtpTdApi(TdApi):
elif data['OptionsType'] == '2':
contract.optionType = OPTION_PUT
# 缓存代码和交易所的印射关系
self.symbolExchangeDict[contract.symbol] = contract.exchange
self.symbolSizeDict[contract.symbol] = contract.size
# 推送
self.gateway.onContract(contract)
@ -747,7 +765,7 @@ class CtpTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQrySettlementInfo(self, data, error, n, last):
"""查询结算信息回报"""
""""""
pass
#----------------------------------------------------------------------
@ -810,6 +828,16 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def onRspQryProductExchRate(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryProductGroup(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryOptionInstrTradeCost(self, data, error, n, last):
""""""
@ -835,6 +863,36 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def onRspQryLock(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryLockPosition(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryInvestorLevel(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryExecFreeze(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryCombInstrumentGuard(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryCombAction(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRspQryTransferSerial(self, data, error, n, last):
""""""
@ -1023,6 +1081,31 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def onRtnCFMMCTradingAccountToken(self, data):
""""""
pass
#----------------------------------------------------------------------
def onRtnLock(self, data):
""""""
pass
#----------------------------------------------------------------------
def onErrRtnLockInsert(self, data, error):
""""""
pass
#----------------------------------------------------------------------
def onRtnCombAction(self, data):
""""""
pass
#----------------------------------------------------------------------
def onErrRtnCombActionInsert(self, data, error):
""""""
pass
#----------------------------------------------------------------------
def onRspQryContractBank(self, data, error, n, last):
""""""
@ -1053,6 +1136,11 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def onRspQueryCFMMCTradingAccountToken(self, data, error, n, last):
""""""
pass
#----------------------------------------------------------------------
def onRtnFromBankToFutureByBank(self, data):
""""""
@ -1163,6 +1251,7 @@ class CtpTdApi(TdApi):
""""""
pass
#----------------------------------------------------------------------
def connect(self, userID, password, brokerID, address):
"""初始化连接"""
@ -1313,15 +1402,16 @@ class PositionBuffer(object):
self.pos = pos
#----------------------------------------------------------------------
def updateBuffer(self, data):
"""更新缓存,返回更新后的持仓数据"""
def updateShfeBuffer(self, data, size):
"""更新上期所缓存,返回更新后的持仓数据"""
# 昨仓和今仓的数据更新是分在两条记录里的,因此需要判断检查该条记录对应仓位
if data['TodayPosition']:
self.todayPosition = data['Position']
self.todayPositionCost = data['PositionCost']
elif data['YdPosition']:
# 因为今仓字段TodayPosition可能变为0被全部平仓因此分辨今昨仓需要用YdPosition字段
if data['YdPosition']:
self.ydPosition = data['Position']
self.ydPositionCost = data['PositionCost']
else:
self.todayPosition = data['Position']
self.todayPositionCost = data['PositionCost']
# 持仓的昨仓和今仓相加后为总持仓
self.pos.position = self.todayPosition + self.ydPosition
@ -1330,13 +1420,28 @@ class PositionBuffer(object):
# 如果手头还有持仓,则通过加权平均方式计算持仓均价
if self.todayPosition or self.ydPosition:
self.pos.price = ((self.todayPositionCost + self.ydPositionCost)/
(self.todayPosition + self.ydPosition))
((self.todayPosition + self.ydPosition) * size))
# 否则价格为0
else:
self.pos.price = 0
return copy(self.pos)
#----------------------------------------------------------------------
def updateBuffer(self, data, size):
"""更新其他交易所的缓存,返回更新后的持仓数据"""
# 其他交易所并不区分今昨因此只关心总仓位昨仓设为0
self.pos.position = data['Position']
self.pos.ydPosition = 0
if data['Position']:
self.pos.price = data['PositionCost'] / (data['Position'] * size)
else:
self.pos.price = 0
return copy(self.pos)
#----------------------------------------------------------------------
def test():
"""测试"""

View File

@ -3,7 +3,7 @@
"tick":
[
["IF1605", "SGIT"],
["m1609", "XSPEED"],
["IF1606", "SGIT"],
["IH1606", "SGIT"],
["IH1606", "SGIT"],

View File

@ -11,6 +11,8 @@ import os
import copy
from collections import OrderedDict
from datetime import datetime, timedelta
from Queue import Queue
from threading import Thread
from eventEngine import *
from vtGateway import VtSubscribeReq, VtLogData
@ -43,6 +45,11 @@ class DrEngine(object):
# K线对象字典
self.barDict = {}
# 负责执行数据库插入的单独线程相关
self.active = False # 工作状态
self.queue = Queue() # 队列
self.thread = Thread(target=self.run) # 线程
# 载入设置,订阅行情
self.loadSetting()
@ -112,6 +119,9 @@ class DrEngine(object):
for activeSymbol, vtSymbol in d.items():
self.activeSymbolDict[vtSymbol] = activeSymbol
# 启动数据插入线程
self.start()
# 注册事件监听
self.registerEvent()
@ -187,7 +197,29 @@ class DrEngine(object):
#----------------------------------------------------------------------
def insertData(self, dbName, collectionName, data):
"""插入数据到数据库这里的data可以是CtaTickData或者CtaBarData"""
self.mainEngine.dbInsert(dbName, collectionName, data.__dict__)
self.queue.put((dbName, collectionName, data.__dict__))
#----------------------------------------------------------------------
def run(self):
"""运行插入线程"""
while self.active:
try:
dbName, collectionName, d = self.queue.get(block=True, timeout=1)
self.mainEngine.dbInsert(dbName, collectionName, d)
except Empty:
pass
#----------------------------------------------------------------------
def start(self):
"""启动"""
self.active = True
self.thread.start()
#----------------------------------------------------------------------
def stop(self):
"""退出"""
if self.active:
self.active = False
self.thread.join()
#----------------------------------------------------------------------
def writeDrLog(self, content):

View File

@ -733,6 +733,10 @@ class LtsTdApi(TdApi):
os.makedirs(path)
self.createFtdcTraderApi(path)
# 设置数据同步模式为推送从今日开始所有数据
self.subscribePrivateTopic(0)
self.subscribePublicTopic(0)
# 注册服务器地址
self.registerFront(self.address)
@ -1153,7 +1157,7 @@ class LtsQryApi(QryApi):
# 持仓均价
if pos.position:
pos.price = data['PositionCost'] / pos.position
pos.price = data['OpenCost'] / pos.position
# VT系统持仓名
pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction])

View File

@ -0,0 +1,7 @@
{
"host": "CNY",
"apiKey": "OKCOIN网站申请",
"secretKey": "OKCOIN网站申请",
"trace": false,
"leverage": 20
}

View File

View File

@ -0,0 +1,657 @@
# encoding: UTF-8
'''
vn.okcoin的gateway接入
注意
1. 该接口尚处于测试阶段用于实盘请谨慎
2. 目前仅支持USD和CNY的现货交易USD的期货合约交易暂不支持
'''
import os
import json
from datetime import datetime
from copy import copy
from threading import Condition
import vnokcoin
from vtGateway import *
# 价格类型映射
priceTypeMap = {}
priceTypeMap['buy'] = (DIRECTION_LONG, PRICETYPE_LIMITPRICE)
priceTypeMap['buy_market'] = (DIRECTION_LONG, PRICETYPE_MARKETPRICE)
priceTypeMap['sell'] = (DIRECTION_SHORT, PRICETYPE_LIMITPRICE)
priceTypeMap['sell_market'] = (DIRECTION_SHORT, PRICETYPE_MARKETPRICE)
priceTypeMapReverse = {v: k for k, v in priceTypeMap.items()}
# 方向类型映射
directionMap = {}
directionMapReverse = {v: k for k, v in directionMap.items()}
# 委托状态印射
statusMap = {}
statusMap[-1] = STATUS_CANCELLED
statusMap[0] = STATUS_NOTTRADED
statusMap[1] = STATUS_PARTTRADED
statusMap[2] = STATUS_ALLTRADED
statusMap[4] = STATUS_UNKNOWN
############################################
## 交易合约代码
############################################
# USD
BTC_USD_SPOT = 'BTC_USD_SPOT'
BTC_USD_THISWEEK = 'BTC_USD_THISWEEK'
BTC_USD_NEXTWEEK = 'BTC_USD_NEXTWEEK'
BTC_USD_QUARTER = 'BTC_USD_QUARTER'
LTC_USD_SPOT = 'LTC_USD_SPOT'
LTC_USD_THISWEEK = 'LTC_USD_THISWEEK'
LTC_USD_NEXTWEEK = 'LTC_USD_NEXTWEEK'
LTC_USD_QUARTER = 'LTC_USD_QUARTER'
# CNY
BTC_CNY_SPOT = 'BTC_CNY_SPOT'
LTC_CNY_SPOT = 'LTC_CNY_SPOT'
# 印射字典
spotSymbolMap = {}
spotSymbolMap['ltc_usd'] = LTC_USD_SPOT
spotSymbolMap['btc_usd'] = BTC_USD_SPOT
spotSymbolMap['ltc_cny'] = LTC_CNY_SPOT
spotSymbolMap['btc_cny'] = BTC_CNY_SPOT
spotSymbolMapReverse = {v: k for k, v in spotSymbolMap.items()}
############################################
## Channel和Symbol的印射
############################################
channelSymbolMap = {}
# USD
channelSymbolMap['ok_sub_spotusd_btc_ticker'] = BTC_USD_SPOT
channelSymbolMap['ok_sub_spotusd_ltc_ticker'] = LTC_USD_SPOT
channelSymbolMap['ok_sub_spotusd_btc_depth_20'] = BTC_USD_SPOT
channelSymbolMap['ok_sub_spotusd_ltc_depth_20'] = LTC_USD_SPOT
# CNY
channelSymbolMap['ok_sub_spotcny_btc_ticker'] = BTC_CNY_SPOT
channelSymbolMap['ok_sub_spotcny_ltc_ticker'] = LTC_CNY_SPOT
channelSymbolMap['ok_sub_spotcny_btc_depth_20'] = BTC_CNY_SPOT
channelSymbolMap['ok_sub_spotcny_ltc_depth_20'] = LTC_CNY_SPOT
########################################################################
class OkcoinGateway(VtGateway):
"""OkCoin接口"""
#----------------------------------------------------------------------
def __init__(self, eventEngine, gatewayName='OKCOIN'):
"""Constructor"""
super(OkcoinGateway, self).__init__(eventEngine, gatewayName)
self.api = Api(self)
self.leverage = 0
self.connected = False
#----------------------------------------------------------------------
def connect(self):
"""连接"""
# 载入json文件
fileName = self.gatewayName + '_connect.json'
fileName = os.getcwd() + '/okcoinGateway/' + fileName
try:
f = file(fileName)
except IOError:
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = u'读取连接配置出错,请检查'
self.onLog(log)
return
# 解析json文件
setting = json.load(f)
try:
host = str(setting['host'])
apiKey = str(setting['apiKey'])
secretKey = str(setting['secretKey'])
trace = setting['trace']
leverage = setting['leverage']
except KeyError:
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = u'连接配置缺少字段,请检查'
self.onLog(log)
return
# 初始化接口
self.leverage = leverage
if host == 'CNY':
host = vnokcoin.OKCOIN_CNY
else:
host = vnokcoin.OKCOIN_USD
self.api.connect(host, apiKey, secretKey, trace)
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = u'接口初始化成功'
self.onLog(log)
# 启动查询
self.initQuery()
self.startQuery()
#----------------------------------------------------------------------
def subscribe(self, subscribeReq):
"""订阅行情"""
pass
#----------------------------------------------------------------------
def sendOrder(self, orderReq):
"""发单"""
return self.api.spotSendOrder(orderReq)
#----------------------------------------------------------------------
def cancelOrder(self, cancelOrderReq):
"""撤单"""
self.api.spotCancel(cancelOrderReq)
#----------------------------------------------------------------------
def qryAccount(self):
"""查询账户资金"""
self.api.spotUserInfo()
#----------------------------------------------------------------------
def qryPosition(self):
"""查询持仓"""
pass
#----------------------------------------------------------------------
def close(self):
"""关闭"""
pass
#----------------------------------------------------------------------
def initQuery(self):
"""初始化连续查询"""
if self.qryEnabled:
# 需要循环的查询函数列表
self.qryFunctionList = [self.qryAccount]
self.qryCount = 0 # 查询触发倒计时
self.qryTrigger = 2 # 查询触发点
self.qryNextFunction = 0 # 上次运行的查询函数索引
self.startQuery()
#----------------------------------------------------------------------
def query(self, event):
"""注册到事件处理引擎上的查询函数"""
self.qryCount += 1
if self.qryCount > self.qryTrigger:
# 清空倒计时
self.qryCount = 0
# 执行查询函数
function = self.qryFunctionList[self.qryNextFunction]
function()
# 计算下次查询函数的索引如果超过了列表长度则重新设为0
self.qryNextFunction += 1
if self.qryNextFunction == len(self.qryFunctionList):
self.qryNextFunction = 0
#----------------------------------------------------------------------
def startQuery(self):
"""启动连续查询"""
self.eventEngine.register(EVENT_TIMER, self.query)
#----------------------------------------------------------------------
def setQryEnabled(self, qryEnabled):
"""设置是否要启动循环查询"""
self.qryEnabled = qryEnabled
########################################################################
class Api(vnokcoin.OkCoinApi):
"""OkCoin的API实现"""
#----------------------------------------------------------------------
def __init__(self, gateway):
"""Constructor"""
super(Api, self).__init__()
self.gateway = gateway # gateway对象
self.gatewayName = gateway.gatewayName # gateway对象名称
self.cbDict = {}
self.tickDict = {}
self.orderDict = {}
self.lastOrderID = ''
self.orderCondition = Condition()
self.initCallback()
#----------------------------------------------------------------------
def onMessage(self, ws, evt):
"""信息推送"""
data = self.readData(evt)[0]
channel = data['channel']
callback = self.cbDict[channel]
callback(data)
#----------------------------------------------------------------------
def onError(self, ws, evt):
"""错误推送"""
error = VtErrorData()
error.gatewayName = self.gatewayName
error.errorMsg = str(evt)
self.gateway.onError(error)
#----------------------------------------------------------------------
def onClose(self, ws):
"""接口断开"""
self.gateway.connected = True
self.writeLog(u'服务器连接断开')
#----------------------------------------------------------------------
def onOpen(self, ws):
self.gateway.connected = True
self.writeLog(u'服务器连接成功')
# 连接后查询账户和委托数据
self.spotUserInfo()
self.spotOrderInfo(vnokcoin.TRADING_SYMBOL_LTC, '-1')
self.spotOrderInfo(vnokcoin.TRADING_SYMBOL_BTC, '-1')
# 连接后订阅现货的成交和账户数据
self.subscribeSpotTrades()
self.subscribeSpotUserInfo()
self.subscribeSpotTicker(vnokcoin.SYMBOL_BTC)
self.subscribeSpotTicker(vnokcoin.SYMBOL_LTC)
self.subscribeSpotDepth(vnokcoin.SYMBOL_BTC, vnokcoin.DEPTH_20)
self.subscribeSpotDepth(vnokcoin.SYMBOL_LTC, vnokcoin.DEPTH_20)
# 如果连接的是USD网站则订阅期货相关回报数据
if self.currency == vnokcoin.CURRENCY_USD:
self.subscribeFutureTrades()
self.subscribeFutureUserInfo()
self.subscribeFuturePositions()
# 返回合约信息
if self.currency == vnokcoin.CURRENCY_CNY:
l = self.generateCnyContract()
else:
l = self.generateUsdContract()
for contract in l:
contract.gatewayName = self.gatewayName
self.gateway.onContract(contract)
#----------------------------------------------------------------------
def writeLog(self, content):
"""快速记录日志"""
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = content
self.gateway.onLog(log)
#----------------------------------------------------------------------
def initCallback(self):
"""初始化回调函数"""
# USD_SPOT
self.cbDict['ok_sub_spotusd_btc_ticker'] = self.onTicker
self.cbDict['ok_sub_spotusd_ltc_ticker'] = self.onTicker
self.cbDict['ok_sub_spotusd_btc_depth_20'] = self.onDepth
self.cbDict['ok_sub_spotusd_ltc_depth_20'] = self.onDepth
self.cbDict['ok_spotusd_userinfo'] = self.onSpotUserInfo
self.cbDict['ok_spotusd_orderinfo'] = self.onSpotOrderInfo
self.cbDict['ok_sub_spotusd_userinfo'] = self.onSpotSubUserInfo
self.cbDict['ok_sub_spotusd_trades'] = self.onSpotSubTrades
self.cbDict['ok_spotusd_trade'] = self.onSpotTrade
self.cbDict['ok_spotusd_cancel_order'] = self.onSpotCancelOrder
# CNY_SPOT
self.cbDict['ok_sub_spotcny_btc_ticker'] = self.onTicker
self.cbDict['ok_sub_spotcny_ltc_ticker'] = self.onTicker
self.cbDict['ok_sub_spotcny_btc_depth_20'] = self.onDepth
self.cbDict['ok_sub_spotcny_ltc_depth_20'] = self.onDepth
self.cbDict['ok_spotcny_userinfo'] = self.onSpotUserInfo
self.cbDict['ok_spotcny_orderinfo'] = self.onSpotOrderInfo
self.cbDict['ok_sub_spotcny_userinfo'] = self.onSpotSubUserInfo
self.cbDict['ok_sub_spotcny_trades'] = self.onSpotSubTrades
self.cbDict['ok_spotcny_trade'] = self.onSpotTrade
self.cbDict['ok_spotcny_cancel_order'] = self.onSpotCancelOrder
# USD_FUTURES
#----------------------------------------------------------------------
def onTicker(self, data):
""""""
if 'data' not in data:
return
channel = data['channel']
symbol = channelSymbolMap[channel]
if symbol not in self.tickDict:
tick = VtTickData()
tick.symbol = symbol
tick.vtSymbol = symbol
tick.gatewayName = self.gatewayName
self.tickDict[symbol] = tick
else:
tick = self.tickDict[symbol]
rawData = data['data']
tick.highPrice = float(rawData['high'])
tick.lowPrice = float(rawData['low'])
tick.lastPrice = float(rawData['last'])
tick.volume = float(rawData['vol'].replace(',', ''))
tick.date, tick.time = generateDateTime(rawData['timestamp'])
newtick = copy(tick)
self.gateway.onTick(newtick)
#----------------------------------------------------------------------
def onDepth(self, data):
""""""
if 'data' not in data:
return
channel = data['channel']
symbol = channelSymbolMap[channel]
if symbol not in self.tickDict:
tick = VtTickData()
tick.symbol = symbol
tick.vtSymbol = symbol
tick.gatewayName = self.gatewayName
self.tickDict[symbol] = tick
else:
tick = self.tickDict[symbol]
if 'data' not in data:
return
rawData = data['data']
tick.bidPrice1, tick.bidVolume1 = rawData['bids'][0]
tick.bidPrice2, tick.bidVolume2 = rawData['bids'][1]
tick.bidPrice3, tick.bidVolume3 = rawData['bids'][2]
tick.bidPrice4, tick.bidVolume4 = rawData['bids'][3]
tick.bidPrice5, tick.bidVolume5 = rawData['bids'][4]
tick.askPrice1, tick.askVolume1 = rawData['asks'][0]
tick.askPrice2, tick.askVolume2 = rawData['asks'][1]
tick.askPrice3, tick.askVolume3 = rawData['asks'][2]
tick.askPrice4, tick.askVolume4 = rawData['asks'][3]
tick.askPrice5, tick.askVolume5 = rawData['asks'][4]
newtick = copy(tick)
self.gateway.onTick(newtick)
#----------------------------------------------------------------------
def onSpotUserInfo(self, data):
"""现货账户资金推送"""
rawData = data['data']
info = rawData['info']
funds = rawData['info']['funds']
# 持仓信息
for symbol in ['btc', 'ltc', self.currency]:
if symbol in funds['free']:
pos = VtPositionData()
pos.gatewayName = self.gatewayName
pos.symbol = symbol
pos.vtSymbol = symbol
pos.vtPositionName = symbol
pos.direction = DIRECTION_NET
pos.frozen = float(funds['freezed'][symbol])
pos.position = pos.frozen + float(funds['free'][symbol])
self.gateway.onPosition(pos)
# 账户资金
account = VtAccountData()
account.gatewayName = self.gatewayName
account.accountID = self.gatewayName
account.vtAccountID = account.accountID
account.balance = float(funds['asset']['net'])
self.gateway.onAccount(account)
#----------------------------------------------------------------------
def onSpotSubUserInfo(self, data):
"""现货账户资金推送"""
if 'data' not in data:
return
rawData = data['data']
info = rawData['info']
# 持仓信息
for symbol in ['btc', 'ltc', self.currency]:
if symbol in info['free']:
pos = VtPositionData()
pos.gatewayName = self.gatewayName
pos.symbol = symbol
pos.vtSymbol = symbol
pos.vtPositionName = symbol
pos.direction = DIRECTION_NET
pos.frozen = float(info['freezed'][symbol])
pos.position = pos.frozen + float(info['free'][symbol])
self.gateway.onPosition(pos)
#----------------------------------------------------------------------
def onSpotSubTrades(self, data):
"""成交和委托推送"""
if 'data' not in data:
return
rawData = data['data']
# 委托信息
orderID = str(rawData['orderId'])
if orderID not in self.orderDict:
order = VtOrderData()
order.gatewayName = self.gatewayName
order.symbol = spotSymbolMap[rawData['symbol']]
order.vtSymbol = order.symbol
order.orderID = str(rawData['orderId'])
order.vtOrderID = '.'.join([self.gatewayName, order.orderID])
order.price = float(rawData['tradeUnitPrice'])
order.totalVolume = float(rawData['tradeAmount'])
order.direction, priceType = priceTypeMap[rawData['tradeType']]
self.orderDict[orderID] = order
else:
order = self.orderDict[orderID]
order.tradedVolume = float(rawData['completedTradeAmount'])
order.status = statusMap[rawData['status']]
self.gateway.onOrder(copy(order))
# 成交信息
if 'sigTradeAmount' in rawData and float(rawData['sigTradeAmount'])>0:
trade = VtTradeData()
trade.gatewayName = self.gatewayName
trade.symbol = spotSymbolMap[rawData['symbol']]
trade.vtSymbol = order.symbol
trade.tradeID = str(rawData['id'])
trade.vtTradeID = '.'.join([self.gatewayName, trade.tradeID])
trade.orderID = str(rawData['orderId'])
trade.vtOrderID = '.'.join([self.gatewayName, trade.orderID])
trade.price = float(rawData['sigTradePrice'])
trade.volume = float(rawData['sigTradeAmount'])
trade.direction, priceType = priceTypeMap[rawData['tradeType']]
trade.tradeTime = datetime.now().strftime('%H:%M:%S')
self.gateway.onTrade(trade)
#----------------------------------------------------------------------
def onSpotOrderInfo(self, data):
"""委托信息查询回调"""
rawData = data['data']
for d in rawData['orders']:
orderID = str(d['order_id'])
if orderID not in self.orderDict:
order = VtOrderData()
order.gatewayName = self.gatewayName
order.symbol = spotSymbolMap[d['symbol']]
order.vtSymbol = order.symbol
order.orderID = str(d['order_id'])
order.vtOrderID = '.'.join([self.gatewayName, order.orderID])
order.price = d['price']
order.totalVolume = d['amount']
order.direction, priceType = priceTypeMap[d['type']]
self.orderDict[orderID] = order
else:
order = self.orderDict[orderID]
order.tradedVolume = d['deal_amount']
order.status = statusMap[d['status']]
self.gateway.onOrder(copy(order))
#----------------------------------------------------------------------
def generateSpecificContract(self, contract, symbol):
"""生成合约"""
new = copy(contract)
new.symbol = symbol
new.vtSymbol = symbol
new.name = symbol
return new
#----------------------------------------------------------------------
def generateCnyContract(self):
"""生成CNY合约信息"""
contractList = []
contract = VtContractData()
contract.exchange = EXCHANGE_OKCOIN
contract.productClass = PRODUCT_SPOT
contract.size = 1
contract.priceTick = 0.01
contractList.append(self.generateSpecificContract(contract, BTC_CNY_SPOT))
contractList.append(self.generateSpecificContract(contract, LTC_CNY_SPOT))
return contractList
#----------------------------------------------------------------------
def generateUsdContract(self):
"""生成USD合约信息"""
contractList = []
# 现货
contract = VtContractData()
contract.exchange = EXCHANGE_OKCOIN
contract.productClass = PRODUCT_SPOT
contract.size = 1
contract.priceTick = 0.01
contractList.append(self.generateSpecificContract(contract, BTC_USD_SPOT))
contractList.append(self.generateSpecificContract(contract, LTC_USD_SPOT))
# 期货
contract.productClass = PRODUCT_FUTURES
contractList.append(self.generateSpecificContract(contract, BTC_USD_THISWEEK))
contractList.append(self.generateSpecificContract(contract, BTC_USD_NEXTWEEK))
contractList.append(self.generateSpecificContract(contract, BTC_USD_QUARTER))
contractList.append(self.generateSpecificContract(contract, LTC_USD_THISWEEK))
contractList.append(self.generateSpecificContract(contract, LTC_USD_NEXTWEEK))
contractList.append(self.generateSpecificContract(contract, LTC_USD_QUARTER))
return contractList
#----------------------------------------------------------------------
def onSpotTrade(self, data):
"""委托回报"""
rawData = data['data']
self.lastOrderID = rawData['order_id']
# 收到委托号后,通知发送委托的线程返回委托号
self.orderCondition.acquire()
self.orderCondition.notify()
self.orderCondition.release()
#----------------------------------------------------------------------
def onSpotCancelOrder(self, data):
"""撤单回报"""
pass
#----------------------------------------------------------------------
def spotSendOrder(self, req):
"""发单"""
symbol = spotSymbolMapReverse[req.symbol][:4]
type_ = priceTypeMapReverse[(req.direction, req.priceType)]
self.spotTrade(symbol, type_, str(req.price), str(req.volume))
# 等待发单回调推送委托号信息
self.orderCondition.acquire()
self.orderCondition.wait()
self.orderCondition.release()
vtOrderID = '.'.join([self.gatewayName, self.lastOrderID])
self.lastOrderID = ''
return vtOrderID
#----------------------------------------------------------------------
def spotCancel(self, req):
"""撤单"""
symbol = spotSymbolMapReverse[req.symbol][:4]
self.spotCancelOrder(symbol, req.orderID)
#----------------------------------------------------------------------
def generateDateTime(s):
"""生成时间"""
dt = datetime.fromtimestamp(float(s)/1e3)
time = dt.strftime("%H:%M:%S.%f")
date = dt.strftime("%Y%m%d")
return date, time

View File

@ -0,0 +1,382 @@
# encoding: UTF-8
import hashlib
import zlib
import json
from time import sleep
from threading import Thread
import websocket
# OKCOIN网站
OKCOIN_CNY = 'wss://real.okcoin.cn:10440/websocket/okcoinapi'
OKCOIN_USD = 'wss://real.okcoin.com:10440/websocket/okcoinapi'
# 账户货币代码
CURRENCY_CNY = 'cny'
CURRENCY_USD = 'usd'
# 电子货币代码
SYMBOL_BTC = 'btc'
SYMBOL_LTC = 'ltc'
# 行情深度
DEPTH_20 = 20
DEPTH_60 = 60
# K线时间区间
INTERVAL_1M = '1min'
INTERVAL_3M = '3min'
INTERVAL_5M = '5min'
INTERVAL_15M = '15min'
INTERVAL_30M = '30min'
INTERVAL_1H = '1hour'
INTERVAL_2H = '2hour'
INTERVAL_4H = '4hour'
INTERVAL_6H = '6hour'
INTERVAL_1D = 'day'
INTERVAL_3D = '3day'
INTERVAL_1W = 'week'
# 交易代码,需要后缀货币名才能完整
TRADING_SYMBOL_BTC = 'btc_'
TRADING_SYMBOL_LTC = 'ltc_'
# 委托类型
TYPE_BUY = 'buy'
TYPE_SELL = 'sell'
TYPE_BUY_MARKET = 'buy_market'
TYPE_SELL_MARKET = 'sell_market'
# 期货合约到期类型
FUTURE_EXPIRY_THIS_WEEK = 'this_week'
FUTURE_EXPIRY_NEXT_WEEK = 'next_week'
FUTURE_EXPIRY_QUARTER = 'quarter'
# 期货委托类型
FUTURE_TYPE_LONG = 1
FUTURE_TYPE_SHORT = 2
FUTURE_TYPE_SELL = 3
FUTURE_TYPE_COVER = 4
# 期货是否用现价
FUTURE_ORDER_MARKET = 1
FUTURE_ORDER_LIMIT = 0
# 期货杠杆
FUTURE_LEVERAGE_10 = 10
FUTURE_LEVERAGE_20 = 20
# 委托状态
ORDER_STATUS_NOTTRADED = 0
ORDER_STATUS_PARTTRADED = 1
ORDER_STATUS_ALLTRADED = 2
ORDER_STATUS_CANCELLED = -1
ORDER_STATUS_CANCELLING = 4
########################################################################
class OkCoinApi(object):
"""基于Websocket的API对象"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.apiKey = '' # 用户名
self.secretKey = '' # 密码
self.host = '' # 服务器地址
self.currency = '' # 货币类型usd或者cny
self.ws = None # websocket应用对象
self.thread = None # 工作线程
#######################
## 通用函数
#######################
#----------------------------------------------------------------------
def readData(self, evt):
"""解压缩推送收到的数据"""
# 创建解压器
decompress = zlib.decompressobj(-zlib.MAX_WBITS)
# 将原始数据解压成字符串
inflated = decompress.decompress(evt) + decompress.flush()
# 通过json解析字符串
data = json.loads(inflated)
return data
#----------------------------------------------------------------------
def generateSign(self, params):
"""生成签名"""
l = []
for key in sorted(params.keys()):
l.append('%s=%s' %(key, params[key]))
l.append('secret_key=%s' %self.secretKey)
sign = '&'.join(l)
return hashlib.md5(sign.encode('utf-8')).hexdigest().upper()
#----------------------------------------------------------------------
def onMessage(self, ws, evt):
"""信息推送"""
print 'onMessage'
data = self.readData(evt)
print data
#----------------------------------------------------------------------
def onError(self, ws, evt):
"""错误推送"""
print 'onError'
print evt
#----------------------------------------------------------------------
def onClose(self, ws):
"""接口断开"""
print 'onClose'
#----------------------------------------------------------------------
def onOpen(self, ws):
"""接口打开"""
print 'onOpen'
#----------------------------------------------------------------------
def connect(self, host, apiKey, secretKey, trace=False):
"""连接服务器"""
self.host = host
self.apiKey = apiKey
self.secretKey = secretKey
if self.host == OKCOIN_CNY:
self.currency = CURRENCY_CNY
else:
self.currency = CURRENCY_USD
websocket.enableTrace(trace)
self.ws = websocket.WebSocketApp(host,
on_message=self.onMessage,
on_error=self.onError,
on_close=self.onClose,
on_open=self.onOpen)
self.thread = Thread(target=self.ws.run_forever)
self.thread.start()
#----------------------------------------------------------------------
def sendMarketDataRequest(self, channel):
"""发送行情请求"""
# 生成请求
d = {}
d['event'] = 'addChannel'
d['binary'] = True
d['channel'] = channel
# 使用json打包并发送
j = json.dumps(d)
self.ws.send(j)
#----------------------------------------------------------------------
def sendTradingRequest(self, channel, params):
"""发送交易请求"""
# 在参数字典中加上api_key和签名字段
params['api_key'] = self.apiKey
params['sign'] = self.generateSign(params)
# 生成请求
d = {}
d['event'] = 'addChannel'
d['binary'] = True
d['channel'] = channel
d['parameters'] = params
# 使用json打包并发送
j = json.dumps(d)
self.ws.send(j)
#######################
## 现货相关
#######################
#----------------------------------------------------------------------
def subscribeSpotTicker(self, symbol):
"""订阅现货普通报价"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_ticker' %(self.currency, symbol))
#----------------------------------------------------------------------
def subscribeSpotDepth(self, symbol, depth):
"""订阅现货深度报价"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_depth_%s' %(self.currency, symbol, depth))
#----------------------------------------------------------------------
def subscribeSpotTradeData(self, symbol):
"""订阅现货成交记录"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_trades' %(self.currency, symbol))
#----------------------------------------------------------------------
def subscribeSpotKline(self, symbol, interval):
"""订阅现货K线"""
self.sendMarketDataRequest('ok_sub_spot%s_%s_kline_%s' %(self.currency, symbol, interval))
#----------------------------------------------------------------------
def spotTrade(self, symbol, type_, price, amount):
"""现货委托"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['type'] = str(type_)
params['price'] = str(price)
params['amount'] = str(amount)
channel = 'ok_spot%s_trade' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def spotCancelOrder(self, symbol, orderid):
"""现货撤单"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
channel = 'ok_spot%s_cancel_order' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def spotUserInfo(self):
"""查询现货账户"""
channel = 'ok_spot%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def spotOrderInfo(self, symbol, orderid):
"""查询现货委托信息"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
channel = 'ok_spot%s_orderinfo' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def subscribeSpotTrades(self):
"""订阅现货成交信息"""
channel = 'ok_sub_spot%s_trades' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeSpotUserInfo(self):
"""订阅现货账户信息"""
channel = 'ok_sub_spot%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#######################
## 期货相关
#######################
#----------------------------------------------------------------------
def subscribeFutureTicker(self, symbol, expiry):
"""订阅期货普通报价"""
self.sendMarketDataRequest('ok_sub_future%s_%s_ticker_%s' %(self.currency, symbol, expiry))
#----------------------------------------------------------------------
def subscribeFutureDepth(self, symbol, expiry, depth):
"""订阅期货深度报价"""
self.sendMarketDataRequest('ok_sub_future%s_%s_depth_%s_%s' %(self.currency, symbol,
expiry, depth))
#----------------------------------------------------------------------
def subscribeFutureTradeData(self, symbol, expiry):
"""订阅期货成交记录"""
self.sendMarketDataRequest('ok_sub_future%s_%s_trade_%s' %(self.currency, symbol, expiry))
#----------------------------------------------------------------------
def subscribeFutureKline(self, symbol, expiry, interval):
"""订阅期货K线"""
self.sendMarketDataRequest('ok_sub_future%s_%s_kline_%s_%s' %(self.currency, symbol,
expiry, interval))
#----------------------------------------------------------------------
def subscribeFutureIndex(self, symbol):
"""订阅期货指数"""
self.sendMarketDataRequest('ok_sub_future%s_%s_index' %(self.currency, symbol))
#----------------------------------------------------------------------
def futureTrade(self, symbol, expiry, type_, price, amount, order, leverage):
"""期货委托"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['type'] = str(type_)
params['price'] = str(price)
params['amount'] = str(amount)
params['contract_type'] = str(expiry)
params['match_price'] = str(order)
params['lever_rate'] = str(leverage)
channel = 'ok_future%s_trade' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def futureCancelOrder(self, symbol, expiry, orderid):
"""期货撤单"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
params['contract_type'] = str(expiry)
channel = 'ok_future%s_cancel_order' %(self.currency)
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def futureUserInfo(self):
"""查询期货账户"""
channel = 'ok_future%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def futureOrderInfo(self, symbol, expiry, orderid, status, page, length):
"""查询期货委托信息"""
params = {}
params['symbol'] = str(symbol+self.currency)
params['order_id'] = str(orderid)
params['contract_type'] = expiry
params['status'] = status
params['current_page'] = page
params['page_length'] = length
channel = 'ok_future%s_orderinfo'
self.sendTradingRequest(channel, params)
#----------------------------------------------------------------------
def subscribeFutureTrades(self):
"""订阅期货成交信息"""
channel = 'ok_sub_future%s_trades' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeFutureUserInfo(self):
"""订阅期货账户信息"""
channel = 'ok_sub_future%s_userinfo' %(self.currency)
self.sendTradingRequest(channel, {})
#----------------------------------------------------------------------
def subscribeFuturePositions(self):
"""订阅期货持仓信息"""
channel = 'ok_sub_future%s_positions' %(self.currency)
self.sendTradingRequest(channel, {})

View File

@ -70,24 +70,6 @@ class DirectionCell(QtGui.QTableWidgetItem):
self.setText(text)
########################################################################
class NameCell(QtGui.QTableWidgetItem):
"""用来显示合约中文名的单元格"""
#----------------------------------------------------------------------
def __init__(self, text=None, mainEngine=None):
"""Constructor"""
super(NameCell, self).__init__()
self.data = None
if text:
self.setContent(text)
#----------------------------------------------------------------------
def setContent(self, text):
"""设置内容"""
self.setText(text)
########################################################################
class NameCell(QtGui.QTableWidgetItem):
"""用来显示合约中文的单元格"""
@ -459,6 +441,7 @@ class ErrorMonitor(BasicMonitor):
super(ErrorMonitor, self).__init__(mainEngine, eventEngine, parent)
d = OrderedDict()
d['errorTime'] = {'chinese':u'错误时间', 'cellType':BasicCell}
d['errorID'] = {'chinese':u'错误代码', 'cellType':BasicCell}
d['errorMsg'] = {'chinese':u'错误信息', 'cellType':BasicCell}
d['additionalInfo'] = {'chinese':u'补充信息', 'cellType':BasicCell}

View File

@ -10,6 +10,7 @@ from riskManager.uiRmWidget import RmEngineManager
########################################################################
class MainWindow(QtGui.QMainWindow):
"""主窗口"""
signalStatusBar = QtCore.pyqtSignal(type(Event()))
#----------------------------------------------------------------------
def __init__(self, mainEngine, eventEngine):
@ -88,6 +89,9 @@ class MainWindow(QtGui.QMainWindow):
connectOandaAction = QtGui.QAction(u'连接OANDA', self)
connectOandaAction.triggered.connect(self.connectOanda)
connectOkcoinAction = QtGui.QAction(u'连接OKCOIN', self)
connectOkcoinAction.triggered.connect(self.connectOkcoin)
connectDbAction = QtGui.QAction(u'连接数据库', self)
connectDbAction.triggered.connect(self.mainEngine.dbConnect)
@ -136,6 +140,8 @@ class MainWindow(QtGui.QMainWindow):
sysMenu.addAction(connectIbAction)
if 'OANDA' in self.mainEngine.gatewayDict:
sysMenu.addAction(connectOandaAction)
if 'OKCOIN' in self.mainEngine.gatewayDict:
sysMenu.addAction(connectOkcoinAction)
sysMenu.addSeparator()
if 'Wind' in self.mainEngine.gatewayDict:
sysMenu.addAction(connectWindAction)
@ -169,7 +175,8 @@ class MainWindow(QtGui.QMainWindow):
self.sbCount = 0
self.sbTrigger = 10 # 10秒刷新一次
self.eventEngine.register(EVENT_TIMER, self.updateStatusBar)
self.signalStatusBar.connect(self.updateStatusBar)
self.eventEngine.register(EVENT_TIMER, self.signalStatusBar.emit)
#----------------------------------------------------------------------
def updateStatusBar(self, event):
@ -237,6 +244,11 @@ class MainWindow(QtGui.QMainWindow):
"""连接OANDA"""
self.mainEngine.connect('OANDA')
#----------------------------------------------------------------------
def connectOkcoin(self):
"""连接OKCOIN"""
self.mainEngine.connect('OKCOIN')
#----------------------------------------------------------------------
def test(self):
"""测试按钮用的函数"""
@ -327,8 +339,17 @@ class MainWindow(QtGui.QMainWindow):
def loadWindowSettings(self):
"""载入窗口设置"""
settings = QtCore.QSettings('vn.py', 'vn.trader')
# 这里由于PyQt4的版本不同settings.value('state')调用返回的结果可能是:
# 1. None初次调用注册表里无相应记录因此为空
# 2. QByteArray比较新的PyQt4
# 3. QVariant以下代码正确执行所需的返回结果
# 所以为了兼容考虑这里加了一个try...except如果是1、2的情况就pass
# 可能导致主界面的设置无法载入(每次退出时的保存其实是成功了)
try:
self.restoreState(settings.value('state').toByteArray())
self.restoreGeometry(settings.value('geometry').toByteArray())
except AttributeError:
pass
########################################################################

View File

@ -69,6 +69,7 @@ EXCHANGE_GLOBEX = 'GLOBEX' # CME电子交易平台
EXCHANGE_IDEALPRO = 'IDEALPRO' # IB外汇ECN
EXCHANGE_OANDA = 'OANDA' # OANDA外汇做市商
EXCHANGE_OKCOIN = 'OKCOIN' # OKCOIN比特币交易所
# 货币类型
CURRENCY_USD = 'USD' # 美元

View File

@ -115,6 +115,13 @@ class MainEngine(object):
except Exception, e:
print e
try:
from okcoinGateway.okcoinGateway import OkcoinGateway
self.addGateway(OkcoinGateway, 'OKCOIN')
self.gatewayDict['OKCOIN'].setQryEnabled(True)
except Exception, e:
print e
#----------------------------------------------------------------------
def addGateway(self, gateway, gatewayName=None):
"""创建接口"""
@ -165,7 +172,7 @@ class MainEngine(object):
"""查询特定接口的账户"""
if gatewayName in self.gatewayDict:
gateway = self.gatewayDict[gatewayName]
gateway.getAccount()
gateway.qryAccount()
else:
self.writeLog(u'接口不存在:%s' %gatewayName)
@ -174,7 +181,7 @@ class MainEngine(object):
"""查询特定接口的持仓"""
if gatewayName in self.gatewayDict:
gateway = self.gatewayDict[gatewayName]
gateway.getPosition()
gateway.qryPosition()
else:
self.writeLog(u'接口不存在:%s' %gatewayName)
@ -188,6 +195,9 @@ class MainEngine(object):
# 停止事件引擎
self.eventEngine.stop()
# 停止数据记录引擎
self.drEngine.stop()
# 保存数据引擎里的合约数据到硬盘
self.dataEngine.saveContracts()

View File

@ -332,6 +332,8 @@ class VtErrorData(VtBaseData):
self.errorMsg = EMPTY_UNICODE # 错误信息
self.additionalInfo = EMPTY_UNICODE # 补充信息
self.errorTime = time.strftime('%X', time.localtime()) # 错误生成时间
########################################################################
class VtLogData(VtBaseData):
@ -353,7 +355,7 @@ class VtContractData(VtBaseData):
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
super(VtBaseData, self).__init__()
super(VtContractData, self).__init__()
self.symbol = EMPTY_STRING # 代码
self.exchange = EMPTY_STRING # 交易所代码