Merge remote-tracking branch 'remotes/origin/DEV' into newest_data
This commit is contained in:
commit
eefb423953
@ -7,13 +7,16 @@ from pathlib import Path
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.constant import Interval
|
||||
from vnpy.trader.utility import extract_vt_symbol
|
||||
from vnpy.trader.rqdata import rqdata_client
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.app.cta_strategy import (
|
||||
CtaTemplate,
|
||||
BacktestingEngine,
|
||||
OptimizationSetting
|
||||
)
|
||||
|
||||
|
||||
APP_NAME = "CtaBacktester"
|
||||
|
||||
EVENT_BACKTESTER_LOG = "eBacktesterLog"
|
||||
@ -53,6 +56,16 @@ class BacktesterEngine(BaseEngine):
|
||||
|
||||
self.write_log("策略文件加载完成")
|
||||
|
||||
self.init_rqdata()
|
||||
|
||||
def init_rqdata(self):
|
||||
"""
|
||||
Init RQData client.
|
||||
"""
|
||||
result = rqdata_client.init()
|
||||
if result:
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
|
||||
def write_log(self, msg: str):
|
||||
""""""
|
||||
event = Event(EVENT_BACKTESTER_LOG)
|
||||
@ -167,7 +180,7 @@ class BacktesterEngine(BaseEngine):
|
||||
setting: dict
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测或者优化在运行中,请等待完成")
|
||||
self.write_log("已有任务在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
@ -275,7 +288,7 @@ class BacktesterEngine(BaseEngine):
|
||||
optimization_setting: OptimizationSetting
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测或者优化在运行中,请等待完成")
|
||||
self.write_log("已有任务在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
@ -298,3 +311,54 @@ class BacktesterEngine(BaseEngine):
|
||||
self.thread.start()
|
||||
|
||||
return True
|
||||
|
||||
def run_downloading(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")
|
||||
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
data = rqdata_client.query_bar(
|
||||
symbol, exchange, Interval(interval), start, end
|
||||
)
|
||||
|
||||
if not data:
|
||||
self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
|
||||
|
||||
database_manager.save_bar_data(data)
|
||||
|
||||
# Clear thread object handler.
|
||||
self.thread = None
|
||||
self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
|
||||
|
||||
def start_downloading(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有任务在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
self.thread = Thread(
|
||||
target=self.run_downloading,
|
||||
args=(
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end
|
||||
)
|
||||
)
|
||||
self.thread.start()
|
||||
|
||||
return True
|
||||
|
@ -92,6 +92,17 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
self.result_button.clicked.connect(self.show_optimization_result)
|
||||
self.result_button.setEnabled(False)
|
||||
|
||||
downloading_button = QtWidgets.QPushButton("下载数据")
|
||||
downloading_button.clicked.connect(self.start_downloading)
|
||||
|
||||
for button in [
|
||||
backtesting_button,
|
||||
optimization_button,
|
||||
downloading_button,
|
||||
self.result_button
|
||||
]:
|
||||
button.setFixedHeight(button.sizeHint().height() * 2)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
form.addRow("交易策略", self.class_combo)
|
||||
form.addRow("本地代码", self.symbol_line)
|
||||
@ -107,6 +118,7 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
|
||||
left_vbox = QtWidgets.QVBoxLayout()
|
||||
left_vbox.addLayout(form)
|
||||
left_vbox.addWidget(downloading_button)
|
||||
left_vbox.addStretch()
|
||||
left_vbox.addWidget(optimization_button)
|
||||
left_vbox.addWidget(self.result_button)
|
||||
@ -247,6 +259,20 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
|
||||
self.result_button.setEnabled(False)
|
||||
|
||||
def start_downloading(self):
|
||||
""""""
|
||||
vt_symbol = self.symbol_line.text()
|
||||
interval = self.interval_combo.currentText()
|
||||
start = self.start_date_edit.date().toPyDate()
|
||||
end = self.end_date_edit.date().toPyDate()
|
||||
|
||||
self.backtester_engine.start_downloading(
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end
|
||||
)
|
||||
|
||||
def show_optimization_result(self):
|
||||
""""""
|
||||
result_values = self.backtester_engine.get_result_values()
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List
|
||||
from typing import Any, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
@ -37,7 +37,7 @@ from vnpy.trader.constant import (
|
||||
)
|
||||
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.setting import SETTINGS
|
||||
from vnpy.trader.rqdata import rqdata_client
|
||||
|
||||
from .base import (
|
||||
APP_NAME,
|
||||
@ -124,26 +124,9 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Init RQData client.
|
||||
"""
|
||||
username = SETTINGS["rqdata.username"]
|
||||
password = SETTINGS["rqdata.password"]
|
||||
if not username or not password:
|
||||
return
|
||||
|
||||
import rqdatac
|
||||
|
||||
self.rq_client = rqdatac
|
||||
self.rq_client.init(username, password,
|
||||
('rqdatad-pro.ricequant.com', 16011))
|
||||
|
||||
try:
|
||||
df = self.rq_client.all_instruments(
|
||||
type='Future', date=datetime.now())
|
||||
for ix, row in df.iterrows():
|
||||
self.rq_symbols.add(row['order_book_id'])
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
result = rqdata_client.init()
|
||||
if result:
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
|
||||
def query_bar_from_rq(
|
||||
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||
@ -151,36 +134,9 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
rq_symbol = to_rq_symbol(symbol, exchange)
|
||||
if rq_symbol not in self.rq_symbols:
|
||||
return None
|
||||
|
||||
end += timedelta(1) # For querying night trading period data
|
||||
|
||||
df = self.rq_client.get_price(
|
||||
rq_symbol,
|
||||
frequency=interval.value,
|
||||
fields=["open", "high", "low", "close", "volume"],
|
||||
start_date=start,
|
||||
end_date=end
|
||||
data = rqdata_client.query_bar(
|
||||
symbol, exchange, interval, start, end
|
||||
)
|
||||
|
||||
data: List[BarData] = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime=row.name.to_pydatetime(),
|
||||
open_price=row["open"],
|
||||
high_price=row["high"],
|
||||
low_price=row["low"],
|
||||
close_price=row["close"],
|
||||
volume=row["volume"],
|
||||
gateway_name="RQ"
|
||||
)
|
||||
data.append(bar)
|
||||
|
||||
return data
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
@ -915,29 +871,3 @@ class CtaEngine(BaseEngine):
|
||||
subject = "CTA策略引擎"
|
||||
|
||||
self.main_engine.send_email(subject, msg)
|
||||
|
||||
|
||||
def to_rq_symbol(symbol: str, exchange: Exchange):
|
||||
"""
|
||||
CZCE product of RQData has symbol like "TA1905" while
|
||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||
"""
|
||||
if exchange is not Exchange.CZCE:
|
||||
return symbol.upper()
|
||||
|
||||
for count, word in enumerate(symbol):
|
||||
if word.isdigit():
|
||||
break
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
product = symbol[:count]
|
||||
year = symbol[count]
|
||||
month = symbol[count + 1:]
|
||||
|
||||
if year == "9":
|
||||
year = "1" + year
|
||||
else:
|
||||
year = "2" + year
|
||||
|
||||
rq_symbol = f"{product}{year}{month}".upper()
|
||||
return rq_symbol
|
||||
|
123
vnpy/trader/rqdata.py
Normal file
123
vnpy/trader/rqdata.py
Normal file
@ -0,0 +1,123 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
from rqdatac import init as rqdata_init
|
||||
from rqdatac.services.basic import all_instruments as rqdata_all_instruments
|
||||
from rqdatac.services.get_price import get_price as rqdata_get_price
|
||||
|
||||
from .setting import SETTINGS
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData
|
||||
|
||||
|
||||
class RqdataClient:
|
||||
"""
|
||||
Client for querying history data from RQData.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
self.username = SETTINGS["rqdata.username"]
|
||||
self.password = SETTINGS["rqdata.password"]
|
||||
|
||||
self.inited = False
|
||||
self.symbols = set()
|
||||
|
||||
def init(self):
|
||||
""""""
|
||||
if self.inited:
|
||||
return True
|
||||
|
||||
if not self.username or not self.password:
|
||||
return False
|
||||
|
||||
rqdata_init(self.username, self.password,
|
||||
('rqdatad-pro.ricequant.com', 16011))
|
||||
|
||||
try:
|
||||
df = rqdata_all_instruments(date=datetime.now())
|
||||
for ix, row in df.iterrows():
|
||||
self.symbols.add(row['order_book_id'])
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
self.inited = True
|
||||
return True
|
||||
|
||||
def to_rq_symbol(self, symbol: str, exchange: Exchange):
|
||||
"""
|
||||
CZCE product of RQData has symbol like "TA1905" while
|
||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||
"""
|
||||
if exchange in [Exchange.SSE, Exchange.SZSE]:
|
||||
if exchange == Exchange.SSE:
|
||||
rq_symbol = f"{symbol}.XSHG"
|
||||
else:
|
||||
rq_symbol = f"{symbol}.XSHE"
|
||||
else:
|
||||
if exchange is not Exchange.CZCE:
|
||||
return symbol.upper()
|
||||
|
||||
for count, word in enumerate(symbol):
|
||||
if word.isdigit():
|
||||
break
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
product = symbol[:count]
|
||||
year = symbol[count]
|
||||
month = symbol[count + 1:]
|
||||
|
||||
if year == "9":
|
||||
year = "1" + year
|
||||
else:
|
||||
year = "2" + year
|
||||
|
||||
rq_symbol = f"{product}{year}{month}".upper()
|
||||
|
||||
return rq_symbol
|
||||
|
||||
def query_bar(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
rq_symbol = self.to_rq_symbol(symbol, exchange)
|
||||
if rq_symbol not in self.symbols:
|
||||
return None
|
||||
|
||||
end += timedelta(1) # For querying night trading period data
|
||||
|
||||
df = rqdata_get_price(
|
||||
rq_symbol,
|
||||
frequency=interval.value,
|
||||
fields=["open", "high", "low", "close", "volume"],
|
||||
start_date=start,
|
||||
end_date=end
|
||||
)
|
||||
|
||||
data: List[BarData] = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime=row.name.to_pydatetime(),
|
||||
open_price=row["open"],
|
||||
high_price=row["high"],
|
||||
low_price=row["low"],
|
||||
close_price=row["close"],
|
||||
volume=row["volume"],
|
||||
gateway_name="RQ"
|
||||
)
|
||||
data.append(bar)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
rqdata_client = RqdataClient()
|
Loading…
Reference in New Issue
Block a user