修改回测模型,支持mysql的tick数据库

修改SendOrder方法,增加返回GatewayName
修改Tick的方法,增加data-》Tick对象。
This commit is contained in:
msincenselee 2016-09-14 22:29:31 +08:00
parent e373f16ac0
commit c2f3ba62c7

View File

@ -7,15 +7,21 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from collections import OrderedDict from collections import OrderedDict
from itertools import product
import pymongo import pymongo
import MySQLdb
import json
import os
import cPickle
from ctaBase import * from ctaBase import *
from ctaSetting import * from ctaSetting import *
from vtConstant import * from vtConstant import *
from vtGateway import VtOrderData, VtTradeData from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting from vtFunction import loadMongoSetting
import logging
######################################################################## ########################################################################
class BacktestingEngine(object): class BacktestingEngine(object):
@ -23,6 +29,10 @@ class BacktestingEngine(object):
CTA回测引擎 CTA回测引擎
函数接口和策略引擎保持一样 函数接口和策略引擎保持一样
从而实现同一套代码从回测到实盘 从而实现同一套代码从回测到实盘
# modified by IncenseLee
1.增加Mysql数据库的支持
2.修改装载数据为批量式后加载模式
""" """
TICK_MODE = 'tick' TICK_MODE = 'tick'
@ -51,10 +61,13 @@ class BacktestingEngine(object):
self.dbClient = None # 数据库客户端 self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针 self.dbCursor = None # 数据库指针
#self.historyData = [] # 历史数据的列表,回测用 self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据 self.initData = [] # 初始化用的数据
#self.backtestingData = [] # 回测用的数据 self.backtestingData = [] # 回测用的数据
self.dbName = '' # 回测数据库名
self.symbol = '' # 回测集合名
self.dataStartDate = None # 回测数据开始日期datetime对象 self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象 self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象 self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
@ -72,33 +85,49 @@ class BacktestingEngine(object):
self.tick = None self.tick = None
self.bar = None self.bar = None
self.dt = None # 最新的时间 self.dt = None # 最新的时间
self.gatewayName = u'BackTest'
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def setStartDate(self, startDate='20100416', initDays=10): def setStartDate(self, startDate='20100416', initDays=10):
"""设置回测的启动日期""" """设置回测的启动日期"""
self.dataStartDate = datetime.strptime(startDate, '%Y%m%d') self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
# 初始化天数
initTimeDelta = timedelta(initDays) initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta self.strategyStartDate = self.dataStartDate + initTimeDelta
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def setEndDate(self, endDate=''): def setEndDate(self, endDate=''):
"""设置回测的结束日期""" """设置回测的结束日期"""
if endDate: if endDate:
self.dataEndDate= datetime.strptime(endDate, '%Y%m%d') self.dataEndDate = datetime.strptime(endDate, '%Y%m%d')
else:
self.dataEndDate = datetime.now()
def setMinDiff(self, minDiff):
"""设置回测品种的最小跳价,用于修正数据"""
self.minDiff = minDiff
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def setBacktestingMode(self, mode): def setBacktestingMode(self, mode):
"""设置回测模式""" """设置回测模式"""
self.mode = mode self.mode = mode
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol): def setDatabase(self, dbName, symbol):
"""设置历史数据所用的数据库"""
self.dbName = dbName
self.symbol = symbol
#----------------------------------------------------------------------
def loadHistoryDataFromMongo(self):
"""载入历史数据""" """载入历史数据"""
host, port = loadMongoSetting() host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port) self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[dbName][symbol] collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据') self.output(u'开始载入数据')
@ -130,10 +159,370 @@ class BacktestingEngine(object):
self.dbCursor = collection.find(flt) self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count())) self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
#----------------------------------------------------------------------
def connectMysql(self):
"""连接MysqlDB"""
# 载入json文件
fileName = 'mysql_connect.json'
try:
f = file(fileName)
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败')
return
# 解析json文件
setting = json.load(f)
try:
mysql_host = str(setting['host'])
mysql_port = int(setting['port'])
mysql_user = str(setting['user'])
mysql_passwd = str(setting['passwd'])
mysql_db = str(setting['db'])
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查')
return
try:
self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user,
passwd=mysql_passwd, db=mysql_db, port=mysql_port)
self.__mysqlConnected = True
self.writeCtaLog(u'回测引擎连接MysqlDB成功')
except ConnectionFailure:
self.writeCtaLog(u'回测引擎连接MysqlDB失败')
#----------------------------------------------------------------------
def loadDataHistoryFromMysql(self, symbol, startDate, endDate):
"""载入历史TICK数据
如果加载过多数据会导致加载失败,间隔不要超过半年
"""
if not endDate:
endDate = datetime.today()
# 看本地缓存是否存在
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
self.writeCtaLog(u'历史TICK数据从Cache载入')
return
# 每次获取日期周期
intervalDays = 10
for i in range (0,(endDate - startDate).days +1, intervalDays):
d1 = startDate + timedelta(days = i )
if (endDate - d1).days > 10:
d2 = startDate + timedelta(days = i + intervalDays -1 )
else:
d2 = endDate
# 从Mysql 提取数据
self.__qryDataHistoryFromMysql(symbol, d1, d2)
self.writeCtaLog(u'历史TICK数据共载入{0}'.format(len(self.historyData)))
# 保存本地cache文件
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
"""看本地缓存是否存在
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# cache文件
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
if not os.path.isfile(cacheFile):
return False
else:
# 从cache文件加载
cache = open(cacheFile,mode='r')
self.historyData = cPickle.load(cache)
cache.close()
return True
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
"""保存本地缓存
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# 创建cache子目录
if not os.path.isdir(cacheFolder):
os.mkdir(cacheFolder)
# cache 文件名
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
# 重复存在 返回
if os.path.isfile(cacheFile):
return False
else:
# 写入cache文件
cache = open(cacheFile, mode='w')
cPickle.dump(self.historyData,cache)
cache.close()
return True
#----------------------------------------------------------------------
def __qryDataHistoryFromMysql(self, symbol, startDate, endDate):
"""从Mysql载入历史TICK数据
added by IncenseLee
"""
try:
self.connectMysql()
if self.__mysqlConnected:
# 获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
# 开始日期 ~ 结束日期
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\
format(symbol, startDate, endDate)
elif startDate:
# 开始日期 - 当前
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\
format( symbol, startDate)
else:
# 所有数据
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\
format(symbol)
self.writeCtaLog(sqlstring)
# 执行查询
count = cur.execute(sqlstring)
self.writeCtaLog(u'历史TICK数据共{0}'.format(count))
# 分批次读取
fetch_counts = 0
fetch_size = 1000
while True:
results = cur.fetchmany(fetch_size)
if not results:
break
fetch_counts = fetch_counts + len(results)
if not self.historyData:
self.historyData =results
else:
self.historyData = self.historyData + results
self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}'.format(fetch_counts,startDate,endDate))
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}'.format(e))
def __dataToTick(self, data):
"""
数据库查询返回的data结构转换为tick对象
added by IncenseLee
"""
tick = CtaTickData()
symbol = data['InstrumentID']
tick.symbol = symbol
# 创建TICK数据对象并更新数据
tick.vtSymbol = symbol
# tick.openPrice = data['OpenPrice']
# tick.highPrice = data['HighestPrice']
# tick.lowPrice = data['LowestPrice']
tick.lastPrice = float(data['LastPrice'])
tick.volume = data['Volume']
tick.openInterest = data['OpenInterest']
# tick.upperLimit = data['UpperLimitPrice']
# tick.lowerLimit = data['LowerLimitPrice']
tick.datetime = data['UpdateTime']
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:%S')
tick.bidPrice1 = float(data['BidPrice1'])
# tick.bidPrice2 = data['BidPrice2']
# tick.bidPrice3 = data['BidPrice3']
# tick.bidPrice4 = data['BidPrice4']
# tick.bidPrice5 = data['BidPrice5']
tick.askPrice1 = float(data['AskPrice1'])
# tick.askPrice2 = data['AskPrice2']
# tick.askPrice3 = data['AskPrice3']
# tick.askPrice4 = data['AskPrice4']
# tick.askPrice5 = data['AskPrice5']
tick.bidVolume1 = data['BidVolume1']
# tick.bidVolume2 = data['BidVolume2']
# tick.bidVolume3 = data['BidVolume3']
# tick.bidVolume4 = data['BidVolume4']
# tick.bidVolume5 = data['BidVolume5']
tick.askVolume1 = data['AskVolume1']
# tick.askVolume2 = data['AskVolume2']
# tick.askVolume3 = data['AskVolume3']
# tick.askVolume4 = data['AskVolume4']
# tick.askVolume5 = data['AskVolume5']
return tick
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
"""从mysql库中获取交易日前若干天
added by IncenseLee
"""
try:
if self.__mysqlConnected:
# 获取mysql指针
cur = self.__mysqlConnection.cursor()
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
# self.writeCtaLog(sqlstring)
count = cur.execute(sqlstring)
if count > 0:
# 提取第一条记录
result = cur.fetchone()
return result[0]
else:
self.writeCtaLog(u'MysqlDB没有查询结果请检查日期')
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
# 出错后缺省返回
return startDate-timedelta(days=3)
#----------------------------------------------------------------------
def runBacktestingWithMysql(self):
"""运行回测(使用Mysql数据
added by IncenseLee
"""
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
# 每次获取日期周期
intervalDays = 10
for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays):
d1 = self.dataStartDate + timedelta(days = i )
if (self.dataEndDate - d1).days > intervalDays:
d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 )
else:
d2 = self.dataEndDate
# 提取历史数据
self.loadDataHistoryFromMysql(self.symbol, d1, d2)
self.output(u'数据日期:{0} => {1}'.format(d1,d2))
# 将逐笔数据推送
for data in self.historyData:
# 记录最新的TICK数据
self.tick = self.__dataToTick(data)
self.dt = self.tick.datetime
# 处理限价单
self.crossLimitOrder()
self.crossStopOrder()
# 推送到策略引擎中
self.strategy.onTick(self.tick)
# 清空历史数据
self.historyData = []
self.output(u'数据回放结束')
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def runBacktesting(self): def runBacktesting(self):
"""运行回测""" """运行回测"""
# 载入历史数据
self.loadHistoryData()
# 首先根据回测模式,确认要使用的数据类 # 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE: if self.mode == self.BAR_MODE:
dataClass = CtaBarData dataClass = CtaBarData
@ -160,7 +549,7 @@ class BacktestingEngine(object):
func(data) func(data)
self.output(u'数据回放结束') self.output(u'数据回放结束')
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def newBar(self, bar): def newBar(self, bar):
"""新的K线""" """新的K线"""
@ -191,6 +580,8 @@ class BacktestingEngine(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def sendOrder(self, vtSymbol, orderType, price, volume, strategy): def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发单""" """发单"""
self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol,orderType,price,volume))
self.limitOrderCount += 1 self.limitOrderCount += 1
orderID = str(self.limitOrderCount) orderID = str(self.limitOrderCount)
@ -202,6 +593,9 @@ class BacktestingEngine(object):
order.orderID = orderID order.orderID = orderID
order.vtOrderID = orderID order.vtOrderID = orderID
order.orderTime = str(self.dt) order.orderTime = str(self.dt)
# added by IncenseLee
order.gatewayName = self.gatewayName
# CTA委托类型映射 # CTA委托类型映射
if orderType == CTAORDER_BUY: if orderType == CTAORDER_BUY:
@ -220,8 +614,9 @@ class BacktestingEngine(object):
# 保存到限价单字典中 # 保存到限价单字典中
self.workingLimitOrderDict[orderID] = order self.workingLimitOrderDict[orderID] = order
self.limitOrderDict[orderID] = order self.limitOrderDict[orderID] = order
return orderID # modified by IncenseLee
return u'{0}.{1}'.format(order.gatewayName, orderID)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def cancelOrder(self, vtOrderID): def cancelOrder(self, vtOrderID):
@ -279,13 +674,15 @@ class BacktestingEngine(object):
"""基于最新数据撮合限价单""" """基于最新数据撮合限价单"""
# 先确定会撮合成交的价格 # 先确定会撮合成交的价格
if self.mode == self.BAR_MODE: if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交 buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交 sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
bestCrossPrice = self.bar.open # 在当前时间点前发出的委托可能的最优成交价 buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
else: else:
buyCrossPrice = self.tick.lastPrice buyCrossPrice = self.tick.askPrice1
sellCrossPrice = self.tick.lastPrice sellCrossPrice = self.tick.bidPrice1
bestCrossPrice = self.tick.lastPrice buyBestCrossPrice = self.tick.askPrice1
sellBestCrossPrice = self.tick.bidPrice1
# 遍历限价单字典中的所有限价单 # 遍历限价单字典中的所有限价单
for orderID, order in self.workingLimitOrderDict.items(): for orderID, order in self.workingLimitOrderDict.items():
@ -312,10 +709,10 @@ class BacktestingEngine(object):
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105 # 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105
# 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100 # 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100
if buyCross: if buyCross:
trade.price = min(order.price, bestCrossPrice) trade.price = min(order.price, buyBestCrossPrice)
self.strategy.pos += order.totalVolume self.strategy.pos += order.totalVolume
else: else:
trade.price = max(order.price, bestCrossPrice) trade.price = max(order.price, sellBestCrossPrice)
self.strategy.pos -= order.totalVolume self.strategy.pos -= order.totalVolume
trade.volume = order.totalVolume trade.volume = order.totalVolume
@ -425,27 +822,27 @@ class BacktestingEngine(object):
"""记录日志""" """记录日志"""
log = str(self.dt) + ' ' + content log = str(self.dt) + ' ' + content
self.logList.append(log) self.logList.append(log)
# 写入本地log日志
logging.info(content)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def output(self, content): def output(self, content):
"""输出内容""" """输出内容"""
print content print str(datetime.now()) + "\t" + content
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def showBacktestingResult(self): def calculateBacktestingResult(self):
""" """
显示回测结果 计算回测结果
""" """
self.output(u'显示回测结果') self.output(u'计算回测结果')
# 首先基于回测后的成交记录,计算每笔交易的盈亏 # 首先基于回测后的成交记录,计算每笔交易的盈亏
pnlDict = OrderedDict() # 每笔盈亏的记录 resultDict = OrderedDict() # 交易结果记录
longTrade = [] # 未平仓的多头交易 longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易 shortTrade = [] # 未平仓的空头交易
# 计算滑点,一个来回包括两次
totalSlippage = self.slippage * 2
for trade in self.tradeDict.values(): for trade in self.tradeDict.values():
# 多头交易 # 多头交易
if trade.direction == DIRECTION_LONG: if trade.direction == DIRECTION_LONG:
@ -455,12 +852,15 @@ class BacktestingEngine(object):
# 当前多头交易为平空 # 当前多头交易为平空
else: else:
entryTrade = shortTrade.pop(0) entryTrade = shortTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate result = TradingResult(entryTrade.price, trade.price, -trade.volume,
# 计算盈亏 self.rate, self.slippage, self.size)
pnl = ((trade.price - entryTrade.price)*(-1) - totalSlippage - commission) \
* trade.volume * self.size resultDict[trade.dt] = result
pnlDict[trade.dt] = pnl
self.output(u'{0},short:{1},{2},cover:{3},vol:{4},{5}'
.format(entryTrade.tradeTime, entryTrade.price,trade.tradeTime,trade.price, trade.volume,result.pnl))
# 空头交易 # 空头交易
else: else:
# 如果尚无多头交易 # 如果尚无多头交易
@ -469,58 +869,101 @@ class BacktestingEngine(object):
# 当前空头交易为平多 # 当前空头交易为平多
else: else:
entryTrade = longTrade.pop(0) entryTrade = longTrade.pop(0)
# 计算比例佣金
commission = (trade.price+entryTrade.price) * self.rate result = TradingResult(entryTrade.price, trade.price, trade.volume,
# 计算盈亏 self.rate, self.slippage, self.size)
pnl = ((trade.price - entryTrade.price) - totalSlippage - commission) \ resultDict[trade.dt] = result
* trade.volume * self.size
pnlDict[trade.dt] = pnl self.output(u'{0},buy:{1},{2},sell:{3},vol:{4},{5}'
.format(entryTrade.tradeTime, entryTrade.price,trade.tradeTime,trade.price, trade.volume,result.pnl))
# 检查是否有交易
if not resultDict:
self.output(u'无交易结果')
return {}
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等 # 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
timeList = pnlDict.keys() capital = 0 # 资金
pnlList = pnlDict.values() maxCapital = 0 # 资金最高净值
drawdown = 0 # 回撤
capital = 0 totalResult = 0 # 总成交数量
maxCapital = 0 totalTurnover = 0 # 总成交金额(合约面值)
drawdown = 0 totalCommission = 0 # 总手续费
totalSlippage = 0 # 总滑点
timeList = [] # 时间序列
pnlList = [] # 每笔盈亏序列
capitalList = [] # 盈亏汇总的时间序列 capitalList = [] # 盈亏汇总的时间序列
maxCapitalList = [] # 最高盈利的时间序列
drawdownList = [] # 回撤的时间序列 drawdownList = [] # 回撤的时间序列
for pnl in pnlList: for time, result in resultDict.items():
capital += pnl capital += result.pnl
maxCapital = max(capital, maxCapital) maxCapital = max(capital, maxCapital)
drawdown = capital - maxCapital drawdown = capital - maxCapital
pnlList.append(result.pnl)
timeList.append(time)
capitalList.append(capital) capitalList.append(capital)
maxCapitalList.append(maxCapital)
drawdownList.append(drawdown) drawdownList.append(drawdown)
totalResult += 1
totalTurnover += result.turnover
totalCommission += result.commission
totalSlippage += result.slippage
# 返回回测结果
d = {}
d['capital'] = capital
d['maxCapital'] = maxCapital
d['drawdown'] = drawdown
d['totalResult'] = totalResult
d['totalTurnover'] = totalTurnover
d['totalCommission'] = totalCommission
d['totalSlippage'] = totalSlippage
d['timeList'] = timeList
d['pnlList'] = pnlList
d['capitalList'] = capitalList
d['drawdownList'] = drawdownList
return d
#----------------------------------------------------------------------
def showBacktestingResult(self):
"""显示回测结果"""
d = self.calculateBacktestingResult()
if len(d)== 0:
self.output(u'无交易结果')
return
# 输出 # 输出
self.output('-' * 50) self.output('-' * 30)
self.output(u'第一笔交易时间:%s' % timeList[0]) self.output(u'第一笔交易:\t%s' % d['timeList'][0])
self.output(u'最后一笔交易时间:%s' % timeList[-1]) self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
self.output(u'总交易次数:%s' % len(pnlList))
self.output(u'总盈亏:%s' % capitalList[-1]) self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
self.output(u'最大回撤: %s' % min(drawdownList)) self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
self.output(u'平均每笔滑点:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
# 绘图 # 绘图
import matplotlib.pyplot as plt #import matplotlib.pyplot as plt
pCapital = plt.subplot(3, 1, 1) #pCapital = plt.subplot(3, 1, 1)
pCapital.set_ylabel("capital") #pCapital.set_ylabel("capital")
pCapital.plot(capitalList) #pCapital.plot(d['capitalList'])
pDD = plt.subplot(3, 1, 2) #pDD = plt.subplot(3, 1, 2)
pDD.set_ylabel("DD") #pDD.set_ylabel("DD")
pDD.bar(range(len(drawdownList)), drawdownList) #pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
pPnl = plt.subplot(3, 1, 3) #pPnl = plt.subplot(3, 1, 3)
pPnl.set_ylabel("pnl") #pPnl.set_ylabel("pnl")
pPnl.hist(pnlList, bins=20) #pPnl.hist(d['pnlList'], bins=50)
plt.show() #plt.show()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def putStrategyEvent(self, name): def putStrategyEvent(self, name):
@ -529,7 +972,7 @@ class BacktestingEngine(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def setSlippage(self, slippage): def setSlippage(self, slippage):
"""设置滑点""" """设置滑点点数"""
self.slippage = slippage self.slippage = slippage
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@ -542,6 +985,136 @@ class BacktestingEngine(object):
"""设置佣金比例""" """设置佣金比例"""
self.rate = rate self.rate = rate
#----------------------------------------------------------------------
def runOptimization(self, strategyClass, optimizationSetting):
"""优化参数"""
# 获取优化设置
settingList = optimizationSetting.generateSetting()
targetName = optimizationSetting.optimizeTarget
# 检查参数设置问题
if not settingList or not targetName:
self.output(u'优化设置有问题,请检查')
# 遍历优化
resultList = []
for setting in settingList:
self.clearBacktestingResult()
self.output('-' * 30)
self.output('setting: %s' %str(setting))
self.initStrategy(strategyClass, setting)
self.runBacktesting()
d = self.calculateBacktestingResult()
try:
targetValue = d[targetName]
except KeyError:
targetValue = 0
resultList.append(([str(setting)], targetValue))
# 显示结果
resultList.sort(reverse=True, key=lambda result:result[1])
self.output('-' * 30)
self.output(u'优化结果:')
for result in resultList:
self.output(u'%s: %s' %(result[0], result[1]))
#----------------------------------------------------------------------
def clearBacktestingResult(self):
"""清空之前回测的结果"""
# 清空限价单相关
self.limitOrderCount = 0
self.limitOrderDict.clear()
self.workingLimitOrderDict.clear()
# 清空停止单相关
self.stopOrderCount = 0
self.stopOrderDict.clear()
self.workingStopOrderDict.clear()
# 清空成交相关
self.tradeCount = 0
self.tradeDict.clear()
########################################################################
class TradingResult(object):
"""每笔交易的结果"""
#----------------------------------------------------------------------
def __init__(self, entry, exit, volume, rate, slippage, size):
"""Constructor"""
self.entry = entry # 开仓价格
self.exit = exit # 平仓价格
self.volume = volume # 交易数量(+/-代表方向)
self.turnover = (self.entry+self.exit)*size # 成交金额
self.commission = self.turnover*rate # 手续费成本
self.slippage = slippage*2*size # 滑点成本
self.pnl = ((self.exit - self.entry) * volume * size
- self.commission - self.slippage) # 净盈亏
########################################################################
class OptimizationSetting(object):
"""优化设置"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.paramDict = OrderedDict()
self.optimizeTarget = '' # 优化目标字段
#----------------------------------------------------------------------
def addParameter(self, name, start, end, step):
"""增加优化参数"""
if end <= start:
print u'参数起始点必须小于终止点'
return
if step <= 0:
print u'参数布进必须大于0'
return
l = []
param = start
while param <= end:
l.append(param)
param += step
self.paramDict[name] = l
#----------------------------------------------------------------------
def generateSetting(self):
"""生成优化参数组合"""
# 参数名的列表
nameList = self.paramDict.keys()
paramList = self.paramDict.values()
# 使用迭代工具生产参数对组合
productList = list(product(*paramList))
# 把参数对组合打包到一个个字典组成的列表中
settingList = []
for p in productList:
d = dict(zip(nameList, p))
settingList.append(d)
return settingList
#----------------------------------------------------------------------
def setOptimizeTarget(self, target):
"""设置优化目标字段"""
self.optimizeTarget = target
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
n = round(n, 2) # 保留两位小数
return format(n, ',') # 加上千分符
if __name__ == '__main__': if __name__ == '__main__':
@ -560,7 +1133,7 @@ if __name__ == '__main__':
engine.setStartDate('20110101') engine.setStartDate('20110101')
# 载入历史数据到引擎中 # 载入历史数据到引擎中
engine.loadHistoryData(MINUTE_DB_NAME, 'IF0000') engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数 # 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳 engine.setSlippage(0.2) # 股指1跳