修复若干bug以及增加手动交易功能,回测引擎允许回测时从数据库读取数据回放,不再需要提前载入到内存

This commit is contained in:
chenxy123 2016-04-20 23:14:21 +08:00
parent 747e08afa9
commit 6b290f1e18
14 changed files with 158 additions and 41 deletions

1
.gitignore vendored
View File

@ -46,6 +46,7 @@ Release/
*.local
*.temp
*.vt
*.dat
=======
vn.ctp/build/*
vn.lts/build/*

View File

@ -1,4 +1,7 @@
{
"fontFamily": "微软雅黑",
"fontSize": 12
"fontSize": 12,
"mongoHost": "localhost",
"mongoPort": 27017
}

View File

@ -7,7 +7,6 @@
from datetime import datetime, timedelta
from collections import OrderedDict
import json
import pymongo
from ctaBase import *
@ -15,6 +14,7 @@ from ctaSetting import *
from vtConstant import *
from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting
########################################################################
@ -51,11 +51,12 @@ class BacktestingEngine(object):
self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针
self.historyData = [] # 历史数据的列表,回测用
#self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据
self.backtestingData = [] # 回测用的数据
#self.backtestingData = [] # 回测用的数据
self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
self.limitOrderDict = OrderedDict() # 限价单字典
@ -80,6 +81,12 @@ class BacktestingEngine(object):
initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta
#----------------------------------------------------------------------
def setEndDate(self, endDate=''):
"""设置回测的结束日期"""
if endDate:
self.dataEndDate= datetime.strptime(endDate, '%Y%m%d')
#----------------------------------------------------------------------
def setBacktestingMode(self, mode):
"""设置回测模式"""
@ -88,35 +95,53 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol):
"""载入历史数据"""
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[dbName][symbol]
self.output(u'开始载入数据')
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
# 从数据库进行查询
self.dbClient = pymongo.MongoClient()
collection = self.dbClient[dbName][symbol]
flt = {'datetime':{'$gte':self.dataStartDate}} # 数据过滤条件
self.dbCursor = collection.find(flt)
# 载入初始化需要用的数据
flt = {'datetime':{'$gte':self.dataStartDate,
'$lt':self.strategyStartDate}}
initCursor = collection.find(flt)
# 将数据从查询指针中读取出,并生成列表
for d in self.dbCursor:
for d in initCursor:
data = dataClass()
data.__dict__ = d
if data.datetime < self.strategyStartDate:
self.initData.append(data)
else:
self.backtestingData.append(data)
self.output(u'载入完成,数据量:%s' %len(self.backtestingData))
# 载入回测数据
if not self.dataEndDate:
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
else:
flt = {'datetime':{'$gte':self.strategyStartDate,
'$lte':self.dataEndDate}}
self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
#----------------------------------------------------------------------
def runBacktesting(self):
"""运行回测"""
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
self.strategy.inited = True
@ -128,13 +153,13 @@ class BacktestingEngine(object):
self.output(u'策略启动完成')
self.output(u'开始回放数据')
if self.mode == self.BAR_MODE:
for data in self.backtestingData:
self.newBar(data)
#print str(data.datetime)
else:
for data in self.backtestingData:
self.newTick(data)
for d in self.dbCursor:
data = dataClass()
data.__dict__ = d
func(data)
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def newBar(self, bar):

View File

@ -45,6 +45,7 @@ class StopOrder(object):
def __init__(self):
"""Constructor"""
self.vtSymbol = EMPTY_STRING
self.orderType = EMPTY_UNICODE
self.direction = EMPTY_UNICODE
self.offset = EMPTY_UNICODE
self.price = EMPTY_FLOAT

View File

@ -118,6 +118,7 @@ class CtaEngine(object):
so = StopOrder()
so.vtSymbol = vtSymbol
so.orderType = orderType
so.price = price
so.volume = volume
so.strategy = strategy
@ -163,7 +164,7 @@ class CtaEngine(object):
for so in self.workingStopOrderDict.values():
if so.vtSymbol == vtSymbol:
longTriggered = so.direction==DIRECTION_LONG and tick.lastPrice>=so.price # 多头停止单被触发
shortTriggered = so.direction==DIRECTION_SHORT and tick.lasatPrice<=so.price # 空头停止单被触发
shortTriggered = so.direction==DIRECTION_SHORT and tick.lastPrice<=so.price # 空头停止单被触发
if longTriggered or shortTriggered:
# 买入和卖出分别以涨停跌停价发单(模拟市价单)

View File

@ -13,6 +13,7 @@ from multiprocessing.pool import ThreadPool
from ctaBase import *
from vtConstant import *
from vtFunction import loadMongoSetting
from datayesClient import DatayesClient
@ -32,7 +33,9 @@ class HistoryDataEngine(object):
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.dbClient = pymongo.MongoClient()
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
self.datayesClient = DatayesClient()
#----------------------------------------------------------------------
@ -319,7 +322,9 @@ def loadMcCsv(fileName, dbName, symbol):
print u'开始读取CSV文件%s中的数据插入到%s%s' %(fileName, dbName, symbol)
# 锁定集合,并创建索引
client = pymongo.MongoClient()
host, port = loadMongoSetting()
client = pymongo.MongoClient(host, port)
collection = client[dbName][symbol]
collection.ensure_index([('datetime', pymongo.ASCENDING)], unique=True)

View File

@ -41,6 +41,7 @@ offsetMapReverse = {v:k for k,v in offsetMap.items()}
exchangeMap = {}
exchangeMap[EXCHANGE_SSE] = 'SSE'
exchangeMap[EXCHANGE_SZSE] = 'SZE'
exchangeMap[EXCHANGE_HKEX] = 'HGE'
exchangeMapReverse = {v:k for k,v in exchangeMap.items()}
# 持仓类型映射

View File

@ -1,7 +1,7 @@
{
"brokerID": "0017",
"tdAddress": "tcp://140.206.81.6:17776",
"password": "联系招金投资申请",
"password": "联系招金投资申请",
"mdAddress": "tcp://140.206.81.6:17777",
"userID": "联系招金投资申请"
"userID": "联系招金投资申请"
}

View File

@ -386,6 +386,26 @@ class SgitMdApi(MdApi):
tick.askPrice1 = data['AskPrice1']
tick.askVolume1 = data['AskVolume1']
tick.bidPrice2 = data['BidPrice2']
tick.bidVolume2 = data['BidVolume2']
tick.askPrice2 = data['AskPrice2']
tick.askVolume2 = data['AskVolume2']
tick.bidPrice3 = data['BidPrice3']
tick.bidVolume3 = data['BidVolume3']
tick.askPrice3 = data['AskPrice3']
tick.askVolume3 = data['AskVolume3']
tick.bidPrice4 = data['BidPrice4']
tick.bidVolume4 = data['BidVolume4']
tick.askPrice4 = data['AskPrice4']
tick.askVolume4 = data['AskVolume4']
tick.bidPrice5 = data['BidPrice5']
tick.bidVolume5 = data['BidVolume5']
tick.askPrice5 = data['AskPrice5']
tick.askVolume5 = data['AskVolume5']
self.gateway.onTick(tick)
@ -411,6 +431,7 @@ class SgitTdApi(TdApi):
self.password = EMPTY_STRING # 密码
self.brokerID = EMPTY_STRING # 经纪商代码
self.address = EMPTY_STRING # 服务器地址
self.investorID = EMPTY_STRING # 投资者代码
self.frontID = EMPTY_INT # 前置机编号
self.sessionID = EMPTY_INT # 会话编号
@ -500,7 +521,7 @@ class SgitTdApi(TdApi):
return ''
req['OrderRef'] = strID
req['InvestorID'] = self.userID
req['InvestorID'] = self.investorID
req['UserID'] = self.userID
req['BrokerID'] = self.brokerID
@ -574,7 +595,7 @@ class SgitTdApi(TdApi):
#----------------------------------------------------------------------
def onRspUserLogin(self, data, error, n, last):
"""登陆回报"""
'''登陆回报'''
# 如果登录成功,推送日志信息
if error['ErrorID'] == 0:
self.loginStatus = True
@ -588,9 +609,9 @@ class SgitTdApi(TdApi):
# 调用ready
self.ready()
# 查询合约代码
# 查询投资者代码
self.reqID += 1
self.reqQryInstrument({}, self.reqID)
self.reqQryInvestor({}, self.reqID)
# 否则,推送错误信息
else:
@ -702,7 +723,17 @@ class SgitTdApi(TdApi):
#----------------------------------------------------------------------
def onRspQryInvestor(self, data, error, n, last):
""""""
pass
self.investorID = data['InvestorID']
if last:
log = VtLogData()
log.gatewayName = self.gatewayName
log.logContent = u'投资者编码获取完成'
self.gateway.onLog(log)
# 查询合约
self.reqID += 1
self.reqQryInstrument({}, self.reqID)
#----------------------------------------------------------------------
def onRspQryInstrument(self, data, error, n, last):

View File

@ -557,7 +557,6 @@ class OrderMonitor(BasicMonitor):
########################################################################
class PositionMonitor(BasicMonitor):
"""持仓监控"""
#----------------------------------------------------------------------
def __init__(self, mainEngine, eventEngine, parent=None):
"""Constructor"""
@ -577,6 +576,8 @@ class PositionMonitor(BasicMonitor):
self.setDataKey('vtPositionName')
self.setEventType(EVENT_POSITION)
self.setFont(BASIC_FONT)
self.setSaveData(True)
self.initTable()
self.registerEvent()
@ -635,6 +636,7 @@ class TradingWidget(QtGui.QFrame):
EXCHANGE_SSE,
EXCHANGE_SZSE,
EXCHANGE_SGE,
EXCHANGE_HKEX,
EXCHANGE_SMART,
EXCHANGE_GLOBEX,
EXCHANGE_IDEALPRO]
@ -1009,6 +1011,29 @@ class TradingWidget(QtGui.QFrame):
req.orderID = order.orderID
self.mainEngine.cancelOrder(req, order.gatewayName)
#----------------------------------------------------------------------
def closePosition(self, cell):
"""根据持仓信息自动填写交易组件"""
# 读取持仓数据cell是一个表格中的单元格对象
pos = cell.data
symbol = pos.symbol
# 更新交易组件的显示合约
self.lineSymbol.setText(symbol)
self.updateSymbol()
# 自动填写信息
self.comboPriceType.setCurrentIndex(self.priceTypeList.index(PRICETYPE_LIMITPRICE))
self.comboOffset.setCurrentIndex(self.offsetList.index(OFFSET_CLOSE))
self.spinVolume.setValue(pos.position)
if pos.direction == DIRECTION_LONG or pos.direction == DIRECTION_NET:
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_SHORT))
else:
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_LONG))
# 价格留待更新后由用户输入,防止有误操作
########################################################################
class ContractMonitor(BasicMonitor):

View File

@ -67,6 +67,9 @@ class MainWindow(QtGui.QMainWindow):
central.setLayout(grid)
self.setCentralWidget(central)
# 连接组件之间的信号
positionM.itemDoubleClicked.connect(tradingW.closePosition)
#----------------------------------------------------------------------
def initMenu(self):
"""初始化菜单"""

View File

@ -61,6 +61,7 @@ EXCHANGE_DCE = 'DCE' # 大商所
EXCHANGE_SGE = 'SGE' # 上金所
EXCHANGE_UNKNOWN = 'UNKNOWN'# 未知交易所
EXCHANGE_NONE = '' # 空交易所
EXCHANGE_HKEX = 'HKEX' # 港交所
EXCHANGE_SMART = 'SMART' # IB智能路由股票、期权
EXCHANGE_GLOBEX = 'GLOBEX' # CME电子交易平台

View File

@ -8,6 +8,7 @@ from pymongo.errors import ConnectionFailure
from eventEngine import *
from vtGateway import *
from vtFunction import loadMongoSetting
from ctaAlgo.ctaEngine import CtaEngine
from dataRecorder.drEngine import DrEngine
@ -197,8 +198,11 @@ class MainEngine(object):
def dbConnect(self):
"""连接MongoDB数据库"""
if not self.dbClient:
# 读取MongoDB的设置
host, port = loadMongoSetting()
try:
self.dbClient = MongoClient()
self.dbClient = MongoClient(host, port)
self.writeLog(u'MongoDB连接成功')
except ConnectionFailure:
self.writeLog(u'MongoDB连接失败')

View File

@ -24,3 +24,19 @@ def safeUnicode(value):
value = round(value, ndigits=MAX_DECIMAL)
return unicode(value)
#----------------------------------------------------------------------
def loadMongoSetting():
"""载入MongoDB数据库的配置"""
try:
f = file("VT_setting.json")
setting = json.load(f)
host = setting['mongoHost']
port = setting['mongoPort']
except:
host = 'localhost'
port = 27017
return host, port