[Add]增加CTA回测模块的历史数据缓存服务器进程功能 #847
This commit is contained in:
parent
672353f369
commit
e1906094dc
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
3
examples/CtaBacktesting/startHds.py
Normal file
3
examples/CtaBacktesting/startHds.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from vnpy.trader.app.ctaStrategy.ctaBacktesting import runHistoryDataServer
|
||||||
|
|
||||||
|
runHistoryDataServer()
|
@ -18,6 +18,9 @@ import pandas as pd
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from vnpy.rpc import RpcClient, RpcServer, RemoteException
|
||||||
|
|
||||||
|
|
||||||
# 如果安装了seaborn则设置为白色风格
|
# 如果安装了seaborn则设置为白色风格
|
||||||
try:
|
try:
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
@ -71,6 +74,7 @@ class BacktestingEngine(object):
|
|||||||
|
|
||||||
self.dbClient = None # 数据库客户端
|
self.dbClient = None # 数据库客户端
|
||||||
self.dbCursor = None # 数据库指针
|
self.dbCursor = None # 数据库指针
|
||||||
|
self.hdsClient = None # 历史数据服务器客户端
|
||||||
|
|
||||||
self.initData = [] # 初始化用的数据
|
self.initData = [] # 初始化用的数据
|
||||||
self.dbName = '' # 回测数据库名
|
self.dbName = '' # 回测数据库名
|
||||||
@ -181,6 +185,15 @@ class BacktestingEngine(object):
|
|||||||
# 数据回放相关
|
# 数据回放相关
|
||||||
#------------------------------------------------
|
#------------------------------------------------
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def initHdsClient(self):
|
||||||
|
"""初始化历史数据服务器客户端"""
|
||||||
|
reqAddress = 'tcp://localhost:5555'
|
||||||
|
subAddress = 'tcp://localhost:7777'
|
||||||
|
|
||||||
|
self.hdsClient = RpcClient(reqAddress, subAddress)
|
||||||
|
self.hdsClient.start()
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def loadHistoryData(self):
|
def loadHistoryData(self):
|
||||||
"""载入历史数据"""
|
"""载入历史数据"""
|
||||||
@ -188,7 +201,7 @@ class BacktestingEngine(object):
|
|||||||
collection = self.dbClient[self.dbName][self.symbol]
|
collection = self.dbClient[self.dbName][self.symbol]
|
||||||
|
|
||||||
self.output(u'开始载入数据')
|
self.output(u'开始载入数据')
|
||||||
|
|
||||||
# 首先根据回测模式,确认要使用的数据类
|
# 首先根据回测模式,确认要使用的数据类
|
||||||
if self.mode == self.BAR_MODE:
|
if self.mode == self.BAR_MODE:
|
||||||
dataClass = VtBarData
|
dataClass = VtBarData
|
||||||
@ -197,10 +210,16 @@ class BacktestingEngine(object):
|
|||||||
dataClass = VtTickData
|
dataClass = VtTickData
|
||||||
func = self.newTick
|
func = self.newTick
|
||||||
|
|
||||||
# 载入初始化需要用的数据
|
# 载入初始化需要用的数据
|
||||||
flt = {'datetime':{'$gte':self.dataStartDate,
|
if self.hdsClient:
|
||||||
'$lt':self.strategyStartDate}}
|
initCursor = self.hdsClient.loadHistoryData(self.dbName,
|
||||||
initCursor = collection.find(flt).sort('datetime')
|
self.symbol,
|
||||||
|
self.dataStartDate,
|
||||||
|
self.strategyStartDate)
|
||||||
|
else:
|
||||||
|
flt = {'datetime':{'$gte':self.dataStartDate,
|
||||||
|
'$lt':self.strategyStartDate}}
|
||||||
|
initCursor = collection.find(flt).sort('datetime')
|
||||||
|
|
||||||
# 将数据从查询指针中读取出,并生成列表
|
# 将数据从查询指针中读取出,并生成列表
|
||||||
self.initData = [] # 清空initData列表
|
self.initData = [] # 清空initData列表
|
||||||
@ -210,14 +229,24 @@ class BacktestingEngine(object):
|
|||||||
self.initData.append(data)
|
self.initData.append(data)
|
||||||
|
|
||||||
# 载入回测数据
|
# 载入回测数据
|
||||||
if not self.dataEndDate:
|
if self.hdsClient:
|
||||||
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
self.dbCursor = self.hdsClient.loadHistoryData(self.dbName,
|
||||||
|
self.symbol,
|
||||||
|
self.strategyStartDate,
|
||||||
|
self.dataEndDate)
|
||||||
else:
|
else:
|
||||||
flt = {'datetime':{'$gte':self.strategyStartDate,
|
if not self.dataEndDate:
|
||||||
'$lte':self.dataEndDate}}
|
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
||||||
self.dbCursor = collection.find(flt).sort('datetime')
|
else:
|
||||||
|
flt = {'datetime':{'$gte':self.strategyStartDate,
|
||||||
|
'$lte':self.dataEndDate}}
|
||||||
|
self.dbCursor = collection.find(flt).sort('datetime')
|
||||||
|
|
||||||
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
|
if isinstance(self.dbCursor, list):
|
||||||
|
count = len(initCursor) + len(self.dbCursor)
|
||||||
|
else:
|
||||||
|
count = initCursor.count() + self.dbCursor.count()
|
||||||
|
self.output(u'载入完成,数据量:%s' %count)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def runBacktesting(self):
|
def runBacktesting(self):
|
||||||
@ -1254,6 +1283,58 @@ class OptimizationSetting(object):
|
|||||||
self.optimizeTarget = target
|
self.optimizeTarget = target
|
||||||
|
|
||||||
|
|
||||||
|
########################################################################
|
||||||
|
class HistoryDataServer(RpcServer):
|
||||||
|
"""历史数据缓存服务器"""
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def __init__(self, repAddress, pubAddress):
|
||||||
|
"""Constructor"""
|
||||||
|
super(HistoryDataServer, self).__init__(repAddress, pubAddress)
|
||||||
|
|
||||||
|
self.dbClient = pymongo.MongoClient(globalSetting['mongoHost'],
|
||||||
|
globalSetting['mongoPort'])
|
||||||
|
|
||||||
|
self.historyDict = {}
|
||||||
|
|
||||||
|
self.register(self.loadHistoryData)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def loadHistoryData(self, dbName, symbol, start, end):
|
||||||
|
""""""
|
||||||
|
# 首先检查是否有缓存,如果有则直接返回
|
||||||
|
history = self.historyDict.get((dbName, symbol, start, end), None)
|
||||||
|
if history:
|
||||||
|
print(u'找到内存缓存:%s %s %s %s' %(dbName, symbol, start, end))
|
||||||
|
return history
|
||||||
|
|
||||||
|
# 否则从数据库加载
|
||||||
|
collection = self.dbClient[dbName][symbol]
|
||||||
|
|
||||||
|
if end:
|
||||||
|
flt = {'datetime':{'$gte':start, '$lt':end}}
|
||||||
|
else:
|
||||||
|
flt = {'datetime':{'$gte':start}}
|
||||||
|
|
||||||
|
cx = collection.find(flt).sort('datetime')
|
||||||
|
history = [d for d in cx]
|
||||||
|
|
||||||
|
self.historyDict[(dbName, symbol, start, end)] = history
|
||||||
|
print(u'从数据库加载:%s %s %s %s' %(dbName, symbol, start, end))
|
||||||
|
return history
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def runHistoryDataServer():
|
||||||
|
""""""
|
||||||
|
repAddress = 'tcp://*:5555'
|
||||||
|
pubAddress = 'tcp://*:7777'
|
||||||
|
|
||||||
|
hds = HistoryDataServer(repAddress, pubAddress)
|
||||||
|
hds.start()
|
||||||
|
|
||||||
|
print(u'按任意键退出')
|
||||||
|
raw_input()
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def formatNumber(n):
|
def formatNumber(n):
|
||||||
"""格式化数字到字符串"""
|
"""格式化数字到字符串"""
|
||||||
|
Loading…
Reference in New Issue
Block a user