[Mod]完成海龟回测引擎的统计和绘图功能
This commit is contained in:
parent
270847132a
commit
272857792e
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -2,10 +2,12 @@
|
||||
|
||||
from csv import DictReader
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo import MongoClient
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pymongo import MongoClient
|
||||
|
||||
from vnpy.trader.vtObject import VtBarData
|
||||
from vnpy.trader.vtConstant import DIRECTION_LONG, DIRECTION_SHORT
|
||||
|
||||
@ -25,7 +27,7 @@ SLIPPAGE_DICT = {}
|
||||
|
||||
########################################################################
|
||||
class BacktestingEngine(object):
|
||||
""""""
|
||||
"""组合类CTA策略回测引擎"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self):
|
||||
@ -40,6 +42,7 @@ class BacktestingEngine(object):
|
||||
self.fixedCommissionDict = {} # 固定手续费字典
|
||||
self.slippageDict = {} # 滑点成本字典
|
||||
|
||||
self.portfolioValue = 0
|
||||
self.startDt = None
|
||||
self.endDt = None
|
||||
self.currentDt = None
|
||||
@ -52,13 +55,15 @@ class BacktestingEngine(object):
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def setPeriod(self, startDt, endDt):
|
||||
""""""
|
||||
"""设置回测周期"""
|
||||
self.startDt = startDt
|
||||
self.endDt = endDt
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def initPortfolio(self, filename, portfolioValue=10000000):
|
||||
""""""
|
||||
"""初始化投资组合"""
|
||||
self.portfolioValue = portfolioValue
|
||||
|
||||
with open(filename) as f:
|
||||
r = DictReader(f)
|
||||
for d in r:
|
||||
@ -73,12 +78,12 @@ class BacktestingEngine(object):
|
||||
self.portfolio = TurtlePortfolio(self)
|
||||
self.portfolio.init(portfolioValue, self.vtSymbolList, SIZE_DICT)
|
||||
|
||||
self.writeLog(u'投资组合的合约代码%s' %(self.vtSymbolList))
|
||||
self.writeLog(u'投资组合的初始价值%s' %(portfolioValue))
|
||||
self.output(u'投资组合的合约代码%s' %(self.vtSymbolList))
|
||||
self.output(u'投资组合的初始价值%s' %(portfolioValue))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadData(self):
|
||||
""""""
|
||||
"""加载数据"""
|
||||
mc = MongoClient()
|
||||
db = mc[DAILY_DB_NAME]
|
||||
|
||||
@ -96,14 +101,14 @@ class BacktestingEngine(object):
|
||||
barDict = self.dataDict.setdefault(bar.datetime, OrderedDict())
|
||||
barDict[bar.vtSymbol] = bar
|
||||
|
||||
self.writeLog(u'%s数据加载完成,总数据量:%s' %(vtSymbol, cursor.count()))
|
||||
self.output(u'%s数据加载完成,总数据量:%s' %(vtSymbol, cursor.count()))
|
||||
|
||||
self.writeLog(u'全部数据加载完成')
|
||||
self.output(u'全部数据加载完成')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def runBacktesting(self):
|
||||
""""""
|
||||
self.writeLog(u'开始回放K线数据')
|
||||
"""运行回测"""
|
||||
self.output(u'开始回放K线数据')
|
||||
|
||||
for dt, barDict in self.dataDict.items():
|
||||
self.currentDt = dt
|
||||
@ -121,21 +126,172 @@ class BacktestingEngine(object):
|
||||
self.portfolio.onBar(bar)
|
||||
self.result.updateBar(bar)
|
||||
|
||||
self.writeLog(u'K线数据回放结束')
|
||||
self.output(u'K线数据回放结束')
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def calculateResult(self):
|
||||
""""""
|
||||
self.writeLog(u'开始统计回测结果')
|
||||
def calculateResult(self, annualDays=240):
|
||||
"""计算结果"""
|
||||
self.output(u'开始统计回测结果')
|
||||
|
||||
for result in self.resultList:
|
||||
result.calculatePnl()
|
||||
|
||||
self.writeLog(u'回测结果统计结束')
|
||||
resultList = self.resultList
|
||||
dateList = [result.date for result in resultList]
|
||||
|
||||
startDate = dateList[0]
|
||||
endDate = dateList[-1]
|
||||
totalDays = len(dateList)
|
||||
|
||||
profitDays = 0
|
||||
lossDays = 0
|
||||
endBalance = self.portfolioValue
|
||||
highlevel = self.portfolioValue
|
||||
totalNetPnl = 0
|
||||
totalCommission = 0
|
||||
totalSlippage = 0
|
||||
totalTradeCount = 0
|
||||
|
||||
netPnlList = []
|
||||
balanceList = []
|
||||
highlevelList = []
|
||||
drawdownList = []
|
||||
ddPercentList = []
|
||||
returnList = []
|
||||
|
||||
for result in resultList:
|
||||
if result.netPnl > 0:
|
||||
profitDays += 1
|
||||
elif result.netPnl < 0:
|
||||
lossDays += 1
|
||||
netPnlList.append(result.netPnl)
|
||||
|
||||
prevBalance = endBalance
|
||||
endBalance += result.netPnl
|
||||
balanceList.append(endBalance)
|
||||
returnList.append(endBalance/prevBalance - 1)
|
||||
|
||||
highlevel = max(highlevel, endBalance)
|
||||
highlevelList.append(highlevel)
|
||||
|
||||
drawdown = endBalance - highlevel
|
||||
drawdownList.append(drawdown)
|
||||
ddPercentList.append(drawdown/highlevel*100)
|
||||
|
||||
totalCommission += result.commission
|
||||
totalSlippage += result.slippage
|
||||
totalTradeCount += result.tradeCount
|
||||
totalNetPnl += result.netPnl
|
||||
|
||||
maxDrawdown = min(drawdownList)
|
||||
maxDdPercent = min(ddPercentList)
|
||||
totalReturn = (endBalance / self.portfolioValue - 1) * 100
|
||||
dailyReturn = np.mean(returnList) * 100
|
||||
annualizedReturn = dailyReturn * annualDays
|
||||
returnStd = np.std(returnList) * 100
|
||||
|
||||
if returnStd:
|
||||
sharpeRatio = dailyReturn / returnStd * np.sqrt(annualDays)
|
||||
else:
|
||||
sharpeRatio = 0
|
||||
|
||||
# 返回结果
|
||||
result = {
|
||||
'startDate': startDate,
|
||||
'endDate': endDate,
|
||||
'totalDays': totalDays,
|
||||
'profitDays': profitDays,
|
||||
'lossDays': lossDays,
|
||||
'endBalance': endBalance,
|
||||
'maxDrawdown': maxDrawdown,
|
||||
'maxDdPercent': maxDdPercent,
|
||||
'totalNetPnl': totalNetPnl,
|
||||
'dailyNetPnl': totalNetPnl/totalDays,
|
||||
'totalCommission': totalCommission,
|
||||
'dailyCommission': totalCommission/totalDays,
|
||||
'totalSlippage': totalSlippage,
|
||||
'dailySlippage': totalSlippage/totalDays,
|
||||
'totalTradeCount': totalTradeCount,
|
||||
'dailyTradeCount': totalTradeCount/totalDays,
|
||||
'totalReturn': totalReturn,
|
||||
'annualizedReturn': annualizedReturn,
|
||||
'dailyReturn': dailyReturn,
|
||||
'returnStd': returnStd,
|
||||
'sharpeRatio': sharpeRatio
|
||||
}
|
||||
|
||||
timeseries = {
|
||||
'balance': balanceList,
|
||||
'return': returnList,
|
||||
'highLevel': highlevel,
|
||||
'drawdown': drawdownList,
|
||||
'ddPercent': ddPercentList,
|
||||
'date': dateList,
|
||||
'netPnl': netPnlList
|
||||
}
|
||||
|
||||
return timeseries, result
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def showResult(self):
|
||||
"""显示回测结果"""
|
||||
timeseries, result = self.calculateResult()
|
||||
|
||||
# 输出统计结果
|
||||
self.output('-' * 30)
|
||||
self.output(u'首个交易日:\t%s' % result['startDate'])
|
||||
self.output(u'最后交易日:\t%s' % result['endDate'])
|
||||
|
||||
self.output(u'总交易日:\t%s' % result['totalDays'])
|
||||
self.output(u'盈利交易日\t%s' % result['profitDays'])
|
||||
self.output(u'亏损交易日:\t%s' % result['lossDays'])
|
||||
|
||||
self.output(u'起始资金:\t%s' % self.portfolioValue)
|
||||
self.output(u'结束资金:\t%s' % formatNumber(result['endBalance']))
|
||||
|
||||
self.output(u'总收益率:\t%s%%' % formatNumber(result['totalReturn']))
|
||||
self.output(u'年化收益:\t%s%%' % formatNumber(result['annualizedReturn']))
|
||||
self.output(u'总盈亏:\t%s' % formatNumber(result['totalNetPnl']))
|
||||
self.output(u'最大回撤: \t%s' % formatNumber(result['maxDrawdown']))
|
||||
self.output(u'百分比最大回撤: %s%%' % formatNumber(result['maxDdPercent']))
|
||||
|
||||
self.output(u'总手续费:\t%s' % formatNumber(result['totalCommission']))
|
||||
self.output(u'总滑点:\t%s' % formatNumber(result['totalSlippage']))
|
||||
self.output(u'总成交笔数:\t%s' % formatNumber(result['totalTradeCount']))
|
||||
|
||||
self.output(u'日均盈亏:\t%s' % formatNumber(result['dailyNetPnl']))
|
||||
self.output(u'日均手续费:\t%s' % formatNumber(result['dailyCommission']))
|
||||
self.output(u'日均滑点:\t%s' % formatNumber(result['dailySlippage']))
|
||||
self.output(u'日均成交笔数:\t%s' % formatNumber(result['dailyTradeCount']))
|
||||
|
||||
self.output(u'日均收益率:\t%s%%' % formatNumber(result['dailyReturn']))
|
||||
self.output(u'收益标准差:\t%s%%' % formatNumber(result['returnStd']))
|
||||
self.output(u'Sharpe Ratio:\t%s' % formatNumber(result['sharpeRatio']))
|
||||
|
||||
# 绘图
|
||||
fig = plt.figure(figsize=(10, 16))
|
||||
|
||||
pBalance = plt.subplot(4, 1, 1)
|
||||
pBalance.set_title('Balance')
|
||||
plt.plot(timeseries['date'], timeseries['balance'])
|
||||
|
||||
pDrawdown = plt.subplot(4, 1, 2)
|
||||
pDrawdown.set_title('Drawdown')
|
||||
pDrawdown.fill_between(range(len(timeseries['drawdown'])), timeseries['drawdown'])
|
||||
|
||||
pPnl = plt.subplot(4, 1, 3)
|
||||
pPnl.set_title('Daily Pnl')
|
||||
plt.bar(range(len(timeseries['drawdown'])), timeseries['netPnl'])
|
||||
|
||||
pKDE = plt.subplot(4, 1, 4)
|
||||
pKDE.set_title('Daily Pnl Distribution')
|
||||
plt.hist(timeseries['netPnl'], bins=50)
|
||||
|
||||
plt.show()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def sendOrder(self, vtSymbol, direction, offset, price, volume):
|
||||
""""""
|
||||
"""记录交易数据(由portfolio调用)"""
|
||||
# 对价格四舍五入
|
||||
priceTick = PRICETICK_DICT[vtSymbol]
|
||||
price = int(round(price/priceTick, 0)) * priceTick
|
||||
@ -148,9 +304,23 @@ class BacktestingEngine(object):
|
||||
self.result.updateTrade(trade)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def writeLog(self, content):
|
||||
""""""
|
||||
print '%s:\t%s' %(datetime.now().strftime('%H:%M:%S.%f'), content)
|
||||
def output(self, content):
|
||||
"""输出信息"""
|
||||
print content
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def getTradeData(self, vtSymbol=''):
|
||||
"""获取交易数据"""
|
||||
tradeList = []
|
||||
|
||||
for l in self.tradeDict.values():
|
||||
for trade in l:
|
||||
if not vtSymbol:
|
||||
tradeList.append(trade)
|
||||
elif trade.vtSymbol == vtSymbol:
|
||||
tradeList.append(trade)
|
||||
|
||||
return tradeList
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -182,15 +352,20 @@ class DailyResult(object):
|
||||
self.tradeDict = defaultdict(list) # 成交字典
|
||||
self.posDict = {} # 持仓字典(开盘时)
|
||||
|
||||
self.tradingPnl = 0
|
||||
self.holdingPnl = 0
|
||||
self.totalPnl = 0
|
||||
self.tradingPnl = 0 # 交易盈亏
|
||||
self.holdingPnl = 0 # 持仓盈亏
|
||||
self.totalPnl = 0 # 总盈亏
|
||||
self.commission = 0 # 佣金
|
||||
self.slippage = 0 # 滑点
|
||||
self.netPnl = 0 # 净盈亏
|
||||
self.tradeCount = 0 # 成交笔数
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def updateTrade(self, trade):
|
||||
"""更新交易"""
|
||||
l = self.tradeDict[trade.vtSymbol]
|
||||
l.append(trade)
|
||||
self.tradeCount += 1
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def updatePos(self, d):
|
||||
@ -227,9 +402,10 @@ class DailyResult(object):
|
||||
commissionCost = (trade.volume * fixedCommission +
|
||||
trade.volume * trade.price * variableCommission)
|
||||
slippageCost = trade.volume * slippage
|
||||
|
||||
pnl = (close - trade.price) * trade.volume * side * size
|
||||
pnl -= (commissionCost + slippageCost)
|
||||
|
||||
self.commission += commissionCost
|
||||
self.slippage += slippageCost
|
||||
self.tradingPnl += pnl
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@ -249,4 +425,11 @@ class DailyResult(object):
|
||||
self.calculateHoldingPnl()
|
||||
self.calculateTradingPnl()
|
||||
self.totalPnl = self.holdingPnl + self.tradingPnl
|
||||
self.netPnl = self.totalPnl - self.commission - self.slippage
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def formatNumber(n):
|
||||
"""格式化数字到字符串"""
|
||||
rn = round(n, 2) # 保留两位小数
|
||||
return format(rn, ',') # 加上千分符
|
@ -394,6 +394,5 @@ class TurtlePortfolio(object):
|
||||
self.totalShort += unit
|
||||
|
||||
# 向回测引擎中发单记录
|
||||
#self.engine.sendOrder(vtSymbol, direction, offset, price, volume)
|
||||
self.engine.sendOrder(vtSymbol, direction, offset, price, volume*multiplier)
|
||||
|
@ -1348,6 +1348,7 @@ class HistoryDataServer(RpcServer):
|
||||
print(u'从数据库加载:%s %s %s %s' %(dbName, symbol, start, end))
|
||||
return history
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def runHistoryDataServer():
|
||||
""""""
|
||||
@ -1361,6 +1362,7 @@ def runHistoryDataServer():
|
||||
hds.stop()
|
||||
raw_input()
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def formatNumber(n):
|
||||
"""格式化数字到字符串"""
|
||||
|
Loading…
Reference in New Issue
Block a user