[Mod]完成海龟回测引擎的统计和绘图功能

This commit is contained in:
vn.py 2018-11-13 23:00:38 +08:00
parent 270847132a
commit 272857792e
5 changed files with 303 additions and 756 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -2,10 +2,12 @@
from csv import DictReader
from datetime import datetime
from pymongo import MongoClient
from collections import OrderedDict, defaultdict
import numpy as np
import matplotlib.pyplot as plt
from pymongo import MongoClient
from vnpy.trader.vtObject import VtBarData
from vnpy.trader.vtConstant import DIRECTION_LONG, DIRECTION_SHORT
@ -25,7 +27,7 @@ SLIPPAGE_DICT = {}
########################################################################
class BacktestingEngine(object):
""""""
"""组合类CTA策略回测引擎"""
#----------------------------------------------------------------------
def __init__(self):
@ -40,6 +42,7 @@ class BacktestingEngine(object):
self.fixedCommissionDict = {} # 固定手续费字典
self.slippageDict = {} # 滑点成本字典
self.portfolioValue = 0
self.startDt = None
self.endDt = None
self.currentDt = None
@ -52,13 +55,15 @@ class BacktestingEngine(object):
#----------------------------------------------------------------------
def setPeriod(self, startDt, endDt):
""""""
"""设置回测周期"""
self.startDt = startDt
self.endDt = endDt
#----------------------------------------------------------------------
def initPortfolio(self, filename, portfolioValue=10000000):
""""""
"""初始化投资组合"""
self.portfolioValue = portfolioValue
with open(filename) as f:
r = DictReader(f)
for d in r:
@ -73,12 +78,12 @@ class BacktestingEngine(object):
self.portfolio = TurtlePortfolio(self)
self.portfolio.init(portfolioValue, self.vtSymbolList, SIZE_DICT)
self.writeLog(u'投资组合的合约代码%s' %(self.vtSymbolList))
self.writeLog(u'投资组合的初始价值%s' %(portfolioValue))
self.output(u'投资组合的合约代码%s' %(self.vtSymbolList))
self.output(u'投资组合的初始价值%s' %(portfolioValue))
#----------------------------------------------------------------------
def loadData(self):
""""""
"""加载数据"""
mc = MongoClient()
db = mc[DAILY_DB_NAME]
@ -96,14 +101,14 @@ class BacktestingEngine(object):
barDict = self.dataDict.setdefault(bar.datetime, OrderedDict())
barDict[bar.vtSymbol] = bar
self.writeLog(u'%s数据加载完成,总数据量:%s' %(vtSymbol, cursor.count()))
self.output(u'%s数据加载完成,总数据量:%s' %(vtSymbol, cursor.count()))
self.writeLog(u'全部数据加载完成')
self.output(u'全部数据加载完成')
#----------------------------------------------------------------------
def runBacktesting(self):
""""""
self.writeLog(u'开始回放K线数据')
"""运行回测"""
self.output(u'开始回放K线数据')
for dt, barDict in self.dataDict.items():
self.currentDt = dt
@ -121,21 +126,172 @@ class BacktestingEngine(object):
self.portfolio.onBar(bar)
self.result.updateBar(bar)
self.writeLog(u'K线数据回放结束')
self.output(u'K线数据回放结束')
#----------------------------------------------------------------------
def calculateResult(self):
""""""
self.writeLog(u'开始统计回测结果')
def calculateResult(self, annualDays=240):
"""计算结果"""
self.output(u'开始统计回测结果')
for result in self.resultList:
result.calculatePnl()
self.writeLog(u'回测结果统计结束')
resultList = self.resultList
dateList = [result.date for result in resultList]
startDate = dateList[0]
endDate = dateList[-1]
totalDays = len(dateList)
profitDays = 0
lossDays = 0
endBalance = self.portfolioValue
highlevel = self.portfolioValue
totalNetPnl = 0
totalCommission = 0
totalSlippage = 0
totalTradeCount = 0
netPnlList = []
balanceList = []
highlevelList = []
drawdownList = []
ddPercentList = []
returnList = []
for result in resultList:
if result.netPnl > 0:
profitDays += 1
elif result.netPnl < 0:
lossDays += 1
netPnlList.append(result.netPnl)
prevBalance = endBalance
endBalance += result.netPnl
balanceList.append(endBalance)
returnList.append(endBalance/prevBalance - 1)
highlevel = max(highlevel, endBalance)
highlevelList.append(highlevel)
drawdown = endBalance - highlevel
drawdownList.append(drawdown)
ddPercentList.append(drawdown/highlevel*100)
totalCommission += result.commission
totalSlippage += result.slippage
totalTradeCount += result.tradeCount
totalNetPnl += result.netPnl
maxDrawdown = min(drawdownList)
maxDdPercent = min(ddPercentList)
totalReturn = (endBalance / self.portfolioValue - 1) * 100
dailyReturn = np.mean(returnList) * 100
annualizedReturn = dailyReturn * annualDays
returnStd = np.std(returnList) * 100
if returnStd:
sharpeRatio = dailyReturn / returnStd * np.sqrt(annualDays)
else:
sharpeRatio = 0
# 返回结果
result = {
'startDate': startDate,
'endDate': endDate,
'totalDays': totalDays,
'profitDays': profitDays,
'lossDays': lossDays,
'endBalance': endBalance,
'maxDrawdown': maxDrawdown,
'maxDdPercent': maxDdPercent,
'totalNetPnl': totalNetPnl,
'dailyNetPnl': totalNetPnl/totalDays,
'totalCommission': totalCommission,
'dailyCommission': totalCommission/totalDays,
'totalSlippage': totalSlippage,
'dailySlippage': totalSlippage/totalDays,
'totalTradeCount': totalTradeCount,
'dailyTradeCount': totalTradeCount/totalDays,
'totalReturn': totalReturn,
'annualizedReturn': annualizedReturn,
'dailyReturn': dailyReturn,
'returnStd': returnStd,
'sharpeRatio': sharpeRatio
}
timeseries = {
'balance': balanceList,
'return': returnList,
'highLevel': highlevel,
'drawdown': drawdownList,
'ddPercent': ddPercentList,
'date': dateList,
'netPnl': netPnlList
}
return timeseries, result
#----------------------------------------------------------------------
def showResult(self):
"""显示回测结果"""
timeseries, result = self.calculateResult()
# 输出统计结果
self.output('-' * 30)
self.output(u'首个交易日:\t%s' % result['startDate'])
self.output(u'最后交易日:\t%s' % result['endDate'])
self.output(u'总交易日:\t%s' % result['totalDays'])
self.output(u'盈利交易日\t%s' % result['profitDays'])
self.output(u'亏损交易日:\t%s' % result['lossDays'])
self.output(u'起始资金:\t%s' % self.portfolioValue)
self.output(u'结束资金:\t%s' % formatNumber(result['endBalance']))
self.output(u'总收益率:\t%s%%' % formatNumber(result['totalReturn']))
self.output(u'年化收益:\t%s%%' % formatNumber(result['annualizedReturn']))
self.output(u'总盈亏:\t%s' % formatNumber(result['totalNetPnl']))
self.output(u'最大回撤: \t%s' % formatNumber(result['maxDrawdown']))
self.output(u'百分比最大回撤: %s%%' % formatNumber(result['maxDdPercent']))
self.output(u'总手续费:\t%s' % formatNumber(result['totalCommission']))
self.output(u'总滑点:\t%s' % formatNumber(result['totalSlippage']))
self.output(u'总成交笔数:\t%s' % formatNumber(result['totalTradeCount']))
self.output(u'日均盈亏:\t%s' % formatNumber(result['dailyNetPnl']))
self.output(u'日均手续费:\t%s' % formatNumber(result['dailyCommission']))
self.output(u'日均滑点:\t%s' % formatNumber(result['dailySlippage']))
self.output(u'日均成交笔数:\t%s' % formatNumber(result['dailyTradeCount']))
self.output(u'日均收益率:\t%s%%' % formatNumber(result['dailyReturn']))
self.output(u'收益标准差:\t%s%%' % formatNumber(result['returnStd']))
self.output(u'Sharpe Ratio\t%s' % formatNumber(result['sharpeRatio']))
# 绘图
fig = plt.figure(figsize=(10, 16))
pBalance = plt.subplot(4, 1, 1)
pBalance.set_title('Balance')
plt.plot(timeseries['date'], timeseries['balance'])
pDrawdown = plt.subplot(4, 1, 2)
pDrawdown.set_title('Drawdown')
pDrawdown.fill_between(range(len(timeseries['drawdown'])), timeseries['drawdown'])
pPnl = plt.subplot(4, 1, 3)
pPnl.set_title('Daily Pnl')
plt.bar(range(len(timeseries['drawdown'])), timeseries['netPnl'])
pKDE = plt.subplot(4, 1, 4)
pKDE.set_title('Daily Pnl Distribution')
plt.hist(timeseries['netPnl'], bins=50)
plt.show()
#----------------------------------------------------------------------
def sendOrder(self, vtSymbol, direction, offset, price, volume):
""""""
"""记录交易数据由portfolio调用"""
# 对价格四舍五入
priceTick = PRICETICK_DICT[vtSymbol]
price = int(round(price/priceTick, 0)) * priceTick
@ -148,9 +304,23 @@ class BacktestingEngine(object):
self.result.updateTrade(trade)
#----------------------------------------------------------------------
def writeLog(self, content):
""""""
print '%s:\t%s' %(datetime.now().strftime('%H:%M:%S.%f'), content)
def output(self, content):
"""输出信息"""
print content
#----------------------------------------------------------------------
def getTradeData(self, vtSymbol=''):
"""获取交易数据"""
tradeList = []
for l in self.tradeDict.values():
for trade in l:
if not vtSymbol:
tradeList.append(trade)
elif trade.vtSymbol == vtSymbol:
tradeList.append(trade)
return tradeList
########################################################################
@ -182,15 +352,20 @@ class DailyResult(object):
self.tradeDict = defaultdict(list) # 成交字典
self.posDict = {} # 持仓字典(开盘时)
self.tradingPnl = 0
self.holdingPnl = 0
self.totalPnl = 0
self.tradingPnl = 0 # 交易盈亏
self.holdingPnl = 0 # 持仓盈亏
self.totalPnl = 0 # 总盈亏
self.commission = 0 # 佣金
self.slippage = 0 # 滑点
self.netPnl = 0 # 净盈亏
self.tradeCount = 0 # 成交笔数
#----------------------------------------------------------------------
def updateTrade(self, trade):
"""更新交易"""
l = self.tradeDict[trade.vtSymbol]
l.append(trade)
self.tradeCount += 1
#----------------------------------------------------------------------
def updatePos(self, d):
@ -227,9 +402,10 @@ class DailyResult(object):
commissionCost = (trade.volume * fixedCommission +
trade.volume * trade.price * variableCommission)
slippageCost = trade.volume * slippage
pnl = (close - trade.price) * trade.volume * side * size
pnl -= (commissionCost + slippageCost)
self.commission += commissionCost
self.slippage += slippageCost
self.tradingPnl += pnl
#----------------------------------------------------------------------
@ -249,4 +425,11 @@ class DailyResult(object):
self.calculateHoldingPnl()
self.calculateTradingPnl()
self.totalPnl = self.holdingPnl + self.tradingPnl
self.netPnl = self.totalPnl - self.commission - self.slippage
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""
rn = round(n, 2) # 保留两位小数
return format(rn, ',') # 加上千分符

View File

@ -394,6 +394,5 @@ class TurtlePortfolio(object):
self.totalShort += unit
# 向回测引擎中发单记录
#self.engine.sendOrder(vtSymbol, direction, offset, price, volume)
self.engine.sendOrder(vtSymbol, direction, offset, price, volume*multiplier)

View File

@ -1348,6 +1348,7 @@ class HistoryDataServer(RpcServer):
print(u'从数据库加载:%s %s %s %s' %(dbName, symbol, start, end))
return history
#----------------------------------------------------------------------
def runHistoryDataServer():
""""""
@ -1361,6 +1362,7 @@ def runHistoryDataServer():
hds.stop()
raw_input()
#----------------------------------------------------------------------
def formatNumber(n):
"""格式化数字到字符串"""