diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py index 40c45f87..b7267fe2 100644 --- a/vnpy/app/cta_backtester/engine.py +++ b/vnpy/app/cta_backtester/engine.py @@ -7,13 +7,16 @@ from pathlib import Path from vnpy.event import Event, EventEngine from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.constant import Interval +from vnpy.trader.utility import extract_vt_symbol +from vnpy.trader.rqdata import rqdata_client +from vnpy.trader.database import database_manager from vnpy.app.cta_strategy import ( CtaTemplate, BacktestingEngine, OptimizationSetting ) - APP_NAME = "CtaBacktester" EVENT_BACKTESTER_LOG = "eBacktesterLog" @@ -53,6 +56,16 @@ class BacktesterEngine(BaseEngine): self.write_log("策略文件加载完成") + self.init_rqdata() + + def init_rqdata(self): + """ + Init RQData client. + """ + result = rqdata_client.init() + if result: + self.write_log("RQData数据接口初始化成功") + def write_log(self, msg: str): """""" event = Event(EVENT_BACKTESTER_LOG) @@ -167,7 +180,7 @@ class BacktesterEngine(BaseEngine): setting: dict ): if self.thread: - self.write_log("已有回测或者优化在运行中,请等待完成") + self.write_log("已有任务在运行中,请等待完成") return False self.write_log("-" * 40) @@ -275,7 +288,7 @@ class BacktesterEngine(BaseEngine): optimization_setting: OptimizationSetting ): if self.thread: - self.write_log("已有回测或者优化在运行中,请等待完成") + self.write_log("已有任务在运行中,请等待完成") return False self.write_log("-" * 40) @@ -298,3 +311,54 @@ class BacktesterEngine(BaseEngine): self.thread.start() return True + + def run_downloading( + self, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime + ): + """ + Query bar data from RQData. + """ + self.write_log(f"{vt_symbol}-{interval}开始下载历史数据") + + symbol, exchange = extract_vt_symbol(vt_symbol) + data = rqdata_client.query_bar( + symbol, exchange, Interval(interval), start, end + ) + + if not data: + self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据") + + database_manager.save_bar_data(data) + + # Clear thread object handler. + self.thread = None + self.write_log(f"{vt_symbol}-{interval}历史数据下载完成") + + def start_downloading( + self, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime + ): + if self.thread: + self.write_log("已有任务在运行中,请等待完成") + return False + + self.write_log("-" * 40) + self.thread = Thread( + target=self.run_downloading, + args=( + vt_symbol, + interval, + start, + end + ) + ) + self.thread.start() + + return True diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py index bbed6991..ac9ff738 100644 --- a/vnpy/app/cta_backtester/ui/widget.py +++ b/vnpy/app/cta_backtester/ui/widget.py @@ -92,6 +92,17 @@ class BacktesterManager(QtWidgets.QWidget): self.result_button.clicked.connect(self.show_optimization_result) self.result_button.setEnabled(False) + downloading_button = QtWidgets.QPushButton("下载数据") + downloading_button.clicked.connect(self.start_downloading) + + for button in [ + backtesting_button, + optimization_button, + downloading_button, + self.result_button + ]: + button.setFixedHeight(button.sizeHint().height() * 2) + form = QtWidgets.QFormLayout() form.addRow("交易策略", self.class_combo) form.addRow("本地代码", self.symbol_line) @@ -107,6 +118,7 @@ class BacktesterManager(QtWidgets.QWidget): left_vbox = QtWidgets.QVBoxLayout() left_vbox.addLayout(form) + left_vbox.addWidget(downloading_button) left_vbox.addStretch() left_vbox.addWidget(optimization_button) left_vbox.addWidget(self.result_button) @@ -247,6 +259,20 @@ class BacktesterManager(QtWidgets.QWidget): self.result_button.setEnabled(False) + def start_downloading(self): + """""" + vt_symbol = self.symbol_line.text() + interval = self.interval_combo.currentText() + start = self.start_date_edit.date().toPyDate() + end = self.end_date_edit.date().toPyDate() + + self.backtester_engine.start_downloading( + vt_symbol, + interval, + start, + end + ) + def show_optimization_result(self): """""" result_values = self.backtester_engine.get_result_values() diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 84939da3..ebe8ad73 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -5,7 +5,7 @@ import os import traceback from collections import defaultdict from pathlib import Path -from typing import Any, Callable, List +from typing import Any, Callable from datetime import datetime, timedelta from threading import Thread from queue import Queue @@ -37,7 +37,7 @@ from vnpy.trader.constant import ( ) from vnpy.trader.utility import load_json, save_json, extract_vt_symbol from vnpy.trader.database import database_manager -from vnpy.trader.setting import SETTINGS +from vnpy.trader.rqdata import rqdata_client from .base import ( APP_NAME, @@ -124,26 +124,9 @@ class CtaEngine(BaseEngine): """ Init RQData client. """ - username = SETTINGS["rqdata.username"] - password = SETTINGS["rqdata.password"] - if not username or not password: - return - - import rqdatac - - self.rq_client = rqdatac - self.rq_client.init(username, password, - ('rqdatad-pro.ricequant.com', 16011)) - - try: - df = self.rq_client.all_instruments( - type='Future', date=datetime.now()) - for ix, row in df.iterrows(): - self.rq_symbols.add(row['order_book_id']) - except RuntimeError: - pass - - self.write_log("RQData数据接口初始化成功") + result = rqdata_client.init() + if result: + self.write_log("RQData数据接口初始化成功") def query_bar_from_rq( self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime @@ -151,36 +134,9 @@ class CtaEngine(BaseEngine): """ Query bar data from RQData. """ - rq_symbol = to_rq_symbol(symbol, exchange) - if rq_symbol not in self.rq_symbols: - return None - - end += timedelta(1) # For querying night trading period data - - df = self.rq_client.get_price( - rq_symbol, - frequency=interval.value, - fields=["open", "high", "low", "close", "volume"], - start_date=start, - end_date=end + data = rqdata_client.query_bar( + symbol, exchange, interval, start, end ) - - data: List[BarData] = [] - for ix, row in df.iterrows(): - bar = BarData( - symbol=symbol, - exchange=exchange, - interval=interval, - datetime=row.name.to_pydatetime(), - open_price=row["open"], - high_price=row["high"], - low_price=row["low"], - close_price=row["close"], - volume=row["volume"], - gateway_name="RQ" - ) - data.append(bar) - return data def process_tick_event(self, event: Event): @@ -915,29 +871,3 @@ class CtaEngine(BaseEngine): subject = "CTA策略引擎" self.main_engine.send_email(subject, msg) - - -def to_rq_symbol(symbol: str, exchange: Exchange): - """ - CZCE product of RQData has symbol like "TA1905" while - vt symbol is "TA905.CZCE" so need to add "1" in symbol. - """ - if exchange is not Exchange.CZCE: - return symbol.upper() - - for count, word in enumerate(symbol): - if word.isdigit(): - break - - # noinspection PyUnboundLocalVariable - product = symbol[:count] - year = symbol[count] - month = symbol[count + 1:] - - if year == "9": - year = "1" + year - else: - year = "2" + year - - rq_symbol = f"{product}{year}{month}".upper() - return rq_symbol diff --git a/vnpy/trader/rqdata.py b/vnpy/trader/rqdata.py new file mode 100644 index 00000000..4f0f261e --- /dev/null +++ b/vnpy/trader/rqdata.py @@ -0,0 +1,123 @@ +from datetime import datetime, timedelta +from typing import List + +from rqdatac import init as rqdata_init +from rqdatac.services.basic import all_instruments as rqdata_all_instruments +from rqdatac.services.get_price import get_price as rqdata_get_price + +from .setting import SETTINGS +from .constant import Exchange, Interval +from .object import BarData + + +class RqdataClient: + """ + Client for querying history data from RQData. + """ + + def __init__(self): + """""" + self.username = SETTINGS["rqdata.username"] + self.password = SETTINGS["rqdata.password"] + + self.inited = False + self.symbols = set() + + def init(self): + """""" + if self.inited: + return True + + if not self.username or not self.password: + return False + + rqdata_init(self.username, self.password, + ('rqdatad-pro.ricequant.com', 16011)) + + try: + df = rqdata_all_instruments(date=datetime.now()) + for ix, row in df.iterrows(): + self.symbols.add(row['order_book_id']) + except RuntimeError: + return False + + self.inited = True + return True + + def to_rq_symbol(self, symbol: str, exchange: Exchange): + """ + CZCE product of RQData has symbol like "TA1905" while + vt symbol is "TA905.CZCE" so need to add "1" in symbol. + """ + if exchange in [Exchange.SSE, Exchange.SZSE]: + if exchange == Exchange.SSE: + rq_symbol = f"{symbol}.XSHG" + else: + rq_symbol = f"{symbol}.XSHE" + else: + if exchange is not Exchange.CZCE: + return symbol.upper() + + for count, word in enumerate(symbol): + if word.isdigit(): + break + + # noinspection PyUnboundLocalVariable + product = symbol[:count] + year = symbol[count] + month = symbol[count + 1:] + + if year == "9": + year = "1" + year + else: + year = "2" + year + + rq_symbol = f"{product}{year}{month}".upper() + + return rq_symbol + + def query_bar( + self, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, + end: datetime + ): + """ + Query bar data from RQData. + """ + rq_symbol = self.to_rq_symbol(symbol, exchange) + if rq_symbol not in self.symbols: + return None + + end += timedelta(1) # For querying night trading period data + + df = rqdata_get_price( + rq_symbol, + frequency=interval.value, + fields=["open", "high", "low", "close", "volume"], + start_date=start, + end_date=end + ) + + data: List[BarData] = [] + for ix, row in df.iterrows(): + bar = BarData( + symbol=symbol, + exchange=exchange, + interval=interval, + datetime=row.name.to_pydatetime(), + open_price=row["open"], + high_price=row["high"], + low_price=row["low"], + close_price=row["close"], + volume=row["volume"], + gateway_name="RQ" + ) + data.append(bar) + + return data + + +rqdata_client = RqdataClient()