修复若干bug以及增加手动交易功能,回测引擎允许回测时从数据库读取数据回放,不再需要提前载入到内存
This commit is contained in:
parent
747e08afa9
commit
6b290f1e18
1
.gitignore
vendored
1
.gitignore
vendored
@ -46,6 +46,7 @@ Release/
|
||||
*.local
|
||||
*.temp
|
||||
*.vt
|
||||
*.dat
|
||||
=======
|
||||
vn.ctp/build/*
|
||||
vn.lts/build/*
|
||||
|
@ -1,4 +1,7 @@
|
||||
{
|
||||
"fontFamily": "微软雅黑",
|
||||
"fontSize": 12
|
||||
"fontSize": 12,
|
||||
|
||||
"mongoHost": "localhost",
|
||||
"mongoPort": 27017
|
||||
}
|
@ -7,7 +7,6 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
import pymongo
|
||||
|
||||
from ctaBase import *
|
||||
@ -15,6 +14,7 @@ from ctaSetting import *
|
||||
|
||||
from vtConstant import *
|
||||
from vtGateway import VtOrderData, VtTradeData
|
||||
from vtFunction import loadMongoSetting
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -51,11 +51,12 @@ class BacktestingEngine(object):
|
||||
self.dbClient = None # 数据库客户端
|
||||
self.dbCursor = None # 数据库指针
|
||||
|
||||
self.historyData = [] # 历史数据的列表,回测用
|
||||
#self.historyData = [] # 历史数据的列表,回测用
|
||||
self.initData = [] # 初始化用的数据
|
||||
self.backtestingData = [] # 回测用的数据
|
||||
#self.backtestingData = [] # 回测用的数据
|
||||
|
||||
self.dataStartDate = None # 回测数据开始日期,datetime对象
|
||||
self.dataEndDate = None # 回测数据结束日期,datetime对象
|
||||
self.strategyStartDate = None # 策略启动日期(即前面的数据用于初始化),datetime对象
|
||||
|
||||
self.limitOrderDict = OrderedDict() # 限价单字典
|
||||
@ -80,6 +81,12 @@ class BacktestingEngine(object):
|
||||
initTimeDelta = timedelta(initDays)
|
||||
self.strategyStartDate = self.dataStartDate + initTimeDelta
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def setEndDate(self, endDate=''):
|
||||
"""设置回测的结束日期"""
|
||||
if endDate:
|
||||
self.dataEndDate= datetime.strptime(endDate, '%Y%m%d')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def setBacktestingMode(self, mode):
|
||||
"""设置回测模式"""
|
||||
@ -88,35 +95,53 @@ class BacktestingEngine(object):
|
||||
#----------------------------------------------------------------------
|
||||
def loadHistoryData(self, dbName, symbol):
|
||||
"""载入历史数据"""
|
||||
host, port = loadMongoSetting()
|
||||
|
||||
self.dbClient = pymongo.MongoClient(host, port)
|
||||
collection = self.dbClient[dbName][symbol]
|
||||
|
||||
self.output(u'开始载入数据')
|
||||
|
||||
# 首先根据回测模式,确认要使用的数据类
|
||||
if self.mode == self.BAR_MODE:
|
||||
dataClass = CtaBarData
|
||||
func = self.newBar
|
||||
else:
|
||||
dataClass = CtaTickData
|
||||
func = self.newTick
|
||||
|
||||
# 从数据库进行查询
|
||||
self.dbClient = pymongo.MongoClient()
|
||||
collection = self.dbClient[dbName][symbol]
|
||||
|
||||
flt = {'datetime':{'$gte':self.dataStartDate}} # 数据过滤条件
|
||||
self.dbCursor = collection.find(flt)
|
||||
# 载入初始化需要用的数据
|
||||
flt = {'datetime':{'$gte':self.dataStartDate,
|
||||
'$lt':self.strategyStartDate}}
|
||||
initCursor = collection.find(flt)
|
||||
|
||||
# 将数据从查询指针中读取出,并生成列表
|
||||
for d in self.dbCursor:
|
||||
for d in initCursor:
|
||||
data = dataClass()
|
||||
data.__dict__ = d
|
||||
if data.datetime < self.strategyStartDate:
|
||||
self.initData.append(data)
|
||||
else:
|
||||
self.backtestingData.append(data)
|
||||
|
||||
self.output(u'载入完成,数据量:%s' %len(self.backtestingData))
|
||||
# 载入回测数据
|
||||
if not self.dataEndDate:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
||||
else:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate,
|
||||
'$lte':self.dataEndDate}}
|
||||
self.dbCursor = collection.find(flt)
|
||||
|
||||
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def runBacktesting(self):
|
||||
"""运行回测"""
|
||||
# 首先根据回测模式,确认要使用的数据类
|
||||
if self.mode == self.BAR_MODE:
|
||||
dataClass = CtaBarData
|
||||
func = self.newBar
|
||||
else:
|
||||
dataClass = CtaTickData
|
||||
func = self.newTick
|
||||
|
||||
self.output(u'开始回测')
|
||||
|
||||
self.strategy.inited = True
|
||||
@ -128,13 +153,13 @@ class BacktestingEngine(object):
|
||||
self.output(u'策略启动完成')
|
||||
|
||||
self.output(u'开始回放数据')
|
||||
if self.mode == self.BAR_MODE:
|
||||
for data in self.backtestingData:
|
||||
self.newBar(data)
|
||||
#print str(data.datetime)
|
||||
else:
|
||||
for data in self.backtestingData:
|
||||
self.newTick(data)
|
||||
|
||||
for d in self.dbCursor:
|
||||
data = dataClass()
|
||||
data.__dict__ = d
|
||||
func(data)
|
||||
|
||||
self.output(u'数据回放结束')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def newBar(self, bar):
|
||||
|
@ -45,6 +45,7 @@ class StopOrder(object):
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
self.vtSymbol = EMPTY_STRING
|
||||
self.orderType = EMPTY_UNICODE
|
||||
self.direction = EMPTY_UNICODE
|
||||
self.offset = EMPTY_UNICODE
|
||||
self.price = EMPTY_FLOAT
|
||||
|
@ -118,6 +118,7 @@ class CtaEngine(object):
|
||||
|
||||
so = StopOrder()
|
||||
so.vtSymbol = vtSymbol
|
||||
so.orderType = orderType
|
||||
so.price = price
|
||||
so.volume = volume
|
||||
so.strategy = strategy
|
||||
@ -163,7 +164,7 @@ class CtaEngine(object):
|
||||
for so in self.workingStopOrderDict.values():
|
||||
if so.vtSymbol == vtSymbol:
|
||||
longTriggered = so.direction==DIRECTION_LONG and tick.lastPrice>=so.price # 多头停止单被触发
|
||||
shortTriggered = so.direction==DIRECTION_SHORT and tick.lasatPrice<=so.price # 空头停止单被触发
|
||||
shortTriggered = so.direction==DIRECTION_SHORT and tick.lastPrice<=so.price # 空头停止单被触发
|
||||
|
||||
if longTriggered or shortTriggered:
|
||||
# 买入和卖出分别以涨停跌停价发单(模拟市价单)
|
||||
|
@ -13,6 +13,7 @@ from multiprocessing.pool import ThreadPool
|
||||
|
||||
from ctaBase import *
|
||||
from vtConstant import *
|
||||
from vtFunction import loadMongoSetting
|
||||
from datayesClient import DatayesClient
|
||||
|
||||
|
||||
@ -32,7 +33,9 @@ class HistoryDataEngine(object):
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
"""Constructor"""
|
||||
self.dbClient = pymongo.MongoClient()
|
||||
host, port = loadMongoSetting()
|
||||
|
||||
self.dbClient = pymongo.MongoClient(host, port)
|
||||
self.datayesClient = DatayesClient()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@ -319,7 +322,9 @@ def loadMcCsv(fileName, dbName, symbol):
|
||||
print u'开始读取CSV文件%s中的数据插入到%s的%s中' %(fileName, dbName, symbol)
|
||||
|
||||
# 锁定集合,并创建索引
|
||||
client = pymongo.MongoClient()
|
||||
host, port = loadMongoSetting()
|
||||
|
||||
client = pymongo.MongoClient(host, port)
|
||||
collection = client[dbName][symbol]
|
||||
collection.ensure_index([('datetime', pymongo.ASCENDING)], unique=True)
|
||||
|
||||
|
@ -41,6 +41,7 @@ offsetMapReverse = {v:k for k,v in offsetMap.items()}
|
||||
exchangeMap = {}
|
||||
exchangeMap[EXCHANGE_SSE] = 'SSE'
|
||||
exchangeMap[EXCHANGE_SZSE] = 'SZE'
|
||||
exchangeMap[EXCHANGE_HKEX] = 'HGE'
|
||||
exchangeMapReverse = {v:k for k,v in exchangeMap.items()}
|
||||
|
||||
# 持仓类型映射
|
||||
|
@ -1,7 +1,7 @@
|
||||
{
|
||||
"brokerID": "0017",
|
||||
"tdAddress": "tcp://140.206.81.6:17776",
|
||||
"password": "联系招金投资申请",
|
||||
"password": "请联系招金投资申请",
|
||||
"mdAddress": "tcp://140.206.81.6:17777",
|
||||
"userID": "联系招金投资申请"
|
||||
"userID": "请联系招金投资申请"
|
||||
}
|
@ -386,6 +386,26 @@ class SgitMdApi(MdApi):
|
||||
tick.askPrice1 = data['AskPrice1']
|
||||
tick.askVolume1 = data['AskVolume1']
|
||||
|
||||
tick.bidPrice2 = data['BidPrice2']
|
||||
tick.bidVolume2 = data['BidVolume2']
|
||||
tick.askPrice2 = data['AskPrice2']
|
||||
tick.askVolume2 = data['AskVolume2']
|
||||
|
||||
tick.bidPrice3 = data['BidPrice3']
|
||||
tick.bidVolume3 = data['BidVolume3']
|
||||
tick.askPrice3 = data['AskPrice3']
|
||||
tick.askVolume3 = data['AskVolume3']
|
||||
|
||||
tick.bidPrice4 = data['BidPrice4']
|
||||
tick.bidVolume4 = data['BidVolume4']
|
||||
tick.askPrice4 = data['AskPrice4']
|
||||
tick.askVolume4 = data['AskVolume4']
|
||||
|
||||
tick.bidPrice5 = data['BidPrice5']
|
||||
tick.bidVolume5 = data['BidVolume5']
|
||||
tick.askPrice5 = data['AskPrice5']
|
||||
tick.askVolume5 = data['AskVolume5']
|
||||
|
||||
self.gateway.onTick(tick)
|
||||
|
||||
|
||||
@ -411,6 +431,7 @@ class SgitTdApi(TdApi):
|
||||
self.password = EMPTY_STRING # 密码
|
||||
self.brokerID = EMPTY_STRING # 经纪商代码
|
||||
self.address = EMPTY_STRING # 服务器地址
|
||||
self.investorID = EMPTY_STRING # 投资者代码
|
||||
|
||||
self.frontID = EMPTY_INT # 前置机编号
|
||||
self.sessionID = EMPTY_INT # 会话编号
|
||||
@ -500,7 +521,7 @@ class SgitTdApi(TdApi):
|
||||
return ''
|
||||
|
||||
req['OrderRef'] = strID
|
||||
req['InvestorID'] = self.userID
|
||||
req['InvestorID'] = self.investorID
|
||||
req['UserID'] = self.userID
|
||||
req['BrokerID'] = self.brokerID
|
||||
|
||||
@ -574,7 +595,7 @@ class SgitTdApi(TdApi):
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onRspUserLogin(self, data, error, n, last):
|
||||
"""登陆回报"""
|
||||
'''登陆回报'''
|
||||
# 如果登录成功,推送日志信息
|
||||
if error['ErrorID'] == 0:
|
||||
self.loginStatus = True
|
||||
@ -588,9 +609,9 @@ class SgitTdApi(TdApi):
|
||||
# 调用ready
|
||||
self.ready()
|
||||
|
||||
# 查询合约代码
|
||||
# 查询投资者代码
|
||||
self.reqID += 1
|
||||
self.reqQryInstrument({}, self.reqID)
|
||||
self.reqQryInvestor({}, self.reqID)
|
||||
|
||||
# 否则,推送错误信息
|
||||
else:
|
||||
@ -702,7 +723,17 @@ class SgitTdApi(TdApi):
|
||||
#----------------------------------------------------------------------
|
||||
def onRspQryInvestor(self, data, error, n, last):
|
||||
""""""
|
||||
pass
|
||||
self.investorID = data['InvestorID']
|
||||
|
||||
if last:
|
||||
log = VtLogData()
|
||||
log.gatewayName = self.gatewayName
|
||||
log.logContent = u'投资者编码获取完成'
|
||||
self.gateway.onLog(log)
|
||||
|
||||
# 查询合约
|
||||
self.reqID += 1
|
||||
self.reqQryInstrument({}, self.reqID)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def onRspQryInstrument(self, data, error, n, last):
|
||||
|
@ -557,7 +557,6 @@ class OrderMonitor(BasicMonitor):
|
||||
########################################################################
|
||||
class PositionMonitor(BasicMonitor):
|
||||
"""持仓监控"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, mainEngine, eventEngine, parent=None):
|
||||
"""Constructor"""
|
||||
@ -577,6 +576,8 @@ class PositionMonitor(BasicMonitor):
|
||||
self.setDataKey('vtPositionName')
|
||||
self.setEventType(EVENT_POSITION)
|
||||
self.setFont(BASIC_FONT)
|
||||
self.setSaveData(True)
|
||||
|
||||
self.initTable()
|
||||
self.registerEvent()
|
||||
|
||||
@ -635,6 +636,7 @@ class TradingWidget(QtGui.QFrame):
|
||||
EXCHANGE_SSE,
|
||||
EXCHANGE_SZSE,
|
||||
EXCHANGE_SGE,
|
||||
EXCHANGE_HKEX,
|
||||
EXCHANGE_SMART,
|
||||
EXCHANGE_GLOBEX,
|
||||
EXCHANGE_IDEALPRO]
|
||||
@ -1009,6 +1011,29 @@ class TradingWidget(QtGui.QFrame):
|
||||
req.orderID = order.orderID
|
||||
self.mainEngine.cancelOrder(req, order.gatewayName)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def closePosition(self, cell):
|
||||
"""根据持仓信息自动填写交易组件"""
|
||||
# 读取持仓数据,cell是一个表格中的单元格对象
|
||||
pos = cell.data
|
||||
symbol = pos.symbol
|
||||
|
||||
# 更新交易组件的显示合约
|
||||
self.lineSymbol.setText(symbol)
|
||||
self.updateSymbol()
|
||||
|
||||
# 自动填写信息
|
||||
self.comboPriceType.setCurrentIndex(self.priceTypeList.index(PRICETYPE_LIMITPRICE))
|
||||
self.comboOffset.setCurrentIndex(self.offsetList.index(OFFSET_CLOSE))
|
||||
self.spinVolume.setValue(pos.position)
|
||||
|
||||
if pos.direction == DIRECTION_LONG or pos.direction == DIRECTION_NET:
|
||||
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_SHORT))
|
||||
else:
|
||||
self.comboDirection.setCurrentIndex(self.directionList.index(DIRECTION_LONG))
|
||||
|
||||
# 价格留待更新后由用户输入,防止有误操作
|
||||
|
||||
|
||||
########################################################################
|
||||
class ContractMonitor(BasicMonitor):
|
||||
|
@ -67,6 +67,9 @@ class MainWindow(QtGui.QMainWindow):
|
||||
central.setLayout(grid)
|
||||
self.setCentralWidget(central)
|
||||
|
||||
# 连接组件之间的信号
|
||||
positionM.itemDoubleClicked.connect(tradingW.closePosition)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def initMenu(self):
|
||||
"""初始化菜单"""
|
||||
|
@ -61,6 +61,7 @@ EXCHANGE_DCE = 'DCE' # 大商所
|
||||
EXCHANGE_SGE = 'SGE' # 上金所
|
||||
EXCHANGE_UNKNOWN = 'UNKNOWN'# 未知交易所
|
||||
EXCHANGE_NONE = '' # 空交易所
|
||||
EXCHANGE_HKEX = 'HKEX' # 港交所
|
||||
|
||||
EXCHANGE_SMART = 'SMART' # IB智能路由(股票、期权)
|
||||
EXCHANGE_GLOBEX = 'GLOBEX' # CME电子交易平台
|
||||
|
@ -8,6 +8,7 @@ from pymongo.errors import ConnectionFailure
|
||||
|
||||
from eventEngine import *
|
||||
from vtGateway import *
|
||||
from vtFunction import loadMongoSetting
|
||||
|
||||
from ctaAlgo.ctaEngine import CtaEngine
|
||||
from dataRecorder.drEngine import DrEngine
|
||||
@ -197,8 +198,11 @@ class MainEngine(object):
|
||||
def dbConnect(self):
|
||||
"""连接MongoDB数据库"""
|
||||
if not self.dbClient:
|
||||
# 读取MongoDB的设置
|
||||
host, port = loadMongoSetting()
|
||||
|
||||
try:
|
||||
self.dbClient = MongoClient()
|
||||
self.dbClient = MongoClient(host, port)
|
||||
self.writeLog(u'MongoDB连接成功')
|
||||
except ConnectionFailure:
|
||||
self.writeLog(u'MongoDB连接失败')
|
||||
|
@ -24,3 +24,19 @@ def safeUnicode(value):
|
||||
value = round(value, ndigits=MAX_DECIMAL)
|
||||
|
||||
return unicode(value)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadMongoSetting():
|
||||
"""载入MongoDB数据库的配置"""
|
||||
try:
|
||||
f = file("VT_setting.json")
|
||||
setting = json.load(f)
|
||||
host = setting['mongoHost']
|
||||
port = setting['mongoPort']
|
||||
except:
|
||||
host = 'localhost'
|
||||
port = 27017
|
||||
|
||||
return host, port
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user