[Add]download history data function into CtaBacktester

This commit is contained in:
vn.py 2019-04-18 11:27:50 +08:00
parent 1d6d191ac6
commit bd36646036
5 changed files with 118 additions and 23 deletions

View File

@ -7,13 +7,16 @@ from pathlib import Path
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Interval
from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.database import database_manager
from vnpy.app.cta_strategy import (
CtaTemplate,
BacktestingEngine,
OptimizationSetting
)
APP_NAME = "CtaBacktester"
EVENT_BACKTESTER_LOG = "eBacktesterLog"
@ -53,6 +56,16 @@ class BacktesterEngine(BaseEngine):
self.write_log("策略文件加载完成")
self.init_rqdata()
def init_rqdata(self):
"""
Init RQData client.
"""
result = rqdata_client.init()
if result:
self.write_log("RQData数据接口初始化成功")
def write_log(self, msg: str):
""""""
event = Event(EVENT_BACKTESTER_LOG)
@ -167,7 +180,7 @@ class BacktesterEngine(BaseEngine):
setting: dict
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
@ -275,7 +288,7 @@ class BacktesterEngine(BaseEngine):
optimization_setting: OptimizationSetting
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
@ -298,3 +311,54 @@ class BacktesterEngine(BaseEngine):
self.thread.start()
return True
def run_downloading(
self,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
"""
Query bar data from RQData.
"""
self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")
symbol, exchange = extract_vt_symbol(vt_symbol)
data = rqdata_client.query_bar(
symbol, exchange, Interval(interval), start, end
)
if not data:
self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
database_manager.save_bar_data(data)
# Clear thread object handler.
self.thread = None
self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
def start_downloading(
self,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
if self.thread:
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_downloading,
args=(
vt_symbol,
interval,
start,
end
)
)
self.thread.start()
return True

View File

@ -92,6 +92,17 @@ class BacktesterManager(QtWidgets.QWidget):
self.result_button.clicked.connect(self.show_optimization_result)
self.result_button.setEnabled(False)
downloading_button = QtWidgets.QPushButton("下载数据")
downloading_button.clicked.connect(self.start_downloading)
for button in [
backtesting_button,
optimization_button,
downloading_button,
self.result_button
]:
button.setFixedHeight(button.sizeHint().height() * 2)
form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo)
form.addRow("本地代码", self.symbol_line)
@ -107,6 +118,7 @@ class BacktesterManager(QtWidgets.QWidget):
left_vbox = QtWidgets.QVBoxLayout()
left_vbox.addLayout(form)
left_vbox.addWidget(downloading_button)
left_vbox.addStretch()
left_vbox.addWidget(optimization_button)
left_vbox.addWidget(self.result_button)
@ -247,6 +259,20 @@ class BacktesterManager(QtWidgets.QWidget):
self.result_button.setEnabled(False)
def start_downloading(self):
""""""
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
self.backtester_engine.start_downloading(
vt_symbol,
interval,
start,
end
)
def show_optimization_result(self):
""""""
result_values = self.backtester_engine.get_result_values()

View File

@ -871,4 +871,3 @@ class CtaEngine(BaseEngine):
subject = "CTA策略引擎"
self.main_engine.send_email(subject, msg)

View File

@ -35,8 +35,7 @@ class RqdataClient:
('rqdatad-pro.ricequant.com', 16011))
try:
df = rqdata_all_instruments(
type='Future', date=datetime.now())
df = rqdata_all_instruments(date=datetime.now())
for ix, row in df.iterrows():
self.symbols.add(row['order_book_id'])
except RuntimeError:
@ -50,24 +49,31 @@ class RqdataClient:
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]
if year == "9":
year = "1" + year
if exchange in [Exchange.SSE, Exchange.SZSE]:
if exchange == Exchange.SSE:
rq_symbol = f"{symbol}.XSHG"
else:
rq_symbol = f"{symbol}.XSHE"
else:
year = "2" + year
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]
if year == "9":
year = "1" + year
else:
year = "2" + year
rq_symbol = f"{product}{year}{month}".upper()
rq_symbol = f"{product}{year}{month}".upper()
return rq_symbol
def query_bar(

View File

@ -4,7 +4,7 @@ General utility functions.
import json
from pathlib import Path
from typing import Callable, TYPE_CHECKING
from typing import Callable
import numpy as np
import talib