[Add]optimization function into CtaBackteserApp

This commit is contained in:
vn.py 2019-04-15 23:32:04 +08:00
parent 4c83b08315
commit 56d0812121
4 changed files with 475 additions and 93 deletions

File diff suppressed because one or more lines are too long

View File

@ -9,7 +9,8 @@ from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.app.cta_strategy import ( from vnpy.app.cta_strategy import (
CtaTemplate, CtaTemplate,
BacktestingEngine BacktestingEngine,
OptimizationSetting
) )
@ -33,9 +34,13 @@ class BacktesterEngine(BaseEngine):
self.backtesting_engine = None self.backtesting_engine = None
self.thread = None self.thread = None
# Backtesting reuslt
self.result_df = None self.result_df = None
self.result_statistics = None self.result_statistics = None
# Optimization result
self.result_values = None
self.load_strategy_class() self.load_strategy_class()
def init_engine(self): def init_engine(self):
@ -162,7 +167,7 @@ class BacktesterEngine(BaseEngine):
setting: dict setting: dict
): ):
if self.thread: if self.thread:
self.write_log("已有回测在运行中,请等待完成") self.write_log("已有回测或者优化在运行中,请等待完成")
return False return False
self.write_log("-" * 40) self.write_log("-" * 40)
@ -194,7 +199,102 @@ class BacktesterEngine(BaseEngine):
"""""" """"""
return self.result_statistics return self.result_statistics
def get_result_values(self):
""""""
return self.result_values
def get_default_setting(self, class_name: str): def get_default_setting(self, class_name: str):
"""""" """"""
strategy_class = self.classes[class_name] strategy_class = self.classes[class_name]
return strategy_class.get_class_parameters() return strategy_class.get_class_parameters()
def run_optimization(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting):
""""""
self.write_log("开始多进程参数优化")
self.result_values = None
engine = self.backtesting_engine
engine.clear_data()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
end=end,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital
)
strategy_class = self.classes[class_name]
engine.add_strategy(
strategy_class,
{}
)
self.result_values = engine.run_optimization(
optimization_setting,
output=False
)
# Clear thread object handler.
self.thread = None
self.write_log("多进程参数优化完成")
# Put optimization done event
event = Event(EVENT_BACKTESTER_OPTIMIZATION_FINISHED)
self.event_engine.put(event)
def start_optimization(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_optimization,
args=(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
optimization_setting
)
)
self.thread.start()
return True

View File

@ -1,19 +1,18 @@
from datetime import datetime, timedelta
import pyqtgraph as pg
import numpy as np import numpy as np
import pyqtgraph as pg
from vnpy.event import Event, EventEngine from datetime import datetime, timedelta
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Interval
from ..engine import ( from ..engine import (
APP_NAME, APP_NAME,
EVENT_BACKTESTER_LOG, EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED, EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
OptimizationSetting
) )
from vnpy.trader.constant import Interval
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.event import Event, EventEngine
class BacktesterManager(QtWidgets.QWidget): class BacktesterManager(QtWidgets.QWidget):
@ -34,6 +33,8 @@ class BacktesterManager(QtWidgets.QWidget):
self.class_names = [] self.class_names = []
self.settings = {} self.settings = {}
self.target_display = ""
self.init_strategy_settings() self.init_strategy_settings()
self.init_ui() self.init_ui()
self.register_event() self.register_event()
@ -81,8 +82,15 @@ class BacktesterManager(QtWidgets.QWidget):
self.pricetick_line = QtWidgets.QLineEdit("0.2") self.pricetick_line = QtWidgets.QLineEdit("0.2")
self.capital_line = QtWidgets.QLineEdit("1000000") self.capital_line = QtWidgets.QLineEdit("1000000")
start_button = QtWidgets.QPushButton("开始回测") backtesting_button = QtWidgets.QPushButton("开始回测")
start_button.clicked.connect(self.start_backtesting) backtesting_button.clicked.connect(self.start_backtesting)
optimization_button = QtWidgets.QPushButton("参数优化")
optimization_button.clicked.connect(self.start_optimization)
self.result_button = QtWidgets.QPushButton("优化结果")
self.result_button.clicked.connect(self.show_optimization_result)
self.result_button.setEnabled(False)
form = QtWidgets.QFormLayout() form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo) form.addRow("交易策略", self.class_combo)
@ -95,7 +103,13 @@ class BacktesterManager(QtWidgets.QWidget):
form.addRow("合约乘数", self.size_line) form.addRow("合约乘数", self.size_line)
form.addRow("价格跳动", self.pricetick_line) form.addRow("价格跳动", self.pricetick_line)
form.addRow("回测资金", self.capital_line) form.addRow("回测资金", self.capital_line)
form.addRow(start_button) form.addRow(backtesting_button)
left_vbox = QtWidgets.QVBoxLayout()
left_vbox.addLayout(form)
left_vbox.addStretch()
left_vbox.addWidget(optimization_button)
left_vbox.addWidget(self.result_button)
# Result part # Result part
self.statistics_monitor = StatisticsMonitor() self.statistics_monitor = StatisticsMonitor()
@ -112,7 +126,7 @@ class BacktesterManager(QtWidgets.QWidget):
vbox.addWidget(self.log_monitor) vbox.addWidget(self.log_monitor)
hbox = QtWidgets.QHBoxLayout() hbox = QtWidgets.QHBoxLayout()
hbox.addLayout(form) hbox.addLayout(left_vbox)
hbox.addLayout(vbox) hbox.addLayout(vbox)
hbox.addWidget(self.chart) hbox.addWidget(self.chart)
self.setLayout(hbox) self.setLayout(hbox)
@ -134,6 +148,10 @@ class BacktesterManager(QtWidgets.QWidget):
def process_log_event(self, event: Event): def process_log_event(self, event: Event):
"""""" """"""
msg = event.data msg = event.data
self.write_log(msg)
def write_log(self, msg):
""""""
timestamp = datetime.now().strftime("%H:%M:%S") timestamp = datetime.now().strftime("%H:%M:%S")
msg = f"{timestamp}\t{msg}" msg = f"{timestamp}\t{msg}"
self.log_monitor.append(msg) self.log_monitor.append(msg)
@ -148,7 +166,8 @@ class BacktesterManager(QtWidgets.QWidget):
def process_optimization_finished_event(self, event: Event): def process_optimization_finished_event(self, event: Event):
"""""" """"""
pass self.write_log("请点击[优化结果]按钮查看")
self.result_button.setEnabled(True)
def start_backtesting(self): def start_backtesting(self):
"""""" """"""
@ -164,7 +183,7 @@ class BacktesterManager(QtWidgets.QWidget):
capital = float(self.capital_line.text()) capital = float(self.capital_line.text())
old_setting = self.settings[class_name] old_setting = self.settings[class_name]
dialog = SettingEditor(class_name, old_setting) dialog = BacktestingSettingEditor(class_name, old_setting)
i = dialog.exec() i = dialog.exec()
if i != dialog.Accepted: if i != dialog.Accepted:
return return
@ -190,6 +209,54 @@ class BacktesterManager(QtWidgets.QWidget):
self.statistics_monitor.clear_data() self.statistics_monitor.clear_data()
self.chart.clear_data() self.chart.clear_data()
def start_optimization(self):
""""""
class_name = self.class_combo.currentText()
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
rate = float(self.rate_line.text())
slippage = float(self.slippage_line.text())
size = float(self.size_line.text())
pricetick = float(self.pricetick_line.text())
capital = float(self.capital_line.text())
parameters = self.settings[class_name]
dialog = OptimizationSettingEditor(class_name, parameters)
i = dialog.exec()
if i != dialog.Accepted:
return
optimization_setting = dialog.get_setting()
self.target_display = dialog.target_display
self.backtester_engine.start_optimization(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
optimization_setting
)
self.result_button.setEnabled(False)
def show_optimization_result(self):
""""""
result_values = self.backtester_engine.get_result_values()
dialog = OptimizationResultMonitor(
result_values,
self.target_display
)
dialog.exec_()
def show(self): def show(self):
"""""" """"""
self.showMaximized() self.showMaximized()
@ -286,7 +353,7 @@ class StatisticsMonitor(QtWidgets.QTableWidget):
cell.setText(str(value)) cell.setText(str(value))
class SettingEditor(QtWidgets.QDialog): class BacktestingSettingEditor(QtWidgets.QDialog):
""" """
For creating new strategy and editing strategy parameters. For creating new strategy and editing strategy parameters.
""" """
@ -295,7 +362,7 @@ class SettingEditor(QtWidgets.QDialog):
self, class_name: str, parameters: dict self, class_name: str, parameters: dict
): ):
"""""" """"""
super(SettingEditor, self).__init__() super(BacktestingSettingEditor, self).__init__()
self.class_name = class_name self.class_name = class_name
self.parameters = parameters self.parameters = parameters
@ -474,3 +541,164 @@ class DateAxis(pg.AxisItem):
dt = self.dates.get(v, "") dt = self.dates.get(v, "")
strings.append(str(dt)) strings.append(str(dt))
return strings return strings
class OptimizationSettingEditor(QtWidgets.QDialog):
"""
For setting up parameters for optimization.
"""
DISPLAY_NAME_MAP = {
"总收益率": "total_return",
"夏普比率": "sharpe_ratio",
"收益回撤比": "return_drawdown_ratio",
"日均盈亏": "daily_net_pnl"
}
def __init__(
self, class_name: str, parameters: dict
):
""""""
super().__init__()
self.class_name = class_name
self.parameters = parameters
self.edits = {}
self.optimization_setting = None
self.init_ui()
def init_ui(self):
""""""
QLabel = QtWidgets.QLabel
self.target_combo = QtWidgets.QComboBox()
self.target_combo.addItems(list(self.DISPLAY_NAME_MAP.keys()))
grid = QtWidgets.QGridLayout()
grid.addWidget(QLabel("目标"), 0, 0)
grid.addWidget(self.target_combo, 0, 1, 1, 3)
grid.addWidget(QLabel("参数"), 1, 0)
grid.addWidget(QLabel("开始"), 1, 1)
grid.addWidget(QLabel("步进"), 1, 2)
grid.addWidget(QLabel("结束"), 1, 3)
# Add vt_symbol and name edit if add new strategy
self.setWindowTitle(f"优化参数配置:{self.class_name}")
validator = QtGui.QDoubleValidator()
row = 2
for name, value in self.parameters.items():
type_ = type(value)
if type_ not in [int, float]:
continue
start_edit = QtWidgets.QLineEdit(str(value))
step_edit = QtWidgets.QLineEdit(str(1))
end_edit = QtWidgets.QLineEdit(str(value))
for edit in [start_edit, step_edit, end_edit]:
edit.setValidator(validator)
grid.addWidget(QLabel(name), row, 0)
grid.addWidget(start_edit, row, 1)
grid.addWidget(step_edit, row, 2)
grid.addWidget(end_edit, row, 3)
self.edits[name] = {
"type": type_,
"start": start_edit,
"step": step_edit,
"end": end_edit
}
row += 1
button = QtWidgets.QPushButton("确定")
button.clicked.connect(self.generate_setting)
grid.addWidget(button, row, 0, 1, 4)
self.setLayout(grid)
def generate_setting(self):
""""""
self.optimization_setting = OptimizationSetting()
self.target_display = self.target_combo.currentText()
target_name = self.DISPLAY_NAME_MAP[self.target_display]
self.optimization_setting.set_target(target_name)
for name, d in self.edits.items():
type_ = d["type"]
start_value = type_(d["start"].text())
step_value = type_(d["step"].text())
end_value = type_(d["end"].text())
if start_value == end_value:
self.optimization_setting.add_parameter(name, start_value)
else:
self.optimization_setting.add_parameter(
name,
start_value,
end_value,
step_value
)
self.accept()
def get_setting(self):
""""""
return self.optimization_setting
class OptimizationResultMonitor(QtWidgets.QDialog):
"""
For viewing optimization result.
"""
def __init__(
self, result_values: list, target_display: str
):
""""""
super().__init__()
self.result_values = result_values
self.target_display = target_display
self.init_ui()
def init_ui(self):
""""""
self.setWindowTitle("参数优化结果")
self.resize(1100, 500)
table = QtWidgets.QTableWidget()
table.setColumnCount(2)
table.setRowCount(len(self.result_values))
table.setHorizontalHeaderLabels(["参数", self.target_display])
table.verticalHeader().setVisible(False)
table.horizontalHeader().setSectionResizeMode(
0, QtWidgets.QHeaderView.ResizeToContents
)
table.horizontalHeader().setSectionResizeMode(
1, QtWidgets.QHeaderView.Stretch
)
for n, tp in enumerate(self.result_values):
setting, target_value, _ = tp
setting_cell = QtWidgets.QTableWidgetItem(str(setting))
target_cell = QtWidgets.QTableWidgetItem(str(target_value))
setting_cell.setTextAlignment(QtCore.Qt.AlignCenter)
target_cell.setTextAlignment(QtCore.Qt.AlignCenter)
table.setItem(n, 0, setting_cell)
table.setItem(n, 1, target_cell)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(table)
self.setLayout(vbox)

View File

@ -460,7 +460,7 @@ class BacktestingEngine:
plt.show() plt.show()
def run_optimization(self, optimization_setting: OptimizationSetting): def run_optimization(self, optimization_setting: OptimizationSetting, output=True):
"""""" """"""
# Get optimization setting and target # Get optimization setting and target
settings = optimization_setting.generate_setting() settings = optimization_setting.generate_setting()
@ -503,9 +503,10 @@ class BacktestingEngine:
result_values = [result.get() for result in results] result_values = [result.get() for result in results]
result_values.sort(reverse=True, key=lambda result: result[1]) result_values.sort(reverse=True, key=lambda result: result[1])
for value in result_values: if output:
msg = f"参数:{value[0]}, 目标:{value[1]}" for value in result_values:
self.output(msg) msg = f"参数:{value[0]}, 目标:{value[1]}"
self.output(msg)
return result_values return result_values
@ -957,7 +958,7 @@ def optimize(
engine.load_data() engine.load_data()
engine.run_backtesting() engine.run_backtesting()
engine.calculate_result() engine.calculate_result()
statistics = engine.calculate_statistics() statistics = engine.calculate_statistics(output=False)
target_value = statistics[target_name] target_value = statistics[target_name]
return (str(setting), target_value, statistics) return (str(setting), target_value, statistics)