# encoding: UTF-8 ''' 本文件中包含的是CTA模块的回测引擎,回测引擎的API和CTA引擎一致, 可以使用和实盘相同的代码进行回测。 ''' from __future__ import division from datetime import datetime, timedelta from collections import OrderedDict from itertools import product import multiprocessing import pymongo from ctaBase import * from vtConstant import * from vtGateway import VtOrderData, VtTradeData from vtFunction import loadMongoSetting from eventEngine import * import MySQLdb import json import os import cPickle import csv import logging import copy import pandas as pd import re ######################################################################## class BacktestingEngine(object): """ CTA回测引擎 函数接口和策略引擎保持一样, 从而实现同一套代码从回测到实盘。 # modified by IncenseLee: 1.增加Mysql数据库的支持; 2.修改装载数据为批量式后加载模式。 3.增加csv 读取bar的回测模式 4.增加csv 读取tick合并价差的回测模式 5.增加EventEngine,并对newBar增加发送OnBar事件,供外部的回测主体显示Bar线。 """ TICK_MODE = 'tick' # 数据模式,逐Tick回测 BAR_MODE = 'bar' # 数据模式,逐Bar回测 REALTIME_MODE ='RealTime' # 逐笔交易计算资金,供策略获取资金容量,计算开仓数量 FINAL_MODE = 'Final' # 最后才统计交易,不适合按照百分比等开仓数量计算 #---------------------------------------------------------------------- def __init__(self, eventEngine = None): """Constructor""" self.eventEngine = eventEngine # 本地停止单编号计数 self.stopOrderCount = 0 # stopOrderID = STOPORDERPREFIX + str(stopOrderCount) # 本地停止单字典 # key为stopOrderID,value为stopOrder对象 self.stopOrderDict = {} # 停止单撤销后不会从本字典中删除 self.workingStopOrderDict = {} # 停止单撤销后会从本字典中删除 # 引擎类型为回测 self.engineType = ENGINETYPE_BACKTESTING # 回测相关 self.strategy = None # 回测策略 self.mode = self.BAR_MODE # 回测模式,默认为K线 self.startDate = '' self.initDays = 0 self.endDate = '' self.slippage = 0 # 回测时假设的滑点 self.rate = 0 # 回测时假设的佣金比例(适用于百分比佣金) self.size = 1 # 合约大小,默认为1 self.priceTick = 0 # 价格最小变动 self.dbClient = None # 数据库客户端 self.dbCursor = None # 数据库指针 self.historyData = [] # 历史数据的列表,回测用 self.initData = [] # 初始化用的数据 self.backtestingData = [] # 回测用的数据 self.dbName = '' # 回测数据库名 self.symbol = '' # 回测集合名 self.margin_rate = 0.11 # 回测合约的保证金比率 self.dataStartDate = None # 回测数据开始日期,datetime对象 self.dataEndDate = None # 回测数据结束日期,datetime对象 self.strategyStartDate = None # 策略启动日期(即前面的数据用于初始化),datetime对象 self.limitOrderDict = OrderedDict() # 限价单字典 self.workingLimitOrderDict = OrderedDict() # 活动限价单字典,用于进行撮合用 self.limitOrderCount = 0 # 限价单编号 # 持仓缓存字典 # key为vtSymbol,value为PositionBuffer对象 self.posBufferDict = {} self.tradeCount = 0 # 成交编号 self.tradeDict = OrderedDict() # 成交字典 self.logList = [] # 日志记录 # 当前最新数据,用于模拟成交用 self.tick = None self.bar = None self.dt = None # 最新的时间 self.gatewayName = u'BackTest' # csvFile相关 self.barTimeInterval = 60 # csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟 # 费用情况 self.avaliable = EMPTY_FLOAT self.percent = EMPTY_FLOAT self.percentLimit = 30 # 投资仓位比例上限 # 回测计算相关 self.calculateMode = self.FINAL_MODE self.usageCompounding = False # 是否使用简单复利 (只针对FINAL_MODE有效) self.initCapital = 10000 # 期初资金 self.capital = self.initCapital # 资金 (相当于Balance) self.maxCapital = self.initCapital # 资金最高净值 self.maxPnl = 0 # 最高盈利 self.minPnl = 0 # 最大亏损 self.maxVolume = 1 # 最大仓位数 self.winningResult = 0 # 盈利次数 self.losingResult = 0 # 亏损次数 self.totalResult = 0 # 总成交数量 self.totalWinning = 0 # 总盈利 self.totalLosing = 0 # 总亏损 self.totalTurnover = 0 # 总成交金额(合约面值) self.totalCommission = 0 # 总手续费 self.totalSlippage = 0 # 总滑点 self.timeList = [] # 时间序列 self.pnlList = [] # 每笔盈亏序列 self.capitalList = [] # 盈亏汇总的时间序列 self.drawdownList = [] # 回撤的时间序列 self.drawdownRateList = [] # 最大回撤比例的时间序列 self.dailyList = [] self.exportTradeList = [] # 导出交易记录列表 self.fixCommission = EMPTY_FLOAT # 固定交易费用 def getAccountInfo(self): """返回账号的实时权益,可用资金,仓位比例,投资仓位比例上限""" if self.capital == EMPTY_FLOAT: self.percent = EMPTY_FLOAT return self.capital, self.avaliable, self.percent, self.percentLimit #---------------------------------------------------------------------- def setStartDate(self, startDate='20100416', initDays=10): """设置回测的启动日期""" self.startDate = startDate self.initDays = initDays self.dataStartDate = datetime.strptime(startDate, '%Y%m%d') # 初始化天数 initTimeDelta = timedelta(initDays) self.strategyStartDate = self.dataStartDate + initTimeDelta #---------------------------------------------------------------------- def setEndDate(self, endDate=''): """设置回测的结束日期""" self.endDate = endDate if endDate: self.dataEndDate = datetime.strptime(endDate, '%Y%m%d') # 若不修改时间则会导致不包含dataEndDate当天数据 self.dataEndDate.replace(hour=23, minute=59) else: self.dataEndDate = datetime.now() def setMinDiff(self, minDiff): """设置回测品种的最小跳价,用于修正数据""" self.minDiff = minDiff self.priceTick = minDiff #---------------------------------------------------------------------- def setBacktestingMode(self, mode): """设置回测模式""" self.mode = mode #---------------------------------------------------------------------- def setDatabase(self, dbName, symbol): """设置历史数据所用的数据库""" self.dbName = dbName self.symbol = symbol def setMarginRate(self, margin_rate): if margin_rate!= EMPTY_FLOAT: self.margin_rate = margin_rate # ---------------------------------------------------------------------- def setSlippage(self, slippage): """设置滑点点数""" self.slippage = slippage # ---------------------------------------------------------------------- def setSize(self, size): """设置合约大小""" self.size = size # ---------------------------------------------------------------------- def setRate(self, rate): """设置佣金比例""" self.rate = rate # ---------------------------------------------------------------------- def setPriceTick(self, priceTick): """设置价格最小变动""" self.priceTick = priceTick self.minDiff = priceTick #---------------------------------------------------------------------- def loadHistoryDataFromMongo(self): """载入历史数据""" host, port, log = loadMongoSetting() self.dbClient = pymongo.MongoClient(host, port) collection = self.dbClient[self.dbName][self.symbol] self.output(u'开始载入数据') # 首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: dataClass = CtaBarData func = self.newBar else: dataClass = CtaTickData func = self.newTick # 载入初始化需要用的数据 flt = {'datetime':{'$gte':self.dataStartDate, '$lt':self.strategyStartDate}} initCursor = collection.find(flt) # 将数据从查询指针中读取出,并生成列表 self.initData = [] # 清空initData列表 for d in initCursor: data = dataClass() data.__dict__ = d self.initData.append(data) # 载入回测数据 if not self.dataEndDate: flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件 else: flt = {'datetime':{'$gte':self.strategyStartDate, '$lte':self.dataEndDate}} self.dbCursor = collection.find(flt) self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count())) #---------------------------------------------------------------------- def connectMysql(self): """连接MysqlDB""" # 载入json文件 fileName = 'mysql_connect.json' try: f = file(fileName) except IOError: self.writeCtaLog(u'回测引擎读取Mysql_connect.json失败') return # 解析json文件 setting = json.load(f) try: mysql_host = str(setting['host']) mysql_port = int(setting['port']) mysql_user = str(setting['user']) mysql_passwd = str(setting['passwd']) mysql_db = str(setting['db']) except IOError: self.writeCtaLog(u'回测引擎读取Mysql_connect.json,连接配置缺少字段,请检查') return try: self.__mysqlConnection = MySQLdb.connect(host=mysql_host, user=mysql_user, passwd=mysql_passwd, db=mysql_db, port=mysql_port) self.__mysqlConnected = True self.writeCtaLog(u'回测引擎连接MysqlDB成功') except Exception: self.writeCtaLog(u'回测引擎连接MysqlDB失败') #---------------------------------------------------------------------- def loadDataHistoryFromMysql(self, symbol, startDate, endDate): """载入历史TICK数据 如果加载过多数据会导致加载失败,间隔不要超过半年 """ if not endDate: endDate = datetime.today() # 看本地缓存是否存在 if self.__loadDataHistoryFromLocalCache(symbol, startDate, endDate): self.writeCtaLog(u'历史TICK数据从Cache载入') return # 每次获取日期周期 intervalDays = 10 for i in range (0,(endDate - startDate).days +1, intervalDays): d1 = startDate + timedelta(days = i ) if (endDate - d1).days > 10: d2 = startDate + timedelta(days = i + intervalDays -1 ) else: d2 = endDate # 从Mysql 提取数据 self.__qryDataHistoryFromMysql(symbol, d1, d2) self.writeCtaLog(u'历史TICK数据共载入{0}条'.format(len(self.historyData))) # 保存本地cache文件 self.__saveDataHistoryToLocalCache(symbol, startDate, endDate) def __loadDataHistoryFromLocalCache(self, symbol, startDate, endDate): """看本地缓存是否存在 added by IncenseLee """ # 运行路径下cache子目录 cacheFolder = os.getcwd()+'/cache' # cache文件 cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\ format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d')) if not os.path.isfile(cacheFile): return False else: try: # 从cache文件加载 cache = open(cacheFile,mode='r') self.historyData = cPickle.load(cache) cache.close() return True except Exception as e: self.writeCtaLog(u'读取文件{0}失败'.format(cacheFile)) return False def __saveDataHistoryToLocalCache(self, symbol, startDate, endDate): """保存本地缓存 added by IncenseLee """ # 运行路径下cache子目录 cacheFolder = os.getcwd()+'/cache' # 创建cache子目录 if not os.path.isdir(cacheFolder): os.mkdir(cacheFolder) # cache 文件名 cacheFile = u'{0}/{1}_{2}_{3}.pickle'.\ format(cacheFolder, symbol, startDate.strftime('%Y-%m-%d'), endDate.strftime('%Y-%m-%d')) # 重复存在 返回 if os.path.isfile(cacheFile): return False else: # 写入cache文件 cache = open(cacheFile, mode='w') cPickle.dump(self.historyData,cache) cache.close() return True #---------------------------------------------------------------------- def __qryDataHistoryFromMysql(self, symbol, startDate, endDate): """从Mysql载入历史TICK数据 added by IncenseLee """ try: self.connectMysql() if self.__mysqlConnected: # 获取指针 cur = self.__mysqlConnection.cursor(MySQLdb.cursors.DictCursor) if endDate: # 开始日期 ~ 结束日期 sqlstring = ' select \'{0}\' as InstrumentID, str_to_date(concat(ndate,\' \', ntime),' \ '\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \ 'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \ 'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB_{0}MI ' \ 'where ndate between cast(\'{1}\' as date) and cast(\'{2}\' as date) order by UpdateTime'.\ format(symbol, startDate, endDate) elif startDate: # 开始日期 - 当前 sqlstring = ' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \ '\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \ 'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \ 'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI ' \ 'where ndate > cast(\'{1}\' as date) order by UpdateTime'.\ format( symbol, startDate) else: # 所有数据 sqlstring =' select \'{0}\' as InstrumentID,str_to_date(concat(ndate,\' \', ntime),' \ '\'%Y-%m-%d %H:%i:%s\') as UpdateTime,price as LastPrice,vol as Volume, day_vol as DayVolume,' \ 'position_vol as OpenInterest,bid1_price as BidPrice1,bid1_vol as BidVolume1, ' \ 'sell1_price as AskPrice1, sell1_vol as AskVolume1 from TB__{0}MI order by UpdateTime'.\ format(symbol) self.writeCtaLog(sqlstring) # 执行查询 count = cur.execute(sqlstring) self.writeCtaLog(u'历史TICK数据共{0}条'.format(count)) # 分批次读取 fetch_counts = 0 fetch_size = 1000 while True: results = cur.fetchmany(fetch_size) if not results: break fetch_counts = fetch_counts + len(results) if not self.historyData: self.historyData =results else: self.historyData = self.historyData + results self.writeCtaLog(u'{1}~{2}历史TICK数据载入共{0}条'.format(fetch_counts,startDate,endDate)) else: self.writeCtaLog(u'MysqlDB未连接,请检查') except MySQLdb.Error as e: self.writeCtaLog(u'MysqlDB载入数据失败,请检查.Error {0}'.format(e)) def __dataToTick(self, data): """ 数据库查询返回的data结构,转换为tick对象 added by IncenseLee """ tick = CtaTickData() symbol = data['InstrumentID'] tick.symbol = symbol # 创建TICK数据对象并更新数据 tick.vtSymbol = symbol # tick.openPrice = data['OpenPrice'] # tick.highPrice = data['HighestPrice'] # tick.lowPrice = data['LowestPrice'] tick.lastPrice = float(data['LastPrice']) # bug fix: # ctp日常传送的volume数据,是交易日日内累加值。数据库的volume,是数据商自行计算整理的 # 因此,改为使用DayVolume,与CTP实盘一致 #tick.volume = data['Volume'] tick.volume = data['DayVolume'] tick.openInterest = data['OpenInterest'] # tick.upperLimit = data['UpperLimitPrice'] # tick.lowerLimit = data['LowerLimitPrice'] tick.datetime = data['UpdateTime'] tick.date = tick.datetime.strftime('%Y-%m-%d') tick.time = tick.datetime.strftime('%H:%M:%S') # 数据库中并没有tradingDay的数据,回测时,暂时按照date授予。 tick.tradingDay = tick.date tick.bidPrice1 = float(data['BidPrice1']) # tick.bidPrice2 = data['BidPrice2'] # tick.bidPrice3 = data['BidPrice3'] # tick.bidPrice4 = data['BidPrice4'] # tick.bidPrice5 = data['BidPrice5'] tick.askPrice1 = float(data['AskPrice1']) # tick.askPrice2 = data['AskPrice2'] # tick.askPrice3 = data['AskPrice3'] # tick.askPrice4 = data['AskPrice4'] # tick.askPrice5 = data['AskPrice5'] tick.bidVolume1 = data['BidVolume1'] # tick.bidVolume2 = data['BidVolume2'] # tick.bidVolume3 = data['BidVolume3'] # tick.bidVolume4 = data['BidVolume4'] # tick.bidVolume5 = data['BidVolume5'] tick.askVolume1 = data['AskVolume1'] # tick.askVolume2 = data['AskVolume2'] # tick.askVolume3 = data['AskVolume3'] # tick.askVolume4 = data['AskVolume4'] # tick.askVolume5 = data['AskVolume5'] return tick #---------------------------------------------------------------------- def getMysqlDeltaDate(self,symbol, startDate, decreaseDays): """从mysql库中获取交易日前若干天 added by IncenseLee """ try: if self.__mysqlConnected: # 获取mysql指针 cur = self.__mysqlConnection.cursor() sqlstring='select distinct ndate from TB_{0}MI where ndate < ' \ 'cast(\'{1}\' as date) order by ndate desc limit {2},1'.format(symbol, startDate, decreaseDays-1) # self.writeCtaLog(sqlstring) count = cur.execute(sqlstring) if count > 0: # 提取第一条记录 result = cur.fetchone() return result[0] else: self.writeCtaLog(u'MysqlDB没有查询结果,请检查日期') else: self.writeCtaLog(u'MysqlDB未连接,请检查') except MySQLdb.Error as e: self.writeCtaLog(u'MysqlDB载入数据失败,请检查.Error {0}: {1}'.format(e.arg[0],e.arg[1])) # 出错后缺省返回 return startDate-timedelta(days=3) # ---------------------------------------------------------------------- def runBackTestingWithArbTickFile(self,mainPath, arbSymbol): """运行套利回测(使用本地tickcsv数据) 参数:套利代码 SP rb1610&rb1701 added by IncenseLee 原始的tick,分别存放在白天目录1和夜盘目录2中,每天都有各个合约的数据 Z:\ticks\SHFE\201606\RB\0601\ RB1610.txt RB1701.txt .... Z:\ticks\SHFE_night\201606\RB\0601 RB1610.txt RB1701.txt .... 夜盘目录为自然日,不是交易日。 按照回测的开始日期,到结束日期,循环每一天。 每天优先读取日盘数据,再读取夜盘数据。 读取eg1(如RB1610),读取Leg2(如RB701),合并成价差tick,灌输到策略的onTick中。 """ self.capital = self.initCapital # 更新设置期初资金 if len(arbSymbol) < 1: self.writeCtaLog(u'套利合约为空') return if not (arbSymbol.upper().index("SP") == 0 and arbSymbol.index(" ") > 0 and arbSymbol.index("&") > 0): self.writeCtaLog(u'套利合约格式不符合') return # 获得Leg1,leg2 legs = arbSymbol[arbSymbol.index(" "):] leg1 = legs[1:legs.index("&")] leg2 = legs[legs.index("&") + 1:] self.writeCtaLog(u'Leg1:{0},Leg2:{1}'.format(leg1, leg2)) if not self.dataStartDate: self.writeCtaLog(u'回测开始日期未设置。') return # RB if len(self.symbol)<1: self.writeCtaLog(u'回测对象未设置。') return if not self.dataEndDate: self.dataEndDate = datetime.today() #首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: self.writeCtaLog(u'本回测仅支持tick模式') return testdays = (self.dataEndDate - self.dataStartDate).days if testdays < 1: self.writeCtaLog(u'回测时间不足') return for i in range(0, testdays): testday = self.dataStartDate + timedelta(days = i) self.output(u'回测日期:{0}'.format(testday)) # 白天数据 self.__loadArbTicks(mainPath,testday,leg1,leg2) # 夜盘数据 self.__loadArbTicks(mainPath+'_night', testday, leg1, leg2) def __loadArbTicks(self,mainPath,testday,leg1,leg2): self.writeCtaLog(u'加载回测日期:{0}\{1}的价差tick'.format(mainPath, testday)) cachefilename = u'{0}_{1}_{2}_{3}_{4}'.format(self.symbol,leg1,leg2, mainPath, testday.strftime('%Y%m%d')) arbTicks = self.__loadArbTicksFromLocalCache(cachefilename) dt = None if len(arbTicks) < 1: leg1File = u'z:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}.txt' \ .format(mainPath, testday.strftime('%Y%m'), self.symbol, testday.strftime('%m%d'), leg1) if not os.path.isfile(leg1File): self.writeCtaLog(u'{0}文件不存在'.format(leg1File)) return leg2File = u'z:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}.txt' \ .format(mainPath, testday.strftime('%Y%m'), self.symbol, testday.strftime('%m%d'), leg2) if not os.path.isfile(leg2File): self.writeCtaLog(u'{0}文件不存在'.format(leg2File)) return # 先读取leg2的数据到目录,以日期时间为key leg2Ticks = {} leg2CsvReadFile = file(leg2File, 'rb') #reader = csv.DictReader((line.replace('\0',' ') for line in leg2CsvReadFile), delimiter=",") reader = csv.DictReader(leg2CsvReadFile, delimiter=",") self.writeCtaLog(u'加载{0}'.format(leg2File)) for row in reader: tick = CtaTickData() tick.vtSymbol = self.symbol tick.symbol = self.symbol tick.date = testday.strftime('%Y%m%d') tick.tradingDay = tick.date tick.time = row['Time'] try: tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y%m%d %H:%M:%S.%f') except Exception as ex: self.writeCtaError(u'日期转换错误:{0},{1}:{2}'.format(tick.date + ' ' + tick.time, Exception, ex)) continue # 修正毫秒 if tick.datetime.replace(microsecond = 0) == dt: # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 tick.datetime=tick.datetime.replace(microsecond = 500) tick.time = tick.datetime.strftime('%H:%M:%S.%f') else: tick.datetime = tick.datetime.replace(microsecond=0) tick.time = tick.datetime.strftime('%H:%M:%S.%f') dt = tick.datetime tick.lastPrice = float(row['LastPrice']) tick.volume = int(float(row['LVolume'])) tick.bidPrice1 = float(row['BidPrice']) # 叫买价(价格低) tick.bidVolume1 = int(float(row['BidVolume'])) tick.askPrice1 = float(row['AskPrice']) # 叫卖价(价格高) tick.askVolume1 = int(float(row['AskVolume'])) # 排除涨停/跌停的数据 if (tick.bidPrice1 == float('1.79769E308') and tick.bidVolume1 == 0) \ or (tick.askPrice1 == float('1.79769E308') and tick.askVolume1 == 0): continue dtStr = tick.date + ' ' + tick.time if dtStr in leg2Ticks: self.writeCtaError(u'日内数据重复,异常,数据时间为:{0}'.format(dtStr)) else: leg2Ticks[dtStr] = tick leg1CsvReadFile = file(leg1File, 'rb') #reader = csv.DictReader((line.replace('\0',' ') for line in leg1CsvReadFile), delimiter=",") reader = csv.DictReader(leg1CsvReadFile, delimiter=",") self.writeCtaLog(u'加载{0}'.format(leg1File)) dt = None for row in reader: arbTick = CtaTickData() arbTick.date = testday.strftime('%Y%m%d') arbTick.time = row['Time'] try: arbTick.datetime = datetime.strptime(arbTick.date + ' ' + arbTick.time, '%Y%m%d %H:%M:%S.%f') except Exception as ex: self.writeCtaError(u'日期转换错误:{0},{1}:{2}'.format(arbTick.date + ' ' + arbTick.time, Exception, ex)) continue # 修正毫秒 if arbTick.datetime.replace(microsecond=0) == dt: # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 arbTick.datetime = arbTick.datetime.replace(microsecond=500) arbTick.time = arbTick.datetime.strftime('%H:%M:%S.%f') else: arbTick.datetime = arbTick.datetime.replace(microsecond=0) arbTick.time = arbTick.datetime.strftime('%H:%M:%S.%f') dt = arbTick.datetime dtStr = ' '.join([arbTick.date, arbTick.time]) if dtStr in leg2Ticks: leg2Tick = leg2Ticks[dtStr] arbTick.vtSymbol = self.symbol arbTick.symbol = self.symbol arbTick.lastPrice = EMPTY_FLOAT arbTick.volume = EMPTY_INT leg1AskPrice1 = float(row['AskPrice']) leg1AskVolume1 = int(float(row['AskVolume'])) leg1BidPrice1 = float(row['BidPrice']) leg1BidVolume1 = int(float(row['BidVolume'])) # 排除涨停/跌停的数据 if ((leg1AskPrice1 == float('1.79769E308') or leg1AskPrice1 == 0) and leg1AskVolume1 == 0) \ or ((leg1BidPrice1 == float('1.79769E308') or leg1BidPrice1 == 0) and leg1BidVolume1 == 0): continue # 叫卖价差=leg1.askPrice1 - leg2.bidPrice1,volume为两者最小 arbTick.askPrice1 = leg1AskPrice1 - leg2Tick.bidPrice1 arbTick.askVolume1 = min(leg1AskVolume1, leg2Tick.bidVolume1) # 叫买价差=leg1.bidPrice1 - leg2.askPrice1,volume为两者最小 arbTick.bidPrice1 = leg1BidPrice1 - leg2Tick.askPrice1 arbTick.bidVolume1 = min(leg1BidVolume1, leg2Tick.askVolume1) arbTicks.append(arbTick) del leg2Ticks[dtStr] # 保存到历史目录 if len(arbTicks) > 0: self.__saveArbTicksToLocalCache(cachefilename, arbTicks) for t in arbTicks: # 推送到策略中 self.newTick(t) def __loadArbTicksFromLocalCache(self, filename): """从本地缓存中,加载数据""" # 运行路径下cache子目录 cacheFolder = os.getcwd() + '/cache' # cache文件 cacheFile = u'{0}/{1}.pickle'. \ format(cacheFolder, filename) if not os.path.isfile(cacheFile): return [] else: # 从cache文件加载 cache = open(cacheFile, mode='r') l = cPickle.load(cache) cache.close() return l def __saveArbTicksToLocalCache(self, filename, arbticks): """保存价差tick到本地缓存目录""" # 运行路径下cache子目录 cacheFolder = os.getcwd() + '/cache' # 创建cache子目录 if not os.path.isdir(cacheFolder): os.mkdir(cacheFolder) # cache 文件名 cacheFile = u'{0}/{1}.pickle'. \ format(cacheFolder, filename) # 重复存在 返回 if os.path.isfile(cacheFile): return False else: # 写入cache文件 cache = open(cacheFile, mode='w') cPickle.dump(arbticks, cache) cache.close() return True def runBackTestingWithNonStrArbTickFile(self, leg1MainPath, leg2MainPath, leg1Symbol,leg2Symbol): """运行套利回测(使用本地tickcsv数据) 参数: leg1MainPath: leg1合约所在的市场路径 leg2MainPath: leg2合约所在的市场路径 leg1Symbol: leg1合约 Leg2Symbol:leg2合约 added by IncenseLee 原始的tick,分别存放在白天目录1和夜盘目录2中,每天都有各个合约的数据 Z:\ticks\SHFE\201606\RB\0601\ RB1610.txt RB1701.txt .... Z:\ticks\SHFE_night\201606\RB\0601 RB1610.txt RB1701.txt .... 夜盘目录为自然日,不是交易日。 按照回测的开始日期,到结束日期,循环每一天。 每天优先读取日盘数据,再读取夜盘数据。 读取eg1(如RB1610),读取Leg2(如RB701),根据两者tick的时间优先顺序,逐一tick灌输到策略的onTick中。 """ self.capital = self.initCapital # 更新设置期初资金 if not self.dataStartDate: self.writeCtaLog(u'回测开始日期未设置。') return # RB if len(self.symbol)<1: self.writeCtaLog(u'回测对象未设置。') return if not self.dataEndDate: self.dataEndDate = datetime.today() #首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: self.writeCtaLog(u'本回测仅支持tick模式') return testdays = (self.dataEndDate - self.dataStartDate).days if testdays < 1: self.writeCtaLog(u'回测时间不足') return for i in range(0, testdays): testday = self.dataStartDate + timedelta(days = i) self.output(u'回测日期:{0}'.format(testday)) # 加载运行白天数据 self.__loadNotStdArbTicks(leg1MainPath, leg2MainPath, testday, leg1Symbol,leg2Symbol) self.savingDailyData(testday, self.capital, self.maxCapital) # 加载运行夜盘数据 self.__loadNotStdArbTicks(leg1MainPath+'_night', leg2MainPath+'_night', testday, leg1Symbol, leg2Symbol) self.savingDailyData(self.dataEndDate, self.capital, self.maxCapital) def __loadTicksFromFile(self, filepath, tickDate, vtSymbol): """从文件中读取tick""" # 先读取数据到Dict,以日期时间为key ticks = OrderedDict() if not os.path.isfile(filepath): self.writeCtaLog(u'{0}文件不存在'.format(filepath)) return ticks dt = None csvReadFile = file(filepath, 'rb') reader = csv.DictReader(csvReadFile, delimiter=",") self.writeCtaLog(u'加载{0}'.format(filepath)) for row in reader: tick = CtaTickData() tick.vtSymbol = vtSymbol tick.symbol = vtSymbol tick.date = tickDate.strftime('%Y%m%d') tick.tradingDay = tick.date tick.time = row['Time'] try: tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y%m%d %H:%M:%S.%f') except Exception as ex: self.writeCtaError(u'日期转换错误:{0},{1}:{2}'.format(tick.date + ' ' + tick.time, Exception, ex)) continue # 修正毫秒 if tick.datetime.replace(microsecond=0) == dt: # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 tick.datetime = tick.datetime.replace(microsecond=500) tick.time = tick.datetime.strftime('%H:%M:%S.%f') else: tick.datetime = tick.datetime.replace(microsecond=0) tick.time = tick.datetime.strftime('%H:%M:%S.%f') dt = tick.datetime tick.lastPrice = float(row['LastPrice']) tick.volume = int(float(row['LVolume'])) tick.bidPrice1 = float(row['BidPrice']) # 叫买价(价格低) tick.bidVolume1 = int(float(row['BidVolume'])) tick.askPrice1 = float(row['AskPrice']) # 叫卖价(价格高) tick.askVolume1 = int(float(row['AskVolume'])) # 排除涨停/跌停的数据 if (tick.bidPrice1 == float('1.79769E308') and tick.bidVolume1 == 0) \ or (tick.askPrice1 == float('1.79769E308') and tick.askVolume1 == 0): continue dtStr = tick.date + ' ' + tick.time if dtStr in ticks: self.writeCtaError(u'日内数据重复,异常,数据时间为:{0}'.format(dtStr)) else: ticks[dtStr] = tick return ticks def __loadNotStdArbTicks(self, leg1MainPath,leg2MainPath, testday, leg1Symbol, leg2Symbol): self.writeCtaLog(u'加载回测日期:{0}的价差tick'.format( testday)) p = re.compile(r"([A-Z]+)[0-9]+", re.I) leg1_shortSymbol = p.match(leg1Symbol) leg2_shortSymbol = p.match(leg2Symbol) if leg1_shortSymbol is None or leg2_shortSymbol is None: self.writeCtaLog(u'{0},{1}不能正则分解'.format(leg1Symbol, leg2Symbol)) return leg1_shortSymbol = leg1_shortSymbol.group(1) leg2_shortSymbol = leg2_shortSymbol.group(1) # E:\Ticks\ZJ\2015\201505\TF leg1File = u'{0}\\{1}\\{2}\\{3}\\{4}\\{5}.txt' \ .format(leg1MainPath, testday.strftime('%Y'),testday.strftime('%Y%m'), leg1_shortSymbol, testday.strftime('%m%d'), leg1Symbol) if not os.path.isfile(leg1File): self.writeCtaLog(u'{0}文件不存在'.format(leg1File)) return leg2File = u'{0}\\{1}\\{2}\\{3}\\{4}\\{5}.txt' \ .format(leg2MainPath, testday.strftime('%Y'), testday.strftime('%Y%m'), leg2_shortSymbol, testday.strftime('%m%d'), leg2Symbol) if not os.path.isfile(leg2File): self.writeCtaLog(u'{0}文件不存在'.format(leg2File)) return leg1Ticks = self.__loadTicksFromFile(filepath=leg1File,tickDate= testday, vtSymbol=leg1Symbol) if len(leg1Ticks) == 0: self.writeCtaLog(u'{0}读取tick数为空'.format(leg1File)) return leg2Ticks = self.__loadTicksFromFile(filepath=leg2File, tickDate=testday, vtSymbol=leg2Symbol) if len(leg2Ticks) == 0: self.writeCtaLog(u'{0}读取tick数为空'.format(leg1File)) return leg1_tick = None leg2_tick = None while not (len(leg1Ticks) == 0 or len(leg2Ticks) == 0): if leg1_tick is None and len(leg1Ticks) > 0: leg1_tick = leg1Ticks.popitem(last=False) if leg2_tick is None and len(leg2Ticks) > 0: leg2_tick = leg2Ticks.popitem(last=False) if leg1_tick is None and leg2_tick is not None: self.newTick(leg2_tick[1]) leg2_tick = None elif leg1_tick is not None and leg2_tick is None: self.newTick(leg1_tick[1]) leg1_tick = None elif leg1_tick is not None and leg2_tick is not None: leg1 = leg1_tick[1] leg2 = leg2_tick[1] if leg1.datetime <= leg2.datetime: self.newTick(leg1) leg1_tick = None else: self.newTick(leg2) leg2_tick = None def runBackTestingWithNonStrArbTickFile2(self, leg1MainPath, leg2MainPath, leg1Symbol, leg2Symbol): """运行套利回测(使用本地tickcsv数据,数据从taobao标普购买) 参数: leg1MainPath: leg1合约所在的市场路径 leg2MainPath: leg2合约所在的市场路径 leg1Symbol: leg1合约 Leg2Symbol:leg2合约 added by IncenseLee 原始的tick,存放在相应市场下每天的目录中,目录包含市场各个合约的数据 E:\ticks\SQ\201606\20160601\ RB10.csv RB01.csv .... 目录为交易日。 按照回测的开始日期,到结束日期,循环每一天。 读取eg1(如RB1610),读取Leg2(如RB701),根据两者tick的时间优先顺序,逐一tick灌输到策略的onTick中。 """ self.capital = self.initCapital # 更新设置期初资金 if not self.dataStartDate: self.writeCtaLog(u'回测开始日期未设置。') return # RB if len(self.symbol) < 1: self.writeCtaLog(u'回测对象未设置。') return if not self.dataEndDate: self.dataEndDate = datetime.today() # 首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: self.writeCtaLog(u'本回测仅支持tick模式') return testdays = (self.dataEndDate - self.dataStartDate).days if testdays < 1: self.writeCtaLog(u'回测时间不足') return for i in range(0, testdays): testday = self.dataStartDate + timedelta(days=i) self.output(u'回测日期:{0}'.format(testday)) # 加载运行每天数据 self.__loadNotStdArbTicks2(leg1MainPath, leg2MainPath, testday, leg1Symbol, leg2Symbol) self.savingDailyData(testday, self.capital, self.maxCapital) def __loadTicksFromFile2(self, filepath, tickDate, vtSymbol): """从csv文件中UnicodeDictReader读取tick""" # 先读取数据到Dict,以日期时间为key ticks = OrderedDict() if not os.path.isfile(filepath): self.writeCtaLog(u'{0}文件不存在'.format(filepath)) return ticks dt = None csvReadFile = file(filepath, 'rb') df = pd.read_csv(filepath, encoding='gbk') df.columns = ['date', 'time', 'lastPrice', 'lastVolume', 'totalInterest', 'position', 'bidPrice1', 'bidVolume1', 'bidPrice2', 'bidVolume2', 'bidPrice3', 'bidVolume3', 'askPrice1', 'askVolume1', 'askPrice2', 'askVolume2', 'askPrice3', 'askVolume3','BS'] self.writeCtaLog(u'加载{0}'.format(filepath)) for i in range(0,len(df)): #日期, 时间, 成交价, 成交量, 总量, 属性(持仓增减), B1价, B1量, B2价, B2量, B3价, B3量, S1价, S1量, S2价, S2量, S3价, S3量, BS # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 row = df.iloc[i].to_dict() tick = CtaTickData() tick.vtSymbol = vtSymbol tick.symbol = vtSymbol tick.date = row['date'] tick.tradingDay = tickDate.strftime('%Y%m%d') tick.time = row['time'] try: tick.datetime = datetime.strptime(tick.date + ' ' + tick.time, '%Y-%m-%d %H:%M:%S') except Exception as ex: self.writeCtaError(u'日期转换错误:{0},{1}:{2}'.format(tick.date + ' ' + tick.time, Exception, ex)) continue tick.date = tick.datetime.strftime('%Y%m%d') # 修正毫秒 if tick.datetime.replace(microsecond=0) == dt: # 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒 tick.datetime = tick.datetime.replace(microsecond=500) tick.time = tick.datetime.strftime('%H:%M:%S.%f') else: tick.datetime = tick.datetime.replace(microsecond=0) tick.time = tick.datetime.strftime('%H:%M:%S.%f') dt = tick.datetime tick.lastPrice = float(row['lastPrice']) tick.volume = int(float(row['lastVolume'])) tick.bidPrice1 = float(row['bidPrice1']) # 叫买价(价格低) tick.bidVolume1 = int(float(row['bidVolume1'])) tick.askPrice1 = float(row['askPrice1']) # 叫卖价(价格高) tick.askVolume1 = int(float(row['askVolume1'])) # 排除涨停/跌停的数据 if (tick.bidPrice1 == float('1.79769E308') and tick.bidVolume1 == 0) \ or (tick.askPrice1 == float('1.79769E308') and tick.askVolume1 == 0): continue dtStr = tick.date + ' ' + tick.time if dtStr in ticks: self.writeCtaError(u'日内数据重复,异常,数据时间为:{0}'.format(dtStr)) else: ticks[dtStr] = tick return ticks def __loadNotStdArbTicks2(self, leg1MainPath, leg2MainPath, testday, leg1Symbol, leg2Symbol): self.writeCtaLog(u'加载回测日期:{0}的价差tick'.format(testday)) p = re.compile(r"([A-Z]+)[0-9]+",re.I) leg1_shortSymbol = p.match(leg1Symbol) leg2_shortSymbol = p.match(leg2Symbol) if leg1_shortSymbol is None or leg2_shortSymbol is None: self.writeCtaLog(u'{0},{1}不能正则分解'.format(leg1Symbol,leg2Symbol)) return leg1_shortSymbol = leg1_shortSymbol.group(1) leg2_shortSymbol = leg2_shortSymbol.group(1) # E:\Ticks\SQ\2014\201401\20140102\ag01_20140102.csv leg1File = u'e:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}{5}_{3}.csv' \ .format(leg1MainPath, testday.strftime('%Y'), testday.strftime('%Y%m'), testday.strftime('%Y%m%d'), leg1_shortSymbol, leg1Symbol[-2:]) if not os.path.isfile(leg1File): self.writeCtaLog(u'{0}文件不存在'.format(leg1File)) return leg2File = u'e:\\ticks\\{0}\\{1}\\{2}\\{3}\\{4}{5}_{3}.csv' \ .format(leg2MainPath,testday.strftime('%Y'), testday.strftime('%Y%m'), testday.strftime('%Y%m%d'), leg2_shortSymbol, leg2Symbol[-2:]) if not os.path.isfile(leg2File): self.writeCtaLog(u'{0}文件不存在'.format(leg2File)) return leg1Ticks = self.__loadTicksFromFile2(filepath=leg1File, tickDate=testday, vtSymbol=leg1Symbol) if len(leg1Ticks) == 0: self.writeCtaLog(u'{0}读取tick数为空'.format(leg1File)) return leg2Ticks = self.__loadTicksFromFile2(filepath=leg2File, tickDate=testday, vtSymbol=leg2Symbol) if len(leg2Ticks) == 0: self.writeCtaLog(u'{0}读取tick数为空'.format(leg1File)) return leg1_tick = None leg2_tick = None while not (len(leg1Ticks) == 0 or len(leg2Ticks) == 0): if leg1_tick is None and len(leg1Ticks) > 0: leg1_tick = leg1Ticks.popitem(last=False) if leg2_tick is None and len(leg2Ticks) > 0: leg2_tick = leg2Ticks.popitem(last=False) if leg1_tick is None and leg2_tick is not None: self.newTick(leg2_tick[1]) leg2_tick = None elif leg1_tick is not None and leg2_tick is None: self.newTick(leg1_tick[1]) leg1_tick = None elif leg1_tick is not None and leg2_tick is not None: leg1 = leg1_tick[1] leg2 = leg2_tick[1] if leg1.datetime <= leg2.datetime: self.newTick(leg1) leg1_tick = None else: self.newTick(leg2) leg2_tick = None #---------------------------------------------------------------------- def runBackTestingWithBarFile(self, filename): """运行回测(使用本地csv数据) added by IncenseLee """ self.capital = self.initCapital # 更新设置期初资金 if not filename: self.writeCtaLog(u'请指定回测数据文件') return if not self.dataStartDate: self.writeCtaLog(u'回测开始日期未设置。') return if not self.dataEndDate: self.dataEndDate = datetime.today() import os if not os.path.isfile(filename): self.writeCtaLog(u'{0}文件不存在'.format(filename)) if len(self.symbol) < 1: self.writeCtaLog(u'回测对象未设置。') return # 首先根据回测模式,确认要使用的数据类 if not self.mode == self.BAR_MODE: self.writeCtaLog(u'文件仅支持bar模式,若扩展tick模式,需要修改本方法') return self.output(u'开始回测') #self.strategy.inited = True self.strategy.onInit() self.output(u'策略初始化完成') self.strategy.trading = True self.strategy.onStart() self.output(u'策略启动完成') self.output(u'开始回放数据') import csv csvfile = file(filename,'rb') reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",") for row in reader: try: bar = CtaBarData() bar.symbol = self.symbol bar.vtSymbol = self.symbol # 从tb导出的csv文件 #bar.open = float(row['Open']) #bar.high = float(row['High']) #bar.low = float(row['Low']) #bar.close = float(row['Close']) #bar.volume = float(row['TotalVolume'])# #barEndTime = datetime.strptime(row['Date']+' ' + row['Time'], '%Y/%m/%d %H:%M:%S') # 从ricequant导出的csv文件 bar.open = float(row['open']) bar.high = float(row['high']) bar.low = float(row['low']) bar.close = float(row['close']) bar.volume = float(row['volume']) barEndTime = datetime.strptime(row['index'], '%Y-%m-%d %H:%M:%S') bar.tradingDay = row['trading_date'] # 使用Bar的开始时间作为datetime bar.datetime = barEndTime - timedelta(seconds=self.barTimeInterval) bar.date = bar.datetime.strftime('%Y-%m-%d') bar.time = bar.datetime.strftime('%H:%M:%S') if not (bar.datetime < self.dataStartDate or bar.datetime >= self.dataEndDate): self.newBar(bar) except Exception as ex: self.writeCtaLog(u'{0}:{1}'.format(Exception, ex)) continue #---------------------------------------------------------------------- def runBacktestingWithMysql(self): """运行回测(使用Mysql数据) added by IncenseLee """ self.capital = self.initCapital # 更新设置期初资金 if not self.dataStartDate: self.writeCtaLog(u'回测开始日期未设置。') return if not self.dataEndDate: self.dataEndDate = datetime.today() if len(self.symbol)<1: self.writeCtaLog(u'回测对象未设置。') return # 首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: dataClass = CtaBarData func = self.newBar else: dataClass = CtaTickData func = self.newTick self.output(u'开始回测') #self.strategy.inited = True self.strategy.onInit() self.output(u'策略初始化完成') self.strategy.trading = True self.strategy.onStart() self.output(u'策略启动完成') self.output(u'开始回放数据') # 每次获取日期周期 intervalDays = 10 for i in range (0,(self.dataEndDate - self.dataStartDate).days +1, intervalDays): d1 = self.dataStartDate + timedelta(days = i ) if (self.dataEndDate - d1).days > intervalDays: d2 = self.dataStartDate + timedelta(days = i + intervalDays -1 ) else: d2 = self.dataEndDate # 提取历史数据 self.loadDataHistoryFromMysql(self.symbol, d1, d2) self.output(u'数据日期:{0} => {1}'.format(d1,d2)) # 将逐笔数据推送 for data in self.historyData: # 记录最新的TICK数据 self.tick = self.__dataToTick(data) self.dt = self.tick.datetime # 处理限价单 self.crossLimitOrder() self.crossStopOrder() # 推送到策略引擎中 self.strategy.onTick(self.tick) # 清空历史数据 self.historyData = [] self.output(u'数据回放结束') #---------------------------------------------------------------------- def runBacktesting(self): """运行回测""" self.capital = self.initCapital # 更新设置期初资金 # 载入历史数据 #self.loadHistoryData() self.loadHistoryDataFromMongo() # 首先根据回测模式,确认要使用的数据类 if self.mode == self.BAR_MODE: dataClass = CtaBarData func = self.newBar else: dataClass = CtaTickData func = self.newTick self.output(u'开始回测') self.strategy.inited = True self.strategy.onInit() self.output(u'策略初始化完成') self.strategy.trading = True self.strategy.onStart() self.output(u'策略启动完成') self.output(u'开始回放数据') for d in self.dbCursor: data = dataClass() data.__dict__ = d func(data) self.output(u'数据回放结束') def __sendOnBarEvent(self, bar): """发送Bar的事件""" if self.eventEngine is not None: eventType = EVENT_ON_BAR + '_' + self.symbol event = Event(type_= eventType) event.dict_['data'] = bar self.eventEngine.put(event) # ---------------------------------------------------------------------- def newBar(self, bar): """新的K线""" self.bar = bar self.dt = bar.datetime self.crossLimitOrder() # 先撮合限价单 self.crossStopOrder() # 再撮合停止单 self.strategy.onBar(bar) # 推送K线到策略中 self.__sendOnBarEvent(bar) # 推送K线到事件 #---------------------------------------------------------------------- def newTick(self, tick): """新的Tick""" self.tick = tick self.dt = tick.datetime self.crossLimitOrder() self.crossStopOrder() self.strategy.onTick(tick) #---------------------------------------------------------------------- def initStrategy(self, strategyClass, setting=None): """ 初始化策略 setting是策略的参数设置,如果使用类中写好的默认设置则可以不传该参数 """ self.strategy = strategyClass(self, setting) if not self.strategy.name: self.strategy.name = self.strategy.className self.strategy.onInit() self.strategy.onStart() #---------------------------------------------------------------------- def sendOrder(self, vtSymbol, orderType, price, volume, strategy): """发单""" self.writeCtaLog(u'{0},{1},{2}@{3}'.format(vtSymbol, orderType, price, volume)) self.limitOrderCount += 1 orderID = str(self.limitOrderCount) order = VtOrderData() order.vtSymbol = vtSymbol order.price = self.roundToPriceTick(price) order.totalVolume = volume order.status = STATUS_NOTTRADED # 刚提交尚未成交 order.orderID = orderID order.vtOrderID = orderID order.orderTime = str(self.dt) # added by IncenseLee order.gatewayName = self.gatewayName # CTA委托类型映射 if orderType == CTAORDER_BUY: order.direction = DIRECTION_LONG order.offset = OFFSET_OPEN elif orderType == CTAORDER_SELL: order.direction = DIRECTION_SHORT order.offset = OFFSET_CLOSE elif orderType == CTAORDER_SHORT: order.direction = DIRECTION_SHORT order.offset = OFFSET_OPEN elif orderType == CTAORDER_COVER: order.direction = DIRECTION_LONG order.offset = OFFSET_CLOSE # modified by IncenseLee key = u'{0}.{1}'.format(order.gatewayName, orderID) # 保存到限价单字典中 self.workingLimitOrderDict[key] = order self.limitOrderDict[key] = order return key #---------------------------------------------------------------------- def cancelOrder(self, vtOrderID): """撤单""" if vtOrderID in self.workingLimitOrderDict: order = self.workingLimitOrderDict[vtOrderID] order.status = STATUS_CANCELLED order.cancelTime = str(self.dt) del self.workingLimitOrderDict[vtOrderID] def cancelOrders(self, symbol, offset=EMPTY_STRING): """撤销所有单""" # Symbol参数:指定合约的撤单; # OFFSET参数:指定Offset的撤单,缺省不填写时,为所有 self.writeCtaLog(u'从所有订单中撤销{0}\{1}'.format(offset, symbol)) for vtOrderID in self.workingLimitOrderDict.keys(): order = self.workingLimitOrderDict[vtOrderID] if offset == EMPTY_STRING: offsetCond = True else: offsetCond = order.offset == offset if order.symbol == symbol and offsetCond: self.writeCtaLog(u'撤销订单:{0},{1} {2}@{3}'.format(vtOrderID, order.direction, order.price, order.totalVolume)) order.status = STATUS_CANCELLED order.cancelTime = str(self.dt) del self.workingLimitOrderDict[vtOrderID] #---------------------------------------------------------------------- def sendStopOrder(self, vtSymbol, orderType, price, volume, strategy): """发停止单(本地实现)""" self.stopOrderCount += 1 stopOrderID = STOPORDERPREFIX + str(self.stopOrderCount) so = StopOrder() so.vtSymbol = vtSymbol so.price = self.roundToPriceTick(price) so.volume = volume so.strategy = strategy so.stopOrderID = stopOrderID so.status = STOPORDER_WAITING if orderType == CTAORDER_BUY: so.direction = DIRECTION_LONG so.offset = OFFSET_OPEN elif orderType == CTAORDER_SELL: so.direction = DIRECTION_SHORT so.offset = OFFSET_CLOSE elif orderType == CTAORDER_SHORT: so.direction = DIRECTION_SHORT so.offset = OFFSET_OPEN elif orderType == CTAORDER_COVER: so.direction = DIRECTION_LONG so.offset = OFFSET_CLOSE # 保存stopOrder对象到字典中 self.stopOrderDict[stopOrderID] = so self.workingStopOrderDict[stopOrderID] = so return stopOrderID #---------------------------------------------------------------------- def cancelStopOrder(self, stopOrderID): """撤销停止单""" # 检查停止单是否存在 if stopOrderID in self.workingStopOrderDict: so = self.workingStopOrderDict[stopOrderID] so.status = STOPORDER_CANCELLED del self.workingStopOrderDict[stopOrderID] #---------------------------------------------------------------------- def crossLimitOrder(self): """基于最新数据撮合限价单""" # 先确定会撮合成交的价格 if self.mode == self.BAR_MODE: buyCrossPrice = self.bar.low # 若买入方向限价单价格高于该价格,则会成交 sellCrossPrice = self.bar.high # 若卖出方向限价单价格低于该价格,则会成交 buyBestCrossPrice = self.bar.open # 在当前时间点前发出的买入委托可能的最优成交价 sellBestCrossPrice = self.bar.open # 在当前时间点前发出的卖出委托可能的最优成交价 vtSymbol = self.bar.vtSymbol else: buyCrossPrice = self.tick.askPrice1 sellCrossPrice = self.tick.bidPrice1 buyBestCrossPrice = self.tick.askPrice1 sellBestCrossPrice = self.tick.bidPrice1 vtSymbol = self.tick.vtSymbol # 遍历限价单字典中的所有限价单 for orderID, order in self.workingLimitOrderDict.items(): # 判断是否会成交 buyCross = order.direction == DIRECTION_LONG and order.price >= buyCrossPrice and vtSymbol == order.vtSymbol sellCross = order.direction == DIRECTION_SHORT and order.price <= sellCrossPrice and vtSymbol == order.vtSymbol # 如果发生了成交 if buyCross or sellCross: # 推送成交数据 self.tradeCount += 1 # 成交编号自增1 tradeID = str(self.tradeCount) trade = VtTradeData() trade.vtSymbol = order.vtSymbol trade.tradeID = tradeID trade.vtTradeID = tradeID trade.orderID = order.orderID trade.vtOrderID = order.orderID trade.direction = order.direction trade.offset = order.offset # 以买入为例: # 1. 假设当根K线的OHLC分别为:100, 125, 90, 110 # 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105 # 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100 if buyCross: trade.price = min(order.price, buyBestCrossPrice) self.strategy.pos += order.totalVolume else: trade.price = max(order.price, sellBestCrossPrice) self.strategy.pos -= order.totalVolume trade.volume = order.totalVolume trade.tradeTime = str(self.dt) trade.dt = self.dt self.strategy.onTrade(trade) self.tradeDict[tradeID] = trade self.writeCtaLog(u'TradeId:{0}'.format(tradeID)) # 推送委托数据 order.tradedVolume = order.totalVolume order.status = STATUS_ALLTRADED self.strategy.onOrder(order) # 从字典中删除该限价单 try: del self.workingLimitOrderDict[orderID] except Exception as ex: self.writeCtaError(u'{0}:{1}'.format(Exception, ex)) # 实时计算模式 if self.calculateMode == self.REALTIME_MODE: self.realtimeCalculate() #---------------------------------------------------------------------- def crossStopOrder(self): """基于最新数据撮合停止单""" # 先确定会撮合成交的价格,这里和限价单规则相反 if self.mode == self.BAR_MODE: buyCrossPrice = self.bar.high # 若买入方向停止单价格低于该价格,则会成交 sellCrossPrice = self.bar.low # 若卖出方向限价单价格高于该价格,则会成交 bestCrossPrice = self.bar.open # 最优成交价,买入停止单不能低于,卖出停止单不能高于 vtSymbol = self.bar.vtSymbol else: buyCrossPrice = self.tick.lastPrice sellCrossPrice = self.tick.lastPrice bestCrossPrice = self.tick.lastPrice vtSymbol = self.tick.vtSymbol # 遍历停止单字典中的所有停止单 for stopOrderID, so in self.workingStopOrderDict.items(): # 判断是否会成交 buyCross = so.direction == DIRECTION_LONG and so.price <= buyCrossPrice and vtSymbol == so.vtSymbol sellCross = so.direction == DIRECTION_SHORT and so.price >= sellCrossPrice and vtSymbol == so.vtSymbol # 如果发生了成交 if buyCross or sellCross: # 推送成交数据 self.tradeCount += 1 # 成交编号自增1 tradeID = str(self.tradeCount) trade = VtTradeData() trade.vtSymbol = so.vtSymbol trade.tradeID = tradeID trade.vtTradeID = tradeID if buyCross: self.strategy.pos += so.volume trade.price = max(bestCrossPrice, so.price) else: self.strategy.pos -= so.volume trade.price = min(bestCrossPrice, so.price) self.limitOrderCount += 1 orderID = str(self.limitOrderCount) trade.orderID = orderID trade.vtOrderID = orderID trade.direction = so.direction trade.offset = so.offset trade.volume = so.volume trade.tradeTime = str(self.dt) trade.dt = self.dt self.strategy.onTrade(trade) self.tradeDict[tradeID] = trade # 推送委托数据 so.status = STOPORDER_TRIGGERED order = VtOrderData() order.vtSymbol = so.vtSymbol order.symbol = so.vtSymbol order.orderID = orderID order.vtOrderID = orderID order.direction = so.direction order.offset = so.offset order.price = so.price order.totalVolume = so.volume order.tradedVolume = so.volume order.status = STATUS_ALLTRADED order.orderTime = trade.tradeTime self.strategy.onOrder(order) self.limitOrderDict[orderID] = order # 从字典中删除该限价单 if stopOrderID in self.workingStopOrderDict: del self.workingStopOrderDict[stopOrderID] # 若采用实时计算净值 if self.calculateMode == self.REALTIME_MODE: self.realtimeCalculate() #---------------------------------------------------------------------- def insertData(self, dbName, collectionName, data): """考虑到回测中不允许向数据库插入数据,防止实盘交易中的一些代码出错""" pass #---------------------------------------------------------------------- def loadBar(self, dbName, collectionName, startDate): """直接返回初始化数据列表中的Bar""" return self.initData #---------------------------------------------------------------------- def loadTick(self, dbName, collectionName, startDate): """直接返回初始化数据列表中的Tick""" return self.initData #---------------------------------------------------------------------- def writeCtaLog(self, content): """记录日志""" #log = str(self.dt) + ' ' + content #self.logList.append(log) # 写入本地log日志 logging.info(content) def writeCtaError(self, content): """记录异常""" self.output(content) self.writeCtaLog(content) #---------------------------------------------------------------------- def output(self, content): """输出内容""" print str(datetime.now()) + "\t" + content def realtimeCalculate(self): """实时计算交易结果""" resultDict = OrderedDict() # 交易结果记录 longTrade = [] # 未平仓的多头交易 shortTrade = [] # 未平仓的空头交易 longid = EMPTY_STRING shortid = EMPTY_STRING no_match_shortTrade = False no_match_longTrade = False # 对交易记录逐一处理 for tradeid in self.tradeDict.keys(): trade = self.tradeDict[tradeid] # 多头交易 if trade.direction == DIRECTION_LONG: # 存在空单 if len(shortTrade)>0: # 检查是否存在与Symbol一致的空单 pop_indexs = [i for i, val in enumerate(shortTrade) if val.vtSymbol == trade.vtSymbol] if len(pop_indexs) < 1: #self.output(u'空头交易清单,没有{0}的空单'.format(trade.vtSymbol)) no_match_shortTrade = True # 如果尚无空单 if len(shortTrade) == 0 or no_match_shortTrade: #self.output(u'{0}多开:{1},{2}'.format(trade.vtSymbol, trade.volume, trade.price)) #self.writeCtaLog(u'{0}多开:{1},{2}'.format(trade.vtSymbol, trade.volume, trade.price)) longTrade.append(trade) longid = tradeid no_match_shortTrade = False # 当前多头交易为平空 else: gId = tradeid # 交易组(多个平仓数为一组) gr = None # 组合的交易结果 coverVolume = trade.volume while coverVolume > 0: if len(shortTrade)==0: self.writeCtaError(u'异常,没有开空仓的数据') break pop_indexs = [i for i, val in enumerate(shortTrade) if val.vtSymbol == trade.vtSymbol] if len(pop_indexs) < 1: self.writeCtaError(u'没有对应的symbol:{0}开空仓数据'.format(trade.vtSymbol)) break pop_index = pop_indexs[0] # 从未平仓的空头交易 entryTrade = shortTrade.pop(pop_index) # 开空volume,不大于平仓volume if coverVolume >= entryTrade.volume: self.writeCtaLog(u'coverVolume:{0} >= entryTrade.volume:{1}'.format(coverVolume, entryTrade.volume)) coverVolume = coverVolume - entryTrade.volume #self.output(u'{0}空平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) #self.writeCtaLog(u'{0}空平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=-entryTrade.volume, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) t = {} t['vtSymbol'] = entryTrade.vtSymbol t['OpenTime'] = entryTrade.tradeTime t['OpenPrice'] = entryTrade.price t['Direction'] = u'Short' t['CloseTime'] = trade.tradeTime t['ClosePrice'] = trade.price t['Volume'] = entryTrade.volume t['Profit'] = result.pnl self.exportTradeList.append(t) msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏:{9}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, shortid, entryTrade.price, trade.tradeTime, tradeid, trade.price, entryTrade.volume, result.pnl) self.output(msg) self.writeCtaLog(msg) if type(gr) == type(None): if coverVolume > 0: # 属于组合 gr = copy.deepcopy(result) # 删除开空交易单 del self.tradeDict[entryTrade.tradeID] else: # 不属于组合 resultDict[entryTrade.dt] = result # 删除平空交易单, del self.tradeDict[trade.tradeID] # 删除开空交易单 del self.tradeDict[entryTrade.tradeID] else: # 更新组合的数据 gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl # 删除开空交易单 del self.tradeDict[entryTrade.tradeID] # 所有仓位平完 if coverVolume == 0: gr.volume = trade.volume resultDict[entryTrade.dt] = gr # 删除平空交易单, del self.tradeDict[trade.tradeID] # 开空volume,大于平仓volume,需要更新减少tradeDict的数量。 else: self.writeCtaLog(u'Short volume:{0} > Cover volume:{1},需要更新减少tradeDict的数量。'.format(entryTrade.volume,coverVolume)) shortVolume = entryTrade.volume - coverVolume result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=-coverVolume, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) t = {} t['vtSymbol'] = entryTrade.vtSymbol t['OpenTime'] = entryTrade.tradeTime t['OpenPrice'] = entryTrade.price t['Direction'] = u'Short' t['CloseTime'] = trade.tradeTime t['ClosePrice'] = trade.price t['Volume'] = coverVolume t['Profit'] = result.pnl self.exportTradeList.append(t) msg = u'Gid:{0} {1}[{2}:开空tid={3}:{4}]-[{5}.平空tid={6},{7},vol:{8}],净盈亏:{9}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, shortid, entryTrade.price, trade.tradeTime, tradeid, trade.price, coverVolume, result.pnl) self.output(msg) self.writeCtaLog(msg) # 更新(减少)开仓单的volume,重新推进开仓单列表中 entryTrade.volume = shortVolume shortTrade.append(entryTrade) coverVolume = 0 if type(gr) == type(None): resultDict[entryTrade.dt] = result else: # 更新组合的数据 gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl gr.volume = trade.volume resultDict[entryTrade.dt] = gr # 删除平空交易单, del self.tradeDict[trade.tradeID] if type(gr) != type(None): self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl)) self.writeCtaLog(u'-------------') # 空头交易 else: if len(longTrade) > 0: pop_indexs = [i for i, val in enumerate(longTrade) if val.vtSymbol == trade.vtSymbol] if len(pop_indexs) < 1: #self.output(u'多头交易清单,没有{0}的多单'.format(trade.vtSymbol)) no_match_longTrade = True # 如果尚无多单 if len(longTrade) == 0 or no_match_longTrade: #self.output(u'{0}空开:{1},{2}'.format(trade.vtSymbol, trade.volume, trade.price)) #self.writeCtaLog(u'{0}空开:{1},{2}'.format(trade.vtSymbol, trade.volume, trade.price)) shortTrade.append(trade) shortid = tradeid no_match_longTrade = False # 当前空头交易为平多 else: gId = tradeid # 交易组(多个平仓数为一组) s gr = None # 组合的交易结果 sellVolume = trade.volume while sellVolume > 0: if len(longTrade) == 0: self.writeCtaError(u'异常,没有开多单') break pop_indexs = [i for i, val in enumerate(longTrade) if val.vtSymbol == trade.vtSymbol] if len(pop_indexs) < 1: self.writeCtaError(u'没有对应的symbol{0}开多仓数据,'.format(trade.vtSymbol)) break pop_index = pop_indexs[0] entryTrade = longTrade.pop(pop_index) # 开多volume,不大于平仓volume if sellVolume >= entryTrade.volume: self.writeCtaLog(u'{0}Sell Volume:{1} >= Entry Volume:{2}'.format(entryTrade.vtSymbol, sellVolume, entryTrade.volume)) sellVolume = sellVolume - entryTrade.volume #self.output(u'{0}多平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) #self.writeCtaLog(u'{0}多平:{1},{2}'.format(entryTrade.vtSymbol, entryTrade.volume, trade.price)) result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=entryTrade.volume, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) t = {} t['vtSymbol'] = entryTrade.vtSymbol t['OpenTime'] = entryTrade.tradeTime t['OpenPrice'] = entryTrade.price t['Direction'] = u'Long' t['CloseTime'] = trade.tradeTime t['ClosePrice'] = trade.price t['Volume'] = entryTrade.volume t['Profit'] = result.pnl self.exportTradeList.append(t) msg = u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏:{9}'\ .format(gId, entryTrade.vtSymbol, entryTrade.tradeTime, longid, entryTrade.price, trade.tradeTime, tradeid, trade.price, entryTrade.volume, result.pnl) self.output(msg) self.writeCtaLog(msg) if type(gr) == type(None): if sellVolume > 0: # 属于组合 gr = copy.deepcopy(result) # 删除开多交易单 del self.tradeDict[entryTrade.tradeID] else: # 不属于组合 resultDict[entryTrade.dt] = result # 删除平多交易单, del self.tradeDict[trade.tradeID] # 删除开多交易单 del self.tradeDict[entryTrade.tradeID] else: # 更新组合的数据 gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl # 删除开多交易单 del self.tradeDict[entryTrade.tradeID] if sellVolume == 0: gr.volume = trade.volume resultDict[entryTrade.dt] = gr # 删除平多交易单, del self.tradeDict[trade.tradeID] # 开多volume,大于平仓volume,需要更新减少tradeDict的数量。 else: longVolume = entryTrade.volume -sellVolume self.writeCtaLog(u'Long Volume:{0} > sell Volume:{1}'.format(entryTrade.volume,sellVolume)) result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=sellVolume, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) t = {} t['vtSymbol'] = entryTrade.vtSymbol t['OpenTime'] = entryTrade.tradeTime t['OpenPrice'] = entryTrade.price t['Direction'] = u'Long' t['CloseTime'] = trade.tradeTime t['ClosePrice'] = trade.price t['Volume'] = sellVolume t['Profit'] = result.pnl self.exportTradeList.append(t) self.writeCtaLog(u'Gid:{0} {1}[{2}:开多tid={3}:{4}]-[{5}.平多tid={6},{7},vol:{8}],净盈亏:{9}' .format(gId, entryTrade.vtSymbol,entryTrade.tradeTime, longid, entryTrade.price, trade.tradeTime, tradeid, trade.price, sellVolume, result.pnl)) # 减少开多volume,重新推进开多单列表中 entryTrade.volume = longVolume longTrade.append(entryTrade) sellVolume = 0 if type(gr) == type(None): resultDict[entryTrade.dt] = result else: # 更新组合的数据 gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl gr.volume = trade.volume resultDict[entryTrade.dt] = gr # 删除平多交易单, del self.tradeDict[trade.tradeID] if type(gr) != type(None): self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl)) self.writeCtaLog(u'-------------') # 计算仓位比例 occupyMoney = EMPTY_FLOAT occupyLongVolume = EMPTY_INT occupyShortVolume = EMPTY_INT if len(longTrade) > 0: for t in longTrade: occupyMoney += t.price * abs(t.volume) * self.size * self.margin_rate occupyLongVolume += abs(t.volume) if len(shortTrade) > 0: for t in shortTrade: occupyMoney += t.price * abs(t.volume) * self.size * self.margin_rate occupyShortVolume += (t.volume) self.output(u'occupyLongVolume:{0},occupyShortVolume:{1}'.format(occupyLongVolume,occupyShortVolume)) self.writeCtaLog(u'occupyLongVolume:{0},occupyShortVolume:{1}'.format(occupyLongVolume, occupyShortVolume)) # 最大持仓 self.maxVolume = max(self.maxVolume, occupyLongVolume + occupyShortVolume) self.avaliable = self.capital - occupyMoney self.percent = round(float(occupyMoney * 100 / self.capital), 2) # 检查是否有平交易 if not resultDict: msg = u'' if len(longTrade) > 0: msg += u'持多仓{0},'.format(occupyLongVolume) if len(shortTrade) > 0: msg += u'持空仓{0},'.format(occupyShortVolume) msg += u'资金占用:{0},仓位:{1}'.format(occupyMoney, self.percent) self.output(msg) self.writeCtaLog(msg) return # 对交易结果汇总统计 for time, result in resultDict.items(): if result.pnl > 0: self.winningResult += 1 self.totalWinning += result.pnl else: self.losingResult += 1 self.totalLosing += result.pnl self.capital += result.pnl self.maxCapital = max(self.capital, self.maxCapital) #self.maxVolume = max(self.maxVolume, result.volume) drawdown = self.capital - self.maxCapital drawdownRate = round(float(drawdown*100/self.maxCapital),4) self.pnlList.append(result.pnl) self.timeList.append(time) self.capitalList.append(self.capital) self.drawdownList.append(drawdown) self.drawdownRateList.append(drawdownRate) self.totalResult += 1 self.totalTurnover += result.turnover self.totalCommission += result.commission self.totalSlippage += result.slippage self.output(u'[{5}],{6} Vol:{0},盈亏:{1},回撤:{2}/{3},权益:{4}'. format(abs(result.volume), result.pnl, drawdown, drawdownRate, self.capital, result.groupId, time)) # 重新计算一次avaliable self.avaliable = self.capital - occupyMoney self.percent = round(float(occupyMoney * 100 / self.capital), 2) def savingDailyData(self, d, c, m): """保存每日数据""" dict = {} dict['date'] = d.strftime('%Y/%m/%d') dict['capital'] = c dict['maxCapital'] = m dict['rate'] = c / self.initCapital self.dailyList.append(dict) # ---------------------------------------------------------------------- def calculateBacktestingResult(self): """ 计算回测结果 Modified by Incense Lee 增加了支持逐步加仓的计算: 例如,前面共有6次开仓(1手开仓+5次加仓,每次1手),平仓只有1次(六手)。那么,交易次数是6次(开仓+平仓)。 暂不支持每次加仓数目不一致的核对(因为比较复杂) 增加组合的支持。(组合中,仍然按照1手逐步加仓和多手平仓的方法,即使启用了复利模式,也仍然按照这个规则,只是在计算收益时才乘以系数) 增加期初权益,每次交易后的权益,可用资金,仓位比例。 """ self.output(u'计算回测结果') # 首先基于回测后的成交记录,计算每笔交易的盈亏 resultDict = OrderedDict() # 交易结果记录 longTrade = [] # 未平仓的多头交易 shortTrade = [] # 未平仓的空头交易 i = 1 tradeUnit = 1 longid = EMPTY_STRING shortid = EMPTY_STRING for tradeid in self.tradeDict.keys(): trade = self.tradeDict[tradeid] # 多头交易 if trade.direction == DIRECTION_LONG: # 如果尚无空头交易 if not shortTrade: longTrade.append(trade) longid = tradeid # 当前多头交易为平空 else: gId = i # 交易组(多个平仓数为一组) gt = 1 # 组合的交易次数 gr = None # 组合的交易结果 if trade.volume >tradeUnit: self.writeCtaLog(u'平仓数{0},组合编号:{1}'.format(trade.volume,gId)) gt = int(trade.volume/tradeUnit) for tv in range(gt): entryTrade = shortTrade.pop(0) result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=-tradeUnit, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) if tv == 0: if gt == 1: resultDict[entryTrade.dt] = result else: gr = copy.deepcopy(result) else: gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl if tv == gt -1: gr.volume = trade.volume resultDict[entryTrade.dt] = gr t = {} t['OpenTime'] = entryTrade.tradeTime.strftime('%Y/%m/%d %H:%M:%S') t['OpenPrice'] = entryTrade.price t['Direction'] = u'Short' t['CloseTime'] = trade.tradeTime.strftime('%Y/%m/%d %H:%M:%S') t['ClosePrice'] = trade.price t['Volume'] = tradeUnit t['Profit'] = result.pnl self.exportTradeList.append(t) self.writeCtaLog(u'{9}@{6} [{7}:开空{0},short:{1}]-[{8}:平空{2},cover:{3},vol:{4}],净盈亏:{5}' .format(entryTrade.tradeTime, entryTrade.price, trade.tradeTime, trade.price, tradeUnit, result.pnl, i, shortid, tradeid, gId)) i = i+1 if type(gr) != type(None): self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl)) self.writeCtaLog(u'-------------') # 空头交易 else: # 如果尚无多头交易 if not longTrade: shortTrade.append(trade) shortid = tradeid # 当前空头交易为平多 else: gId = i # 交易组(多个平仓数为一组) gt = 1 # 组合的交易次数 gr = None # 组合的交易结果 if trade.volume >tradeUnit: self.writeCtaLog(u'平仓数{0},组合编号:{1}'.format(trade.volume,gId)) gt = int(trade.volume/tradeUnit) for tv in range(gt): entryTrade = longTrade.pop(0) result = TradingResult(entryPrice=entryTrade.price, entryDt=entryTrade.dt, exitPrice=trade.price, exitDt=trade.dt, volume=tradeUnit, rate=self.rate, slippage=self.slippage, size=self.size, groupId=gId, fixcommission=self.fixCommission) if tv == 0: if gt==1: resultDict[entryTrade.dt] = result else: gr = copy.deepcopy(result) else: gr.turnover = gr.turnover + result.turnover gr.commission = gr.commission + result.commission gr.slippage = gr.slippage + result.slippage gr.pnl = gr.pnl + result.pnl if tv == gt -1: gr.volume = trade.volume resultDict[entryTrade.dt] = gr t = {} t['OpenTime'] = entryTrade.tradeTime.strftime('%Y/%m/%d %H:%M:%S') t['OpenPrice'] = entryTrade.price t['Direction'] = u'Long' t['CloseTime'] = trade.tradeTime.strftime('%Y/%m/%d %H:%M:%S') t['ClosePrice'] = trade.price t['Volume'] = tradeUnit t['Profit'] = result.pnl self.exportTradeList.append(t) self.writeCtaLog(u'{9}@{6} [{7}:开多{0},buy:{1}]-[{8}.平多{2},sell:{3},vol:{4}],净盈亏:{5}' .format(entryTrade.tradeTime, entryTrade.price, trade.tradeTime,trade.price, tradeUnit, result.pnl, i, longid, tradeid, gId)) i = i+1 if type(gr) != type(None): self.writeCtaLog(u'组合净盈亏:{0}'.format(gr.pnl)) self.writeCtaLog(u'-------------') # 检查是否有交易 if not resultDict: self.output(u'无交易结果') return {} # 然后基于每笔交易的结果,我们可以计算具体的盈亏曲线和最大回撤等 """ initCapital = 40000 # 期初资金 capital = initCapital # 资金 maxCapital = initCapital # 资金最高净值 maxPnl = 0 # 最高盈利 minPnl = 0 # 最大亏损 maxVolume = 1 # 最大仓位数 wins = 0 totalResult = 0 # 总成交数量 totalTurnover = 0 # 总成交金额(合约面值) totalCommission = 0 # 总手续费 totalSlippage = 0 # 总滑点 timeList = [] # 时间序列 pnlList = [] # 每笔盈亏序列 capitalList = [] # 盈亏汇总的时间序列 drawdownList = [] # 回撤的时间序列 drawdownRateList = [] # 最大回撤比例的时间序列 """ drawdown = 0 # 回撤 compounding = 1 # 简单的复利基数(如果资金是期初资金的x倍,就扩大开仓比例,例如3w开1手,6w开2手,12w开4手) for time, result in resultDict.items(): # 是否使用简单复利 if self.usageCompounding: compounding = int(self.capital/self.initCapital) if result.pnl > 0: self.winningResult += 1 self.totalWinning += result.pnl else: self.losingResult += 1 self.totalLosing += result.pnl self.capital += result.pnl*compounding self.maxCapital = max(self.capital, self.maxCapital) self.maxVolume = max(self.maxVolume, result.volume*compounding) drawdown = self.capital - self.maxCapital drawdownRate = round(float(drawdown*100/self.maxCapital),4) self.pnlList.append(result.pnl*compounding) self.timeList.append(time) self.capitalList.append(self.capningital) self.drawdownList.append(drawdown) self.drawdownRateList.append(drawdownRate) self.totalResult += 1 self.totalTurnover += result.turnover*compounding self.totalCommission += result.commission*compounding self.totalSlippage += result.slippage*compounding # --------------------------------------------------------------------- def exportTradeResult(self): """到处回测结果表""" if not self.exportTradeList: return csvOutputFile = os.path.abspath(os.path.join(os.path.dirname(__file__), 'TestLogs', 'TradeList_{0}.csv'.format(datetime.now().strftime('%Y%m%d_%H%M')))) import csv csvWriteFile = file(csvOutputFile, 'wb') fieldnames = ['vtSymbol','OpenTime', 'OpenPrice', 'Direction', 'CloseTime', 'ClosePrice', 'Volume', 'Profit'] writer = csv.DictWriter(f=csvWriteFile, fieldnames=fieldnames, dialect='excel') writer.writeheader() for row in self.exportTradeList: writer.writerow(row) if not self.dailyList: return csvOutputFile2 = os.path.abspath(os.path.join(os.path.dirname(__file__), 'TestLogs', 'DailyList_{0}.csv'.format(datetime.now().strftime('%Y%m%d_%H%M')))) csvWriteFile2 = file(csvOutputFile2, 'wb') fieldnames = ['date','capital', 'maxCapital','rate'] writer2 = csv.DictWriter(f=csvWriteFile2, fieldnames=fieldnames, dialect='excel') writer2.writeheader() for row in self.dailyList: writer2.writerow(row) def getResult(self): # 返回回测结果 d = {} d['initCapital'] = self.initCapital d['capital'] = self.capital - self.initCapital d['maxCapital'] = self.maxCapital if len(self.pnlList) == 0: return {} d['maxPnl'] = max(self.pnlList) d['minPnl'] = min(self.pnlList) d['maxVolume'] = self.maxVolume d['totalResult'] = self.totalResult d['totalTurnover'] = self.totalTurnover d['totalCommission'] = self.totalCommission d['totalSlippage'] = self.totalSlippage d['timeList'] = self.timeList d['pnlList'] = self.pnlList d['capitalList'] = self.capitalList d['drawdownList'] = self.drawdownList d['drawdownRateList'] = self.drawdownRateList d['winningRate'] = round(100 * self.winningResult / len(self.pnlList), 4) averageWinning = 0 # 这里把数据都初始化为0 averageLosing = 0 profitLossRatio = 0 if self.winningResult: averageWinning = self.totalWinning / self.winningResult # 平均每笔盈利 if self.losingResult: averageLosing = self.totalLosing / self.losingResult # 平均每笔亏损 if averageLosing: profitLossRatio = -averageWinning / averageLosing # 盈亏比 d['averageWinning'] = averageWinning d['averageLosing'] = averageLosing d['profitLossRatio'] = profitLossRatio return d #---------------------------------------------------------------------- def showBacktestingResult(self): """显示回测结果""" if self.calculateMode != self.REALTIME_MODE: self.calculateBacktestingResult() d = self.getResult() if len(d) == 0: self.output(u'无交易结果') return # 导出交易清单 self.exportTradeResult() # 输出 self.output('-' * 30) self.output(u'第一笔交易:\t%s' % d['timeList'][0]) self.output(u'最后一笔交易:\t%s' % d['timeList'][-1]) self.output(u'总交易次数:\t%s' % formatNumber(d['totalResult'])) self.output(u'期初资金:\t%s' % formatNumber(d['initCapital'])) self.output(u'总盈亏:\t%s' % formatNumber(d['capital'])) self.output(u'资金最高净值:\t%s' % formatNumber(d['maxCapital'])) self.output(u'每笔最大盈利:\t%s' % formatNumber(d['maxPnl'])) self.output(u'每笔最大亏损:\t%s' % formatNumber(d['minPnl'])) self.output(u'净值最大回撤: \t%s' % formatNumber(min(d['drawdownList']))) self.output(u'净值最大回撤率: \t%s' % formatNumber(min(d['drawdownRateList']))) self.output(u'胜率:\t%s' % formatNumber(d['winningRate'])) self.output(u'盈利交易平均值\t%s' % formatNumber(d['averageWinning'])) self.output(u'亏损交易平均值\t%s' % formatNumber(d['averageLosing'])) self.output(u'盈亏比:\t%s' % formatNumber(d['profitLossRatio'])) self.output(u'最大持仓:\t%s' % formatNumber(d['maxVolume'])) self.output(u'平均每笔盈利:\t%s' %formatNumber(d['capital']/d['totalResult'])) self.output(u'平均每笔滑点成本:\t%s' %formatNumber(d['totalSlippage']/d['totalResult'])) self.output(u'平均每笔佣金:\t%s' %formatNumber(d['totalCommission']/d['totalResult'])) # 绘图 import matplotlib.pyplot as plt import numpy as np try: import seaborn as sns # 如果安装了seaborn则设置为白色风格 sns.set_style('whitegrid') except ImportError: pass pCapital = plt.subplot(4, 1, 1) pCapital.set_ylabel("capital") pCapital.plot(d['capitalList'], color='r', lw=0.8) pDD = plt.subplot(4, 1, 2) pDD.set_ylabel("DD") pDD.bar(range(len(d['drawdownList'])), d['drawdownList'], color='g') pPnl = plt.subplot(4, 1, 3) pPnl.set_ylabel("pnl") pPnl.hist(d['pnlList'], bins=50, color='c') """ pPos = plt.subplot(4, 1, 4) pPos.set_ylabel("Position") if d['posList'][-1] == 0: del d['posList'][-1] tradeTimeIndex = [item.strftime("%m/%d %H:%M:%S") for item in d['tradeTimeList']] xindex = np.arange(0, len(tradeTimeIndex), np.int(len(tradeTimeIndex)/10)) tradeTimeIndex = map(lambda i: tradeTimeIndex[i], xindex) pPos.plot(d['posList'], color='k', drawstyle='steps-pre') pPos.set_ylim(-1.2, 1.2) plt.sca(pPos) """ plt.tight_layout() #plt.xticks(xindex, tradeTimeIndex, rotation=30) # 旋转15 plt.show() #---------------------------------------------------------------------- def putStrategyEvent(self, name): """发送策略更新事件,回测中忽略""" pass #---------------------------------------------------------------------- def runOptimization(self, strategyClass, optimizationSetting): """优化参数""" # 获取优化设置 settingList = optimizationSetting.generateSetting() targetName = optimizationSetting.optimizeTarget # 检查参数设置问题 if not settingList or not targetName: self.output(u'优化设置有问题,请检查') # 遍历优化 resultList = [] for setting in settingList: self.clearBacktestingResult() self.output('-' * 30) self.output('setting: %s' %str(setting)) self.initStrategy(strategyClass, setting) self.runBacktesting() d = self.calculateBacktestingResult() try: targetValue = d[targetName] except KeyError: targetValue = 0 resultList.append(([str(setting)], targetValue)) # 显示结果 resultList.sort(reverse=True, key=lambda result:result[1]) self.output('-' * 30) self.output(u'优化结果:') for result in resultList: self.output(u'%s: %s' %(result[0], result[1])) return result #---------------------------------------------------------------------- def clearBacktestingResult(self): """清空之前回测的结果""" # 清空限价单相关 self.limitOrderCount = 0 self.limitOrderDict.clear() self.workingLimitOrderDict.clear() # 清空停止单相关 self.stopOrderCount = 0 self.stopOrderDict.clear() self.workingStopOrderDict.clear() # 清空成交相关 self.tradeCount = 0 self.tradeDict.clear() #---------------------------------------------------------------------- def runParallelOptimization(self, strategyClass, optimizationSetting): """并行优化参数""" # 获取优化设置 settingList = optimizationSetting.generateSetting() targetName = optimizationSetting.optimizeTarget # 检查参数设置问题 if not settingList or not targetName: self.output(u'优化设置有问题,请检查') # 多进程优化,启动一个对应CPU核心数量的进程池 pool = multiprocessing.Pool(multiprocessing.cpu_count()) l = [] for setting in settingList: l.append(pool.apply_async(optimize, (strategyClass, setting, targetName, self.mode, self.startDate, self.initDays, self.endDate, self.slippage, self.rate, self.size, self.dbName, self.symbol))) pool.close() pool.join() # 显示结果 resultList = [res.get() for res in l] resultList.sort(reverse=True, key=lambda result:result[1]) self.output('-' * 30) self.output(u'优化结果:') for result in resultList: self.output(u'%s: %s' %(result[0], result[1])) #---------------------------------------------------------------------- def roundToPriceTick(self, price): """取整价格到合约最小价格变动""" if not self.priceTick: return price newPrice = round(price/self.priceTick, 0) * self.priceTick return newPrice ######################################################################## class TradingResult(object): """每笔交易的结果""" #---------------------------------------------------------------------- def __init__(self, entryPrice,entryDt, exitPrice,exitDt,volume, rate, slippage, size, groupId, fixcommission=EMPTY_FLOAT): """Constructor""" self.entryPrice = entryPrice # 开仓价格 self.exitPrice = exitPrice # 平仓价格 self.entryDt = entryDt # 开仓时间datetime self.exitDt = exitDt # 平仓时间 self.volume = volume # 交易数量(+/-代表方向) self.groupId = groupId # 主交易ID(针对多手平仓) self.turnover = (self.entryPrice + self.exitPrice) * size * abs(volume) # 成交金额 if fixcommission: self.commission = fixcommission * self.volume else: self.commission = self.turnover * rate # 手续费成本 self.slippage = slippage * 2 * size * abs(volume) # 滑点成本 self.pnl = ((self.exitPrice - self.entryPrice) * volume * size - self.commission - self.slippage) # 净盈亏 ######################################################################## class OptimizationSetting(object): """优化设置""" #---------------------------------------------------------------------- def __init__(self): """Constructor""" self.paramDict = OrderedDict() self.optimizeTarget = '' # 优化目标字段 #---------------------------------------------------------------------- def addParameter(self, name, start, end=None, step=None): """增加优化参数""" if end is None and step is None: self.paramDict[name] = [start] return if end < start: print u'参数起始点必须不大于终止点' return if step <= 0: print u'参数布进必须大于0' return l = [] param = start while param <= end: l.append(param) param += step self.paramDict[name] = l #---------------------------------------------------------------------- def generateSetting(self): """生成优化参数组合""" # 参数名的列表 nameList = self.paramDict.keys() paramList = self.paramDict.values() # 使用迭代工具生产参数对组合 productList = list(product(*paramList)) # 把参数对组合打包到一个个字典组成的列表中 settingList = [] for p in productList: d = dict(zip(nameList, p)) settingList.append(d) return settingList #---------------------------------------------------------------------- def setOptimizeTarget(self, target): """设置优化目标字段""" self.optimizeTarget = target #---------------------------------------------------------------------- def formatNumber(n): """格式化数字到字符串""" rn = round(n, 2) # 保留两位小数 return format(rn, ',') # 加上千分符 #---------------------------------------------------------------------- def optimize(strategyClass, setting, targetName, mode, startDate, initDays, endDate, slippage, rate, size, dbName, symbol): """多进程优化时跑在每个进程中运行的函数""" engine = BacktestingEngine() engine.setBacktestingMode(mode) engine.setStartDate(startDate, initDays) engine.setEndDate(endDate) engine.setSlippage(slippage) engine.setRate(rate) engine.setSize(size) engine.setDatabase(dbName, symbol) engine.initStrategy(strategyClass, setting) engine.runBacktesting() d = engine.calculateBacktestingResult() try: targetValue = d[targetName] except KeyError: targetValue = 0 return (str(setting), targetValue) if __name__ == '__main__': # 以下内容是一段回测脚本的演示,用户可以根据自己的需求修改 # 建议使用ipython notebook或者spyder来做回测 # 同样可以在命令模式下进行回测(一行一行输入运行) from strategy.strategyEmaDemo import * # 创建回测引擎 engine = BacktestingEngine() # 设置引擎的回测模式为K线 engine.setBacktestingMode(engine.BAR_MODE) # 设置回测用的数据起始日期 engine.setStartDate('20110101') # 载入历史数据到引擎中 engine.setDatabase(MINUTE_DB_NAME, 'IF0000') # 设置产品相关参数 engine.setSlippage(0.2) # 股指1跳 engine.setRate(0.3/10000) # 万0.3 engine.setSize(300) # 股指合约大小 # 在引擎中创建策略对象 engine.initStrategy(EmaDemoStrategy, {}) # 开始跑回测 engine.runBacktesting() # 显示回测结果 # spyder或者ipython notebook中运行时,会弹出盈亏曲线图 # 直接在cmd中回测则只会打印一些回测数值 engine.showBacktestingResult()