[Add]增加CTA回测模块的历史数据缓存服务器进程功能 #847
This commit is contained in:
parent
672353f369
commit
e1906094dc
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
3
examples/CtaBacktesting/startHds.py
Normal file
3
examples/CtaBacktesting/startHds.py
Normal file
@ -0,0 +1,3 @@
|
||||
from vnpy.trader.app.ctaStrategy.ctaBacktesting import runHistoryDataServer
|
||||
|
||||
runHistoryDataServer()
|
@ -18,6 +18,9 @@ import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from vnpy.rpc import RpcClient, RpcServer, RemoteException
|
||||
|
||||
|
||||
# 如果安装了seaborn则设置为白色风格
|
||||
try:
|
||||
import seaborn as sns
|
||||
@ -71,6 +74,7 @@ class BacktestingEngine(object):
|
||||
|
||||
self.dbClient = None # 数据库客户端
|
||||
self.dbCursor = None # 数据库指针
|
||||
self.hdsClient = None # 历史数据服务器客户端
|
||||
|
||||
self.initData = [] # 初始化用的数据
|
||||
self.dbName = '' # 回测数据库名
|
||||
@ -181,6 +185,15 @@ class BacktestingEngine(object):
|
||||
# 数据回放相关
|
||||
#------------------------------------------------
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def initHdsClient(self):
|
||||
"""初始化历史数据服务器客户端"""
|
||||
reqAddress = 'tcp://localhost:5555'
|
||||
subAddress = 'tcp://localhost:7777'
|
||||
|
||||
self.hdsClient = RpcClient(reqAddress, subAddress)
|
||||
self.hdsClient.start()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadHistoryData(self):
|
||||
"""载入历史数据"""
|
||||
@ -198,9 +211,15 @@ class BacktestingEngine(object):
|
||||
func = self.newTick
|
||||
|
||||
# 载入初始化需要用的数据
|
||||
flt = {'datetime':{'$gte':self.dataStartDate,
|
||||
'$lt':self.strategyStartDate}}
|
||||
initCursor = collection.find(flt).sort('datetime')
|
||||
if self.hdsClient:
|
||||
initCursor = self.hdsClient.loadHistoryData(self.dbName,
|
||||
self.symbol,
|
||||
self.dataStartDate,
|
||||
self.strategyStartDate)
|
||||
else:
|
||||
flt = {'datetime':{'$gte':self.dataStartDate,
|
||||
'$lt':self.strategyStartDate}}
|
||||
initCursor = collection.find(flt).sort('datetime')
|
||||
|
||||
# 将数据从查询指针中读取出,并生成列表
|
||||
self.initData = [] # 清空initData列表
|
||||
@ -210,14 +229,24 @@ class BacktestingEngine(object):
|
||||
self.initData.append(data)
|
||||
|
||||
# 载入回测数据
|
||||
if not self.dataEndDate:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
||||
if self.hdsClient:
|
||||
self.dbCursor = self.hdsClient.loadHistoryData(self.dbName,
|
||||
self.symbol,
|
||||
self.strategyStartDate,
|
||||
self.dataEndDate)
|
||||
else:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate,
|
||||
'$lte':self.dataEndDate}}
|
||||
self.dbCursor = collection.find(flt).sort('datetime')
|
||||
if not self.dataEndDate:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate}} # 数据过滤条件
|
||||
else:
|
||||
flt = {'datetime':{'$gte':self.strategyStartDate,
|
||||
'$lte':self.dataEndDate}}
|
||||
self.dbCursor = collection.find(flt).sort('datetime')
|
||||
|
||||
self.output(u'载入完成,数据量:%s' %(initCursor.count() + self.dbCursor.count()))
|
||||
if isinstance(self.dbCursor, list):
|
||||
count = len(initCursor) + len(self.dbCursor)
|
||||
else:
|
||||
count = initCursor.count() + self.dbCursor.count()
|
||||
self.output(u'载入完成,数据量:%s' %count)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def runBacktesting(self):
|
||||
@ -1254,6 +1283,58 @@ class OptimizationSetting(object):
|
||||
self.optimizeTarget = target
|
||||
|
||||
|
||||
########################################################################
|
||||
class HistoryDataServer(RpcServer):
|
||||
"""历史数据缓存服务器"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, repAddress, pubAddress):
|
||||
"""Constructor"""
|
||||
super(HistoryDataServer, self).__init__(repAddress, pubAddress)
|
||||
|
||||
self.dbClient = pymongo.MongoClient(globalSetting['mongoHost'],
|
||||
globalSetting['mongoPort'])
|
||||
|
||||
self.historyDict = {}
|
||||
|
||||
self.register(self.loadHistoryData)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadHistoryData(self, dbName, symbol, start, end):
|
||||
""""""
|
||||
# 首先检查是否有缓存,如果有则直接返回
|
||||
history = self.historyDict.get((dbName, symbol, start, end), None)
|
||||
if history:
|
||||
print(u'找到内存缓存:%s %s %s %s' %(dbName, symbol, start, end))
|
||||
return history
|
||||
|
||||
# 否则从数据库加载
|
||||
collection = self.dbClient[dbName][symbol]
|
||||
|
||||
if end:
|
||||
flt = {'datetime':{'$gte':start, '$lt':end}}
|
||||
else:
|
||||
flt = {'datetime':{'$gte':start}}
|
||||
|
||||
cx = collection.find(flt).sort('datetime')
|
||||
history = [d for d in cx]
|
||||
|
||||
self.historyDict[(dbName, symbol, start, end)] = history
|
||||
print(u'从数据库加载:%s %s %s %s' %(dbName, symbol, start, end))
|
||||
return history
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def runHistoryDataServer():
|
||||
""""""
|
||||
repAddress = 'tcp://*:5555'
|
||||
pubAddress = 'tcp://*:7777'
|
||||
|
||||
hds = HistoryDataServer(repAddress, pubAddress)
|
||||
hds.start()
|
||||
|
||||
print(u'按任意键退出')
|
||||
raw_input()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def formatNumber(n):
|
||||
"""格式化数字到字符串"""
|
||||
|
Loading…
Reference in New Issue
Block a user