vnpy/vn.trader/ctaAlgo/ctaBacktesting.py
2016-10-31 17:33:09 +08:00

2058 lines
81 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# encoding: UTF-8
'''
本文件中包含的是CTA模块的回测引擎回测引擎的API和CTA引擎一致
可以使用和实盘相同的代码进行回测。
'''
from datetime import datetime, timedelta
from collections import OrderedDict
from itertools import product
import pymongo
import MySQLdb
import json
import os
import cPickle
import csv
from ctaBase import *
from ctaSetting import *
from vtConstant import *
from vtGateway import VtOrderData, VtTradeData
from vtFunction import loadMongoSetting
import logging
import copy
########################################################################
class BacktestingEngine(object):
"""
CTA回测引擎
函数接口和策略引擎保持一样,
从而实现同一套代码从回测到实盘。
# modified by IncenseLee
1.增加Mysql数据库的支持
2.修改装载数据为批量式后加载模式。
3.增加csv 读取bar的回测模式
4.增加csv 读取tick合并价差的回测模式
"""
TICK_MODE = 'tick' # 数据模式逐Tick回测
BAR_MODE = 'bar' # 数据模式逐Bar回测
REALTIME_MODE ='RealTime' # 逐笔交易计算资金,供策略获取资金容量,计算开仓数量
FINAL_MODE = 'Final' # 最后才统计交易,不适合按照百分比等开仓数量计算
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
# 本地停止单编号计数
self.stopOrderCount = 0
# stopOrderID = STOPORDERPREFIX + str(stopOrderCount)
# 本地停止单字典
# key为stopOrderIDvalue为stopOrder对象
self.stopOrderDict = {} # 停止单撤销后不会从本字典中删除
self.workingStopOrderDict = {} # 停止单撤销后会从本字典中删除
# 回测相关
self.strategy = None # 回测策略
self.mode = self.BAR_MODE # 回测模式默认为K线
self.slippage = 0 # 回测时假设的滑点
self.rate = 0 # 回测时假设的佣金比例(适用于百分比佣金)
self.size = 1 # 合约大小默认为1
self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针
self.historyData = [] # 历史数据的列表,回测用
self.initData = [] # 初始化用的数据
self.backtestingData = [] # 回测用的数据
self.dbName = '' # 回测数据库名
self.symbol = '' # 回测集合名
self.dataStartDate = None # 回测数据开始日期datetime对象
self.dataEndDate = None # 回测数据结束日期datetime对象
self.strategyStartDate = None # 策略启动日期即前面的数据用于初始化datetime对象
self.limitOrderDict = OrderedDict() # 限价单字典
self.workingLimitOrderDict = OrderedDict() # 活动限价单字典,用于进行撮合用
self.limitOrderCount = 0 # 限价单编号
self.tradeCount = 0 # 成交编号
self.tradeDict = OrderedDict() # 成交字典
self.logList = [] # 日志记录
# 当前最新数据,用于模拟成交用
self.tick = None
self.bar = None
self.dt = None # 最新的时间
self.gatewayName = u'BackTest'
# csvFile相关
self.barTimeInterval = 60 # csv文件属于K线类型K线的周期秒数,缺省是1分钟
# 费用情况
self.avaliable = EMPTY_FLOAT
self.percent = EMPTY_FLOAT
self.percentLimit = 30 # 投资仓位比例上限
# 回测计算相关
self.calculateMode = self.FINAL_MODE
self.usageCompounding = False # 是否使用简单复利 只针对FINAL_MODE有效
self.initCapital = 10000 # 期初资金
self.capital = self.initCapital # 资金 相当于Balance
self.maxCapital = self.initCapital # 资金最高净值
self.maxPnl = 0 # 最高盈利
self.minPnl = 0 # 最大亏损
self.maxVolume = 1 # 最大仓位数
self.wins = 0 # 盈利次数
self.totalResult = 0 # 总成交数量
self.totalTurnover = 0 # 总成交金额(合约面值)
self.totalCommission = 0 # 总手续费
self.totalSlippage = 0 # 总滑点
self.timeList = [] # 时间序列
self.pnlList = [] # 每笔盈亏序列
self.capitalList = [] # 盈亏汇总的时间序列
self.drawdownList = [] # 回撤的时间序列
self.drawdownRateList = [] # 最大回撤比例的时间序列
self.exportTradeList = [] # 导出交易记录列表
self.fixCommission = EMPTY_FLOAT # 固定交易费用
def getAccountInfo(self):
"""返回账号的实时权益,可用资金,仓位比例,投资仓位比例上限"""
if self.capital == EMPTY_FLOAT:
self.percent = EMPTY_FLOAT
return self.capital, self.avaliable, self.percent, self.percentLimit
#----------------------------------------------------------------------
def setStartDate(self, startDate='20100416', initDays=10):
"""设置回测的启动日期"""
self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
# 初始化天数
initTimeDelta = timedelta(initDays)
self.strategyStartDate = self.dataStartDate + initTimeDelta
#----------------------------------------------------------------------
def setEndDate(self, endDate=''):
"""设置回测的结束日期"""
if endDate:
self.dataEndDate = datetime.strptime(endDate, '%Y%m%d')
else:
self.dataEndDate = datetime.now()
def setMinDiff(self, minDiff):
"""设置回测品种的最小跳价,用于修正数据"""
self.minDiff = minDiff
#----------------------------------------------------------------------
def setBacktestingMode(self, mode):
"""设置回测模式"""
self.mode = mode
#----------------------------------------------------------------------
def setDatabase(self, dbName, symbol):
"""设置历史数据所用的数据库"""
self.dbName = dbName
self.symbol = symbol
#----------------------------------------------------------------------
def loadHistoryDataFromMongo(self):
"""载入历史数据"""
host, port = loadMongoSetting()
self.dbClient = pymongo.MongoClient(host, port)
collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据')
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
# 载入初始化需要用的数据
flt = {'datetime':{'$gte':self.dataStartDate,
'$lt':self.strategyStartDate}}
initCursor = collection.find(flt)
# 将数据从查询指针中读取出,并生成列表
for d in initCursor:
data = dataClass()
data.__dict__ = d
self.initData.append(data)
# 载入回测数据
if not self.dataEndDate:
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
else:
flt = {'datetime':{'$gte':self.strategyStartDate,
'$lte':self.dataEndDate}}
self.dbCursor = collection.find(flt)
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
#----------------------------------------------------------------------
def connectMysql(self):
"""连接MysqlDB"""
# 载入json文件
fileName = 'mysql_connect.json'
try:
f = file(fileName)
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败')
return
# 解析json文件
setting = json.load(f)
try:
mysql_host = str(setting['host'])
mysql_port = int(setting['port'])
mysql_user = str(setting['user'])
mysql_passwd = str(setting['passwd'])
mysql_db = str(setting['db'])
except IOError:
self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查')
return
try:
self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user,
passwd=mysql_passwd, db=mysql_db, port=mysql_port)
self.__mysqlConnected = True
self.writeCtaLog(u'回测引擎连接MysqlDB成功')
except ConnectionFailure:
self.writeCtaLog(u'回测引擎连接MysqlDB失败')
#----------------------------------------------------------------------
def loadDataHistoryFromMysql(self, symbol, startDate, endDate):
"""载入历史TICK数据
如果加载过多数据会导致加载失败,间隔不要超过半年
"""
if not endDate:
endDate = datetime.today()
# 看本地缓存是否存在
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
self.writeCtaLog(u'历史TICK数据从Cache载入')
return
# 每次获取日期周期
intervalDays = 10
for i in range (0,(endDate - startDate).days +1, intervalDays):
d1 = startDate + timedelta(days = i )
if (endDate - d1).days > 10:
d2 = startDate + timedelta(days = i + intervalDays -1 )
else:
d2 = endDate
# 从Mysql 提取数据
self.__qryDataHistoryFromMysql(symbol, d1, d2)
self.writeCtaLog(u'历史TICK数据共载入{0}'.format(len(self.historyData)))
# 保存本地cache文件
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
"""看本地缓存是否存在
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# cache文件
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
if not os.path.isfile(cacheFile):
return False
else:
# 从cache文件加载
cache = open(cacheFile,mode='r')
self.historyData = cPickle.load(cache)
cache.close()
return True
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
"""保存本地缓存
added by IncenseLee
"""
# 运行路径下cache子目录
cacheFolder = os.getcwd()+'/cache'
# 创建cache子目录
if not os.path.isdir(cacheFolder):
os.mkdir(cacheFolder)
# cache 文件名
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
# 重复存在 返回
if os.path.isfile(cacheFile):
return False
else:
# 写入cache文件
cache = open(cacheFile, mode='w')
cPickle.dump(self.historyData,cache)
cache.close()
return True
#----------------------------------------------------------------------
def __qryDataHistoryFromMysql(self, symbol, startDate, endDate):
"""从Mysql载入历史TICK数据
added by IncenseLee
"""
try:
self.connectMysql()
if self.__mysqlConnected:
# 获取指针
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
if endDate:
# 开始日期 ~ 结束日期
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\
format(symbol, startDate, endDate)
elif startDate:
# 开始日期 - 当前
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\
format( symbol, startDate)
else:
# 所有数据
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\
format(symbol)
self.writeCtaLog(sqlstring)
# 执行查询
count = cur.execute(sqlstring)
self.writeCtaLog(u'历史TICK数据共{0}'.format(count))
# 分批次读取
fetch_counts = 0
fetch_size = 1000
while True:
results = cur.fetchmany(fetch_size)
if not results:
break
fetch_counts = fetch_counts + len(results)
if not self.historyData:
self.historyData =results
else:
self.historyData = self.historyData + results
self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}'.format(fetch_counts,startDate,endDate))
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}'.format(e))
def __dataToTick(self, data):
"""
数据库查询返回的data结构转换为tick对象
added by IncenseLee
"""
tick = CtaTickData()
symbol = data['InstrumentID']
tick.symbol = symbol
# 创建TICK数据对象并更新数据
tick.vtSymbol = symbol
# tick.openPrice = data['OpenPrice']
# tick.highPrice = data['HighestPrice']
# tick.lowPrice = data['LowestPrice']
tick.lastPrice = float(data['LastPrice'])
# bug fix:
# ctp日常传送的volume数据是交易日日内累加值。数据库的volume是数据商自行计算整理的
# 因此改为使用DayVolume与CTP实盘一致
#tick.volume = data['Volume']
tick.volume = data['DayVolume']
tick.openInterest = data['OpenInterest']
# tick.upperLimit = data['UpperLimitPrice']
# tick.lowerLimit = data['LowerLimitPrice']
tick.datetime = data['UpdateTime']
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:%S')
# 数据库中并没有tradingDay的数据回测时暂时按照date授予。
tick.tradingDay = tick.date
tick.bidPrice1 = float(data['BidPrice1'])
# tick.bidPrice2 = data['BidPrice2']
# tick.bidPrice3 = data['BidPrice3']
# tick.bidPrice4 = data['BidPrice4']
# tick.bidPrice5 = data['BidPrice5']
tick.askPrice1 = float(data['AskPrice1'])
# tick.askPrice2 = data['AskPrice2']
# tick.askPrice3 = data['AskPrice3']
# tick.askPrice4 = data['AskPrice4']
# tick.askPrice5 = data['AskPrice5']
tick.bidVolume1 = data['BidVolume1']
# tick.bidVolume2 = data['BidVolume2']
# tick.bidVolume3 = data['BidVolume3']
# tick.bidVolume4 = data['BidVolume4']
# tick.bidVolume5 = data['BidVolume5']
tick.askVolume1 = data['AskVolume1']
# tick.askVolume2 = data['AskVolume2']
# tick.askVolume3 = data['AskVolume3']
# tick.askVolume4 = data['AskVolume4']
# tick.askVolume5 = data['AskVolume5']
return tick
#----------------------------------------------------------------------
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
"""从mysql库中获取交易日前若干天
added by IncenseLee
"""
try:
if self.__mysqlConnected:
# 获取mysql指针
cur = self.__mysqlConnection.cursor()
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
# self.writeCtaLog(sqlstring)
count = cur.execute(sqlstring)
if count > 0:
# 提取第一条记录
result = cur.fetchone()
return result[0]
else:
self.writeCtaLog(u'MysqlDB没有查询结果请检查日期')
else:
self.writeCtaLog(u'MysqlDB未连接请检查')
except MySQLdb.Error, e:
self.writeCtaLog(u'MysqlDB载入数据失败请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
# 出错后缺省返回
return startDate-timedelta(days=3)
# ----------------------------------------------------------------------
def runBackTestingWithArbTickFile(self, arbSymbol):
"""运行套利回测使用本地tickcsv数据)
参数:套利代码 SP rb1610&rb1701
added by IncenseLee
原始的tick分别存放在白天目录1和夜盘目录2中每天都有各个合约的数据
Z:\ticks\SHFE\201606\RB\0601\
RB1610.txt
RB1701.txt
....
Z:\ticks\SHFE_night\201606\RB\0601
RB1610.txt
RB1701.txt
....
夜盘目录为自然日,不是交易日。
按照回测的开始日期,到结束日期,循环每一天。
每天优先读取日盘数据,再读取夜盘数据。
读取eg1如RB1610读取Leg2如RB701合并成价差tick灌输到策略的onTick中。
"""
self.capital = self.initCapital # 更新设置期初资金
if len(arbSymbol) < 1:
self.writeCtaLog(u'套利合约为空')
return
if not (arbSymbol.upper().index("SP") == 0 and arbSymbol.index(" ") > 0 and arbSymbol.index("&") > 0):
self.writeCtaLog(u'套利合约格式不符合')
return
# 获得Leg1leg2
legs = arbSymbol[arbSymbol.index(" "):]
leg1 = legs[1:legs.index("&")]
leg2 = legs[legs.index("&") + 1:]
self.writeCtaLog(u'Leg1:{0},Leg2:{1}'.format(leg1, leg2))
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
# RB
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
#首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
self.writeCtaLog(u'本回测仅支持tick模式')
return
testdays = (self.dataEndDate - self.dataStartDate).days
if testdays < 1:
self.writeCtaLog(u'回测时间不足')
return
for i in range(0, testdays):
testday = self.dataStartDate + timedelta(days = i)
self.output(u'回测日期:{0}'.format(testday))
# 白天数据
self.__loadArbTicks(u'SHFE',testday,leg1,leg2)
# 夜盘数据
self.__loadArbTicks(u'SHFE_night', testday, leg1, leg2)
def __loadArbTicks(self,mainPath,testday,leg1,leg2):
leg1File = u'z:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}.txt' \
.format(mainPath,testday.strftime('%Y%m'), self.symbol, testday.strftime('%m%d'), leg1)
if not os.path.isfile(leg1File):
self.writeCtaLog(u'{0}文件不存在'.format(leg1File))
return
leg2File = u'z:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}.txt' \
.format(mainPath,testday.strftime('%Y%m'), self.symbol, testday.strftime('%m%d'), leg2)
if not os.path.isfile(leg2File):
self.writeCtaLog(u'{0}文件不存在'.format(leg2File))
return
self.writeCtaLog(u'加载回测日期:{0}\{1}的价差tick'.format(mainPath, testday))
cachefilename = u'{0}_{1}_{2}_{3}_{4}'.format(self.symbol,leg1,leg2, mainPath, testday.strftime('%Y%m%d'))
arbTicks = self.__loadArbTicksFromLocalCache(cachefilename)
if len(arbTicks) < 1:
# 先读取leg2的数据到目录以日期时间为key
leg2Ticks = {}
leg2CsvReadFile = file(leg2File, 'rb')
reader = csv.DictReader((line.replace('\0', '') for line in leg2CsvReadFile), delimiter=",")
self.writeCtaLog(u'加载{0}'.format(leg2File))
for row in reader:
tick = CtaTickData()
tick.vtSymbol = self.symbol
tick.symbol = self.symbol
tick.date = testday.strftime('%Y%m%d')
tick.time = row['Time']
tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y%m%d %H:%M:%S.%f')
tick.lastPrice = float(row['LastPrice'])
tick.volume = int(float(row['LVolume']))
tick.bidPrice1 = float(row['BidPrice']) # 叫买价(价格低)
tick.bidVolume1 = int(float(row['BidVolume']))
tick.askPrice1 = float(row['AskPrice']) # 叫卖价(价格高)
tick.askVolume1 = int(float(row['AskVolume']))
# 排除涨停/跌停的数据
if (tick.bidPrice1 == float('1.79769E308') and tick.bidVolume1 == 0) \
or (tick.askPrice1 == float('1.79769E308') and tick.askVolume1 == 0):
continue
leg2Ticks[tick.date + ' ' + tick.time] = tick
leg1CsvReadFile = file(leg1File, 'rb')
reader = csv.DictReader((line.replace('\0', '') for line in leg1CsvReadFile), delimiter=",")
self.writeCtaLog(u'加载{0}'.format(leg1File))
for row in reader:
dtStr = ' '.join([testday.strftime('%Y%m%d'), row['Time']])
if dtStr in leg2Ticks:
leg2Tick = leg2Ticks[dtStr]
arbTick = CtaTickData()
arbTick.vtSymbol = self.symbol
arbTick.symbol = self.symbol
arbTick.date = testday.strftime('%Y%m%d')
arbTick.time = row['Time']
arbTick.datetime = datetime.strptime(arbTick.date + ' ' + arbTick.time, '%Y%m%d %H:%M:%S.%f')
arbTick.lastPrice = EMPTY_FLOAT
arbTick.volume = EMPTY_INT
leg1AskPrice1 = float(row['AskPrice'])
leg1AskVolume1 = int(float(row['AskVolume']))
leg1BidPrice1 = float(row['BidPrice'])
leg1BidVolume1 = int(float(row['BidVolume']))
# 排除涨停/跌停的数据
if (leg1AskPrice1== float('1.79769E308') and leg1AskVolume1 == 0) \
or (leg1BidPrice1 == float('1.79769E308') and leg1BidVolume1 == 0):
continue
# 叫卖价差=leg1.askPrice1 - leg2.bidPrice1volume为两者最小
arbTick.askPrice1 = leg1AskPrice1 - leg2Tick.bidPrice1
arbTick.askVolume1 = min(leg1AskVolume1, leg2Tick.bidVolume1)
# 叫买价差=leg1.bidPrice1 - leg2.askPrice1volume为两者最小
arbTick.bidPrice1 = leg1BidPrice1 - leg2Tick.askPrice1
arbTick.bidVolume1 = min(leg1BidVolume1, leg2Tick.askVolume1)
arbTicks.append(arbTick)
# 保存到历史目录
if len(arbTicks) > 0:
self.__saveArbTicksToLocalCache(cachefilename, arbTicks)
for t in arbTicks:
# 推送到策略中
self.newTick(t)
def __loadArbTicksFromLocalCache(self, filename):
"""从本地缓存中,加载数据"""
# 运行路径下cache子目录
cacheFolder = os.getcwd() + '/cache'
# cache文件
cacheFile = u'{0}/{1}.pickle'. \
format(cacheFolder, filename)
if not os.path.isfile(cacheFile):
return []
else:
# 从cache文件加载
cache = open(cacheFile, mode='r')
l = cPickle.load(cache)
cache.close()
return l
def __saveArbTicksToLocalCache(self, filename, arbticks):
"""保存价差tick到本地缓存目录"""
# 运行路径下cache子目录
cacheFolder = os.getcwd() + '/cache'
# 创建cache子目录
if not os.path.isdir(cacheFolder):
os.mkdir(cacheFolder)
# cache 文件名
cacheFile = u'{0}/{1}.pickle'. \
format(cacheFolder, filename)
# 重复存在 返回
if os.path.isfile(cacheFile):
return False
else:
# 写入cache文件
cache = open(cacheFile, mode='w')
cPickle.dump(arbticks, cache)
cache.close()
return True
#----------------------------------------------------------------------
def runBackTestingWithBarFile(self, filename):
"""运行回测使用本地csv数据)
added by IncenseLee
"""
self.capital = self.initCapital # 更新设置期初资金
if not filename:
self.writeCtaLog(u'请指定回测数据文件')
return
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
import os
if not os.path.isfile(filename):
self.writeCtaLog(u'{0}文件不存在'.format(filename))
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if not self.mode == self.BAR_MODE:
self.writeCtaLog(u'文件仅支持bar模式若扩展tick模式需要修改本方法')
return
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
import csv
csvfile = file(filename,'rb')
reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",")
for row in reader:
try:
bar = CtaBarData()
bar.open = float(row['Open'])
bar.high = float(row['High'])
bar.low = float(row['Low'])
bar.close = float(row['Close'])
bar.volume = float(row['TotalVolume'])
barEndTime = datetime.strptime(row['Date']+' ' + row['Time'], '%Y/%m/%d %H:%M:%S')
# 使用Bar的开始时间作为datetime
bar.datetime = barEndTime - timedelta(seconds=self.barTimeInterval)
bar.date = bar.datetime.strftime('%Y-%m-%d')
bar.time = bar.datetime.strftime('%H:%M:%S')
if not (bar.datetime < self.dataStartDate or bar.datetime >= self.dataEndDate):
self.newBar(bar)
except Exception, ex:
self.writeCtaLog(u'{0}:{1}'.format(Exception,ex))
continue
#----------------------------------------------------------------------
def runBacktestingWithMysql(self):
"""运行回测(使用Mysql数据
added by IncenseLee
"""
self.capital = self.initCapital # 更新设置期初资金
if not self.dataStartDate:
self.writeCtaLog(u'回测开始日期未设置。')
return
if not self.dataEndDate:
self.dataEndDate = datetime.today()
if len(self.symbol)<1:
self.writeCtaLog(u'回测对象未设置。')
return
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
#self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
# 每次获取日期周期
intervalDays = 10
for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays):
d1 = self.dataStartDate + timedelta(days = i )
if (self.dataEndDate - d1).days > intervalDays:
d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 )
else:
d2 = self.dataEndDate
# 提取历史数据
self.loadDataHistoryFromMysql(self.symbol, d1, d2)
self.output(u'数据日期:{0} => {1}'.format(d1,d2))
# 将逐笔数据推送
for data in self.historyData:
# 记录最新的TICK数据
self.tick = self.__dataToTick(data)
self.dt = self.tick.datetime
# 处理限价单
self.crossLimitOrder()
self.crossStopOrder()
# 推送到策略引擎中
self.strategy.onTick(self.tick)
# 清空历史数据
self.historyData = []
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def runBacktesting(self):
"""运行回测"""
self.capital = self.initCapital # 更新设置期初资金
# 载入历史数据
self.loadHistoryData()
# 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE:
dataClass = CtaBarData
func = self.newBar
else:
dataClass = CtaTickData
func = self.newTick
self.output(u'开始回测')
self.strategy.inited = True
self.strategy.onInit()
self.output(u'策略初始化完成')
self.strategy.trading = True
self.strategy.onStart()
self.output(u'策略启动完成')
self.output(u'开始回放数据')
for d in self.dbCursor:
data = dataClass()
data.__dict__ = d
func(data)
self.output(u'数据回放结束')
#----------------------------------------------------------------------
def newBar(self, bar):
"""新的K线"""
self.bar = bar
self.dt = bar.datetime
self.crossLimitOrder() # 先撮合限价单
self.crossStopOrder() # 再撮合停止单
self.strategy.onBar(bar) # 推送K线到策略中
#----------------------------------------------------------------------
def newTick(self, tick):
"""新的Tick"""
self.tick = tick
self.dt = tick.datetime
self.crossLimitOrder()
self.crossStopOrder()
self.strategy.onTick(tick)
#----------------------------------------------------------------------
def initStrategy(self, strategyClass, setting=None):
"""
初始化策略
setting是策略的参数设置如果使用类中写好的默认设置则可以不传该参数
"""
self.strategy = strategyClass(self, setting)
self.strategy.name = self.strategy.className
#----------------------------------------------------------------------
def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发单"""
self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol,orderType,price,volume))
self.limitOrderCount += 1
orderID = str(self.limitOrderCount)
order = VtOrderData()
order.vtSymbol = vtSymbol
order.price = price
order.totalVolume = volume
order.status = STATUS_NOTTRADED # 刚提交尚未成交
order.orderID = orderID
order.vtOrderID = orderID
order.orderTime = str(self.dt)
# added by IncenseLee
order.gatewayName = self.gatewayName
# CTA委托类型映射
if orderType == CTAORDER_BUY:
order.direction = DIRECTION_LONG
order.offset = OFFSET_OPEN
elif orderType == CTAORDER_SELL:
order.direction = DIRECTION_SHORT
order.offset = OFFSET_CLOSE
elif orderType == CTAORDER_SHORT:
order.direction = DIRECTION_SHORT
order.offset = OFFSET_OPEN
elif orderType == CTAORDER_COVER:
order.direction = DIRECTION_LONG
order.offset = OFFSET_CLOSE
# modified by IncenseLee
key = u'{0}.{1}'.format(order.gatewayName, orderID)
# 保存到限价单字典中
self.workingLimitOrderDict[key] = order
self.limitOrderDict[key] = order
return key
#----------------------------------------------------------------------
def cancelOrder(self, vtOrderID):
"""撤单"""
if vtOrderID in self.workingLimitOrderDict:
order = self.workingLimitOrderDict[vtOrderID]
order.status = STATUS_CANCELLED
order.cancelTime = str(self.dt)
del self.workingLimitOrderDict[vtOrderID]
def cancelOrders(self, symbol, offset=EMPTY_STRING):
"""撤销所有单"""
# Symbol参数:指定合约的撤单;
# OFFSET参数:指定Offset的撤单,缺省不填写时,为所有
self.writeCtaLog(u'从所有订单中撤销{0}\{1}'.format(offset, symbol))
for vtOrderID in self.workingLimitOrderDict.keys():
order = self.workingLimitOrderDict[vtOrderID]
if offset == EMPTY_STRING:
offsetCond = True
else:
offsetCond = order.offset == offset
if order.symbol == symbol and offsetCond:
self.writeCtaLog(u'撤销订单:{0},{1} {2}@{3}'.format(vtOrderID, order.direction, order.price, order.totalVolume))
order.status = STATUS_CANCELLED
order.cancelTime = str(self.dt)
del self.workingLimitOrderDict[vtOrderID]
#----------------------------------------------------------------------
def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy):
"""发停止单(本地实现)"""
self.stopOrderCount += 1
stopOrderID = STOPORDERPREFIX + str(self.stopOrderCount)
so = StopOrder()
so.vtSymbol = vtSymbol
so.price = price
so.volume = volume
so.strategy = strategy
so.stopOrderID = stopOrderID
so.status = STOPORDER_WAITING
if orderType == CTAORDER_BUY:
so.direction = DIRECTION_LONG
so.offset = OFFSET_OPEN
elif orderType == CTAORDER_SELL:
so.direction = DIRECTION_SHORT
so.offset = OFFSET_CLOSE
elif orderType == CTAORDER_SHORT:
so.direction = DIRECTION_SHORT
so.offset = OFFSET_OPEN
elif orderType == CTAORDER_COVER:
so.direction = DIRECTION_LONG
so.offset = OFFSET_CLOSE
# 保存stopOrder对象到字典中
self.stopOrderDict[stopOrderID] = so
self.workingStopOrderDict[stopOrderID] = so
return stopOrderID
#----------------------------------------------------------------------
def cancelStopOrder(self, stopOrderID):
"""撤销停止单"""
# 检查停止单是否存在
if stopOrderID in self.workingStopOrderDict:
so = self.workingStopOrderDict[stopOrderID]
so.status = STOPORDER_CANCELLED
del self.workingStopOrderDict[stopOrderID]
#----------------------------------------------------------------------
def crossLimitOrder(self):
"""基于最新数据撮合限价单"""
# 先确定会撮合成交的价格
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
else:
buyCrossPrice = self.tick.askPrice1
sellCrossPrice = self.tick.bidPrice1
buyBestCrossPrice = self.tick.askPrice1
sellBestCrossPrice = self.tick.bidPrice1
# 遍历限价单字典中的所有限价单
for orderID, order in self.workingLimitOrderDict.items():
# 判断是否会成交
buyCross = order.direction==DIRECTION_LONG and order.price>=buyCrossPrice
sellCross = order.direction==DIRECTION_SHORT and order.price<=sellCrossPrice
# 如果发生了成交
if buyCross or sellCross:
# 推送成交数据
self.tradeCount += 1 # 成交编号自增1
tradeID = str(self.tradeCount)
trade = VtTradeData()
trade.vtSymbol = order.vtSymbol
trade.tradeID = tradeID
trade.vtTradeID = tradeID
trade.orderID = order.orderID
trade.vtOrderID = order.orderID
trade.direction = order.direction
trade.offset = order.offset
# 以买入为例:
# 1. 假设当根K线的OHLC分别为100, 125, 90, 110
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻策略发出的委托为限价105
# 3. 则在实际中的成交价会是100而不是105因为委托发出时市场的最优价格是100
if buyCross:
trade.price = min(order.price, buyBestCrossPrice)
self.strategy.pos += order.totalVolume
else:
trade.price = max(order.price, sellBestCrossPrice)
self.strategy.pos -= order.totalVolume
trade.volume = order.totalVolume
trade.tradeTime = str(self.dt)
trade.dt = self.dt
self.strategy.onTrade(trade)
self.tradeDict[tradeID] = trade
self.writeCtaLog(u'TradeId:{0}'.format(tradeID))
# 推送委托数据
order.tradedVolume = order.totalVolume
order.status = STATUS_ALLTRADED
self.strategy.onOrder(order)
# 从字典中删除该限价单
try:
del self.workingLimitOrderDict[orderID]
except Exception,ex:
self.writeCtaError(u'{0}:{1}'.format(Exception,ex))
if self.calculateMode == self.REALTIME_MODE:
self.realtimeCalculate()
#----------------------------------------------------------------------
def crossStopOrder(self):
"""基于最新数据撮合停止单"""
# 先确定会撮合成交的价格,这里和限价单规则相反
if self.mode == self.BAR_MODE:
buyCrossPrice = self.bar.high # 若买入方向停止单价格低于该价格,则会成交
sellCrossPrice = self.bar.low # 若卖出方向限价单价格高于该价格,则会成交
bestCrossPrice = self.bar.open # 最优成交价,买入停止单不能低于,卖出停止单不能高于
else:
buyCrossPrice = self.tick.lastPrice
sellCrossPrice = self.tick.lastPrice
bestCrossPrice = self.tick.lastPrice
# 遍历停止单字典中的所有停止单
for stopOrderID, so in self.workingStopOrderDict.items():
# 判断是否会成交
buyCross = so.direction==DIRECTION_LONG and so.price<=buyCrossPrice
sellCross = so.direction==DIRECTION_SHORT and so.price>=sellCrossPrice
# 如果发生了成交
if buyCross or sellCross:
# 推送成交数据
self.tradeCount += 1 # 成交编号自增1
tradeID = str(self.tradeCount)
trade = VtTradeData()
trade.vtSymbol = so.vtSymbol
trade.tradeID = tradeID
trade.vtTradeID = tradeID
if buyCross:
self.strategy.pos += so.volume
trade.price = max(bestCrossPrice, so.price)
else:
self.strategy.pos -= so.volume
trade.price = min(bestCrossPrice, so.price)
self.limitOrderCount += 1
orderID = str(self.limitOrderCount)
trade.orderID = orderID
trade.vtOrderID = orderID
trade.direction = so.direction
trade.offset = so.offset
trade.volume = so.volume
trade.tradeTime = str(self.dt)
trade.dt = self.dt
self.strategy.onTrade(trade)
self.tradeDict[tradeID] = trade
# 推送委托数据
so.status = STOPORDER_TRIGGERED
order = VtOrderData()
order.vtSymbol = so.vtSymbol
order.symbol = so.vtSymbol
order.orderID = orderID
order.vtOrderID = orderID
order.direction = so.direction
order.offset = so.offset
order.price = so.price
order.totalVolume = so.volume
order.tradedVolume = so.volume
order.status = STATUS_ALLTRADED
order.orderTime = trade.tradeTime
self.strategy.onOrder(order)
self.limitOrderDict[orderID] = order
# 从字典中删除该限价单
del self.workingStopOrderDict[stopOrderID]
if self.calculateMode == self.REALTIME_MODE:
self.realtimeCalculate()
#----------------------------------------------------------------------
def insertData(self, dbName, collectionName, data):
"""考虑到回测中不允许向数据库插入数据,防止实盘交易中的一些代码出错"""
pass
#----------------------------------------------------------------------
def loadBar(self, dbName, collectionName, startDate):
"""直接返回初始化数据列表中的Bar"""
return self.initData
#----------------------------------------------------------------------
def loadTick(self, dbName, collectionName, startDate):
"""直接返回初始化数据列表中的Tick"""
return self.initData
#----------------------------------------------------------------------
def writeCtaLog(self, content):
"""记录日志"""
#log = str(self.dt) + ' ' + content
#self.logList.append(log)
# 写入本地log日志
logging.info(content)
def writeCtaError(self, content):
"""记录异常"""
self.output(content)
self.writeCtaLog(content)
#----------------------------------------------------------------------
def output(self, content):
"""输出内容"""
print str(datetime.now()) + "\t" + content
def realtimeCalculate(self):
"""实时计算交易结果"""
resultDict = OrderedDict() # 交易结果记录
longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易
longid = EMPTY_STRING
shortid = EMPTY_STRING
for tradeid in self.tradeDict.keys():
trade = self.tradeDict[tradeid]
# 多头交易
if trade.direction == DIRECTION_LONG:
# 如果尚无空头交易
if not shortTrade:
longTrade.append(trade)
longid = tradeid
# 当前多头交易为平空
else:
gId = tradeid # 交易组(多个平仓数为一组)
gr = None # 组合的交易结果
coverVolume = trade.volume
while coverVolume > 0:
if len(shortTrade)==0:
self.writeCtaError(u'异常,没有开空仓的数据')
break
# 从未平仓的空头交易
entryTrade = shortTrade.pop(0)
# 开空volume不大于平仓volume
if coverVolume >= entryTrade.volume:
self.writeCtaLog(u'coverVolume:{0} >= entryTrade.volume:{1}'.format(coverVolume, entryTrade.volume))
coverVolume = coverVolume - entryTrade.volume
result = TradingResult(entryTrade.price, trade.price, -entryTrade.volume,
self.rate, self.slippage, self.size,
groupId=gId, fixcommission=self.fixCommission)
t = {}
t['OpenTime'] = entryTrade.tradeTime
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Short'
t['CloseTime'] = trade.tradeTime
t['ClosePrice'] = trade.price
t['Volume'] = entryTrade.volume
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime, trade.price, entryTrade.volume, result.pnl,
gId, shortid, tradeid))
if type(gr) == type(None):
if coverVolume > 0:
# 属于组合
gr = copy.deepcopy(result)
# 删除开空交易单
del self.tradeDict[entryTrade.tradeID]
else:
# 不属于组合
resultDict[entryTrade.dt] = result
# 删除平空交易单,
del self.tradeDict[trade.tradeID]
# 删除开空交易单
del self.tradeDict[entryTrade.tradeID]
else:
# 更新组合的数据
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
# 删除开空交易单
del self.tradeDict[entryTrade.tradeID]
# 所有仓位平完
if coverVolume == 0:
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
# 删除平空交易单,
del self.tradeDict[trade.tradeID]
# 开空volume,大于平仓volume需要更新减少tradeDict的数量。
else:
self.writeCtaLog(u'Short volume:{0} > Cover volume:{1}需要更新减少tradeDict的数量。'.format(entryTrade.volume,coverVolume))
shortVolume = entryTrade.volume - coverVolume
result = TradingResult(entryTrade.price, trade.price, -coverVolume,
self.rate, self.slippage, self.size,
groupId=gId, fixcommission=self.fixCommission)
t = {}
t['OpenTime'] = entryTrade.tradeTime
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Short'
t['CloseTime'] = trade.tradeTime
t['ClosePrice'] = trade.price
t['Volume'] = coverVolume
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime, trade.price, coverVolume, result.pnl,
gId, shortid, tradeid))
# 更新减少开仓单的volume,重新推进开仓单列表中
entryTrade.volume = shortVolume
shortTrade.append(entryTrade)
coverVolume = 0
if type(gr) == type(None):
resultDict[entryTrade.dt] = result
else:
# 更新组合的数据
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
# 删除平空交易单,
del self.tradeDict[trade.tradeID]
if type(gr) != type(None):
self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl))
self.writeCtaLog(u'-------------')
# 空头交易
else:
# 如果尚无多头交易
if not longTrade:
shortTrade.append(trade)
shortid = tradeid
# 当前空头交易为平多
else:
gId = tradeid # 交易组(多个平仓数为一组) s
gr = None # 组合的交易结果
sellVolume = trade.volume
self.output(u'多平:{0}'.format(sellVolume))
self.writeCtaLog(u'多平:{0}'.format(sellVolume))
while sellVolume > 0:
if len(longTrade)==0:
self.writeCtaError(u'异常,没有开多单')
break
entryTrade = longTrade.pop(0)
# 开多volume不大于平仓volume
if sellVolume >= entryTrade.volume:
self.writeCtaLog(u'Sell Volume:{0} >= Entry Volume:{1}'.format(sellVolume, entryTrade.volume))
sellVolume = sellVolume - entryTrade.volume
result = TradingResult(entryTrade.price, trade.price, entryTrade.volume,
self.rate, self.slippage, self.size,
groupId=gId, fixcommission=self.fixCommission)
t = {}
t['OpenTime'] = entryTrade.tradeTime
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Long'
t['CloseTime'] = trade.tradeTime
t['ClosePrice'] = trade.price
t['Volume'] = entryTrade.volume
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime,trade.price, entryTrade.volume, result.pnl,
gId, longid, tradeid))
if type(gr) == type(None):
if sellVolume > 0:
# 属于组合
gr = copy.deepcopy(result)
# 删除开多交易单
del self.tradeDict[entryTrade.tradeID]
else:
# 不属于组合
resultDict[entryTrade.dt] = result
# 删除平多交易单,
del self.tradeDict[trade.tradeID]
# 删除开多交易单
del self.tradeDict[entryTrade.tradeID]
else:
# 更新组合的数据
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
# 删除开多交易单
del self.tradeDict[entryTrade.tradeID]
if sellVolume == 0:
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
# 删除平多交易单,
del self.tradeDict[trade.tradeID]
# 开多volume,大于平仓volume需要更新减少tradeDict的数量。
else:
longVolume = entryTrade.volume -sellVolume
self.writeCtaLog(u'Long Volume:{0} > sell Volume:{1}'.format(entryTrade.volume,sellVolume))
result = TradingResult(entryTrade.price, trade.price, sellVolume,
self.rate, self.slippage, self.size,
groupId=gId, fixcommission=self.fixCommission)
t = {}
t['OpenTime'] = entryTrade.tradeTime
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Long'
t['CloseTime'] = trade.tradeTime
t['ClosePrice'] = trade.price
t['Volume'] = sellVolume
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime, trade.price, sellVolume, result.pnl,
gId, longid, tradeid))
# 减少开多volume,重新推进开多单列表中
entryTrade.volume = longVolume
longTrade.append(entryTrade)
sellVolume = 0
if type(gr) == type(None):
resultDict[entryTrade.dt] = result
else:
# 更新组合的数据
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
# 删除平多交易单,
del self.tradeDict[trade.tradeID]
if type(gr) != type(None):
self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl))
self.writeCtaLog(u'-------------')
# 计算仓位比例
occupyMoney = EMPTY_FLOAT
occupyLongVolume = EMPTY_INT
occupyShortVolume = EMPTY_INT
if len(longTrade) > 0:
for t in longTrade:
occupyMoney += t.price * abs(t.volume) * self.size * 0.11
occupyLongVolume += abs(t.volume)
if len(shortTrade) > 0:
for t in shortTrade:
occupyMoney += t.price * abs(t.volume) * self.size * 0.11
occupyShortVolume += (t.volume)
self.output(u'occupyLongVolume:{0},occupyShortVolume:{1}'.format(occupyLongVolume,occupyShortVolume))
self.writeCtaLog(u'occupyLongVolume:{0},occupyShortVolume:{1}'.format(occupyLongVolume, occupyShortVolume))
# 最大持仓
self.maxVolume = max(self.maxVolume, max(occupyLongVolume, occupyShortVolume))
self.avaliable = self.capital - occupyMoney
self.percent = round(float(occupyMoney * 100 / self.capital), 2)
# 检查是否有平交易
if not resultDict:
if len(longTrade) > 0:
msg = u'持多仓{0},资金占用:{1},仓位:{2}'.format(occupyLongVolume, occupyMoney, self.percent)
self.output(msg)
self.writeCtaLog(msg)
elif len(shortTrade) > 0:
msg = u'持空仓{0},资金占用:{1},仓位:{2}'.format(occupyShortVolume, occupyMoney, self.percent)
self.output(msg)
self.writeCtaLog(msg)
return
for time, result in resultDict.items():
if result.pnl > 0:
self.wins += 1
self.capital += result.pnl
self.maxCapital = max(self.capital, self.maxCapital)
#self.maxVolume = max(self.maxVolume, result.volume)
drawdown = self.capital - self.maxCapital
drawdownRate = round(float(drawdown*100/self.maxCapital),4)
self.pnlList.append(result.pnl)
self.timeList.append(time)
self.capitalList.append(self.capital)
self.drawdownList.append(drawdown)
self.drawdownRateList.append(drawdownRate)
self.totalResult += 1
self.totalTurnover += result.turnover
self.totalCommission += result.commission
self.totalSlippage += result.slippage
self.output(u'[{5}],{6} Vol:{0},盈亏:{1},回撤:{2}/{3},权益:{4}'.
format(abs(result.volume), result.pnl, drawdown,
drawdownRate, self.capital, result.groupId, time))
# 重新计算一次avaliable
self.avaliable = self.capital - occupyMoney
self.percent = round(float(occupyMoney * 100 / self.capital), 2)
#----------------------------------------------------------------------
def calculateBacktestingResult(self):
"""
计算回测结果
Modified by Incense Lee
增加了支持逐步加仓的计算:
例如前面共有6次开仓1手开仓+5次加仓每次1手平仓只有1次六手。那么交易次数是6次开仓+平仓)。
暂不支持每次加仓数目不一致的核对(因为比较复杂)
增加组合的支持。组合中仍然按照1手逐步加仓和多手平仓的方法即使启用了复利模式也仍然按照这个规则只是在计算收益时才乘以系数
增加期初权益,每次交易后的权益,可用资金,仓位比例。
"""
self.output(u'计算回测结果')
# 首先基于回测后的成交记录,计算每笔交易的盈亏
resultDict = OrderedDict() # 交易结果记录
longTrade = [] # 未平仓的多头交易
shortTrade = [] # 未平仓的空头交易
i = 1
tradeUnit = 1
longid = EMPTY_STRING
shortid = EMPTY_STRING
for tradeid in self.tradeDict.keys():
trade = self.tradeDict[tradeid]
# 多头交易
if trade.direction == DIRECTION_LONG:
# 如果尚无空头交易
if not shortTrade:
longTrade.append(trade)
longid = tradeid
# 当前多头交易为平空
else:
gId = i # 交易组(多个平仓数为一组)
gt = 1 # 组合的交易次数
gr = None # 组合的交易结果
if trade.volume >tradeUnit:
self.writeCtaLog(u'平仓数{0},组合编号:{1}'.format(trade.volume,gId))
gt = int(trade.volume/tradeUnit)
for tv in range(gt):
entryTrade = shortTrade.pop(0)
result = TradingResult(entryTrade.price, trade.price, -tradeUnit,
self.rate, self.slippage, self.size,
groupId=gId, fixcommission=self.fixCommission)
if tv == 0:
if gt==1:
resultDict[entryTrade.dt] = result
else:
gr = copy.deepcopy(result)
else:
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
if tv == gt -1:
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
t = {}
t['OpenTime'] = entryTrade.tradeTime.strftime('%Y/%m/%d %H:%M:%S')
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Short'
t['CloseTime'] = trade.tradeTime.strftime('%Y/%m/%d %H:%M:%S')
t['ClosePrice'] = trade.price
t['Volume'] = tradeUnit
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{9}@{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime, trade.price, tradeUnit, result.pnl,
i, shortid, tradeid,gId))
i = i+1
if type(gr) != type(None):
self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl))
self.writeCtaLog(u'-------------')
# 空头交易
else:
# 如果尚无多头交易
if not longTrade:
shortTrade.append(trade)
shortid = tradeid
# 当前空头交易为平多
else:
gId = i # 交易组(多个平仓数为一组)
gt = 1 # 组合的交易次数
gr = None # 组合的交易结果
if trade.volume >tradeUnit:
self.writeCtaLog(u'平仓数{0},组合编号:{1}'.format(trade.volume,gId))
gt = int(trade.volume/tradeUnit)
for tv in range(gt):
entryTrade = longTrade.pop(0)
result = TradingResult(entryTrade.price, trade.price, tradeUnit,
self.rate, self.slippage, self.size,
groupId= gId, fixcommission=self.fixCommission)
if tv == 0:
if gt==1:
resultDict[entryTrade.dt] = result
else:
gr = copy.deepcopy(result)
else:
gr.turnover = gr.turnover + result.turnover
gr.commission = gr.commission + result.commission
gr.slippage = gr.slippage + result.slippage
gr.pnl = gr.pnl + result.pnl
if tv == gt -1:
gr.volume = trade.volume
resultDict[entryTrade.dt] = gr
t = {}
t['OpenTime'] = entryTrade.tradeTime.strftime('%Y/%m/%d %H:%M:%S')
t['OpenPrice'] = entryTrade.price
t['Direction'] = u'Long'
t['CloseTime'] = trade.tradeTime.strftime('%Y/%m/%d %H:%M:%S')
t['ClosePrice'] = trade.price
t['Volume'] = tradeUnit
t['Profit'] = result.pnl
self.exportTradeList.append(t)
self.writeCtaLog(u'{9}@{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏:{5}'
.format(entryTrade.tradeTime, entryTrade.price,
trade.tradeTime,trade.price, tradeUnit, result.pnl,
i, longid, tradeid, gId))
i = i+1
if type(gr) != type(None):
self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl))
self.writeCtaLog(u'-------------')
# 检查是否有交易
if not resultDict:
self.output(u'无交易结果')
return {}
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
"""
initCapital = 40000 # 期初资金
capital = initCapital # 资金
maxCapital = initCapital # 资金最高净值
maxPnl = 0 # 最高盈利
minPnl = 0 # 最大亏损
maxVolume = 1 # 最大仓位数
wins = 0
totalResult = 0 # 总成交数量
totalTurnover = 0 # 总成交金额(合约面值)
totalCommission = 0 # 总手续费
totalSlippage = 0 # 总滑点
timeList = [] # 时间序列
pnlList = [] # 每笔盈亏序列
capitalList = [] # 盈亏汇总的时间序列
drawdownList = [] # 回撤的时间序列
drawdownRateList = [] # 最大回撤比例的时间序列
"""
drawdown = 0 # 回撤
compounding = 1 # 简单的复利基数如果资金是期初资金的x倍就扩大开仓比例,例如3w开1手6w开2手12w开4手)
for time, result in resultDict.items():
# 是否使用简单复利
if self.usageCompounding:
compounding = int(self.capital/self.initCapital)
if result.pnl > 0:
self.wins += 1
self.capital += result.pnl*compounding
self.maxCapital = max(self.capital, self.maxCapital)
self.maxVolume = max(self.maxVolume, result.volume*compounding)
drawdown = self.capital - self.maxCapital
drawdownRate = round(float(drawdown*100/self.maxCapital),4)
self.pnlList.append(result.pnl*compounding)
self.timeList.append(time)
self.capitalList.append(self.capital)
self.drawdownList.append(drawdown)
self.drawdownRateList.append(drawdownRate)
self.totalResult += 1
self.totalTurnover += result.turnover*compounding
self.totalCommission += result.commission*compounding
self.totalSlippage += result.slippage*compounding
# ---------------------------------------------------------------------
def exportTradeResult(self):
"""到处回测结果表"""
if not self.exportTradeList:
return
csvOutputFile = os.getcwd() + '/TestLogs/Output_{0}.csv'.format(datetime.now().strftime('%Y%m%d_%H%M'))
import csv
csvWriteFile = file(csvOutputFile, 'wb')
fieldnames = ['OpenTime', 'OpenPrice', 'Direction', 'CloseTime', 'ClosePrice', 'Volume', 'Profit']
writer = csv.DictWriter(f=csvWriteFile, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.exportTradeList:
writer.writerow(row)
def getResult(self):
# 返回回测结果
d = {}
d['initCapital'] = self.initCapital
d['capital'] = self.capital - self.initCapital
d['maxCapital'] = self.maxCapital
if len(self.pnlList) == 0:
return {}
d['maxPnl'] = max(self.pnlList)
d['minPnl'] = min(self.pnlList)
d['maxVolume'] = self.maxVolume
d['totalResult'] = self.totalResult
d['totalTurnover'] = self.totalTurnover
d['totalCommission'] = self.totalCommission
d['totalSlippage'] = self.totalSlippage
d['timeList'] = self.timeList
d['pnlList'] = self.pnlList
d['capitalList'] = self.capitalList
d['drawdownList'] = self.drawdownList
d['drawdownRateList'] = self.drawdownRateList
d['winRate'] = round(100*self.wins/len(self.pnlList),4)
return d
#----------------------------------------------------------------------
def showBacktestingResult(self):
"""显示回测结果"""
if self.calculateMode != self.REALTIME_MODE:
self.calculateBacktestingResult()
d = self.getResult()
if len(d)== 0:
self.output(u'无交易结果')
return
# 导出交易清单
self.exportTradeResult()
# 输出
self.output('-' * 30)
self.output(u'第一笔交易:\t%s' % d['timeList'][0])
self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
self.output(u'期初资金:\t%s' % formatNumber(d['initCapital']))
self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
self.output(u'资金最高净值:\t%s' % formatNumber(d['maxCapital']))
self.output(u'每笔最大盈利:\t%s' % formatNumber(d['maxPnl']))
self.output(u'每笔最大亏损:\t%s' % formatNumber(d['minPnl']))
self.output(u'净值最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
self.output(u'净值最大回撤率: \t%s' % formatNumber(min(d['drawdownRateList'])))
self.output(u'胜率:\t%s' % formatNumber(d['winRate']))
self.output(u'最大持仓:\t%s' % formatNumber(d['maxVolume']))
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
self.output(u'平均每笔滑点成本:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
# 绘图
import matplotlib.pyplot as plt
pCapital = plt.subplot(3, 1, 1)
pCapital.set_ylabel("capital")
pCapital.plot(d['capitalList'])
pDD = plt.subplot(3, 1, 2)
pDD.set_ylabel("DD")
pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
pPnl = plt.subplot(3, 1, 3)
pPnl.set_ylabel("pnl")
pPnl.hist(d['pnlList'], bins=50)
plt.show()
#----------------------------------------------------------------------
def putStrategyEvent(self, name):
"""发送策略更新事件,回测中忽略"""
pass
#----------------------------------------------------------------------
def setSlippage(self, slippage):
"""设置滑点点数"""
self.slippage = slippage
#----------------------------------------------------------------------
def setSize(self, size):
"""设置合约大小"""
self.size = size
#----------------------------------------------------------------------
def setRate(self, rate):
"""设置佣金比例"""
self.rate = float(rate)
#----------------------------------------------------------------------
def runOptimization(self, strategyClass, optimizationSetting):
"""优化参数"""
# 获取优化设置
settingList = optimizationSetting.generateSetting()
targetName = optimizationSetting.optimizeTarget
# 检查参数设置问题
if not settingList or not targetName:
self.output(u'优化设置有问题,请检查')
# 遍历优化
resultList = []
for setting in settingList:
self.clearBacktestingResult()
self.output('-' * 30)
self.output('setting: %s' %str(setting))
self.initStrategy(strategyClass, setting)
self.runBacktesting()
d = self.calculateBacktestingResult()
try:
targetValue = d[targetName]
except KeyError:
targetValue = 0
resultList.append(([str(setting)], targetValue))
# 显示结果
resultList.sort(reverse=True, key=lambda result:result[1])
self.output('-' * 30)
self.output(u'优化结果:')
for result in resultList:
self.output(u'%s: %s' %(result[0], result[1]))
#----------------------------------------------------------------------
def clearBacktestingResult(self):
"""清空之前回测的结果"""
# 清空限价单相关
self.limitOrderCount = 0
self.limitOrderDict.clear()
self.workingLimitOrderDict.clear()
# 清空停止单相关
self.stopOrderCount = 0
self.stopOrderDict.clear()
self.workingStopOrderDict.clear()
# 清空成交相关
self.tradeCount = 0
self.tradeDict.clear()
########################################################################
class TradingResult(object):
"""每笔交易的结果"""
#----------------------------------------------------------------------
def __init__(self, entry, exit, volume, rate, slippage, size, groupId, fixcommission=EMPTY_FLOAT):
"""Constructor"""
self.entry = entry # 开仓价格
self.exit = exit # 平仓价格
self.volume = volume # 交易数量(+/-代表方向)
self.groupId = groupId # 主交易ID针对多手平仓
self.turnover = (self.entry+self.exit)*size # 成交金额
if fixcommission:
self.commission = fixcommission * self.volume
else:
self.commission =round(float(self.turnover*rate),4) # 手续费成本
self.slippage = slippage*2*size # 滑点成本
self.pnl = ((self.exit - self.entry) * volume * size
- self.commission - self.slippage) # 净盈亏
########################################################################
class OptimizationSetting(object):
"""优化设置"""
#----------------------------------------------------------------------
def __init__(self):
"""Constructor"""
self.paramDict = OrderedDict()
self.optimizeTarget = '' # 优化目标字段
#----------------------------------------------------------------------
def addParameter(self, name, start, end, step):
"""增加优化参数"""
if end <= start:
print u'参数起始点必须小于终止点'
return
if step <= 0:
print u'参数布进必须大于0'
return
l = []
param = start
while param <= end:
l.append(param)
param += step
self.paramDict[name] = l
#----------------------------------------------------------------------
def generateSetting(self):
"""生成优化参数组合"""
# 参数名的列表
nameList = self.paramDict.keys()
paramList = self.paramDict.values()
# 使用迭代工具生产参数对组合
productList = list(product(*paramList))
# 把参数对组合打包到一个个字典组成的列表中
settingList = []
for p in productList:
d = dict(zip(nameList, p))
settingList.append(d)
return settingList
#----------------------------------------------------------------------
def setOptimizeTarget(self, target):
"""设置优化目标字段"""
self.optimizeTarget = target
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
n = round(n, 2) # 保留两位小数
return format(n, ',') # 加上千分符
if __name__ == '__main__':
# 以下内容是一段回测脚本的演示,用户可以根据自己的需求修改
# 建议使用ipython notebook或者spyder来做回测
# 同样可以在命令模式下进行回测(一行一行输入运行)
from ctaDemo import *
# 创建回测引擎
engine = BacktestingEngine()
# 设置引擎的回测模式为K线
engine.setBacktestingMode(engine.BAR_MODE)
# 设置回测用的数据起始日期
engine.setStartDate('20110101')
# 载入历史数据到引擎中
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
# 设置产品相关参数
engine.setSlippage(0.2) # 股指1跳
engine.setRate(0.3/10000) # 万0.3
engine.setSize(300) # 股指合约大小
# 在引擎中创建策略对象
engine.initStrategy(DoubleEmaDemo, {})
# 开始跑回测
engine.runBacktesting()
# 显示回测结果
# spyder或者ipython notebook中运行时会弹出盈亏曲线图
# 直接在cmd中回测则只会打印一些回测数值
engine.showBacktestingResult()