Merge remote-tracking branch 'remotes/origin/DEV' into newest_data

This commit is contained in:
nanoric 2019-04-18 00:25:10 -04:00
commit eefb423953
4 changed files with 223 additions and 80 deletions

View File

@ -7,13 +7,16 @@ from pathlib import Path
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Interval
from vnpy.trader.utility import extract_vt_symbol
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.database import database_manager
from vnpy.app.cta_strategy import (
CtaTemplate,
BacktestingEngine,
OptimizationSetting
)
APP_NAME = "CtaBacktester"
EVENT_BACKTESTER_LOG = "eBacktesterLog"
@ -53,6 +56,16 @@ class BacktesterEngine(BaseEngine):
self.write_log("策略文件加载完成")
self.init_rqdata()
def init_rqdata(self):
"""
Init RQData client.
"""
result = rqdata_client.init()
if result:
self.write_log("RQData数据接口初始化成功")
def write_log(self, msg: str):
""""""
event = Event(EVENT_BACKTESTER_LOG)
@ -167,7 +180,7 @@ class BacktesterEngine(BaseEngine):
setting: dict
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
@ -275,7 +288,7 @@ class BacktesterEngine(BaseEngine):
optimization_setting: OptimizationSetting
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
@ -298,3 +311,54 @@ class BacktesterEngine(BaseEngine):
self.thread.start()
return True
def run_downloading(
self,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
"""
Query bar data from RQData.
"""
self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")
symbol, exchange = extract_vt_symbol(vt_symbol)
data = rqdata_client.query_bar(
symbol, exchange, Interval(interval), start, end
)
if not data:
self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
database_manager.save_bar_data(data)
# Clear thread object handler.
self.thread = None
self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
def start_downloading(
self,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
if self.thread:
self.write_log("已有任务在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_downloading,
args=(
vt_symbol,
interval,
start,
end
)
)
self.thread.start()
return True

View File

@ -92,6 +92,17 @@ class BacktesterManager(QtWidgets.QWidget):
self.result_button.clicked.connect(self.show_optimization_result)
self.result_button.setEnabled(False)
downloading_button = QtWidgets.QPushButton("下载数据")
downloading_button.clicked.connect(self.start_downloading)
for button in [
backtesting_button,
optimization_button,
downloading_button,
self.result_button
]:
button.setFixedHeight(button.sizeHint().height() * 2)
form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo)
form.addRow("本地代码", self.symbol_line)
@ -107,6 +118,7 @@ class BacktesterManager(QtWidgets.QWidget):
left_vbox = QtWidgets.QVBoxLayout()
left_vbox.addLayout(form)
left_vbox.addWidget(downloading_button)
left_vbox.addStretch()
left_vbox.addWidget(optimization_button)
left_vbox.addWidget(self.result_button)
@ -247,6 +259,20 @@ class BacktesterManager(QtWidgets.QWidget):
self.result_button.setEnabled(False)
def start_downloading(self):
""""""
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
self.backtester_engine.start_downloading(
vt_symbol,
interval,
start,
end
)
def show_optimization_result(self):
""""""
result_values = self.backtester_engine.get_result_values()

View File

@ -5,7 +5,7 @@ import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, List
from typing import Any, Callable
from datetime import datetime, timedelta
from threading import Thread
from queue import Queue
@ -37,7 +37,7 @@ from vnpy.trader.constant import (
)
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol
from vnpy.trader.database import database_manager
from vnpy.trader.setting import SETTINGS
from vnpy.trader.rqdata import rqdata_client
from .base import (
APP_NAME,
@ -124,26 +124,9 @@ class CtaEngine(BaseEngine):
"""
Init RQData client.
"""
username = SETTINGS["rqdata.username"]
password = SETTINGS["rqdata.password"]
if not username or not password:
return
import rqdatac
self.rq_client = rqdatac
self.rq_client.init(username, password,
('rqdatad-pro.ricequant.com', 16011))
try:
df = self.rq_client.all_instruments(
type='Future', date=datetime.now())
for ix, row in df.iterrows():
self.rq_symbols.add(row['order_book_id'])
except RuntimeError:
pass
self.write_log("RQData数据接口初始化成功")
result = rqdata_client.init()
if result:
self.write_log("RQData数据接口初始化成功")
def query_bar_from_rq(
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
@ -151,36 +134,9 @@ class CtaEngine(BaseEngine):
"""
Query bar data from RQData.
"""
rq_symbol = to_rq_symbol(symbol, exchange)
if rq_symbol not in self.rq_symbols:
return None
end += timedelta(1) # For querying night trading period data
df = self.rq_client.get_price(
rq_symbol,
frequency=interval.value,
fields=["open", "high", "low", "close", "volume"],
start_date=start,
end_date=end
data = rqdata_client.query_bar(
symbol, exchange, interval, start, end
)
data: List[BarData] = []
for ix, row in df.iterrows():
bar = BarData(
symbol=symbol,
exchange=exchange,
interval=interval,
datetime=row.name.to_pydatetime(),
open_price=row["open"],
high_price=row["high"],
low_price=row["low"],
close_price=row["close"],
volume=row["volume"],
gateway_name="RQ"
)
data.append(bar)
return data
def process_tick_event(self, event: Event):
@ -915,29 +871,3 @@ class CtaEngine(BaseEngine):
subject = "CTA策略引擎"
self.main_engine.send_email(subject, msg)
def to_rq_symbol(symbol: str, exchange: Exchange):
"""
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]
if year == "9":
year = "1" + year
else:
year = "2" + year
rq_symbol = f"{product}{year}{month}".upper()
return rq_symbol

123
vnpy/trader/rqdata.py Normal file
View File

@ -0,0 +1,123 @@
from datetime import datetime, timedelta
from typing import List
from rqdatac import init as rqdata_init
from rqdatac.services.basic import all_instruments as rqdata_all_instruments
from rqdatac.services.get_price import get_price as rqdata_get_price
from .setting import SETTINGS
from .constant import Exchange, Interval
from .object import BarData
class RqdataClient:
"""
Client for querying history data from RQData.
"""
def __init__(self):
""""""
self.username = SETTINGS["rqdata.username"]
self.password = SETTINGS["rqdata.password"]
self.inited = False
self.symbols = set()
def init(self):
""""""
if self.inited:
return True
if not self.username or not self.password:
return False
rqdata_init(self.username, self.password,
('rqdatad-pro.ricequant.com', 16011))
try:
df = rqdata_all_instruments(date=datetime.now())
for ix, row in df.iterrows():
self.symbols.add(row['order_book_id'])
except RuntimeError:
return False
self.inited = True
return True
def to_rq_symbol(self, symbol: str, exchange: Exchange):
"""
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
if exchange in [Exchange.SSE, Exchange.SZSE]:
if exchange == Exchange.SSE:
rq_symbol = f"{symbol}.XSHG"
else:
rq_symbol = f"{symbol}.XSHE"
else:
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]
if year == "9":
year = "1" + year
else:
year = "2" + year
rq_symbol = f"{product}{year}{month}".upper()
return rq_symbol
def query_bar(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
):
"""
Query bar data from RQData.
"""
rq_symbol = self.to_rq_symbol(symbol, exchange)
if rq_symbol not in self.symbols:
return None
end += timedelta(1) # For querying night trading period data
df = rqdata_get_price(
rq_symbol,
frequency=interval.value,
fields=["open", "high", "low", "close", "volume"],
start_date=start,
end_date=end
)
data: List[BarData] = []
for ix, row in df.iterrows():
bar = BarData(
symbol=symbol,
exchange=exchange,
interval=interval,
datetime=row.name.to_pydatetime(),
open_price=row["open"],
high_price=row["high"],
low_price=row["low"],
close_price=row["close"],
volume=row["volume"],
gateway_name="RQ"
)
data.append(bar)
return data
rqdata_client = RqdataClient()