[Add]optimization function into CtaBackteserApp

This commit is contained in:
vn.py 2019-04-15 23:32:04 +08:00
parent 4c83b08315
commit 56d0812121
4 changed files with 475 additions and 93 deletions

File diff suppressed because one or more lines are too long

View File

@ -9,7 +9,8 @@ from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.app.cta_strategy import (
CtaTemplate,
BacktestingEngine
BacktestingEngine,
OptimizationSetting
)
@ -33,9 +34,13 @@ class BacktesterEngine(BaseEngine):
self.backtesting_engine = None
self.thread = None
# Backtesting reuslt
self.result_df = None
self.result_statistics = None
# Optimization result
self.result_values = None
self.load_strategy_class()
def init_engine(self):
@ -162,7 +167,7 @@ class BacktesterEngine(BaseEngine):
setting: dict
):
if self.thread:
self.write_log("已有回测在运行中,请等待完成")
self.write_log("已有回测或者优化在运行中,请等待完成")
return False
self.write_log("-" * 40)
@ -194,7 +199,102 @@ class BacktesterEngine(BaseEngine):
""""""
return self.result_statistics
def get_result_values(self):
""""""
return self.result_values
def get_default_setting(self, class_name: str):
""""""
strategy_class = self.classes[class_name]
return strategy_class.get_class_parameters()
def run_optimization(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting):
""""""
self.write_log("开始多进程参数优化")
self.result_values = None
engine = self.backtesting_engine
engine.clear_data()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
end=end,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital
)
strategy_class = self.classes[class_name]
engine.add_strategy(
strategy_class,
{}
)
self.result_values = engine.run_optimization(
optimization_setting,
output=False
)
# Clear thread object handler.
self.thread = None
self.write_log("多进程参数优化完成")
# Put optimization done event
event = Event(EVENT_BACKTESTER_OPTIMIZATION_FINISHED)
self.event_engine.put(event)
def start_optimization(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
optimization_setting: OptimizationSetting
):
if self.thread:
self.write_log("已有回测或者优化在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_optimization,
args=(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
optimization_setting
)
)
self.thread.start()
return True

View File

@ -1,19 +1,18 @@
from datetime import datetime, timedelta
import pyqtgraph as pg
import numpy as np
from vnpy.event import Event, EventEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Interval
import pyqtgraph as pg
from datetime import datetime, timedelta
from ..engine import (
APP_NAME,
EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
OptimizationSetting
)
from vnpy.trader.constant import Interval
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.event import Event, EventEngine
class BacktesterManager(QtWidgets.QWidget):
@ -34,6 +33,8 @@ class BacktesterManager(QtWidgets.QWidget):
self.class_names = []
self.settings = {}
self.target_display = ""
self.init_strategy_settings()
self.init_ui()
self.register_event()
@ -81,8 +82,15 @@ class BacktesterManager(QtWidgets.QWidget):
self.pricetick_line = QtWidgets.QLineEdit("0.2")
self.capital_line = QtWidgets.QLineEdit("1000000")
start_button = QtWidgets.QPushButton("开始回测")
start_button.clicked.connect(self.start_backtesting)
backtesting_button = QtWidgets.QPushButton("开始回测")
backtesting_button.clicked.connect(self.start_backtesting)
optimization_button = QtWidgets.QPushButton("参数优化")
optimization_button.clicked.connect(self.start_optimization)
self.result_button = QtWidgets.QPushButton("优化结果")
self.result_button.clicked.connect(self.show_optimization_result)
self.result_button.setEnabled(False)
form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo)
@ -95,7 +103,13 @@ class BacktesterManager(QtWidgets.QWidget):
form.addRow("合约乘数", self.size_line)
form.addRow("价格跳动", self.pricetick_line)
form.addRow("回测资金", self.capital_line)
form.addRow(start_button)
form.addRow(backtesting_button)
left_vbox = QtWidgets.QVBoxLayout()
left_vbox.addLayout(form)
left_vbox.addStretch()
left_vbox.addWidget(optimization_button)
left_vbox.addWidget(self.result_button)
# Result part
self.statistics_monitor = StatisticsMonitor()
@ -112,7 +126,7 @@ class BacktesterManager(QtWidgets.QWidget):
vbox.addWidget(self.log_monitor)
hbox = QtWidgets.QHBoxLayout()
hbox.addLayout(form)
hbox.addLayout(left_vbox)
hbox.addLayout(vbox)
hbox.addWidget(self.chart)
self.setLayout(hbox)
@ -134,6 +148,10 @@ class BacktesterManager(QtWidgets.QWidget):
def process_log_event(self, event: Event):
""""""
msg = event.data
self.write_log(msg)
def write_log(self, msg):
""""""
timestamp = datetime.now().strftime("%H:%M:%S")
msg = f"{timestamp}\t{msg}"
self.log_monitor.append(msg)
@ -148,7 +166,8 @@ class BacktesterManager(QtWidgets.QWidget):
def process_optimization_finished_event(self, event: Event):
""""""
pass
self.write_log("请点击[优化结果]按钮查看")
self.result_button.setEnabled(True)
def start_backtesting(self):
""""""
@ -164,7 +183,7 @@ class BacktesterManager(QtWidgets.QWidget):
capital = float(self.capital_line.text())
old_setting = self.settings[class_name]
dialog = SettingEditor(class_name, old_setting)
dialog = BacktestingSettingEditor(class_name, old_setting)
i = dialog.exec()
if i != dialog.Accepted:
return
@ -190,6 +209,54 @@ class BacktesterManager(QtWidgets.QWidget):
self.statistics_monitor.clear_data()
self.chart.clear_data()
def start_optimization(self):
""""""
class_name = self.class_combo.currentText()
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
rate = float(self.rate_line.text())
slippage = float(self.slippage_line.text())
size = float(self.size_line.text())
pricetick = float(self.pricetick_line.text())
capital = float(self.capital_line.text())
parameters = self.settings[class_name]
dialog = OptimizationSettingEditor(class_name, parameters)
i = dialog.exec()
if i != dialog.Accepted:
return
optimization_setting = dialog.get_setting()
self.target_display = dialog.target_display
self.backtester_engine.start_optimization(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
optimization_setting
)
self.result_button.setEnabled(False)
def show_optimization_result(self):
""""""
result_values = self.backtester_engine.get_result_values()
dialog = OptimizationResultMonitor(
result_values,
self.target_display
)
dialog.exec_()
def show(self):
""""""
self.showMaximized()
@ -286,7 +353,7 @@ class StatisticsMonitor(QtWidgets.QTableWidget):
cell.setText(str(value))
class SettingEditor(QtWidgets.QDialog):
class BacktestingSettingEditor(QtWidgets.QDialog):
"""
For creating new strategy and editing strategy parameters.
"""
@ -295,7 +362,7 @@ class SettingEditor(QtWidgets.QDialog):
self, class_name: str, parameters: dict
):
""""""
super(SettingEditor, self).__init__()
super(BacktestingSettingEditor, self).__init__()
self.class_name = class_name
self.parameters = parameters
@ -474,3 +541,164 @@ class DateAxis(pg.AxisItem):
dt = self.dates.get(v, "")
strings.append(str(dt))
return strings
class OptimizationSettingEditor(QtWidgets.QDialog):
"""
For setting up parameters for optimization.
"""
DISPLAY_NAME_MAP = {
"总收益率": "total_return",
"夏普比率": "sharpe_ratio",
"收益回撤比": "return_drawdown_ratio",
"日均盈亏": "daily_net_pnl"
}
def __init__(
self, class_name: str, parameters: dict
):
""""""
super().__init__()
self.class_name = class_name
self.parameters = parameters
self.edits = {}
self.optimization_setting = None
self.init_ui()
def init_ui(self):
""""""
QLabel = QtWidgets.QLabel
self.target_combo = QtWidgets.QComboBox()
self.target_combo.addItems(list(self.DISPLAY_NAME_MAP.keys()))
grid = QtWidgets.QGridLayout()
grid.addWidget(QLabel("目标"), 0, 0)
grid.addWidget(self.target_combo, 0, 1, 1, 3)
grid.addWidget(QLabel("参数"), 1, 0)
grid.addWidget(QLabel("开始"), 1, 1)
grid.addWidget(QLabel("步进"), 1, 2)
grid.addWidget(QLabel("结束"), 1, 3)
# Add vt_symbol and name edit if add new strategy
self.setWindowTitle(f"优化参数配置:{self.class_name}")
validator = QtGui.QDoubleValidator()
row = 2
for name, value in self.parameters.items():
type_ = type(value)
if type_ not in [int, float]:
continue
start_edit = QtWidgets.QLineEdit(str(value))
step_edit = QtWidgets.QLineEdit(str(1))
end_edit = QtWidgets.QLineEdit(str(value))
for edit in [start_edit, step_edit, end_edit]:
edit.setValidator(validator)
grid.addWidget(QLabel(name), row, 0)
grid.addWidget(start_edit, row, 1)
grid.addWidget(step_edit, row, 2)
grid.addWidget(end_edit, row, 3)
self.edits[name] = {
"type": type_,
"start": start_edit,
"step": step_edit,
"end": end_edit
}
row += 1
button = QtWidgets.QPushButton("确定")
button.clicked.connect(self.generate_setting)
grid.addWidget(button, row, 0, 1, 4)
self.setLayout(grid)
def generate_setting(self):
""""""
self.optimization_setting = OptimizationSetting()
self.target_display = self.target_combo.currentText()
target_name = self.DISPLAY_NAME_MAP[self.target_display]
self.optimization_setting.set_target(target_name)
for name, d in self.edits.items():
type_ = d["type"]
start_value = type_(d["start"].text())
step_value = type_(d["step"].text())
end_value = type_(d["end"].text())
if start_value == end_value:
self.optimization_setting.add_parameter(name, start_value)
else:
self.optimization_setting.add_parameter(
name,
start_value,
end_value,
step_value
)
self.accept()
def get_setting(self):
""""""
return self.optimization_setting
class OptimizationResultMonitor(QtWidgets.QDialog):
"""
For viewing optimization result.
"""
def __init__(
self, result_values: list, target_display: str
):
""""""
super().__init__()
self.result_values = result_values
self.target_display = target_display
self.init_ui()
def init_ui(self):
""""""
self.setWindowTitle("参数优化结果")
self.resize(1100, 500)
table = QtWidgets.QTableWidget()
table.setColumnCount(2)
table.setRowCount(len(self.result_values))
table.setHorizontalHeaderLabels(["参数", self.target_display])
table.verticalHeader().setVisible(False)
table.horizontalHeader().setSectionResizeMode(
0, QtWidgets.QHeaderView.ResizeToContents
)
table.horizontalHeader().setSectionResizeMode(
1, QtWidgets.QHeaderView.Stretch
)
for n, tp in enumerate(self.result_values):
setting, target_value, _ = tp
setting_cell = QtWidgets.QTableWidgetItem(str(setting))
target_cell = QtWidgets.QTableWidgetItem(str(target_value))
setting_cell.setTextAlignment(QtCore.Qt.AlignCenter)
target_cell.setTextAlignment(QtCore.Qt.AlignCenter)
table.setItem(n, 0, setting_cell)
table.setItem(n, 1, target_cell)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(table)
self.setLayout(vbox)

View File

@ -460,7 +460,7 @@ class BacktestingEngine:
plt.show()
def run_optimization(self, optimization_setting: OptimizationSetting):
def run_optimization(self, optimization_setting: OptimizationSetting, output=True):
""""""
# Get optimization setting and target
settings = optimization_setting.generate_setting()
@ -503,6 +503,7 @@ class BacktestingEngine:
result_values = [result.get() for result in results]
result_values.sort(reverse=True, key=lambda result: result[1])
if output:
for value in result_values:
msg = f"参数:{value[0]}, 目标:{value[1]}"
self.output(msg)
@ -957,7 +958,7 @@ def optimize(
engine.load_data()
engine.run_backtesting()
engine.calculate_result()
statistics = engine.calculate_statistics()
statistics = engine.calculate_statistics(output=False)
target_value = statistics[target_name]
return (str(setting), target_value, statistics)