[Add]增加CTA回测模块的历史数据缓存服务器进程功能 #847

This commit is contained in:
vn.py 2018-08-06 15:04:24 +08:00
parent 672353f369
commit e1906094dc
4 changed files with 246 additions and 174 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,3 @@
from vnpy.trader.app.ctaStrategy.ctaBacktesting import runHistoryDataServer
runHistoryDataServer()

View File

@ -18,6 +18,9 @@ import pandas as pd
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from vnpy.rpc import RpcClient, RpcServer, RemoteException
# 如果安装了seaborn则设置为白色风格 # 如果安装了seaborn则设置为白色风格
try: try:
import seaborn as sns import seaborn as sns
@ -71,6 +74,7 @@ class BacktestingEngine(object):
self.dbClient = None # 数据库客户端 self.dbClient = None # 数据库客户端
self.dbCursor = None # 数据库指针 self.dbCursor = None # 数据库指针
self.hdsClient = None # 历史数据服务器客户端
self.initData = [] # 初始化用的数据 self.initData = [] # 初始化用的数据
self.dbName = '' # 回测数据库名 self.dbName = '' # 回测数据库名
@ -181,6 +185,15 @@ class BacktestingEngine(object):
# 数据回放相关 # 数据回放相关
#------------------------------------------------ #------------------------------------------------
#----------------------------------------------------------------------
def initHdsClient(self):
"""初始化历史数据服务器客户端"""
reqAddress = 'tcp://localhost:5555'
subAddress = 'tcp://localhost:7777'
self.hdsClient = RpcClient(reqAddress, subAddress)
self.hdsClient.start()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def loadHistoryData(self): def loadHistoryData(self):
"""载入历史数据""" """载入历史数据"""
@ -188,7 +201,7 @@ class BacktestingEngine(object):
collection = self.dbClient[self.dbName][self.symbol] collection = self.dbClient[self.dbName][self.symbol]
self.output(u'开始载入数据') self.output(u'开始载入数据')
# 首先根据回测模式,确认要使用的数据类 # 首先根据回测模式,确认要使用的数据类
if self.mode == self.BAR_MODE: if self.mode == self.BAR_MODE:
dataClass = VtBarData dataClass = VtBarData
@ -197,10 +210,16 @@ class BacktestingEngine(object):
dataClass = VtTickData dataClass = VtTickData
func = self.newTick func = self.newTick
# 载入初始化需要用的数据 # 载入初始化需要用的数据
flt = {'datetime':{'$gte':self.dataStartDate, if self.hdsClient:
'$lt':self.strategyStartDate}} initCursor = self.hdsClient.loadHistoryData(self.dbName,
initCursor = collection.find(flt).sort('datetime') self.symbol,
self.dataStartDate,
self.strategyStartDate)
else:
flt = {'datetime':{'$gte':self.dataStartDate,
'$lt':self.strategyStartDate}}
initCursor = collection.find(flt).sort('datetime')
# 将数据从查询指针中读取出,并生成列表 # 将数据从查询指针中读取出,并生成列表
self.initData = [] # 清空initData列表 self.initData = [] # 清空initData列表
@ -210,14 +229,24 @@ class BacktestingEngine(object):
self.initData.append(data) self.initData.append(data)
# 载入回测数据 # 载入回测数据
if not self.dataEndDate: if self.hdsClient:
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件 self.dbCursor = self.hdsClient.loadHistoryData(self.dbName,
self.symbol,
self.strategyStartDate,
self.dataEndDate)
else: else:
flt = {'datetime':{'$gte':self.strategyStartDate, if not self.dataEndDate:
'$lte':self.dataEndDate}} flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
self.dbCursor = collection.find(flt).sort('datetime') else:
flt = {'datetime':{'$gte':self.strategyStartDate,
'$lte':self.dataEndDate}}
self.dbCursor = collection.find(flt).sort('datetime')
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count())) if isinstance(self.dbCursor, list):
count = len(initCursor) + len(self.dbCursor)
else:
count = initCursor.count() + self.dbCursor.count()
self.output(u'载入完成,数据量:%s' %count)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def runBacktesting(self): def runBacktesting(self):
@ -1254,6 +1283,58 @@ class OptimizationSetting(object):
self.optimizeTarget = target self.optimizeTarget = target
########################################################################
class HistoryDataServer(RpcServer):
"""历史数据缓存服务器"""
#----------------------------------------------------------------------
def __init__(self, repAddress, pubAddress):
"""Constructor"""
super(HistoryDataServer, self).__init__(repAddress, pubAddress)
self.dbClient = pymongo.MongoClient(globalSetting['mongoHost'],
globalSetting['mongoPort'])
self.historyDict = {}
self.register(self.loadHistoryData)
#----------------------------------------------------------------------
def loadHistoryData(self, dbName, symbol, start, end):
""""""
# 首先检查是否有缓存,如果有则直接返回
history = self.historyDict.get((dbName, symbol, start, end), None)
if history:
print(u'找到内存缓存:%s %s %s %s' %(dbName, symbol, start, end))
return history
# 否则从数据库加载
collection = self.dbClient[dbName][symbol]
if end:
flt = {'datetime':{'$gte':start, '$lt':end}}
else:
flt = {'datetime':{'$gte':start}}
cx = collection.find(flt).sort('datetime')
history = [d for d in cx]
self.historyDict[(dbName, symbol, start, end)] = history
print(u'从数据库加载:%s %s %s %s' %(dbName, symbol, start, end))
return history
#----------------------------------------------------------------------
def runHistoryDataServer():
""""""
repAddress = 'tcp://*:5555'
pubAddress = 'tcp://*:7777'
hds = HistoryDataServer(repAddress, pubAddress)
hds.start()
print(u'按任意键退出')
raw_input()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def formatNumber(n): def formatNumber(n):
"""格式化数字到字符串""" """格式化数字到字符串"""