vnpy/vn.trader/ctaAlgo/ctaBacktesting.py
msincenselee ccf3bb1b8a update
2016-10-11 00:50:47 +08:00

1241 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: UTF-8
'''
本文件中包含的是CTA模块的回测引擎回测引擎的API和CTA引擎一致
可以使用和实盘相同的代码进行回测。
'''
from datetime import datetime, timedelta
from collections import OrderedDict
from itertools import product
import pymongo
import MySQLdb
import json
import os
import cPickle
from ctaBase import *
from ctaSetting import *
from vtConstant import *
from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting
import logging
########################################################################
class BacktestingEngine(object):
"""
CTA回测引擎
函数接口和策略引擎保持一样,
从而实现同一套代码从回测到实盘。
# modified by IncenseLee
1.增加Mysql数据库的支持
2.修改装载数据为批量式后加载模式。
"""
TICK_MODE = 'tick'
BAR_MODE = 'bar'
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
# 本地停止单编号计数
self.stopOrderCount = 0
# stopOrderID = STOPORDERPREFIX + str(stopOrderCount)
# 本地停止单字典
# key为stopOrderIDvalue为stopOrder对象
self.stopOrderDict = {} # 停止单撤销后不会从本字典中删除
self.workingStopOrderDict = {} # 停止单撤销后会从本字典中删除
# 回测相关
self.strategy = None # 回测策略
self.mode = self.BAR_MODE # 回测模式默认为K线
self.slippage = 0 # 回测时假设的滑点
self.rate = 0 # 回测时假设的佣金比例(适用于百分比佣金)
self.size = 1 # 合约大小默认为1
self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针
self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据
self.backtestingData = [] # 回测用的数据
self.dbName = '' # 回测数据库名
self.symbol = '' # 回测集合名
self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
self.limitOrderDict = OrderedDict() # 限价单字典
self.workingLimitOrderDict = OrderedDict() # 活动限价单字典,用于进行撮合用
self.limitOrderCount = 0 # 限价单编号
self.tradeCount = 0 # 成交编号
self.tradeDict = OrderedDict() # 成交字典
self.logList = [] # 日志记录
# 当前最新数据,用于模拟成交用
self.tick = None
self.bar = None
self.dt = None # 最新的时间
self.gatewayName = u'BackTest'
#----------------------------------------------------------------------
def setStartDate(self, startDate='20100416', initDays=10):
"""设置回测的启动日期"""
self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
# 初始化天数
initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta
#----------------------------------------------------------------------
def setEndDate(self, endDate=''):
"""设置回测的结束日期"""
if endDate:
self.dataEndDate = datetime.strptime(endDate, '%Y%m%d')
else:
self.dataEndDate = datetime.now()
def setMinDiff(self, minDiff):
"""设置回测品种的最小跳价,用于修正数据"""
self.minDiff = minDiff
#----------------------------------------------------------------------
def setBacktestingMode(self, mode):
"""设置回测模式"""
self.mode = mode
#----------------------------------------------------------------------
def setDatabase(self, dbName, symbol):
"""设置历史数据所用的数据库"""
self.dbName = dbName
self.symbol = symbol
#----------------------------------------------------------------------
def loadHistoryDataFromMongo(self):
"""载入历史数据"""
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据')
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
# 载入初始化需要用的数据
flt = {'datetime':{'$gte':self.dataStartDate,
'$lt':self.strategyStartDate}}
initCursor = collection.find(flt)
# 将数据从查询指针中读取出,并生成列表
for d in initCursor:
data = dataClass()
data.__dict__ = d
self.initData.append(data)
# 载入回测数据
if not self.dataEndDate:
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
else:
flt = {'datetime':{'$gte':self.strategyStartDate,
'$lte':self.dataEndDate}}
self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
#----------------------------------------------------------------------
def connectMysql(self):
"""连接MysqlDB"""
# 载入json文件
fileName = 'mysql_connect.json'
try:
f = file(fileName)
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败')
return
# 解析json文件
setting = json.load(f)
try:
mysql_host = str(setting['host'])
mysql_port = int(setting['port'])
mysql_user = str(setting['user'])
mysql_passwd = str(setting['passwd'])
mysql_db = str(setting['db'])
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查')
return
try:
self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user,
passwd=mysql_passwd, db=mysql_db, port=mysql_port)
self.__mysqlConnected = True
self.writeCtaLog(u'回测引擎连接MysqlDB成功')
except ConnectionFailure:
self.writeCtaLog(u'回测引擎连接MysqlDB失败')
#----------------------------------------------------------------------
def loadDataHistoryFromMysql(self, symbol, startDate, endDate):
"""载入历史TICK数据
如果加载过多数据会导致加载失败,间隔不要超过半年
"""
if not endDate:
endDate = datetime.today()
# 看本地缓存是否存在
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
self.writeCtaLog(u'历史TICK数据从Cache载入')
return
# 每次获取日期周期
intervalDays = 10
for i in range (0,(endDate - startDate).days +1, intervalDays):
d1 = startDate + timedelta(days = i )
if (endDate - d1).days > 10:
d2 = startDate + timedelta(days = i + intervalDays -1 )
else:
d2 = endDate
# 从Mysql 提取数据
self.__qryDataHistoryFromMysql(symbol, d1, d2)
self.writeCtaLog(u'历史TICK数据共载入{0}'.format(len(self.historyData)))
# 保存本地cache文件
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
"""看本地缓存是否存在
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# cache文件
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
if not os.path.isfile(cacheFile):
return False
else:
# 从cache文件加载
cache = open(cacheFile,mode='r')
self.historyData = cPickle.load(cache)
cache.close()
return True
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
"""保存本地缓存
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# 创建cache子目录
if not os.path.isdir(cacheFolder):
os.mkdir(cacheFolder)
# cache 文件名
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
# 重复存在 返回
if os.path.isfile(cacheFile):
return False
else:
# 写入cache文件
cache = open(cacheFile, mode='w')
cPickle.dump(self.historyData,cache)
cache.close()
return True
#----------------------------------------------------------------------
def __qryDataHistoryFromMysql(self, symbol, startDate, endDate):
"""从Mysql载入历史TICK数据
added by IncenseLee
"""
try:
self.connectMysql()
if self.__mysqlConnected:
# 获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
# 开始日期 ~ 结束日期
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\
format(symbol, startDate, endDate)
elif startDate:
# 开始日期 - 当前
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\
format( symbol, startDate)
else:
# 所有数据
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\
format(symbol)
self.writeCtaLog(sqlstring)
# 执行查询
count = cur.execute(sqlstring)
self.writeCtaLog(u'历史TICK数据共{0}'.format(count))
# 分批次读取
fetch_counts = 0
fetch_size = 1000
while True:
results = cur.fetchmany(fetch_size)
if not results:
break
fetch_counts = fetch_counts + len(results)
if not self.historyData:
self.historyData =results
else:
self.historyData = self.historyData + results
self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}'.format(fetch_counts,startDate,endDate))
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}'.format(e))
def __dataToTick(self, data):
"""
数据库查询返回的data结构转换为tick对象
added by IncenseLee
"""
tick = CtaTickData()
symbol = data['InstrumentID']
tick.symbol = symbol
# 创建TICK数据对象并更新数据
tick.vtSymbol = symbol
# tick.openPrice = data['OpenPrice']
# tick.highPrice = data['HighestPrice']
# tick.lowPrice = data['LowestPrice']
tick.lastPrice = float(data['LastPrice'])
tick.volume = data['Volume']
tick.openInterest = data['OpenInterest']
# tick.upperLimit = data['UpperLimitPrice']
# tick.lowerLimit = data['LowerLimitPrice']
tick.datetime = data['UpdateTime']
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:%S')
tick.bidPrice1 = float(data['BidPrice1'])
# tick.bidPrice2 = data['BidPrice2']
# tick.bidPrice3 = data['BidPrice3']
# tick.bidPrice4 = data['BidPrice4']
# tick.bidPrice5 = data['BidPrice5']
tick.askPrice1 = float(data['AskPrice1'])
# tick.askPrice2 = data['AskPrice2']
# tick.askPrice3 = data['AskPrice3']
# tick.askPrice4 = data['AskPrice4']
# tick.askPrice5 = data['AskPrice5']
tick.bidVolume1 = data['BidVolume1']
# tick.bidVolume2 = data['BidVolume2']
# tick.bidVolume3 = data['BidVolume3']
# tick.bidVolume4 = data['BidVolume4']
# tick.bidVolume5 = data['BidVolume5']
tick.askVolume1 = data['AskVolume1']
# tick.askVolume2 = data['AskVolume2']
# tick.askVolume3 = data['AskVolume3']
# tick.askVolume4 = data['AskVolume4']
# tick.askVolume5 = data['AskVolume5']
return tick
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
"""从mysql库中获取交易日前若干天
added by IncenseLee
"""
try:
if self.__mysqlConnected:
# 获取mysql指针
cur = self.__mysqlConnection.cursor()
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
# self.writeCtaLog(sqlstring)
count = cur.execute(sqlstring)
if count > 0:
# 提取第一条记录
result = cur.fetchone()
return result[0]
else:
self.writeCtaLog(u'MysqlDB没有查询结果请检查日期')
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
# 出错后缺省返回
return startDate-timedelta(days=3)
#----------------------------------------------------------------------
def runBackTestingWithFile(self, filename):
"""运行回测使用本地csv数据)
added by IncenseLee
"""
if not filename:
self.writeCtaLog(u'请指定回测数据文件')
return
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
import os
if not os.path.isfile(filename):
self.writeCtaLog(u'{0}文件不存在'.format(filename))
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if not self.mode == self.BAR_MODE:
self.writeCtaLog(u'文件仅支持bar模式若扩展tick模式需要修改本方法')
return
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
import csv
csvfile = file(filename,'rb')
reader = csv.DictReader((line.replace('\0','') for line in csvfile),delimiter=",")
for row in reader:
bar = CtaBarData()
bar.open = float(row['Open'])
bar.high = float(row['High'])
bar.low = float(row['Low'])
bar.close = float(row['Close'])
bar.volume = float(row['TotalVolume'])
bar.date = row['Date']
bar.time = row['Time']
bar.datetime = datetime.strptime(bar.date+' '+ bar.time, '%Y/%m/%d %H:%M:%S')
if not (bar.datetime < self.dataStartDate or bar.datetime > self.dataEndDate):
self.newBar(bar)
#----------------------------------------------------------------------
def runBacktestingWithMysql(self):
"""运行回测(使用Mysql数据
added by IncenseLee
"""
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
# 每次获取日期周期
intervalDays = 10
for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays):
d1 = self.dataStartDate + timedelta(days = i )
if (self.dataEndDate - d1).days > intervalDays:
d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 )
else:
d2 = self.dataEndDate
# 提取历史数据
self.loadDataHistoryFromMysql(self.symbol, d1, d2)
self.output(u'数据日期:{0} => {1}'.format(d1,d2))
# 将逐笔数据推送
for data in self.historyData:
# 记录最新的TICK数据
self.tick = self.__dataToTick(data)
self.dt = self.tick.datetime
# 处理限价单
self.crossLimitOrder()
self.crossStopOrder()
# 推送到策略引擎中
self.strategy.onTick(self.tick)
# 清空历史数据
self.historyData = []
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def runBacktesting(self):
"""运行回测"""
# 载入历史数据
self.loadHistoryData()
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
for d in self.dbCursor:
data = dataClass()
data.__dict__ = d
func(data)
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def newBar(self, bar):
"""新的K线"""
self.bar = bar
self.dt = bar.datetime
self.crossLimitOrder() # 先撮合限价单
self.crossStopOrder() # 再撮合停止单
self.strategy.onBar(bar) # 推送K线到策略中
#----------------------------------------------------------------------
def newTick(self, tick):
"""新的Tick"""
self.tick = tick
self.dt = tick.datetime
self.crossLimitOrder()
self.crossStopOrder()
self.strategy.onTick(tick)
#----------------------------------------------------------------------
def initStrategy(self, strategyClass, setting=None):
"""
初始化策略
setting是策略的参数设置如果使用类中写好的默认设置则可以不传该参数
"""
self.strategy = strategyClass(self, setting)
self.strategy.name = self.strategy.className
#----------------------------------------------------------------------
def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发单"""
self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol,orderType,price,volume))
self.limitOrderCount += 1
orderID = str(self.limitOrderCount)
order = VtOrderData()
order.vtSymbol = vtSymbol
order.price = price
order.totalVolume = volume
order.status = STATUS_NOTTRADED # 刚提交尚未成交
order.orderID = orderID
order.vtOrderID = orderID
order.orderTime = str(self.dt)
# added by IncenseLee
order.gatewayName = self.gatewayName
# CTA委托类型映射
if orderType == CTAORDER_BUY:
order.direction = DIRECTION_LONG
order.offset = OFFSET_OPEN
elif orderType == CTAORDER_SELL:
order.direction = DIRECTION_SHORT
order.offset = OFFSET_CLOSE
elif orderType == CTAORDER_SHORT:
order.direction = DIRECTION_SHORT
order.offset = OFFSET_OPEN
elif orderType == CTAORDER_COVER:
order.direction = DIRECTION_LONG
order.offset = OFFSET_CLOSE
# modified by IncenseLee
key = u'{0}.{1}'.format(order.gatewayName, orderID)
# 保存到限价单字典中
self.workingLimitOrderDict[key] = order
self.limitOrderDict[key] = order
return key
#----------------------------------------------------------------------
def cancelOrder(self, vtOrderID):
"""撤单"""
if vtOrderID in self.workingLimitOrderDict:
order = self.workingLimitOrderDict[vtOrderID]
order.status = STATUS_CANCELLED
order.cancelTime = str(self.dt)
del self.workingLimitOrderDict[vtOrderID]
#----------------------------------------------------------------------
def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发停止单(本地实现)"""
self.stopOrderCount += 1
stopOrderID = STOPORDERPREFIX + str(self.stopOrderCount)
so = StopOrder()
so.vtSymbol = vtSymbol
so.price = price
so.volume = volume
so.strategy = strategy
so.stopOrderID = stopOrderID
so.status = STOPORDER_WAITING
if orderType == CTAORDER_BUY:
so.direction = DIRECTION_LONG
so.offset = OFFSET_OPEN
elif orderType == CTAORDER_SELL:
so.direction = DIRECTION_SHORT
so.offset = OFFSET_CLOSE
elif orderType == CTAORDER_SHORT:
so.direction = DIRECTION_SHORT
so.offset = OFFSET_OPEN
elif orderType == CTAORDER_COVER:
so.direction = DIRECTION_LONG
so.offset = OFFSET_CLOSE
# 保存stopOrder对象到字典中
self.stopOrderDict[stopOrderID] = so
self.workingStopOrderDict[stopOrderID] = so
return stopOrderID
#----------------------------------------------------------------------
def cancelStopOrder(self, stopOrderID):
"""撤销停止单"""
# 检查停止单是否存在
if stopOrderID in self.workingStopOrderDict:
so = self.workingStopOrderDict[stopOrderID]
so.status = STOPORDER_CANCELLED
del self.workingStopOrderDict[stopOrderID]
#----------------------------------------------------------------------
def crossLimitOrder(self):
"""基于最新数据撮合限价单"""
# 先确定会撮合成交的价格
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
else:
buyCrossPrice = self.tick.askPrice1
sellCrossPrice = self.tick.bidPrice1
buyBestCrossPrice = self.tick.askPrice1
sellBestCrossPrice = self.tick.bidPrice1
# 遍历限价单字典中的所有限价单
for orderID, order in self.workingLimitOrderDict.items():
# 判断是否会成交
buyCross = order.direction==DIRECTION_LONG and order.price>=buyCrossPrice
sellCross = order.direction==DIRECTION_SHORT and order.price<=sellCrossPrice
# 如果发生了成交
if buyCross or sellCross:
# 推送成交数据
self.tradeCount += 1 # 成交编号自增1
tradeID = str(self.tradeCount)
trade = VtTradeData()
trade.vtSymbol = order.vtSymbol
trade.tradeID = tradeID
trade.vtTradeID = tradeID
trade.orderID = order.orderID
trade.vtOrderID = order.orderID
trade.direction = order.direction
trade.offset = order.offset
# 以买入为例:
# 1. 假设当根K线的OHLC分别为100, 125, 90, 110
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105
# 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100
if buyCross:
trade.price = min(order.price, buyBestCrossPrice)
self.strategy.pos += order.totalVolume
else:
trade.price = max(order.price, sellBestCrossPrice)
self.strategy.pos -= order.totalVolume
trade.volume = order.totalVolume
trade.tradeTime = str(self.dt)
trade.dt = self.dt
self.strategy.onTrade(trade)
self.tradeDict[tradeID] = trade
self.writeCtaLog(u'TradeId:{0}'.format(tradeID))
# 推送委托数据
order.tradedVolume = order.totalVolume
order.status = STATUS_ALLTRADED
self.strategy.onOrder(order)
# 从字典中删除该限价单
del self.workingLimitOrderDict[orderID]
#----------------------------------------------------------------------
def crossStopOrder(self):
"""基于最新数据撮合停止单"""
# 先确定会撮合成交的价格,这里和限价单规则相反
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.high # 若买入方向停止单价格低于该价格,则会成交
sellCrossPrice = self.bar.low # 若卖出方向限价单价格高于该价格,则会成交
bestCrossPrice = self.bar.open # 最优成交价,买入停止单不能低于,卖出停止单不能高于
else:
buyCrossPrice = self.tick.lastPrice
sellCrossPrice = self.tick.lastPrice
bestCrossPrice = self.tick.lastPrice
# 遍历停止单字典中的所有停止单
for stopOrderID, so in self.workingStopOrderDict.items():
# 判断是否会成交
buyCross = so.direction==DIRECTION_LONG and so.price<=buyCrossPrice
sellCross = so.direction==DIRECTION_SHORT and so.price>=sellCrossPrice
# 如果发生了成交
if buyCross or sellCross:
# 推送成交数据
self.tradeCount += 1 # 成交编号自增1
tradeID = str(self.tradeCount)
trade = VtTradeData()
trade.vtSymbol = so.vtSymbol
trade.tradeID = tradeID
trade.vtTradeID = tradeID
if buyCross:
self.strategy.pos += so.volume
trade.price = max(bestCrossPrice, so.price)
else:
self.strategy.pos -= so.volume
trade.price = min(bestCrossPrice, so.price)
self.limitOrderCount += 1
orderID = str(self.limitOrderCount)
trade.orderID = orderID
trade.vtOrderID = orderID
trade.direction = so.direction
trade.offset = so.offset
trade.volume = so.volume
trade.tradeTime = str(self.dt)
trade.dt = self.dt
self.strategy.onTrade(trade)
self.tradeDict[tradeID] = trade
# 推送委托数据
so.status = STOPORDER_TRIGGERED
order = VtOrderData()
order.vtSymbol = so.vtSymbol
order.symbol = so.vtSymbol
order.orderID = orderID
order.vtOrderID = orderID
order.direction = so.direction
order.offset = so.offset
order.price = so.price
order.totalVolume = so.volume
order.tradedVolume = so.volume
order.status = STATUS_ALLTRADED
order.orderTime = trade.tradeTime
self.strategy.onOrder(order)
self.limitOrderDict[orderID] = order
# 从字典中删除该限价单
del self.workingStopOrderDict[stopOrderID]
#----------------------------------------------------------------------
def insertData(self, dbName, collectionName, data):
"""考虑到回测中不允许向数据库插入数据,防止实盘交易中的一些代码出错"""
pass
#----------------------------------------------------------------------
def loadBar(self, dbName, collectionName, startDate):
"""直接返回初始化数据列表中的Bar"""
return self.initData
#----------------------------------------------------------------------
def loadTick(self, dbName, collectionName, startDate):
"""直接返回初始化数据列表中的Tick"""
return self.initData
#----------------------------------------------------------------------
def writeCtaLog(self, content):
"""记录日志"""
log = str(self.dt) + ' ' + content
self.logList.append(log)
# 写入本地log日志
logging.info(content)
#----------------------------------------------------------------------
def output(self, content):
"""输出内容"""
print str(datetime.now()) + "\t" + content
#----------------------------------------------------------------------
def calculateBacktestingResult(self):
"""
计算回测结果
"""
self.output(u'计算回测结果')
# 首先基于回测后的成交记录,计算每笔交易的盈亏
resultDict = OrderedDict() # 交易结果记录
longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易
i = 0
longid = EMPTY_STRING
shortid = EMPTY_STRING
for tradeid in self.tradeDict.keys():
trade = self.tradeDict[tradeid]
if tradeid == '127':
pass
# 多头交易
if trade.direction == DIRECTION_LONG:
# 如果尚无空头交易
if not shortTrade:
longTrade.append(trade)
longid = tradeid
# 当前多头交易为平空
else:
entryTrade = shortTrade.pop(0)
result = TradingResult(entryTrade.price, trade.price, -trade.volume,
self.rate, self.slippage, self.size)
resultDict[trade.dt] = result
i = i+1
self.writeCtaLog(u'{6}.{7}.开空{0},short:{1},{8}.平空{2},cover:{3},vol:{4},净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime,trade.price, trade.volume,result.pnl,
i,shortid,tradeid))
# 空头交易
else:
# 如果尚无多头交易
if not longTrade:
shortTrade.append(trade)
shortid = tradeid
# 当前空头交易为平多
else:
entryTrade = longTrade.pop(0)
result = TradingResult(entryTrade.price, trade.price, trade.volume,
self.rate, self.slippage, self.size)
resultDict[trade.dt] = result
i = i+1
self.writeCtaLog(u'{6}.{7}开多{0},buy:{1},{8}.平多{2},sell:{3},vol:{4},净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime,trade.price, trade.volume,result.pnl,
i,longid,tradeid))
# 检查是否有交易
if not resultDict:
self.output(u'无交易结果')
return {}
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
capital = 0 # 资金
maxCapital = 0 # 资金最高净值
drawdown = 0 # 回撤
totalResult = 0 # 总成交数量
totalTurnover = 0 # 总成交金额(合约面值)
totalCommission = 0 # 总手续费
totalSlippage = 0 # 总滑点
timeList = [] # 时间序列
pnlList = [] # 每笔盈亏序列
capitalList = [] # 盈亏汇总的时间序列
drawdownList = [] # 回撤的时间序列
for time, result in resultDict.items():
capital += result.pnl
maxCapital = max(capital, maxCapital)
drawdown = capital - maxCapital
pnlList.append(result.pnl)
timeList.append(time)
capitalList.append(capital)
drawdownList.append(drawdown)
totalResult += 1
totalTurnover += result.turnover
totalCommission += result.commission
totalSlippage += result.slippage
# 返回回测结果
d = {}
d['capital'] = capital
d['maxCapital'] = maxCapital
d['drawdown'] = drawdown
d['totalResult'] = totalResult
d['totalTurnover'] = totalTurnover
d['totalCommission'] = totalCommission
d['totalSlippage'] = totalSlippage
d['timeList'] = timeList
d['pnlList'] = pnlList
d['capitalList'] = capitalList
d['drawdownList'] = drawdownList
return d
#----------------------------------------------------------------------
def showBacktestingResult(self):
"""显示回测结果"""
d = self.calculateBacktestingResult()
if len(d)== 0:
self.output(u'无交易结果')
return
# 输出
self.output('-' * 30)
self.output(u'第一笔交易:\t%s' % d['timeList'][0])
self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
self.output(u'平均每笔滑点成本:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
# 绘图
import matplotlib.pyplot as plt
pCapital = plt.subplot(3, 1, 1)
pCapital.set_ylabel("capital")
pCapital.plot(d['capitalList'])
pDD = plt.subplot(3, 1, 2)
pDD.set_ylabel("DD")
pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
pPnl = plt.subplot(3, 1, 3)
pPnl.set_ylabel("pnl")
pPnl.hist(d['pnlList'], bins=50)
plt.show()
#----------------------------------------------------------------------
def putStrategyEvent(self, name):
"""发送策略更新事件,回测中忽略"""
pass
#----------------------------------------------------------------------
def setSlippage(self, slippage):
"""设置滑点点数"""
self.slippage = slippage
#----------------------------------------------------------------------
def setSize(self, size):
"""设置合约大小"""
self.size = size
#----------------------------------------------------------------------
def setRate(self, rate):
"""设置佣金比例"""
self.rate = float(rate)
#----------------------------------------------------------------------
def runOptimization(self, strategyClass, optimizationSetting):
"""优化参数"""
# 获取优化设置
settingList = optimizationSetting.generateSetting()
targetName = optimizationSetting.optimizeTarget
# 检查参数设置问题
if not settingList or not targetName:
self.output(u'优化设置有问题,请检查')
# 遍历优化
resultList = []
for setting in settingList:
self.clearBacktestingResult()
self.output('-' * 30)
self.output('setting: %s' %str(setting))
self.initStrategy(strategyClass, setting)
self.runBacktesting()
d = self.calculateBacktestingResult()
try:
targetValue = d[targetName]
except KeyError:
targetValue = 0
resultList.append(([str(setting)], targetValue))
# 显示结果
resultList.sort(reverse=True, key=lambda result:result[1])
self.output('-' * 30)
self.output(u'优化结果:')
for result in resultList:
self.output(u'%s: %s' %(result[0], result[1]))
#----------------------------------------------------------------------
def clearBacktestingResult(self):
"""清空之前回测的结果"""
# 清空限价单相关
self.limitOrderCount = 0
self.limitOrderDict.clear()
self.workingLimitOrderDict.clear()
# 清空停止单相关
self.stopOrderCount = 0
self.stopOrderDict.clear()
self.workingStopOrderDict.clear()
# 清空成交相关
self.tradeCount = 0
self.tradeDict.clear()
########################################################################
class TradingResult(object):
"""每笔交易的结果"""
#----------------------------------------------------------------------
def __init__(self, entry, exit, volume, rate, slippage, size):
"""Constructor"""
self.entry = entry # 开仓价格
self.exit = exit # 平仓价格
self.volume = volume # 交易数量(+/-代表方向)
self.turnover = (self.entry+self.exit)*size # 成交金额
self.commission =round(float(self.turnover*rate),4) # 手续费成本
self.slippage = slippage*2*size # 滑点成本
self.pnl = ((self.exit - self.entry) * volume * size
- self.commission - self.slippage) # 净盈亏
########################################################################
class OptimizationSetting(object):
"""优化设置"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.paramDict = OrderedDict()
self.optimizeTarget = '' # 优化目标字段
#----------------------------------------------------------------------
def addParameter(self, name, start, end, step):
"""增加优化参数"""
if end <= start:
print u'参数起始点必须小于终止点'
return
if step <= 0:
print u'参数布进必须大于0'
return
l = []
param = start
while param <= end:
l.append(param)
param += step
self.paramDict[name] = l
#----------------------------------------------------------------------
def generateSetting(self):
"""生成优化参数组合"""
# 参数名的列表
nameList = self.paramDict.keys()
paramList = self.paramDict.values()
# 使用迭代工具生产参数对组合
productList = list(product(*paramList))
# 把参数对组合打包到一个个字典组成的列表中
settingList = []
for p in productList:
d = dict(zip(nameList, p))
settingList.append(d)
return settingList
#----------------------------------------------------------------------
def setOptimizeTarget(self, target):
"""设置优化目标字段"""
self.optimizeTarget = target
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
n = round(n, 2) # 保留两位小数
return format(n, ',') # 加上千分符
if __name__ == '__main__':
# 以下内容是一段回测脚本的演示,用户可以根据自己的需求修改
# 建议使用ipython notebook或者spyder来做回测
# 同样可以在命令模式下进行回测(一行一行输入运行)
from ctaDemo import *
# 创建回测引擎
engine = BacktestingEngine()
# 设置引擎的回测模式为K线
engine.setBacktestingMode(engine.BAR_MODE)
# 设置回测用的数据起始日期
engine.setStartDate('20110101')
# 载入历史数据到引擎中
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 在引擎中创建策略对象
engine.initStrategy(DoubleEmaDemo, {})
# 开始跑回测
engine.runBacktesting()
# 显示回测结果
# spyder或者ipython notebook中运行时会弹出盈亏曲线图
# 直接在cmd中回测则只会打印一些回测数值
engine.showBacktestingResult()