diff --git a/vn.trader/ctaAlgo/ctaBacktesting.py b/vn.trader/ctaAlgo/ctaBacktesting.py index e7900cf2..f00d2cf9 100644 --- a/vn.trader/ctaAlgo/ctaBacktesting.py +++ b/vn.trader/ctaAlgo/ctaBacktesting.py @@ -9,6 +9,7 @@ from __future__ import division from datetime import datetime, timedelta from collections import OrderedDict from itertools import product +import multiprocessing import pymongo from ctaBase import * @@ -49,6 +50,10 @@ class BacktestingEngine(object): self.strategy = None # 回测策略 self.mode = self.BAR_MODE # 回测模式,默认为K线 + self.startDate = '' + self.initDays = 0 + self.endDate = '' + self.slippage = 0 # 回测时假设的滑点 self.rate = 0 # 回测时假设的佣金比例(适用于百分比佣金) self.size = 1 # 合约大小,默认为1 @@ -84,6 +89,9 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def setStartDate(self, startDate='20100416', initDays=10): """设置回测的启动日期""" + self.startDate = startDate + self.initDays = initDays + self.dataStartDate = datetime.strptime(startDate, '%Y%m%d') initTimeDelta = timedelta(initDays) @@ -92,8 +100,11 @@ class BacktestingEngine(object): #---------------------------------------------------------------------- def setEndDate(self, endDate=''): """设置回测的结束日期""" + self.endDate = endDate if endDate: self.dataEndDate= datetime.strptime(endDate, '%Y%m%d') + # 若不修改时间则会导致不包含dataEndDate当天数据 + self.dataEndDate.replace(hour=23, minute=59) #---------------------------------------------------------------------- def setBacktestingMode(self, mode): @@ -595,10 +606,18 @@ class BacktestingEngine(object): totalLosing += result.pnl # 计算盈亏相关数据 - winningRate = winningResult/totalResult*100 # 胜率 - averageWinning = totalWinning/winningResult # 平均每笔盈利 - averageLosing = totalLosing/losingResult # 平均每笔亏损 - profitLossRatio = -averageWinning/averageLosing # 盈亏比 + winningRate = winningResult/totalResult*100 # 胜率 + + averageWinning = 0 # 这里把数据都初始化为0 + averageLosing = 0 + profitLossRatio = 0 + + if winningResult: + averageWinning = totalWinning/winningResult # 平均每笔盈利 + if losingResult: + averageLosing = totalLosing/losingResult # 平均每笔亏损 + if averageLosing: + profitLossRatio = -averageWinning/averageLosing # 盈亏比 # 返回回测结果 d = {} @@ -712,6 +731,7 @@ class BacktestingEngine(object): self.output(u'优化结果:') for result in resultList: self.output(u'%s: %s' %(result[0], result[1])) + return result #---------------------------------------------------------------------- def clearBacktestingResult(self): @@ -730,6 +750,37 @@ class BacktestingEngine(object): self.tradeCount = 0 self.tradeDict.clear() + #---------------------------------------------------------------------- + def runParallelOptimization(self, strategyClass, optimizationSetting): + """并行优化参数""" + # 获取优化设置 + settingList = optimizationSetting.generateSetting() + targetName = optimizationSetting.optimizeTarget + + # 检查参数设置问题 + if not settingList or not targetName: + self.output(u'优化设置有问题,请检查') + + # 多进程优化,启动一个对应CPU核心数量的进程池 + pool = multiprocessing.Pool(multiprocessing.cpu_count()) + l = [] + for setting in settingList: + l.append(pool.apply_async(optimize, (strategyClass, setting, + targetName, self.mode, + self.startDate, self.initDays, self.endDate, + self.slippage, self.rate, self.size, + self.dbName, self.symbol))) + pool.close() + pool.join() + + # 显示结果 + resultList = [res.get() for res in l] + resultList.sort(reverse=True, key=lambda result:result[1]) + self.output('-' * 30) + self.output(u'优化结果:') + for result in resultList: + self.output(u'%s: %s' %(result[0], result[1])) + ######################################################################## class TradingResult(object): @@ -812,10 +863,32 @@ class OptimizationSetting(object): #---------------------------------------------------------------------- def formatNumber(n): """格式化数字到字符串""" - n = round(n, 2) # 保留两位小数 - return format(n, ',') # 加上千分符 + rn = round(n, 2) # 保留两位小数 + return format(rn, ',') # 加上千分符 + +#---------------------------------------------------------------------- +def optimize(strategyClass, setting, targetName, + mode, startDate, initDays, endDate, + slippage, rate, size, + dbName, symbol): + """多进程优化时跑在每个进程中运行的函数""" + engine = BacktestingEngine() + engine.setBacktestingMode(mode) + engine.setStartDate(startDate, initDays) + engine.setSlippage(slippage) + engine.setRate(rate) + engine.setSize(size) + engine.setDatabase(dbName, symbol) + engine.initStrategy(strategyClass, setting) + engine.runBacktesting() + d = engine.calculateBacktestingResult() + try: + targetValue = d[targetName] + except KeyError: + targetValue = 0 + return (str(setting), targetValue) if __name__ == '__main__': diff --git a/vn.trader/ctaAlgo/strategyAtrRsi.py b/vn.trader/ctaAlgo/strategyAtrRsi.py index 70029941..d19660c7 100644 --- a/vn.trader/ctaAlgo/strategyAtrRsi.py +++ b/vn.trader/ctaAlgo/strategyAtrRsi.py @@ -259,21 +259,31 @@ if __name__ == '__main__': # 设置使用的历史数据库 engine.setDatabase(MINUTE_DB_NAME, 'IF0000') - # 在引擎中创建策略对象 - d = {'atrLength': 11} - engine.initStrategy(AtrRsiStrategy, d) + ## 在引擎中创建策略对象 + #d = {'atrLength': 11} + #engine.initStrategy(AtrRsiStrategy, d) - # 开始跑回测 - engine.runBacktesting() + ## 开始跑回测 + ##engine.runBacktesting() - # 显示回测结果 - engine.showBacktestingResult() + ## 显示回测结果 + ##engine.showBacktestingResult() - ## 跑优化 - #setting = OptimizationSetting() # 新建一个优化任务设置对象 - #setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利 - #setting.addParameter('atrLength', 11, 12, 1) # 增加第一个优化参数atrLength,起始11,结束12,步进1 - #setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa,起始20,结束30,步进1 - #engine.runOptimization(AtrRsiStrategy, setting) # 运行优化函数,自动输出结果 + # 跑优化 + setting = OptimizationSetting() # 新建一个优化任务设置对象 + setting.setOptimizeTarget('capital') # 设置优化排序的目标是策略净盈利 + setting.addParameter('atrLength', 11, 20, 1) # 增加第一个优化参数atrLength,起始11,结束12,步进1 + setting.addParameter('atrMa', 20, 30, 5) # 增加第二个优化参数atrMa,起始20,结束30,步进1 - \ No newline at end of file + # 性能测试环境:I7-3770,主频3.4G, 8核心,内存16G,Windows 7 专业版 + # 测试时还跑着一堆其他的程序,性能仅供参考 + import time + start = time.time() + + # 运行单进程优化函数,自动输出结果,耗时:359秒 + #engine.runOptimization(AtrRsiStrategy, setting) + + # 多进程优化,耗时:89秒 + engine.runParallelOptimization(AtrRsiStrategy, setting) + + print u'耗时:%s' %(time.time()-start) \ No newline at end of file