修改回测模型,支持mysql的tick数据库

修改SendOrder方法,增加返回GatewayName
修改Tick的方法,增加data-》Tick对象。
This commit is contained in:
msincenselee 2016-09-14 22:29:31 +08:00
parent e373f16ac0
commit c2f3ba62c7

View File

@ -7,15 +7,21 @@
from datetime import datetime, timedelta
from collections import OrderedDict
from itertools import product
import pymongo
import MySQLdb
import json
import os
import cPickle
from ctaBase import *
from ctaSetting import *
from vtConstant import *
from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting
import logging
########################################################################
class BacktestingEngine(object):
@ -23,6 +29,10 @@ class BacktestingEngine(object):
CTA回测引擎
函数接口和策略引擎保持一样
从而实现同一套代码从回测到实盘
# modified by IncenseLee
1.增加Mysql数据库的支持
2.修改装载数据为批量式后加载模式
"""
TICK_MODE = 'tick'
@ -51,10 +61,13 @@ class BacktestingEngine(object):
self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针
#self.historyData = [] # 历史数据的列表,回测用
self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据
#self.backtestingData = [] # 回测用的数据
self.backtestingData = [] # 回测用的数据
self.dbName = '' # 回测数据库名
self.symbol = '' # 回测集合名
self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
@ -72,33 +85,49 @@ class BacktestingEngine(object):
self.tick = None
self.bar = None
self.dt = None # 最新的时间
self.gatewayName = u'BackTest'
#----------------------------------------------------------------------
def setStartDate(self, startDate='20100416', initDays=10):
"""设置回测的启动日期"""
self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
# 初始化天数
initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta
#----------------------------------------------------------------------
def setEndDate(self, endDate=''):
"""设置回测的结束日期"""
if endDate:
self.dataEndDate= datetime.strptime(endDate, '%Y%m%d')
self.dataEndDate = datetime.strptime(endDate, '%Y%m%d')
else:
self.dataEndDate = datetime.now()
def setMinDiff(self, minDiff):
"""设置回测品种的最小跳价,用于修正数据"""
self.minDiff = minDiff
#----------------------------------------------------------------------
def setBacktestingMode(self, mode):
"""设置回测模式"""
self.mode = mode
#----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol):
def setDatabase(self, dbName, symbol):
"""设置历史数据所用的数据库"""
self.dbName = dbName
self.symbol = symbol
#----------------------------------------------------------------------
def loadHistoryDataFromMongo(self):
"""载入历史数据"""
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[dbName][symbol]
collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据')
@ -130,10 +159,370 @@ class BacktestingEngine(object):
self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
#----------------------------------------------------------------------
def connectMysql(self):
"""连接MysqlDB"""
# 载入json文件
fileName = 'mysql_connect.json'
try:
f = file(fileName)
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败')
return
# 解析json文件
setting = json.load(f)
try:
mysql_host = str(setting['host'])
mysql_port = int(setting['port'])
mysql_user = str(setting['user'])
mysql_passwd = str(setting['passwd'])
mysql_db = str(setting['db'])
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查')
return
try:
self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user,
passwd=mysql_passwd, db=mysql_db, port=mysql_port)
self.__mysqlConnected = True
self.writeCtaLog(u'回测引擎连接MysqlDB成功')
except ConnectionFailure:
self.writeCtaLog(u'回测引擎连接MysqlDB失败')
#----------------------------------------------------------------------
def loadDataHistoryFromMysql(self, symbol, startDate, endDate):
"""载入历史TICK数据
如果加载过多数据会导致加载失败,间隔不要超过半年
"""
if not endDate:
endDate = datetime.today()
# 看本地缓存是否存在
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
self.writeCtaLog(u'历史TICK数据从Cache载入')
return
# 每次获取日期周期
intervalDays = 10
for i in range (0,(endDate - startDate).days +1, intervalDays):
d1 = startDate + timedelta(days = i )
if (endDate - d1).days > 10:
d2 = startDate + timedelta(days = i + intervalDays -1 )
else:
d2 = endDate
# 从Mysql 提取数据
self.__qryDataHistoryFromMysql(symbol, d1, d2)
self.writeCtaLog(u'历史TICK数据共载入{0}'.format(len(self.historyData)))
# 保存本地cache文件
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
"""看本地缓存是否存在
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# cache文件
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
if not os.path.isfile(cacheFile):
return False
else:
# 从cache文件加载
cache = open(cacheFile,mode='r')
self.historyData = cPickle.load(cache)
cache.close()
return True
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
"""保存本地缓存
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# 创建cache子目录
if not os.path.isdir(cacheFolder):
os.mkdir(cacheFolder)
# cache 文件名
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
# 重复存在 返回
if os.path.isfile(cacheFile):
return False
else:
# 写入cache文件
cache = open(cacheFile, mode='w')
cPickle.dump(self.historyData,cache)
cache.close()
return True
#----------------------------------------------------------------------
def __qryDataHistoryFromMysql(self, symbol, startDate, endDate):
"""从Mysql载入历史TICK数据
added by IncenseLee
"""
try:
self.connectMysql()
if self.__mysqlConnected:
# 获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
# 开始日期 ~ 结束日期
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\
format(symbol, startDate, endDate)
elif startDate:
# 开始日期 - 当前
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\
format( symbol, startDate)
else:
# 所有数据
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\
format(symbol)
self.writeCtaLog(sqlstring)
# 执行查询
count = cur.execute(sqlstring)
self.writeCtaLog(u'历史TICK数据共{0}'.format(count))
# 分批次读取
fetch_counts = 0
fetch_size = 1000
while True:
results = cur.fetchmany(fetch_size)
if not results:
break
fetch_counts = fetch_counts + len(results)
if not self.historyData:
self.historyData =results
else:
self.historyData = self.historyData + results
self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}'.format(fetch_counts,startDate,endDate))
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}'.format(e))
def __dataToTick(self, data):
"""
数据库查询返回的data结构转换为tick对象
added by IncenseLee
"""
tick = CtaTickData()
symbol = data['InstrumentID']
tick.symbol = symbol
# 创建TICK数据对象并更新数据
tick.vtSymbol = symbol
# tick.openPrice = data['OpenPrice']
# tick.highPrice = data['HighestPrice']
# tick.lowPrice = data['LowestPrice']
tick.lastPrice = float(data['LastPrice'])
tick.volume = data['Volume']
tick.openInterest = data['OpenInterest']
# tick.upperLimit = data['UpperLimitPrice']
# tick.lowerLimit = data['LowerLimitPrice']
tick.datetime = data['UpdateTime']
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:%S')
tick.bidPrice1 = float(data['BidPrice1'])
# tick.bidPrice2 = data['BidPrice2']
# tick.bidPrice3 = data['BidPrice3']
# tick.bidPrice4 = data['BidPrice4']
# tick.bidPrice5 = data['BidPrice5']
tick.askPrice1 = float(data['AskPrice1'])
# tick.askPrice2 = data['AskPrice2']
# tick.askPrice3 = data['AskPrice3']
# tick.askPrice4 = data['AskPrice4']
# tick.askPrice5 = data['AskPrice5']
tick.bidVolume1 = data['BidVolume1']
# tick.bidVolume2 = data['BidVolume2']
# tick.bidVolume3 = data['BidVolume3']
# tick.bidVolume4 = data['BidVolume4']
# tick.bidVolume5 = data['BidVolume5']
tick.askVolume1 = data['AskVolume1']
# tick.askVolume2 = data['AskVolume2']
# tick.askVolume3 = data['AskVolume3']
# tick.askVolume4 = data['AskVolume4']
# tick.askVolume5 = data['AskVolume5']
return tick
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
"""从mysql库中获取交易日前若干天
added by IncenseLee
"""
try:
if self.__mysqlConnected:
# 获取mysql指针
cur = self.__mysqlConnection.cursor()
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
# self.writeCtaLog(sqlstring)
count = cur.execute(sqlstring)
if count > 0:
# 提取第一条记录
result = cur.fetchone()
return result[0]
else:
self.writeCtaLog(u'MysqlDB没有查询结果请检查日期')
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
# 出错后缺省返回
return startDate-timedelta(days=3)
#----------------------------------------------------------------------
def runBacktestingWithMysql(self):
"""运行回测(使用Mysql数据
added by IncenseLee
"""
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
# 每次获取日期周期
intervalDays = 10
for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays):
d1 = self.dataStartDate + timedelta(days = i )
if (self.dataEndDate - d1).days > intervalDays:
d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 )
else:
d2 = self.dataEndDate
# 提取历史数据
self.loadDataHistoryFromMysql(self.symbol, d1, d2)
self.output(u'数据日期:{0} => {1}'.format(d1,d2))
# 将逐笔数据推送
for data in self.historyData:
# 记录最新的TICK数据
self.tick = self.__dataToTick(data)
self.dt = self.tick.datetime
# 处理限价单
self.crossLimitOrder()
self.crossStopOrder()
# 推送到策略引擎中
self.strategy.onTick(self.tick)
# 清空历史数据
self.historyData = []
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def runBacktesting(self):
"""运行回测"""
# 载入历史数据
self.loadHistoryData()
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
@ -160,7 +549,7 @@ class BacktestingEngine(object):
func(data)
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def newBar(self, bar):
"""新的K线"""
@ -191,6 +580,8 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发单"""
self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol,orderType,price,volume))
self.limitOrderCount += 1
orderID = str(self.limitOrderCount)
@ -202,6 +593,9 @@ class BacktestingEngine(object):
order.orderID = orderID
order.vtOrderID = orderID
order.orderTime = str(self.dt)
# added by IncenseLee
order.gatewayName = self.gatewayName
# CTA委托类型映射
if orderType == CTAORDER_BUY:
@ -220,8 +614,9 @@ class BacktestingEngine(object):
# 保存到限价单字典中
self.workingLimitOrderDict[orderID] = order
self.limitOrderDict[orderID] = order
return orderID
# modified by IncenseLee
return u'{0}.{1}'.format(order.gatewayName, orderID)
#----------------------------------------------------------------------
def cancelOrder(self, vtOrderID):
@ -279,13 +674,15 @@ class BacktestingEngine(object):
"""基于最新数据撮合限价单"""
# 先确定会撮合成交的价格
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
bestCrossPrice = self.bar.open # 在当前时间点前发出的委托可能的最优成交价
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
else:
buyCrossPrice = self.tick.lastPrice
sellCrossPrice = self.tick.lastPrice
bestCrossPrice = self.tick.lastPrice
buyCrossPrice = self.tick.askPrice1
sellCrossPrice = self.tick.bidPrice1
buyBestCrossPrice = self.tick.askPrice1
sellBestCrossPrice = self.tick.bidPrice1
# 遍历限价单字典中的所有限价单
for orderID, order in self.workingLimitOrderDict.items():
@ -312,10 +709,10 @@ class BacktestingEngine(object):
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105
# 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100
if buyCross:
trade.price = min(order.price, bestCrossPrice)
trade.price = min(order.price, buyBestCrossPrice)
self.strategy.pos += order.totalVolume
else:
trade.price = max(order.price, bestCrossPrice)
trade.price = max(order.price, sellBestCrossPrice)
self.strategy.pos -= order.totalVolume
trade.volume = order.totalVolume
@ -425,27 +822,27 @@ class BacktestingEngine(object):
"""记录日志"""
log = str(self.dt) + ' ' + content
self.logList.append(log)
# 写入本地log日志
logging.info(content)
#----------------------------------------------------------------------
def output(self, content):
"""输出内容"""
print content
print str(datetime.now()) + "\t" + content
#----------------------------------------------------------------------
def showBacktestingResult(self):
def calculateBacktestingResult(self):
"""
显示回测结果
计算回测结果
"""
self.output(u'显示回测结果')
self.output(u'计算回测结果')
# 首先基于回测后的成交记录,计算每笔交易的盈亏
pnlDict = OrderedDict() # 每笔盈亏的记录
resultDict = OrderedDict() # 交易结果记录
longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易
# 计算滑点,一个来回包括两次
totalSlippage = self.slippage * 2
for trade in self.tradeDict.values():
# 多头交易
if trade.direction == DIRECTION_LONG:
@ -455,12 +852,15 @@ class BacktestingEngine(object):
# 当前多头交易为平空
else:
entryTrade = shortTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate
# 计算盈亏
pnl = ((trade.price - entryTrade.price)*(-1) - totalSlippage - commission) \
* trade.volume * self.size
pnlDict[trade.dt] = pnl
result = TradingResult(entryTrade.price, trade.price, -trade.volume,
self.rate, self.slippage, self.size)
resultDict[trade.dt] = result
self.output(u'{0},short:{1},{2},cover:{3},vol:{4},{5}'
.format(entryTrade.tradeTime, entryTrade.price,trade.tradeTime,trade.price, trade.volume,result.pnl))
# 空头交易
else:
# 如果尚无多头交易
@ -469,58 +869,101 @@ class BacktestingEngine(object):
# 当前空头交易为平多
else:
entryTrade = longTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate
# 计算盈亏
pnl = ((trade.price - entryTrade.price) - totalSlippage - commission) \
* trade.volume * self.size
pnlDict[trade.dt] = pnl
result = TradingResult(entryTrade.price, trade.price, trade.volume,
self.rate, self.slippage, self.size)
resultDict[trade.dt] = result
self.output(u'{0},buy:{1},{2},sell:{3},vol:{4},{5}'
.format(entryTrade.tradeTime, entryTrade.price,trade.tradeTime,trade.price, trade.volume,result.pnl))
# 检查是否有交易
if not resultDict:
self.output(u'无交易结果')
return {}
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
timeList = pnlDict.keys()
pnlList = pnlDict.values()
capital = 0 # 资金
maxCapital = 0 # 资金最高净值
drawdown = 0 # 回撤
capital = 0
maxCapital = 0
drawdown = 0
totalResult = 0 # 总成交数量
totalTurnover = 0 # 总成交金额(合约面值)
totalCommission = 0 # 总手续费
totalSlippage = 0 # 总滑点
timeList = [] # 时间序列
pnlList = [] # 每笔盈亏序列
capitalList = [] # 盈亏汇总的时间序列
maxCapitalList = [] # 最高盈利的时间序列
drawdownList = [] # 回撤的时间序列
for pnl in pnlList:
capital += pnl
for time, result in resultDict.items():
capital += result.pnl
maxCapital = max(capital, maxCapital)
drawdown = capital - maxCapital
pnlList.append(result.pnl)
timeList.append(time)
capitalList.append(capital)
maxCapitalList.append(maxCapital)
drawdownList.append(drawdown)
totalResult += 1
totalTurnover += result.turnover
totalCommission += result.commission
totalSlippage += result.slippage
# 返回回测结果
d = {}
d['capital'] = capital
d['maxCapital'] = maxCapital
d['drawdown'] = drawdown
d['totalResult'] = totalResult
d['totalTurnover'] = totalTurnover
d['totalCommission'] = totalCommission
d['totalSlippage'] = totalSlippage
d['timeList'] = timeList
d['pnlList'] = pnlList
d['capitalList'] = capitalList
d['drawdownList'] = drawdownList
return d
#----------------------------------------------------------------------
def showBacktestingResult(self):
"""显示回测结果"""
d = self.calculateBacktestingResult()
if len(d)== 0:
self.output(u'无交易结果')
return
# 输出
self.output('-' * 50)
self.output(u'第一笔交易时间:%s' % timeList[0])
self.output(u'最后一笔交易时间:%s' % timeList[-1])
self.output(u'总交易次数:%s' % len(pnlList))
self.output(u'总盈亏:%s' % capitalList[-1])
self.output(u'最大回撤: %s' % min(drawdownList))
self.output('-' * 30)
self.output(u'第一笔交易:\t%s' % d['timeList'][0])
self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
self.output(u'平均每笔滑点:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
# 绘图
import matplotlib.pyplot as plt
#import matplotlib.pyplot as plt
pCapital = plt.subplot(3, 1, 1)
pCapital.set_ylabel("capital")
pCapital.plot(capitalList)
#pCapital = plt.subplot(3, 1, 1)
#pCapital.set_ylabel("capital")
#pCapital.plot(d['capitalList'])
pDD = plt.subplot(3, 1, 2)
pDD.set_ylabel("DD")
pDD.bar(range(len(drawdownList)), drawdownList)
#pDD = plt.subplot(3, 1, 2)
#pDD.set_ylabel("DD")
#pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
pPnl = plt.subplot(3, 1, 3)
pPnl.set_ylabel("pnl")
pPnl.hist(pnlList, bins=20)
#pPnl = plt.subplot(3, 1, 3)
#pPnl.set_ylabel("pnl")
#pPnl.hist(d['pnlList'], bins=50)
plt.show()
#plt.show()
#----------------------------------------------------------------------
def putStrategyEvent(self, name):
@ -529,7 +972,7 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def setSlippage(self, slippage):
"""设置滑点"""
"""设置滑点点数"""
self.slippage = slippage
#----------------------------------------------------------------------
@ -542,6 +985,136 @@ class BacktestingEngine(object):
"""设置佣金比例"""
self.rate = rate
#----------------------------------------------------------------------
def runOptimization(self, strategyClass, optimizationSetting):
"""优化参数"""
# 获取优化设置
settingList = optimizationSetting.generateSetting()
targetName = optimizationSetting.optimizeTarget
# 检查参数设置问题
if not settingList or not targetName:
self.output(u'优化设置有问题,请检查')
# 遍历优化
resultList = []
for setting in settingList:
self.clearBacktestingResult()
self.output('-' * 30)
self.output('setting: %s' %str(setting))
self.initStrategy(strategyClass, setting)
self.runBacktesting()
d = self.calculateBacktestingResult()
try:
targetValue = d[targetName]
except KeyError:
targetValue = 0
resultList.append(([str(setting)], targetValue))
# 显示结果
resultList.sort(reverse=True, key=lambda result:result[1])
self.output('-' * 30)
self.output(u'优化结果:')
for result in resultList:
self.output(u'%s: %s' %(result[0], result[1]))
#----------------------------------------------------------------------
def clearBacktestingResult(self):
"""清空之前回测的结果"""
# 清空限价单相关
self.limitOrderCount = 0
self.limitOrderDict.clear()
self.workingLimitOrderDict.clear()
# 清空停止单相关
self.stopOrderCount = 0
self.stopOrderDict.clear()
self.workingStopOrderDict.clear()
# 清空成交相关
self.tradeCount = 0
self.tradeDict.clear()
########################################################################
class TradingResult(object):
"""每笔交易的结果"""
#----------------------------------------------------------------------
def __init__(self, entry, exit, volume, rate, slippage, size):
"""Constructor"""
self.entry = entry # 开仓价格
self.exit = exit # 平仓价格
self.volume = volume # 交易数量(+/-代表方向)
self.turnover = (self.entry+self.exit)*size # 成交金额
self.commission = self.turnover*rate # 手续费成本
self.slippage = slippage*2*size # 滑点成本
self.pnl = ((self.exit - self.entry) * volume * size
- self.commission - self.slippage) # 净盈亏
########################################################################
class OptimizationSetting(object):
"""优化设置"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.paramDict = OrderedDict()
self.optimizeTarget = '' # 优化目标字段
#----------------------------------------------------------------------
def addParameter(self, name, start, end, step):
"""增加优化参数"""
if end <= start:
print u'参数起始点必须小于终止点'
return
if step <= 0:
print u'参数布进必须大于0'
return
l = []
param = start
while param <= end:
l.append(param)
param += step
self.paramDict[name] = l
#----------------------------------------------------------------------
def generateSetting(self):
"""生成优化参数组合"""
# 参数名的列表
nameList = self.paramDict.keys()
paramList = self.paramDict.values()
# 使用迭代工具生产参数对组合
productList = list(product(*paramList))
# 把参数对组合打包到一个个字典组成的列表中
settingList = []
for p in productList:
d = dict(zip(nameList, p))
settingList.append(d)
return settingList
#----------------------------------------------------------------------
def setOptimizeTarget(self, target):
"""设置优化目标字段"""
self.optimizeTarget = target
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
n = round(n, 2) # 保留两位小数
return format(n, ',') # 加上千分符
if __name__ == '__main__':
@ -560,7 +1133,7 @@ if __name__ == '__main__':
engine.setStartDate('20110101')
# 载入历史数据到引擎中
engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000')
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳