修复若干bug以及增加手动交易功能,回测引擎允许回测时从数据库读取数据回放,不再需要提前载入到内存

This commit is contained in:
chenxy123 2016-04-20 23:14:21 +08:00
parent 747e08afa9
commit 6b290f1e18
14 changed files with 158 additions and 41 deletions

1
.gitignore vendored
View File

@ -46,6 +46,7 @@ Release/
*.local *.local
*.temp *.temp
*.vt *.vt
*.dat
======= =======
vn.ctp/build/* vn.ctp/build/*
vn.lts/build/* vn.lts/build/*

View File

@ -1,4 +1,7 @@
{ {
"fontFamily": "微软雅黑", "fontFamily": "微软雅黑",
"fontSize": 12 "fontSize": 12,
"mongoHost": "localhost",
"mongoPort": 27017
} }

View File

@ -7,7 +7,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from collections import OrderedDict from collections import OrderedDict
import json
import pymongo import pymongo
from ctaBase import * from ctaBase import *
@ -15,6 +14,7 @@ from ctaSetting import *
from vtConstant import * from vtConstant import *
from vtGateway import VtOrderData, VtTradeData from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting
######################################################################## ########################################################################
@ -51,11 +51,12 @@ class BacktestingEngine(object):
self.dbClient = None # 数据库客户端 self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针 self.dbCursor = None # 数据库指针
self.historyData = [] # 历史数据的列表,回测用 #self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据 self.initData = [] # 初始化用的数据
self.backtestingData = [] # 回测用的数据 #self.backtestingData = [] # 回测用的数据
self.dataStartDate = None # 回测数据开始日期datetime对象 self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象 self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
self.limitOrderDict = OrderedDict() # 限价单字典 self.limitOrderDict = OrderedDict() # 限价单字典
@ -80,6 +81,12 @@ class BacktestingEngine(object):
initTimeDelta = timedelta(initDays) initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta self.strategyStartDate = self.dataStartDate + initTimeDelta
#----------------------------------------------------------------------
def setEndDate(self, endDate=''):
"""设置回测的结束日期"""
if endDate:
self.dataEndDate= datetime.strptime(endDate, '%Y%m%d')
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def setBacktestingMode(self, mode): def setBacktestingMode(self, mode):
"""设置回测模式""" """设置回测模式"""
@ -88,35 +95,53 @@ class BacktestingEngine(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol): def loadHistoryData(self, dbName, symbol):
"""载入历史数据""" """载入历史数据"""
self.output(u'开始载入数据') host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[dbName][symbol]
self.output(u'开始载入数据')
# 首先根据回测模式,确认要使用的数据类 # 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE: if self.mode == self.BAR_MODE:
dataClass = CtaBarData dataClass = CtaBarData
func = self.newBar
else: else:
dataClass = CtaTickData dataClass = CtaTickData
func = self.newTick
# 从数据库进行查询
self.dbClient = pymongo.MongoClient() # 载入初始化需要用的数据
collection = self.dbClient[dbName][symbol] flt = {'datetime':{'$gte':self.dataStartDate,
'$lt':self.strategyStartDate}}
flt = {'datetime':{'$gte':self.dataStartDate}} # 数据过滤条件 initCursor = collection.find(flt)
self.dbCursor = collection.find(flt)
# 将数据从查询指针中读取出,并生成列表 # 将数据从查询指针中读取出,并生成列表
for d in self.dbCursor: for d in initCursor:
data = dataClass() data = dataClass()
data.__dict__ = d data.__dict__ = d
if data.datetime < self.strategyStartDate: self.initData.append(data)
self.initData.append(data)
else: # 载入回测数据
self.backtestingData.append(data) if not self.dataEndDate:
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
else:
flt = {'datetime':{'$gte':self.strategyStartDate,
'$lte':self.dataEndDate}}
self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
self.output(u'载入完成,数据量:%s' %len(self.backtestingData))
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def runBacktesting(self): def runBacktesting(self):
"""运行回测""" """运行回测"""
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测') self.output(u'开始回测')
self.strategy.inited = True self.strategy.inited = True
@ -128,13 +153,13 @@ class BacktestingEngine(object):
self.output(u'策略启动完成') self.output(u'策略启动完成')
self.output(u'开始回放数据') self.output(u'开始回放数据')
if self.mode == self.BAR_MODE:
for data in self.backtestingData: for d in self.dbCursor:
self.newBar(data) data = dataClass()
#print str(data.datetime) data.__dict__ = d
else: func(data)
for data in self.backtestingData:
self.newTick(data) self.output(u'数据回放结束')
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def newBar(self, bar): def newBar(self, bar):

View File

@ -45,6 +45,7 @@ class StopOrder(object):
def __init__(self): def __init__(self):
"""Constructor""" """Constructor"""
self.vtSymbol = EMPTY_STRING self.vtSymbol = EMPTY_STRING
self.orderType = EMPTY_UNICODE
self.direction = EMPTY_UNICODE self.direction = EMPTY_UNICODE
self.offset = EMPTY_UNICODE self.offset = EMPTY_UNICODE
self.price = EMPTY_FLOAT self.price = EMPTY_FLOAT

View File

@ -118,6 +118,7 @@ class CtaEngine(object):
so = StopOrder() so = StopOrder()
so.vtSymbol = vtSymbol so.vtSymbol = vtSymbol
so.orderType = orderType
so.price = price so.price = price
so.volume = volume so.volume = volume
so.strategy = strategy so.strategy = strategy
@ -163,7 +164,7 @@ class CtaEngine(object):
for so in self.workingStopOrderDict.values(): for so in self.workingStopOrderDict.values():
if so.vtSymbol == vtSymbol: if so.vtSymbol == vtSymbol:
longTriggered = so.direction==DIRECTION_LONG and tick.lastPrice>=so.price # 多头停止单被触发 longTriggered = so.direction==DIRECTION_LONG and tick.lastPrice>=so.price # 多头停止单被触发
shortTriggered = so.direction==DIRECTION_SHORT and tick.lasatPrice<=so.price # 空头停止单被触发 shortTriggered = so.direction==DIRECTION_SHORT and tick.lastPrice<=so.price # 空头停止单被触发
if longTriggered or shortTriggered: if longTriggered or shortTriggered:
# 买入和卖出分别以涨停跌停价发单(模拟市价单) # 买入和卖出分别以涨停跌停价发单(模拟市价单)

View File

@ -13,6 +13,7 @@ from multiprocessing.pool import ThreadPool
from ctaBase import * from ctaBase import *
from vtConstant import * from vtConstant import *
from vtFunction import loadMongoSetting
from datayesClient import DatayesClient from datayesClient import DatayesClient
@ -32,7 +33,9 @@ class HistoryDataEngine(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self): def __init__(self):
"""Constructor""" """Constructor"""
self.dbClient = pymongo.MongoClient() host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
self.datayesClient = DatayesClient() self.datayesClient = DatayesClient()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@ -319,7 +322,9 @@ def loadMcCsv(fileName, dbName, symbol):
print u'开始读取CSV文件%s中的数据插入到%s%s' %(fileName, dbName, symbol) print u'开始读取CSV文件%s中的数据插入到%s%s' %(fileName, dbName, symbol)
# 锁定集合,并创建索引 # 锁定集合,并创建索引
client = pymongo.MongoClient() host, port = loadMongoSetting()
client = pymongo.MongoClient(host, port)
collection = client[dbName][symbol] collection = client[dbName][symbol]
collection.ensure_index([('datetime', pymongo.ASCENDING)], unique=True) collection.ensure_index([('datetime', pymongo.ASCENDING)], unique=True)

View File

@ -41,6 +41,7 @@ offsetMapReverse = {v:k for k,v in offsetMap.items()}
exchangeMap = {} exchangeMap = {}
exchangeMap[EXCHANGE_SSE] = 'SSE' exchangeMap[EXCHANGE_SSE] = 'SSE'
exchangeMap[EXCHANGE_SZSE] = 'SZE' exchangeMap[EXCHANGE_SZSE] = 'SZE'
exchangeMap[EXCHANGE_HKEX] = 'HGE'
exchangeMapReverse = {v:k for k,v in exchangeMap.items()} exchangeMapReverse = {v:k for k,v in exchangeMap.items()}
# 持仓类型映射 # 持仓类型映射

View File

@ -1,7 +1,7 @@
{ {
"brokerID": "0017", "brokerID": "0017",
"tdAddress": "tcp://140.206.81.6:17776", "tdAddress": "tcp://140.206.81.6:17776",
"password": "联系招金投资申请", "password": "联系招金投资申请",
"mdAddress": "tcp://140.206.81.6:17777", "mdAddress": "tcp://140.206.81.6:17777",
"userID": "联系招金投资申请" "userID": "联系招金投资申请"
} }

View File

@ -385,6 +385,26 @@ class SgitMdApi(MdApi):
tick.bidVolume1 = data['BidVolume1'] tick.bidVolume1 = data['BidVolume1']
tick.askPrice1 = data['AskPrice1'] tick.askPrice1 = data['AskPrice1']
tick.askVolume1 = data['AskVolume1'] tick.askVolume1 = data['AskVolume1']
tick.bidPrice2 = data['BidPrice2']
tick.bidVolume2 = data['BidVolume2']
tick.askPrice2 = data['AskPrice2']
tick.askVolume2 = data['AskVolume2']
tick.bidPrice3 = data['BidPrice3']
tick.bidVolume3 = data['BidVolume3']
tick.askPrice3 = data['AskPrice3']
tick.askVolume3 = data['AskVolume3']
tick.bidPrice4 = data['BidPrice4']
tick.bidVolume4 = data['BidVolume4']
tick.askPrice4 = data['AskPrice4']
tick.askVolume4 = data['AskVolume4']
tick.bidPrice5 = data['BidPrice5']
tick.bidVolume5 = data['BidVolume5']
tick.askPrice5 = data['AskPrice5']
tick.askVolume5 = data['AskVolume5']
self.gateway.onTick(tick) self.gateway.onTick(tick)
@ -411,6 +431,7 @@ class SgitTdApi(TdApi):
self.password = EMPTY_STRING # 密码 self.password = EMPTY_STRING # 密码
self.brokerID = EMPTY_STRING # 经纪商代码 self.brokerID = EMPTY_STRING # 经纪商代码
self.address = EMPTY_STRING # 服务器地址 self.address = EMPTY_STRING # 服务器地址
self.investorID = EMPTY_STRING # 投资者代码
self.frontID = EMPTY_INT # 前置机编号 self.frontID = EMPTY_INT # 前置机编号
self.sessionID = EMPTY_INT # 会话编号 self.sessionID = EMPTY_INT # 会话编号
@ -500,7 +521,7 @@ class SgitTdApi(TdApi):
return '' return ''
req['OrderRef'] = strID req['OrderRef'] = strID
req['InvestorID'] = self.userID req['InvestorID'] = self.investorID
req['UserID'] = self.userID req['UserID'] = self.userID
req['BrokerID'] = self.brokerID req['BrokerID'] = self.brokerID
@ -574,7 +595,7 @@ class SgitTdApi(TdApi):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def onRspUserLogin(self, data, error, n, last): def onRspUserLogin(self, data, error, n, last):
"""登陆回报""" '''登陆回报'''
# 如果登录成功,推送日志信息 # 如果登录成功,推送日志信息
if error['ErrorID'] == 0: if error['ErrorID'] == 0:
self.loginStatus = True self.loginStatus = True
@ -588,9 +609,9 @@ class SgitTdApi(TdApi):
# 调用ready # 调用ready
self.ready() self.ready()
# 查询合约代码 # 查询投资者代码
self.reqID += 1 self.reqID += 1
self.reqQryInstrument({}, self.reqID) self.reqQryInvestor({}, self.reqID)
# 否则,推送错误信息 # 否则,推送错误信息
else: else:
@ -702,7 +723,17 @@ class SgitTdApi(TdApi):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def onRspQryInvestor(self, data, error, n, last): def onRspQryInvestor(self, data, error, n, last):
"""""" """"""
pass self.investorID = data['InvestorID']
if last:
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = u'投资者编码获取完成'
self.gateway.onLog(log)
# 查询合约
self.reqID += 1
self.reqQryInstrument({}, self.reqID)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def onRspQryInstrument(self, data, error, n, last): def onRspQryInstrument(self, data, error, n, last):

View File

@ -557,7 +557,6 @@ class OrderMonitor(BasicMonitor):
######################################################################## ########################################################################
class PositionMonitor(BasicMonitor): class PositionMonitor(BasicMonitor):
"""持仓监控""" """持仓监控"""
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self, mainEngine, eventEngine, parent=None): def __init__(self, mainEngine, eventEngine, parent=None):
"""Constructor""" """Constructor"""
@ -577,10 +576,12 @@ class PositionMonitor(BasicMonitor):
self.setDataKey('vtPositionName') self.setDataKey('vtPositionName')
self.setEventType(EVENT_POSITION) self.setEventType(EVENT_POSITION)
self.setFont(BASIC_FONT) self.setFont(BASIC_FONT)
self.setSaveData(True)
self.initTable() self.initTable()
self.registerEvent() self.registerEvent()
######################################################################## ########################################################################
class AccountMonitor(BasicMonitor): class AccountMonitor(BasicMonitor):
"""账户监控""" """账户监控"""
@ -635,6 +636,7 @@ class TradingWidget(QtGui.QFrame):
EXCHANGE_SSE, EXCHANGE_SSE,
EXCHANGE_SZSE, EXCHANGE_SZSE,
EXCHANGE_SGE, EXCHANGE_SGE,
EXCHANGE_HKEX,
EXCHANGE_SMART, EXCHANGE_SMART,
EXCHANGE_GLOBEX, EXCHANGE_GLOBEX,
EXCHANGE_IDEALPRO] EXCHANGE_IDEALPRO]
@ -1008,6 +1010,29 @@ class TradingWidget(QtGui.QFrame):
req.sessionID = order.sessionID req.sessionID = order.sessionID
req.orderID = order.orderID req.orderID = order.orderID
self.mainEngine.cancelOrder(req, order.gatewayName) self.mainEngine.cancelOrder(req, order.gatewayName)
#----------------------------------------------------------------------
def closePosition(self, cell):
"""根据持仓信息自动填写交易组件"""
# 读取持仓数据cell是一个表格中的单元格对象
pos = cell.data
symbol = pos.symbol
# 更新交易组件的显示合约
self.lineSymbol.setText(symbol)
self.updateSymbol()
# 自动填写信息
self.comboPriceType.setCurrentIndex(self.priceTypeList.index(PRICETYPE_LIMITPRICE))
self.comboOffset.setCurrentIndex(self.offsetList.index(OFFSET_CLOSE))
self.spinVolume.setValue(pos.position)
if pos.direction == DIRECTION_LONG or pos.direction == DIRECTION_NET:
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_SHORT))
else:
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_LONG))
# 价格留待更新后由用户输入,防止有误操作
######################################################################## ########################################################################

View File

@ -67,6 +67,9 @@ class MainWindow(QtGui.QMainWindow):
central.setLayout(grid) central.setLayout(grid)
self.setCentralWidget(central) self.setCentralWidget(central)
# 连接组件之间的信号
positionM.itemDoubleClicked.connect(tradingW.closePosition)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def initMenu(self): def initMenu(self):
"""初始化菜单""" """初始化菜单"""

View File

@ -61,6 +61,7 @@ EXCHANGE_DCE = 'DCE' # 大商所
EXCHANGE_SGE = 'SGE' # 上金所 EXCHANGE_SGE = 'SGE' # 上金所
EXCHANGE_UNKNOWN = 'UNKNOWN'# 未知交易所 EXCHANGE_UNKNOWN = 'UNKNOWN'# 未知交易所
EXCHANGE_NONE = '' # 空交易所 EXCHANGE_NONE = '' # 空交易所
EXCHANGE_HKEX = 'HKEX' # 港交所
EXCHANGE_SMART = 'SMART' # IB智能路由股票、期权 EXCHANGE_SMART = 'SMART' # IB智能路由股票、期权
EXCHANGE_GLOBEX = 'GLOBEX' # CME电子交易平台 EXCHANGE_GLOBEX = 'GLOBEX' # CME电子交易平台

View File

@ -8,6 +8,7 @@ from pymongo.errors import ConnectionFailure
from eventEngine import * from eventEngine import *
from vtGateway import * from vtGateway import *
from vtFunction import loadMongoSetting
from ctaAlgo.ctaEngine import CtaEngine from ctaAlgo.ctaEngine import CtaEngine
from dataRecorder.drEngine import DrEngine from dataRecorder.drEngine import DrEngine
@ -197,8 +198,11 @@ class MainEngine(object):
def dbConnect(self): def dbConnect(self):
"""连接MongoDB数据库""" """连接MongoDB数据库"""
if not self.dbClient: if not self.dbClient:
# 读取MongoDB的设置
host, port = loadMongoSetting()
try: try:
self.dbClient = MongoClient() self.dbClient = MongoClient(host, port)
self.writeLog(u'MongoDB连接成功') self.writeLog(u'MongoDB连接成功')
except ConnectionFailure: except ConnectionFailure:
self.writeLog(u'MongoDB连接失败') self.writeLog(u'MongoDB连接失败')

View File

@ -23,4 +23,20 @@ def safeUnicode(value):
if abs(d.as_tuple().exponent) > MAX_DECIMAL: if abs(d.as_tuple().exponent) > MAX_DECIMAL:
value = round(value, ndigits=MAX_DECIMAL) value = round(value, ndigits=MAX_DECIMAL)
return unicode(value) return unicode(value)
#----------------------------------------------------------------------
def loadMongoSetting():
"""载入MongoDB数据库的配置"""
try:
f = file("VT_setting.json")
setting = json.load(f)
host = setting['mongoHost']
port = setting['mongoPort']
except:
host = 'localhost'
port = 27017
return host, port