2016-07-02 03:12:44 +00:00
|
|
|
|
# encoding: UTF-8
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
本文件中包含的是CTA模块的回测引擎,回测引擎的API和CTA引擎一致,
|
|
|
|
|
可以使用和实盘相同的代码进行回测。
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
from collections import OrderedDict
|
2016-09-14 14:29:31 +00:00
|
|
|
|
from itertools import product
|
2016-07-02 03:12:44 +00:00
|
|
|
|
import pymongo
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
import MySQLdb
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import cPickle
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
from ctaBase import *
|
|
|
|
|
from ctaSetting import *
|
|
|
|
|
|
|
|
|
|
from vtConstant import *
|
|
|
|
|
from vtGateway import VtOrderData, VtTradeData
|
|
|
|
|
from vtFunction import loadMongoSetting
|
2016-09-14 14:29:31 +00:00
|
|
|
|
import logging
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
########################################################################
|
|
|
|
|
class BacktestingEngine(object):
|
|
|
|
|
"""
|
|
|
|
|
CTA回测引擎
|
|
|
|
|
函数接口和策略引擎保持一样,
|
|
|
|
|
从而实现同一套代码从回测到实盘。
|
2016-09-14 14:29:31 +00:00
|
|
|
|
# modified by IncenseLee:
|
|
|
|
|
1.增加Mysql数据库的支持;
|
|
|
|
|
2.修改装载数据为批量式后加载模式。
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
TICK_MODE = 'tick'
|
|
|
|
|
BAR_MODE = 'bar'
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""Constructor"""
|
|
|
|
|
# 本地停止单编号计数
|
|
|
|
|
self.stopOrderCount = 0
|
|
|
|
|
# stopOrderID = STOPORDERPREFIX + str(stopOrderCount)
|
|
|
|
|
|
|
|
|
|
# 本地停止单字典
|
|
|
|
|
# key为stopOrderID,value为stopOrder对象
|
|
|
|
|
self.stopOrderDict = {} # 停止单撤销后不会从本字典中删除
|
|
|
|
|
self.workingStopOrderDict = {} # 停止单撤销后会从本字典中删除
|
|
|
|
|
|
|
|
|
|
# 回测相关
|
|
|
|
|
self.strategy = None # 回测策略
|
|
|
|
|
self.mode = self.BAR_MODE # 回测模式,默认为K线
|
|
|
|
|
|
|
|
|
|
self.slippage = 0 # 回测时假设的滑点
|
|
|
|
|
self.rate = 0 # 回测时假设的佣金比例(适用于百分比佣金)
|
|
|
|
|
self.size = 1 # 合约大小,默认为1
|
|
|
|
|
|
|
|
|
|
self.dbClient = None # 数据库客户端
|
|
|
|
|
self.dbCursor = None # 数据库指针
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.historyData = [] # 历史数据的列表,回测用
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.initData = [] # 初始化用的数据
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.backtestingData = [] # 回测用的数据
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.dbName = '' # 回测数据库名
|
|
|
|
|
self.symbol = '' # 回测集合名
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.dataStartDate = None # 回测数据开始日期,datetime对象
|
|
|
|
|
self.dataEndDate = None # 回测数据结束日期,datetime对象
|
|
|
|
|
self.strategyStartDate = None # 策略启动日期(即前面的数据用于初始化),datetime对象
|
|
|
|
|
|
|
|
|
|
self.limitOrderDict = OrderedDict() # 限价单字典
|
|
|
|
|
self.workingLimitOrderDict = OrderedDict() # 活动限价单字典,用于进行撮合用
|
|
|
|
|
self.limitOrderCount = 0 # 限价单编号
|
|
|
|
|
|
|
|
|
|
self.tradeCount = 0 # 成交编号
|
|
|
|
|
self.tradeDict = OrderedDict() # 成交字典
|
|
|
|
|
|
|
|
|
|
self.logList = [] # 日志记录
|
|
|
|
|
|
|
|
|
|
# 当前最新数据,用于模拟成交用
|
|
|
|
|
self.tick = None
|
|
|
|
|
self.bar = None
|
|
|
|
|
self.dt = None # 最新的时间
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.gatewayName = u'BackTest'
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setStartDate(self, startDate='20100416', initDays=10):
|
|
|
|
|
"""设置回测的启动日期"""
|
|
|
|
|
self.dataStartDate = datetime.strptime(startDate, '%Y%m%d')
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
# 初始化天数
|
2016-07-02 03:12:44 +00:00
|
|
|
|
initTimeDelta = timedelta(initDays)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.strategyStartDate = self.dataStartDate + initTimeDelta
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setEndDate(self, endDate=''):
|
|
|
|
|
"""设置回测的结束日期"""
|
|
|
|
|
if endDate:
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.dataEndDate = datetime.strptime(endDate, '%Y%m%d')
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
self.dataEndDate = datetime.now()
|
|
|
|
|
|
|
|
|
|
def setMinDiff(self, minDiff):
|
|
|
|
|
"""设置回测品种的最小跳价,用于修正数据"""
|
|
|
|
|
self.minDiff = minDiff
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setBacktestingMode(self, mode):
|
|
|
|
|
"""设置回测模式"""
|
|
|
|
|
self.mode = mode
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
2016-09-14 14:29:31 +00:00
|
|
|
|
def setDatabase(self, dbName, symbol):
|
|
|
|
|
"""设置历史数据所用的数据库"""
|
|
|
|
|
self.dbName = dbName
|
|
|
|
|
self.symbol = symbol
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def loadHistoryDataFromMongo(self):
|
2016-07-02 03:12:44 +00:00
|
|
|
|
"""载入历史数据"""
|
|
|
|
|
host, port = loadMongoSetting()
|
|
|
|
|
|
|
|
|
|
self.dbClient = pymongo.MongoClient(host, port)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
collection = self.dbClient[self.dbName][self.symbol]
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
self.output(u'开始载入数据')
|
|
|
|
|
|
|
|
|
|
# 首先根据回测模式,确认要使用的数据类
|
|
|
|
|
if self.mode == self.BAR_MODE:
|
|
|
|
|
dataClass = CtaBarData
|
|
|
|
|
func = self.newBar
|
|
|
|
|
else:
|
|
|
|
|
dataClass = CtaTickData
|
|
|
|
|
func = self.newTick
|
|
|
|
|
|
|
|
|
|
# 载入初始化需要用的数据
|
|
|
|
|
flt = {'datetime':{'$gte':self.dataStartDate,
|
|
|
|
|
'$lt':self.strategyStartDate}}
|
|
|
|
|
initCursor = collection.find(flt)
|
|
|
|
|
|
|
|
|
|
# 将数据从查询指针中读取出,并生成列表
|
|
|
|
|
for d in initCursor:
|
|
|
|
|
data = dataClass()
|
|
|
|
|
data.__dict__ = d
|
|
|
|
|
self.initData.append(data)
|
|
|
|
|
|
|
|
|
|
# 载入回测数据
|
|
|
|
|
if not self.dataEndDate:
|
|
|
|
|
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
|
|
|
|
else:
|
|
|
|
|
flt = {'datetime':{'$gte':self.strategyStartDate,
|
|
|
|
|
'$lte':self.dataEndDate}}
|
|
|
|
|
self.dbCursor = collection.find(flt)
|
|
|
|
|
|
|
|
|
|
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def connectMysql(self):
|
|
|
|
|
"""连接MysqlDB"""
|
|
|
|
|
|
|
|
|
|
# 载入json文件
|
|
|
|
|
fileName = 'mysql_connect.json'
|
|
|
|
|
try:
|
|
|
|
|
f = file(fileName)
|
|
|
|
|
except IOError:
|
|
|
|
|
self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 解析json文件
|
|
|
|
|
setting = json.load(f)
|
|
|
|
|
try:
|
|
|
|
|
mysql_host = str(setting['host'])
|
|
|
|
|
mysql_port = int(setting['port'])
|
|
|
|
|
mysql_user = str(setting['user'])
|
|
|
|
|
mysql_passwd = str(setting['passwd'])
|
|
|
|
|
mysql_db = str(setting['db'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
except IOError:
|
|
|
|
|
self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user,
|
|
|
|
|
passwd=mysql_passwd, db=mysql_db, port=mysql_port)
|
|
|
|
|
self.__mysqlConnected = True
|
|
|
|
|
self.writeCtaLog(u'回测引擎连接MysqlDB成功')
|
|
|
|
|
except ConnectionFailure:
|
|
|
|
|
self.writeCtaLog(u'回测引擎连接MysqlDB失败')
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def loadDataHistoryFromMysql(self, symbol, startDate, endDate):
|
|
|
|
|
"""载入历史TICK数据
|
|
|
|
|
如果加载过多数据会导致加载失败,间隔不要超过半年
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not endDate:
|
|
|
|
|
endDate = datetime.today()
|
|
|
|
|
|
|
|
|
|
# 看本地缓存是否存在
|
|
|
|
|
if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate):
|
|
|
|
|
self.writeCtaLog(u'历史TICK数据从Cache载入')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 每次获取日期周期
|
|
|
|
|
intervalDays = 10
|
|
|
|
|
|
|
|
|
|
for i in range (0,(endDate - startDate).days +1, intervalDays):
|
|
|
|
|
d1 = startDate + timedelta(days = i )
|
|
|
|
|
|
|
|
|
|
if (endDate - d1).days > 10:
|
|
|
|
|
d2 = startDate + timedelta(days = i + intervalDays -1 )
|
|
|
|
|
else:
|
|
|
|
|
d2 = endDate
|
|
|
|
|
|
|
|
|
|
# 从Mysql 提取数据
|
|
|
|
|
self.__qryDataHistoryFromMysql(symbol, d1, d2)
|
|
|
|
|
|
|
|
|
|
self.writeCtaLog(u'历史TICK数据共载入{0}条'.format(len(self.historyData)))
|
|
|
|
|
|
|
|
|
|
# 保存本地cache文件
|
|
|
|
|
self.__saveDataHistoryToLocalCache(symbol, startDate, endDate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate):
|
|
|
|
|
"""看本地缓存是否存在
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 运行路径下cache子目录
|
|
|
|
|
cacheFolder = os.getcwd()+'/cache'
|
|
|
|
|
|
|
|
|
|
# cache文件
|
|
|
|
|
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
|
|
|
|
|
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
|
|
|
|
|
|
|
|
|
|
if not os.path.isfile(cacheFile):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# 从cache文件加载
|
|
|
|
|
cache = open(cacheFile,mode='r')
|
|
|
|
|
self.historyData = cPickle.load(cache)
|
|
|
|
|
cache.close()
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate):
|
|
|
|
|
"""保存本地缓存
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 运行路径下cache子目录
|
|
|
|
|
cacheFolder = os.getcwd()+'/cache'
|
|
|
|
|
|
|
|
|
|
# 创建cache子目录
|
|
|
|
|
if not os.path.isdir(cacheFolder):
|
|
|
|
|
os.mkdir(cacheFolder)
|
|
|
|
|
|
|
|
|
|
# cache 文件名
|
|
|
|
|
cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\
|
|
|
|
|
format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d'))
|
|
|
|
|
|
|
|
|
|
# 重复存在 返回
|
|
|
|
|
if os.path.isfile(cacheFile):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
# 写入cache文件
|
|
|
|
|
cache = open(cacheFile, mode='w')
|
|
|
|
|
cPickle.dump(self.historyData,cache)
|
|
|
|
|
cache.close()
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def __qryDataHistoryFromMysql(self, symbol, startDate, endDate):
|
|
|
|
|
"""从Mysql载入历史TICK数据
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
self.connectMysql()
|
|
|
|
|
if self.__mysqlConnected:
|
|
|
|
|
|
|
|
|
|
# 获取指针
|
|
|
|
|
cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor)
|
|
|
|
|
|
|
|
|
|
if endDate:
|
|
|
|
|
|
|
|
|
|
# 开始日期 ~ 结束日期
|
|
|
|
|
sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \
|
|
|
|
|
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
|
|
|
|
|
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
|
|
|
|
|
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \
|
|
|
|
|
'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\
|
|
|
|
|
format(symbol, startDate, endDate)
|
|
|
|
|
|
|
|
|
|
elif startDate:
|
|
|
|
|
|
|
|
|
|
# 开始日期 - 当前
|
|
|
|
|
sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
|
|
|
|
|
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
|
|
|
|
|
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
|
|
|
|
|
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \
|
|
|
|
|
'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\
|
|
|
|
|
format( symbol, startDate)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
# 所有数据
|
|
|
|
|
sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \
|
|
|
|
|
'\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume,' \
|
|
|
|
|
'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \
|
|
|
|
|
'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\
|
|
|
|
|
format(symbol)
|
|
|
|
|
|
|
|
|
|
self.writeCtaLog(sqlstring)
|
|
|
|
|
|
|
|
|
|
# 执行查询
|
|
|
|
|
count = cur.execute(sqlstring)
|
|
|
|
|
self.writeCtaLog(u'历史TICK数据共{0}条'.format(count))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 分批次读取
|
|
|
|
|
fetch_counts = 0
|
|
|
|
|
fetch_size = 1000
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
results = cur.fetchmany(fetch_size)
|
|
|
|
|
|
|
|
|
|
if not results:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
fetch_counts = fetch_counts + len(results)
|
|
|
|
|
|
|
|
|
|
if not self.historyData:
|
|
|
|
|
self.historyData =results
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
self.historyData = self.historyData + results
|
|
|
|
|
|
|
|
|
|
self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}条'.format(fetch_counts,startDate,endDate))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
self.writeCtaLog(u'MysqlDB未连接,请检查')
|
|
|
|
|
|
|
|
|
|
except MySQLdb.Error, e:
|
|
|
|
|
|
|
|
|
|
self.writeCtaLog(u'MysqlDB载入数据失败,请检查.Error {0}'.format(e))
|
|
|
|
|
|
|
|
|
|
def __dataToTick(self, data):
|
|
|
|
|
"""
|
|
|
|
|
数据库查询返回的data结构,转换为tick对象
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
tick = CtaTickData()
|
|
|
|
|
symbol = data['InstrumentID']
|
|
|
|
|
tick.symbol = symbol
|
|
|
|
|
|
|
|
|
|
# 创建TICK数据对象并更新数据
|
|
|
|
|
tick.vtSymbol = symbol
|
|
|
|
|
# tick.openPrice = data['OpenPrice']
|
|
|
|
|
# tick.highPrice = data['HighestPrice']
|
|
|
|
|
# tick.lowPrice = data['LowestPrice']
|
|
|
|
|
tick.lastPrice = float(data['LastPrice'])
|
|
|
|
|
|
|
|
|
|
tick.volume = data['Volume']
|
|
|
|
|
tick.openInterest = data['OpenInterest']
|
|
|
|
|
|
|
|
|
|
# tick.upperLimit = data['UpperLimitPrice']
|
|
|
|
|
# tick.lowerLimit = data['LowerLimitPrice']
|
|
|
|
|
|
|
|
|
|
tick.datetime = data['UpdateTime']
|
|
|
|
|
tick.date = tick.datetime.strftime('%Y-%m-%d')
|
|
|
|
|
tick.time = tick.datetime.strftime('%H:%M:%S')
|
|
|
|
|
|
|
|
|
|
tick.bidPrice1 = float(data['BidPrice1'])
|
|
|
|
|
# tick.bidPrice2 = data['BidPrice2']
|
|
|
|
|
# tick.bidPrice3 = data['BidPrice3']
|
|
|
|
|
# tick.bidPrice4 = data['BidPrice4']
|
|
|
|
|
# tick.bidPrice5 = data['BidPrice5']
|
|
|
|
|
|
|
|
|
|
tick.askPrice1 = float(data['AskPrice1'])
|
|
|
|
|
# tick.askPrice2 = data['AskPrice2']
|
|
|
|
|
# tick.askPrice3 = data['AskPrice3']
|
|
|
|
|
# tick.askPrice4 = data['AskPrice4']
|
|
|
|
|
# tick.askPrice5 = data['AskPrice5']
|
|
|
|
|
|
|
|
|
|
tick.bidVolume1 = data['BidVolume1']
|
|
|
|
|
# tick.bidVolume2 = data['BidVolume2']
|
|
|
|
|
# tick.bidVolume3 = data['BidVolume3']
|
|
|
|
|
# tick.bidVolume4 = data['BidVolume4']
|
|
|
|
|
# tick.bidVolume5 = data['BidVolume5']
|
|
|
|
|
|
|
|
|
|
tick.askVolume1 = data['AskVolume1']
|
|
|
|
|
# tick.askVolume2 = data['AskVolume2']
|
|
|
|
|
# tick.askVolume3 = data['AskVolume3']
|
|
|
|
|
# tick.askVolume4 = data['AskVolume4']
|
|
|
|
|
# tick.askVolume5 = data['AskVolume5']
|
|
|
|
|
|
|
|
|
|
return tick
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def getMysqlDeltaDate(self,symbol, startDate, decreaseDays):
|
|
|
|
|
"""从mysql库中获取交易日前若干天
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
if self.__mysqlConnected:
|
|
|
|
|
|
|
|
|
|
# 获取mysql指针
|
|
|
|
|
cur = self.__mysqlConnection.cursor()
|
|
|
|
|
|
|
|
|
|
sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \
|
|
|
|
|
'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1)
|
|
|
|
|
|
|
|
|
|
# self.writeCtaLog(sqlstring)
|
|
|
|
|
|
|
|
|
|
count = cur.execute(sqlstring)
|
|
|
|
|
|
|
|
|
|
if count > 0:
|
|
|
|
|
|
|
|
|
|
# 提取第一条记录
|
|
|
|
|
result = cur.fetchone()
|
|
|
|
|
|
|
|
|
|
return result[0]
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
self.writeCtaLog(u'MysqlDB没有查询结果,请检查日期')
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
self.writeCtaLog(u'MysqlDB未连接,请检查')
|
|
|
|
|
|
|
|
|
|
except MySQLdb.Error, e:
|
|
|
|
|
|
|
|
|
|
self.writeCtaLog(u'MysqlDB载入数据失败,请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1]))
|
|
|
|
|
|
|
|
|
|
# 出错后缺省返回
|
|
|
|
|
return startDate-timedelta(days=3)
|
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def runBackTestingWithFile(self, filename):
|
|
|
|
|
"""运行回测(使用本地csv数据)
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
if not filename:
|
|
|
|
|
self.writeCtaLog(u'请指定回测数据文件')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not self.dataStartDate:
|
|
|
|
|
self.writeCtaLog(u'回测开始日期未设置。')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not self.dataEndDate:
|
|
|
|
|
self.dataEndDate = datetime.today()
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
if not os.path.isfile(filename):
|
|
|
|
|
self.writeCtaLog(u'{0}文件不存在'.format(filename))
|
|
|
|
|
|
|
|
|
|
if len(self.symbol)<1:
|
|
|
|
|
self.writeCtaLog(u'回测对象未设置。')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 首先根据回测模式,确认要使用的数据类
|
|
|
|
|
if not self.mode == self.BAR_MODE:
|
|
|
|
|
self.writeCtaLog(u'文件仅支持bar模式,若扩展tick模式,需要修改本方法')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回测')
|
|
|
|
|
|
|
|
|
|
#self.strategy.inited = True
|
|
|
|
|
self.strategy.onInit()
|
|
|
|
|
self.output(u'策略初始化完成')
|
|
|
|
|
|
|
|
|
|
self.strategy.trading = True
|
|
|
|
|
self.strategy.onStart()
|
|
|
|
|
self.output(u'策略启动完成')
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回放数据')
|
|
|
|
|
|
|
|
|
|
import csv
|
|
|
|
|
|
|
|
|
|
csvfile = file(filename,'rb')
|
|
|
|
|
|
|
|
|
|
reader = csv.DictReader((line.replace('\0','') for line in csvfile),delimiter=",")
|
|
|
|
|
|
|
|
|
|
for row in reader:
|
|
|
|
|
|
|
|
|
|
bar = CtaBarData()
|
|
|
|
|
bar.open = float(row['Open'])
|
|
|
|
|
bar.high = float(row['High'])
|
|
|
|
|
bar.low = float(row['Low'])
|
|
|
|
|
bar.close = float(row['Close'])
|
|
|
|
|
bar.volume = float(row['TotalVolume'])
|
|
|
|
|
bar.date = row['Date']
|
|
|
|
|
bar.time = row['Time']
|
|
|
|
|
bar.datetime = datetime.strptime(bar.date+' '+ bar.time, '%Y/%m/%d %H:%M:%S')
|
|
|
|
|
|
|
|
|
|
if not (bar.datetime < self.dataStartDate or bar.datetime > self.dataEndDate):
|
|
|
|
|
|
|
|
|
|
self.newBar(bar)
|
|
|
|
|
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def runBacktestingWithMysql(self):
|
|
|
|
|
"""运行回测(使用Mysql数据)
|
|
|
|
|
added by IncenseLee
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not self.dataStartDate:
|
|
|
|
|
self.writeCtaLog(u'回测开始日期未设置。')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not self.dataEndDate:
|
|
|
|
|
self.dataEndDate = datetime.today()
|
|
|
|
|
|
|
|
|
|
if len(self.symbol)<1:
|
|
|
|
|
self.writeCtaLog(u'回测对象未设置。')
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 首先根据回测模式,确认要使用的数据类
|
|
|
|
|
if self.mode == self.BAR_MODE:
|
|
|
|
|
dataClass = CtaBarData
|
|
|
|
|
func = self.newBar
|
|
|
|
|
else:
|
|
|
|
|
dataClass = CtaTickData
|
|
|
|
|
func = self.newTick
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回测')
|
|
|
|
|
|
|
|
|
|
#self.strategy.inited = True
|
|
|
|
|
self.strategy.onInit()
|
|
|
|
|
self.output(u'策略初始化完成')
|
|
|
|
|
|
|
|
|
|
self.strategy.trading = True
|
|
|
|
|
self.strategy.onStart()
|
|
|
|
|
self.output(u'策略启动完成')
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回放数据')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 每次获取日期周期
|
|
|
|
|
intervalDays = 10
|
|
|
|
|
|
|
|
|
|
for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays):
|
|
|
|
|
d1 = self.dataStartDate + timedelta(days = i )
|
|
|
|
|
|
|
|
|
|
if (self.dataEndDate - d1).days > intervalDays:
|
|
|
|
|
d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 )
|
|
|
|
|
else:
|
|
|
|
|
d2 = self.dataEndDate
|
|
|
|
|
|
|
|
|
|
# 提取历史数据
|
|
|
|
|
self.loadDataHistoryFromMysql(self.symbol, d1, d2)
|
|
|
|
|
|
|
|
|
|
self.output(u'数据日期:{0} => {1}'.format(d1,d2))
|
|
|
|
|
# 将逐笔数据推送
|
|
|
|
|
for data in self.historyData:
|
|
|
|
|
|
|
|
|
|
# 记录最新的TICK数据
|
|
|
|
|
self.tick = self.__dataToTick(data)
|
|
|
|
|
self.dt = self.tick.datetime
|
|
|
|
|
|
|
|
|
|
# 处理限价单
|
|
|
|
|
self.crossLimitOrder()
|
|
|
|
|
self.crossStopOrder()
|
|
|
|
|
|
|
|
|
|
# 推送到策略引擎中
|
|
|
|
|
self.strategy.onTick(self.tick)
|
|
|
|
|
|
|
|
|
|
# 清空历史数据
|
|
|
|
|
self.historyData = []
|
|
|
|
|
|
|
|
|
|
self.output(u'数据回放结束')
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def runBacktesting(self):
|
|
|
|
|
"""运行回测"""
|
2016-09-14 14:29:31 +00:00
|
|
|
|
# 载入历史数据
|
|
|
|
|
self.loadHistoryData()
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 首先根据回测模式,确认要使用的数据类
|
|
|
|
|
if self.mode == self.BAR_MODE:
|
|
|
|
|
dataClass = CtaBarData
|
|
|
|
|
func = self.newBar
|
|
|
|
|
else:
|
|
|
|
|
dataClass = CtaTickData
|
|
|
|
|
func = self.newTick
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回测')
|
|
|
|
|
|
|
|
|
|
self.strategy.inited = True
|
|
|
|
|
self.strategy.onInit()
|
|
|
|
|
self.output(u'策略初始化完成')
|
|
|
|
|
|
|
|
|
|
self.strategy.trading = True
|
|
|
|
|
self.strategy.onStart()
|
|
|
|
|
self.output(u'策略启动完成')
|
|
|
|
|
|
|
|
|
|
self.output(u'开始回放数据')
|
|
|
|
|
|
|
|
|
|
for d in self.dbCursor:
|
|
|
|
|
data = dataClass()
|
|
|
|
|
data.__dict__ = d
|
|
|
|
|
func(data)
|
|
|
|
|
|
|
|
|
|
self.output(u'数据回放结束')
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def newBar(self, bar):
|
|
|
|
|
"""新的K线"""
|
|
|
|
|
self.bar = bar
|
|
|
|
|
self.dt = bar.datetime
|
|
|
|
|
self.crossLimitOrder() # 先撮合限价单
|
|
|
|
|
self.crossStopOrder() # 再撮合停止单
|
|
|
|
|
self.strategy.onBar(bar) # 推送K线到策略中
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def newTick(self, tick):
|
|
|
|
|
"""新的Tick"""
|
|
|
|
|
self.tick = tick
|
|
|
|
|
self.dt = tick.datetime
|
|
|
|
|
self.crossLimitOrder()
|
|
|
|
|
self.crossStopOrder()
|
|
|
|
|
self.strategy.onTick(tick)
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def initStrategy(self, strategyClass, setting=None):
|
|
|
|
|
"""
|
|
|
|
|
初始化策略
|
|
|
|
|
setting是策略的参数设置,如果使用类中写好的默认设置则可以不传该参数
|
|
|
|
|
"""
|
|
|
|
|
self.strategy = strategyClass(self, setting)
|
|
|
|
|
self.strategy.name = self.strategy.className
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def sendOrder(self, vtSymbol, orderType, price, volume, strategy):
|
|
|
|
|
"""发单"""
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol,orderType,price,volume))
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.limitOrderCount += 1
|
|
|
|
|
orderID = str(self.limitOrderCount)
|
|
|
|
|
|
|
|
|
|
order = VtOrderData()
|
|
|
|
|
order.vtSymbol = vtSymbol
|
|
|
|
|
order.price = price
|
|
|
|
|
order.totalVolume = volume
|
|
|
|
|
order.status = STATUS_NOTTRADED # 刚提交尚未成交
|
|
|
|
|
order.orderID = orderID
|
|
|
|
|
order.vtOrderID = orderID
|
|
|
|
|
order.orderTime = str(self.dt)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
# added by IncenseLee
|
|
|
|
|
order.gatewayName = self.gatewayName
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# CTA委托类型映射
|
|
|
|
|
if orderType == CTAORDER_BUY:
|
|
|
|
|
order.direction = DIRECTION_LONG
|
|
|
|
|
order.offset = OFFSET_OPEN
|
|
|
|
|
elif orderType == CTAORDER_SELL:
|
|
|
|
|
order.direction = DIRECTION_SHORT
|
|
|
|
|
order.offset = OFFSET_CLOSE
|
|
|
|
|
elif orderType == CTAORDER_SHORT:
|
|
|
|
|
order.direction = DIRECTION_SHORT
|
|
|
|
|
order.offset = OFFSET_OPEN
|
|
|
|
|
elif orderType == CTAORDER_COVER:
|
|
|
|
|
order.direction = DIRECTION_LONG
|
|
|
|
|
order.offset = OFFSET_CLOSE
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
# modified by IncenseLee
|
2016-09-15 02:08:14 +00:00
|
|
|
|
key = u'{0}.{1}'.format(order.gatewayName, orderID)
|
|
|
|
|
# 保存到限价单字典中
|
|
|
|
|
self.workingLimitOrderDict[key] = order
|
|
|
|
|
self.limitOrderDict[key] = order
|
|
|
|
|
return key
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def cancelOrder(self, vtOrderID):
|
|
|
|
|
"""撤单"""
|
|
|
|
|
if vtOrderID in self.workingLimitOrderDict:
|
|
|
|
|
order = self.workingLimitOrderDict[vtOrderID]
|
|
|
|
|
order.status = STATUS_CANCELLED
|
|
|
|
|
order.cancelTime = str(self.dt)
|
|
|
|
|
del self.workingLimitOrderDict[vtOrderID]
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy):
|
|
|
|
|
"""发停止单(本地实现)"""
|
2016-09-15 02:08:14 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.stopOrderCount += 1
|
|
|
|
|
stopOrderID = STOPORDERPREFIX + str(self.stopOrderCount)
|
|
|
|
|
|
|
|
|
|
so = StopOrder()
|
|
|
|
|
so.vtSymbol = vtSymbol
|
|
|
|
|
so.price = price
|
|
|
|
|
so.volume = volume
|
|
|
|
|
so.strategy = strategy
|
|
|
|
|
so.stopOrderID = stopOrderID
|
|
|
|
|
so.status = STOPORDER_WAITING
|
|
|
|
|
|
|
|
|
|
if orderType == CTAORDER_BUY:
|
|
|
|
|
so.direction = DIRECTION_LONG
|
|
|
|
|
so.offset = OFFSET_OPEN
|
|
|
|
|
elif orderType == CTAORDER_SELL:
|
|
|
|
|
so.direction = DIRECTION_SHORT
|
|
|
|
|
so.offset = OFFSET_CLOSE
|
|
|
|
|
elif orderType == CTAORDER_SHORT:
|
|
|
|
|
so.direction = DIRECTION_SHORT
|
|
|
|
|
so.offset = OFFSET_OPEN
|
|
|
|
|
elif orderType == CTAORDER_COVER:
|
|
|
|
|
so.direction = DIRECTION_LONG
|
|
|
|
|
so.offset = OFFSET_CLOSE
|
|
|
|
|
|
|
|
|
|
# 保存stopOrder对象到字典中
|
|
|
|
|
self.stopOrderDict[stopOrderID] = so
|
|
|
|
|
self.workingStopOrderDict[stopOrderID] = so
|
|
|
|
|
|
|
|
|
|
return stopOrderID
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def cancelStopOrder(self, stopOrderID):
|
|
|
|
|
"""撤销停止单"""
|
|
|
|
|
# 检查停止单是否存在
|
|
|
|
|
if stopOrderID in self.workingStopOrderDict:
|
|
|
|
|
so = self.workingStopOrderDict[stopOrderID]
|
|
|
|
|
so.status = STOPORDER_CANCELLED
|
|
|
|
|
del self.workingStopOrderDict[stopOrderID]
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def crossLimitOrder(self):
|
|
|
|
|
"""基于最新数据撮合限价单"""
|
|
|
|
|
# 先确定会撮合成交的价格
|
|
|
|
|
if self.mode == self.BAR_MODE:
|
2016-09-14 14:29:31 +00:00
|
|
|
|
buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交
|
|
|
|
|
sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交
|
|
|
|
|
buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价
|
|
|
|
|
sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价
|
2016-07-02 03:12:44 +00:00
|
|
|
|
else:
|
2016-09-14 14:29:31 +00:00
|
|
|
|
buyCrossPrice = self.tick.askPrice1
|
|
|
|
|
sellCrossPrice = self.tick.bidPrice1
|
|
|
|
|
buyBestCrossPrice = self.tick.askPrice1
|
|
|
|
|
sellBestCrossPrice = self.tick.bidPrice1
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 遍历限价单字典中的所有限价单
|
|
|
|
|
for orderID, order in self.workingLimitOrderDict.items():
|
|
|
|
|
# 判断是否会成交
|
|
|
|
|
buyCross = order.direction==DIRECTION_LONG and order.price>=buyCrossPrice
|
|
|
|
|
sellCross = order.direction==DIRECTION_SHORT and order.price<=sellCrossPrice
|
|
|
|
|
|
|
|
|
|
# 如果发生了成交
|
|
|
|
|
if buyCross or sellCross:
|
|
|
|
|
# 推送成交数据
|
|
|
|
|
self.tradeCount += 1 # 成交编号自增1
|
2016-09-30 12:39:18 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
tradeID = str(self.tradeCount)
|
|
|
|
|
trade = VtTradeData()
|
|
|
|
|
trade.vtSymbol = order.vtSymbol
|
|
|
|
|
trade.tradeID = tradeID
|
|
|
|
|
trade.vtTradeID = tradeID
|
|
|
|
|
trade.orderID = order.orderID
|
|
|
|
|
trade.vtOrderID = order.orderID
|
|
|
|
|
trade.direction = order.direction
|
|
|
|
|
trade.offset = order.offset
|
|
|
|
|
|
|
|
|
|
# 以买入为例:
|
|
|
|
|
# 1. 假设当根K线的OHLC分别为:100, 125, 90, 110
|
|
|
|
|
# 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105
|
|
|
|
|
# 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100
|
|
|
|
|
if buyCross:
|
2016-09-14 14:29:31 +00:00
|
|
|
|
trade.price = min(order.price, buyBestCrossPrice)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.strategy.pos += order.totalVolume
|
|
|
|
|
else:
|
2016-09-14 14:29:31 +00:00
|
|
|
|
trade.price = max(order.price, sellBestCrossPrice)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.strategy.pos -= order.totalVolume
|
|
|
|
|
|
|
|
|
|
trade.volume = order.totalVolume
|
|
|
|
|
trade.tradeTime = str(self.dt)
|
|
|
|
|
trade.dt = self.dt
|
|
|
|
|
self.strategy.onTrade(trade)
|
|
|
|
|
|
|
|
|
|
self.tradeDict[tradeID] = trade
|
2016-09-30 12:39:18 +00:00
|
|
|
|
self.writeCtaLog(u'TradeId:{0}'.format(tradeID))
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 推送委托数据
|
|
|
|
|
order.tradedVolume = order.totalVolume
|
|
|
|
|
order.status = STATUS_ALLTRADED
|
2016-09-30 12:39:18 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.strategy.onOrder(order)
|
|
|
|
|
|
|
|
|
|
# 从字典中删除该限价单
|
|
|
|
|
del self.workingLimitOrderDict[orderID]
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def crossStopOrder(self):
|
|
|
|
|
"""基于最新数据撮合停止单"""
|
|
|
|
|
# 先确定会撮合成交的价格,这里和限价单规则相反
|
|
|
|
|
if self.mode == self.BAR_MODE:
|
|
|
|
|
buyCrossPrice = self.bar.high # 若买入方向停止单价格低于该价格,则会成交
|
|
|
|
|
sellCrossPrice = self.bar.low # 若卖出方向限价单价格高于该价格,则会成交
|
|
|
|
|
bestCrossPrice = self.bar.open # 最优成交价,买入停止单不能低于,卖出停止单不能高于
|
|
|
|
|
else:
|
|
|
|
|
buyCrossPrice = self.tick.lastPrice
|
|
|
|
|
sellCrossPrice = self.tick.lastPrice
|
|
|
|
|
bestCrossPrice = self.tick.lastPrice
|
|
|
|
|
|
|
|
|
|
# 遍历停止单字典中的所有停止单
|
|
|
|
|
for stopOrderID, so in self.workingStopOrderDict.items():
|
|
|
|
|
# 判断是否会成交
|
|
|
|
|
buyCross = so.direction==DIRECTION_LONG and so.price<=buyCrossPrice
|
|
|
|
|
sellCross = so.direction==DIRECTION_SHORT and so.price>=sellCrossPrice
|
|
|
|
|
|
|
|
|
|
# 如果发生了成交
|
|
|
|
|
if buyCross or sellCross:
|
|
|
|
|
# 推送成交数据
|
|
|
|
|
self.tradeCount += 1 # 成交编号自增1
|
|
|
|
|
tradeID = str(self.tradeCount)
|
|
|
|
|
trade = VtTradeData()
|
|
|
|
|
trade.vtSymbol = so.vtSymbol
|
|
|
|
|
trade.tradeID = tradeID
|
|
|
|
|
trade.vtTradeID = tradeID
|
|
|
|
|
|
|
|
|
|
if buyCross:
|
|
|
|
|
self.strategy.pos += so.volume
|
|
|
|
|
trade.price = max(bestCrossPrice, so.price)
|
|
|
|
|
else:
|
|
|
|
|
self.strategy.pos -= so.volume
|
|
|
|
|
trade.price = min(bestCrossPrice, so.price)
|
|
|
|
|
|
|
|
|
|
self.limitOrderCount += 1
|
|
|
|
|
orderID = str(self.limitOrderCount)
|
|
|
|
|
trade.orderID = orderID
|
|
|
|
|
trade.vtOrderID = orderID
|
|
|
|
|
|
|
|
|
|
trade.direction = so.direction
|
|
|
|
|
trade.offset = so.offset
|
|
|
|
|
trade.volume = so.volume
|
|
|
|
|
trade.tradeTime = str(self.dt)
|
|
|
|
|
trade.dt = self.dt
|
|
|
|
|
self.strategy.onTrade(trade)
|
|
|
|
|
|
|
|
|
|
self.tradeDict[tradeID] = trade
|
|
|
|
|
|
|
|
|
|
# 推送委托数据
|
|
|
|
|
so.status = STOPORDER_TRIGGERED
|
|
|
|
|
|
|
|
|
|
order = VtOrderData()
|
|
|
|
|
order.vtSymbol = so.vtSymbol
|
|
|
|
|
order.symbol = so.vtSymbol
|
|
|
|
|
order.orderID = orderID
|
|
|
|
|
order.vtOrderID = orderID
|
|
|
|
|
order.direction = so.direction
|
|
|
|
|
order.offset = so.offset
|
|
|
|
|
order.price = so.price
|
|
|
|
|
order.totalVolume = so.volume
|
|
|
|
|
order.tradedVolume = so.volume
|
|
|
|
|
order.status = STATUS_ALLTRADED
|
|
|
|
|
order.orderTime = trade.tradeTime
|
|
|
|
|
self.strategy.onOrder(order)
|
|
|
|
|
|
|
|
|
|
self.limitOrderDict[orderID] = order
|
|
|
|
|
|
|
|
|
|
# 从字典中删除该限价单
|
|
|
|
|
del self.workingStopOrderDict[stopOrderID]
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def insertData(self, dbName, collectionName, data):
|
|
|
|
|
"""考虑到回测中不允许向数据库插入数据,防止实盘交易中的一些代码出错"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def loadBar(self, dbName, collectionName, startDate):
|
|
|
|
|
"""直接返回初始化数据列表中的Bar"""
|
|
|
|
|
return self.initData
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def loadTick(self, dbName, collectionName, startDate):
|
|
|
|
|
"""直接返回初始化数据列表中的Tick"""
|
|
|
|
|
return self.initData
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def writeCtaLog(self, content):
|
|
|
|
|
"""记录日志"""
|
|
|
|
|
log = str(self.dt) + ' ' + content
|
|
|
|
|
self.logList.append(log)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
# 写入本地log日志
|
|
|
|
|
logging.info(content)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def output(self, content):
|
|
|
|
|
"""输出内容"""
|
2016-09-14 14:29:31 +00:00
|
|
|
|
print str(datetime.now()) + "\t" + content
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
2016-09-14 14:29:31 +00:00
|
|
|
|
def calculateBacktestingResult(self):
|
2016-07-02 03:12:44 +00:00
|
|
|
|
"""
|
2016-09-14 14:29:31 +00:00
|
|
|
|
计算回测结果
|
2016-07-02 03:12:44 +00:00
|
|
|
|
"""
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.output(u'计算回测结果')
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 首先基于回测后的成交记录,计算每笔交易的盈亏
|
2016-09-14 14:29:31 +00:00
|
|
|
|
resultDict = OrderedDict() # 交易结果记录
|
2016-07-02 03:12:44 +00:00
|
|
|
|
longTrade = [] # 未平仓的多头交易
|
|
|
|
|
shortTrade = [] # 未平仓的空头交易
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
i = 0
|
|
|
|
|
|
|
|
|
|
longid = EMPTY_STRING
|
|
|
|
|
shortid = EMPTY_STRING
|
|
|
|
|
|
|
|
|
|
for tradeid in self.tradeDict.keys():
|
|
|
|
|
|
|
|
|
|
trade = self.tradeDict[tradeid]
|
|
|
|
|
|
|
|
|
|
if tradeid == '127':
|
|
|
|
|
pass
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 多头交易
|
|
|
|
|
if trade.direction == DIRECTION_LONG:
|
|
|
|
|
# 如果尚无空头交易
|
|
|
|
|
if not shortTrade:
|
|
|
|
|
longTrade.append(trade)
|
2016-09-30 12:39:18 +00:00
|
|
|
|
longid = tradeid
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 当前多头交易为平空
|
|
|
|
|
else:
|
|
|
|
|
entryTrade = shortTrade.pop(0)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
result = TradingResult(entryTrade.price, trade.price, -trade.volume,
|
|
|
|
|
self.rate, self.slippage, self.size)
|
|
|
|
|
|
|
|
|
|
resultDict[trade.dt] = result
|
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
i = i+1
|
|
|
|
|
self.writeCtaLog(u'{6}.{7}.开空{0},short:{1},{8}.平空{2},cover:{3},vol:{4},净盈亏:{5}'
|
|
|
|
|
.format(entryTrade.tradeTime, entryTrade.price,
|
|
|
|
|
trade.tradeTime,trade.price, trade.volume,result.pnl,
|
|
|
|
|
i,shortid,tradeid))
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 空头交易
|
|
|
|
|
else:
|
|
|
|
|
# 如果尚无多头交易
|
|
|
|
|
if not longTrade:
|
|
|
|
|
shortTrade.append(trade)
|
2016-09-30 12:39:18 +00:00
|
|
|
|
shortid = tradeid
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 当前空头交易为平多
|
|
|
|
|
else:
|
|
|
|
|
entryTrade = longTrade.pop(0)
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
result = TradingResult(entryTrade.price, trade.price, trade.volume,
|
|
|
|
|
self.rate, self.slippage, self.size)
|
|
|
|
|
resultDict[trade.dt] = result
|
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
i = i+1
|
|
|
|
|
self.writeCtaLog(u'{6}.{7}开多{0},buy:{1},{8}.平多{2},sell:{3},vol:{4},净盈亏:{5}'
|
|
|
|
|
.format(entryTrade.tradeTime, entryTrade.price,
|
|
|
|
|
trade.tradeTime,trade.price, trade.volume,result.pnl,
|
|
|
|
|
i,longid,tradeid))
|
2016-09-14 14:29:31 +00:00
|
|
|
|
|
|
|
|
|
# 检查是否有交易
|
|
|
|
|
if not resultDict:
|
|
|
|
|
self.output(u'无交易结果')
|
|
|
|
|
return {}
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等
|
2016-09-14 14:29:31 +00:00
|
|
|
|
capital = 0 # 资金
|
|
|
|
|
maxCapital = 0 # 资金最高净值
|
|
|
|
|
drawdown = 0 # 回撤
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
totalResult = 0 # 总成交数量
|
|
|
|
|
totalTurnover = 0 # 总成交金额(合约面值)
|
|
|
|
|
totalCommission = 0 # 总手续费
|
|
|
|
|
totalSlippage = 0 # 总滑点
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
timeList = [] # 时间序列
|
|
|
|
|
pnlList = [] # 每笔盈亏序列
|
2016-07-02 03:12:44 +00:00
|
|
|
|
capitalList = [] # 盈亏汇总的时间序列
|
|
|
|
|
drawdownList = [] # 回撤的时间序列
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
for time, result in resultDict.items():
|
|
|
|
|
capital += result.pnl
|
2016-07-02 03:12:44 +00:00
|
|
|
|
maxCapital = max(capital, maxCapital)
|
|
|
|
|
drawdown = capital - maxCapital
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
pnlList.append(result.pnl)
|
|
|
|
|
timeList.append(time)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
capitalList.append(capital)
|
|
|
|
|
drawdownList.append(drawdown)
|
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
totalResult += 1
|
|
|
|
|
totalTurnover += result.turnover
|
|
|
|
|
totalCommission += result.commission
|
|
|
|
|
totalSlippage += result.slippage
|
|
|
|
|
|
|
|
|
|
# 返回回测结果
|
|
|
|
|
d = {}
|
|
|
|
|
d['capital'] = capital
|
|
|
|
|
d['maxCapital'] = maxCapital
|
|
|
|
|
d['drawdown'] = drawdown
|
|
|
|
|
d['totalResult'] = totalResult
|
|
|
|
|
d['totalTurnover'] = totalTurnover
|
|
|
|
|
d['totalCommission'] = totalCommission
|
|
|
|
|
d['totalSlippage'] = totalSlippage
|
|
|
|
|
d['timeList'] = timeList
|
|
|
|
|
d['pnlList'] = pnlList
|
|
|
|
|
d['capitalList'] = capitalList
|
|
|
|
|
d['drawdownList'] = drawdownList
|
|
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def showBacktestingResult(self):
|
|
|
|
|
"""显示回测结果"""
|
|
|
|
|
d = self.calculateBacktestingResult()
|
|
|
|
|
|
|
|
|
|
if len(d)== 0:
|
|
|
|
|
self.output(u'无交易结果')
|
|
|
|
|
return
|
2016-07-02 03:12:44 +00:00
|
|
|
|
# 输出
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.output('-' * 30)
|
|
|
|
|
self.output(u'第一笔交易:\t%s' % d['timeList'][0])
|
|
|
|
|
self.output(u'最后一笔交易:\t%s' % d['timeList'][-1])
|
|
|
|
|
|
|
|
|
|
self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult']))
|
|
|
|
|
self.output(u'总盈亏:\t%s' % formatNumber(d['capital']))
|
|
|
|
|
self.output(u'最大回撤: \t%s' % formatNumber(min(d['drawdownList'])))
|
|
|
|
|
|
|
|
|
|
self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult']))
|
2016-10-10 16:50:47 +00:00
|
|
|
|
self.output(u'平均每笔滑点成本:\t%s' %formatNumber(d['totalSlippage']/d['totalResult']))
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult']))
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 绘图
|
2016-09-30 12:39:18 +00:00
|
|
|
|
import matplotlib.pyplot as plt
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
pCapital = plt.subplot(3, 1, 1)
|
|
|
|
|
pCapital.set_ylabel("capital")
|
|
|
|
|
pCapital.plot(d['capitalList'])
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
pDD = plt.subplot(3, 1, 2)
|
|
|
|
|
pDD.set_ylabel("DD")
|
|
|
|
|
pDD.bar(range(len(d['drawdownList'])), d['drawdownList'])
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
pPnl = plt.subplot(3, 1, 3)
|
|
|
|
|
pPnl.set_ylabel("pnl")
|
|
|
|
|
pPnl.hist(d['pnlList'], bins=50)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-30 12:39:18 +00:00
|
|
|
|
plt.show()
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def putStrategyEvent(self, name):
|
|
|
|
|
"""发送策略更新事件,回测中忽略"""
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setSlippage(self, slippage):
|
2016-09-14 14:29:31 +00:00
|
|
|
|
"""设置滑点点数"""
|
2016-07-02 03:12:44 +00:00
|
|
|
|
self.slippage = slippage
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setSize(self, size):
|
|
|
|
|
"""设置合约大小"""
|
|
|
|
|
self.size = size
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setRate(self, rate):
|
|
|
|
|
"""设置佣金比例"""
|
2016-10-10 16:50:47 +00:00
|
|
|
|
self.rate = float(rate)
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
2016-09-14 14:29:31 +00:00
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def runOptimization(self, strategyClass, optimizationSetting):
|
|
|
|
|
"""优化参数"""
|
|
|
|
|
# 获取优化设置
|
|
|
|
|
settingList = optimizationSetting.generateSetting()
|
|
|
|
|
targetName = optimizationSetting.optimizeTarget
|
|
|
|
|
|
|
|
|
|
# 检查参数设置问题
|
|
|
|
|
if not settingList or not targetName:
|
|
|
|
|
self.output(u'优化设置有问题,请检查')
|
|
|
|
|
|
|
|
|
|
# 遍历优化
|
|
|
|
|
resultList = []
|
|
|
|
|
for setting in settingList:
|
|
|
|
|
self.clearBacktestingResult()
|
|
|
|
|
self.output('-' * 30)
|
|
|
|
|
self.output('setting: %s' %str(setting))
|
|
|
|
|
self.initStrategy(strategyClass, setting)
|
|
|
|
|
self.runBacktesting()
|
|
|
|
|
d = self.calculateBacktestingResult()
|
|
|
|
|
try:
|
|
|
|
|
targetValue = d[targetName]
|
|
|
|
|
except KeyError:
|
|
|
|
|
targetValue = 0
|
|
|
|
|
resultList.append(([str(setting)], targetValue))
|
|
|
|
|
|
|
|
|
|
# 显示结果
|
|
|
|
|
resultList.sort(reverse=True, key=lambda result:result[1])
|
|
|
|
|
self.output('-' * 30)
|
|
|
|
|
self.output(u'优化结果:')
|
|
|
|
|
for result in resultList:
|
|
|
|
|
self.output(u'%s: %s' %(result[0], result[1]))
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def clearBacktestingResult(self):
|
|
|
|
|
"""清空之前回测的结果"""
|
|
|
|
|
# 清空限价单相关
|
|
|
|
|
self.limitOrderCount = 0
|
|
|
|
|
self.limitOrderDict.clear()
|
|
|
|
|
self.workingLimitOrderDict.clear()
|
|
|
|
|
|
|
|
|
|
# 清空停止单相关
|
|
|
|
|
self.stopOrderCount = 0
|
|
|
|
|
self.stopOrderDict.clear()
|
|
|
|
|
self.workingStopOrderDict.clear()
|
|
|
|
|
|
|
|
|
|
# 清空成交相关
|
|
|
|
|
self.tradeCount = 0
|
|
|
|
|
self.tradeDict.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################
|
|
|
|
|
class TradingResult(object):
|
|
|
|
|
"""每笔交易的结果"""
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def __init__(self, entry, exit, volume, rate, slippage, size):
|
|
|
|
|
"""Constructor"""
|
|
|
|
|
self.entry = entry # 开仓价格
|
|
|
|
|
self.exit = exit # 平仓价格
|
|
|
|
|
self.volume = volume # 交易数量(+/-代表方向)
|
|
|
|
|
|
|
|
|
|
self.turnover = (self.entry+self.exit)*size # 成交金额
|
2016-10-10 16:50:47 +00:00
|
|
|
|
self.commission =round(float(self.turnover*rate),4) # 手续费成本
|
2016-09-14 14:29:31 +00:00
|
|
|
|
self.slippage = slippage*2*size # 滑点成本
|
|
|
|
|
self.pnl = ((self.exit - self.entry) * volume * size
|
|
|
|
|
- self.commission - self.slippage) # 净盈亏
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
########################################################################
|
|
|
|
|
class OptimizationSetting(object):
|
|
|
|
|
"""优化设置"""
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""Constructor"""
|
|
|
|
|
self.paramDict = OrderedDict()
|
|
|
|
|
|
|
|
|
|
self.optimizeTarget = '' # 优化目标字段
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def addParameter(self, name, start, end, step):
|
|
|
|
|
"""增加优化参数"""
|
|
|
|
|
if end <= start:
|
|
|
|
|
print u'参数起始点必须小于终止点'
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if step <= 0:
|
|
|
|
|
print u'参数布进必须大于0'
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
l = []
|
|
|
|
|
param = start
|
|
|
|
|
|
|
|
|
|
while param <= end:
|
|
|
|
|
l.append(param)
|
|
|
|
|
param += step
|
|
|
|
|
|
|
|
|
|
self.paramDict[name] = l
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def generateSetting(self):
|
|
|
|
|
"""生成优化参数组合"""
|
|
|
|
|
# 参数名的列表
|
|
|
|
|
nameList = self.paramDict.keys()
|
|
|
|
|
paramList = self.paramDict.values()
|
|
|
|
|
|
|
|
|
|
# 使用迭代工具生产参数对组合
|
|
|
|
|
productList = list(product(*paramList))
|
|
|
|
|
|
|
|
|
|
# 把参数对组合打包到一个个字典组成的列表中
|
|
|
|
|
settingList = []
|
|
|
|
|
for p in productList:
|
|
|
|
|
d = dict(zip(nameList, p))
|
|
|
|
|
settingList.append(d)
|
|
|
|
|
|
|
|
|
|
return settingList
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def setOptimizeTarget(self, target):
|
|
|
|
|
"""设置优化目标字段"""
|
|
|
|
|
self.optimizeTarget = target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#----------------------------------------------------------------------
|
|
|
|
|
def formatNumber(n):
|
|
|
|
|
"""格式化数字到字符串"""
|
|
|
|
|
n = round(n, 2) # 保留两位小数
|
|
|
|
|
return format(n, ',') # 加上千分符
|
|
|
|
|
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
# 以下内容是一段回测脚本的演示,用户可以根据自己的需求修改
|
|
|
|
|
# 建议使用ipython notebook或者spyder来做回测
|
|
|
|
|
# 同样可以在命令模式下进行回测(一行一行输入运行)
|
|
|
|
|
from ctaDemo import *
|
|
|
|
|
|
|
|
|
|
# 创建回测引擎
|
|
|
|
|
engine = BacktestingEngine()
|
|
|
|
|
|
|
|
|
|
# 设置引擎的回测模式为K线
|
|
|
|
|
engine.setBacktestingMode(engine.BAR_MODE)
|
|
|
|
|
|
|
|
|
|
# 设置回测用的数据起始日期
|
|
|
|
|
engine.setStartDate('20110101')
|
|
|
|
|
|
|
|
|
|
# 载入历史数据到引擎中
|
2016-09-14 14:29:31 +00:00
|
|
|
|
engine.setDatabase(MINUTE_DB_NAME, 'IF0000')
|
2016-07-02 03:12:44 +00:00
|
|
|
|
|
|
|
|
|
# 设置产品相关参数
|
|
|
|
|
engine.setSlippage(0.2) # 股指1跳
|
|
|
|
|
engine.setRate(0.3/10000) # 万0.3
|
|
|
|
|
engine.setSize(300) # 股指合约大小
|
|
|
|
|
|
|
|
|
|
# 在引擎中创建策略对象
|
|
|
|
|
engine.initStrategy(DoubleEmaDemo, {})
|
|
|
|
|
|
|
|
|
|
# 开始跑回测
|
|
|
|
|
engine.runBacktesting()
|
|
|
|
|
|
|
|
|
|
# 显示回测结果
|
|
|
|
|
# spyder或者ipython notebook中运行时,会弹出盈亏曲线图
|
|
|
|
|
# 直接在cmd中回测则只会打印一些回测数值
|
|
|
|
|
engine.showBacktestingResult()
|
|
|
|
|
|
|
|
|
|
|