[Add]optimization function into CtaBackteserApp
This commit is contained in:
parent
4c83b08315
commit
56d0812121
File diff suppressed because one or more lines are too long
@ -9,7 +9,8 @@ from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.app.cta_strategy import (
|
||||
CtaTemplate,
|
||||
BacktestingEngine
|
||||
BacktestingEngine,
|
||||
OptimizationSetting
|
||||
)
|
||||
|
||||
|
||||
@ -33,9 +34,13 @@ class BacktesterEngine(BaseEngine):
|
||||
self.backtesting_engine = None
|
||||
self.thread = None
|
||||
|
||||
# Backtesting reuslt
|
||||
self.result_df = None
|
||||
self.result_statistics = None
|
||||
|
||||
# Optimization result
|
||||
self.result_values = None
|
||||
|
||||
self.load_strategy_class()
|
||||
|
||||
def init_engine(self):
|
||||
@ -162,7 +167,7 @@ class BacktesterEngine(BaseEngine):
|
||||
setting: dict
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测在运行中,请等待完成")
|
||||
self.write_log("已有回测或者优化在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
@ -194,7 +199,102 @@ class BacktesterEngine(BaseEngine):
|
||||
""""""
|
||||
return self.result_statistics
|
||||
|
||||
def get_result_values(self):
|
||||
""""""
|
||||
return self.result_values
|
||||
|
||||
def get_default_setting(self, class_name: str):
|
||||
""""""
|
||||
strategy_class = self.classes[class_name]
|
||||
return strategy_class.get_class_parameters()
|
||||
|
||||
def run_optimization(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
optimization_setting: OptimizationSetting):
|
||||
""""""
|
||||
self.write_log("开始多进程参数优化")
|
||||
|
||||
self.result_values = None
|
||||
|
||||
engine = self.backtesting_engine
|
||||
engine.clear_data()
|
||||
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
rate=rate,
|
||||
slippage=slippage,
|
||||
size=size,
|
||||
pricetick=pricetick,
|
||||
capital=capital
|
||||
)
|
||||
|
||||
strategy_class = self.classes[class_name]
|
||||
engine.add_strategy(
|
||||
strategy_class,
|
||||
{}
|
||||
)
|
||||
|
||||
self.result_values = engine.run_optimization(
|
||||
optimization_setting,
|
||||
output=False
|
||||
)
|
||||
|
||||
# Clear thread object handler.
|
||||
self.thread = None
|
||||
self.write_log("多进程参数优化完成")
|
||||
|
||||
# Put optimization done event
|
||||
event = Event(EVENT_BACKTESTER_OPTIMIZATION_FINISHED)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def start_optimization(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
optimization_setting: OptimizationSetting
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测或者优化在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
self.thread = Thread(
|
||||
target=self.run_optimization,
|
||||
args=(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
optimization_setting
|
||||
)
|
||||
)
|
||||
self.thread.start()
|
||||
|
||||
return True
|
||||
|
@ -1,19 +1,18 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pyqtgraph as pg
|
||||
import numpy as np
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.constant import Interval
|
||||
import pyqtgraph as pg
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..engine import (
|
||||
APP_NAME,
|
||||
EVENT_BACKTESTER_LOG,
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED,
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
|
||||
OptimizationSetting
|
||||
)
|
||||
from vnpy.trader.constant import Interval
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
|
||||
from vnpy.event import Event, EventEngine
|
||||
|
||||
|
||||
class BacktesterManager(QtWidgets.QWidget):
|
||||
@ -34,6 +33,8 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
self.class_names = []
|
||||
self.settings = {}
|
||||
|
||||
self.target_display = ""
|
||||
|
||||
self.init_strategy_settings()
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
@ -81,8 +82,15 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
self.pricetick_line = QtWidgets.QLineEdit("0.2")
|
||||
self.capital_line = QtWidgets.QLineEdit("1000000")
|
||||
|
||||
start_button = QtWidgets.QPushButton("开始回测")
|
||||
start_button.clicked.connect(self.start_backtesting)
|
||||
backtesting_button = QtWidgets.QPushButton("开始回测")
|
||||
backtesting_button.clicked.connect(self.start_backtesting)
|
||||
|
||||
optimization_button = QtWidgets.QPushButton("参数优化")
|
||||
optimization_button.clicked.connect(self.start_optimization)
|
||||
|
||||
self.result_button = QtWidgets.QPushButton("优化结果")
|
||||
self.result_button.clicked.connect(self.show_optimization_result)
|
||||
self.result_button.setEnabled(False)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
form.addRow("交易策略", self.class_combo)
|
||||
@ -95,7 +103,13 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
form.addRow("合约乘数", self.size_line)
|
||||
form.addRow("价格跳动", self.pricetick_line)
|
||||
form.addRow("回测资金", self.capital_line)
|
||||
form.addRow(start_button)
|
||||
form.addRow(backtesting_button)
|
||||
|
||||
left_vbox = QtWidgets.QVBoxLayout()
|
||||
left_vbox.addLayout(form)
|
||||
left_vbox.addStretch()
|
||||
left_vbox.addWidget(optimization_button)
|
||||
left_vbox.addWidget(self.result_button)
|
||||
|
||||
# Result part
|
||||
self.statistics_monitor = StatisticsMonitor()
|
||||
@ -112,7 +126,7 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
vbox.addWidget(self.log_monitor)
|
||||
|
||||
hbox = QtWidgets.QHBoxLayout()
|
||||
hbox.addLayout(form)
|
||||
hbox.addLayout(left_vbox)
|
||||
hbox.addLayout(vbox)
|
||||
hbox.addWidget(self.chart)
|
||||
self.setLayout(hbox)
|
||||
@ -134,6 +148,10 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
def process_log_event(self, event: Event):
|
||||
""""""
|
||||
msg = event.data
|
||||
self.write_log(msg)
|
||||
|
||||
def write_log(self, msg):
|
||||
""""""
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
msg = f"{timestamp}\t{msg}"
|
||||
self.log_monitor.append(msg)
|
||||
@ -148,7 +166,8 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
|
||||
def process_optimization_finished_event(self, event: Event):
|
||||
""""""
|
||||
pass
|
||||
self.write_log("请点击[优化结果]按钮查看")
|
||||
self.result_button.setEnabled(True)
|
||||
|
||||
def start_backtesting(self):
|
||||
""""""
|
||||
@ -164,7 +183,7 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
capital = float(self.capital_line.text())
|
||||
|
||||
old_setting = self.settings[class_name]
|
||||
dialog = SettingEditor(class_name, old_setting)
|
||||
dialog = BacktestingSettingEditor(class_name, old_setting)
|
||||
i = dialog.exec()
|
||||
if i != dialog.Accepted:
|
||||
return
|
||||
@ -190,6 +209,54 @@ class BacktesterManager(QtWidgets.QWidget):
|
||||
self.statistics_monitor.clear_data()
|
||||
self.chart.clear_data()
|
||||
|
||||
def start_optimization(self):
|
||||
""""""
|
||||
class_name = self.class_combo.currentText()
|
||||
vt_symbol = self.symbol_line.text()
|
||||
interval = self.interval_combo.currentText()
|
||||
start = self.start_date_edit.date().toPyDate()
|
||||
end = self.end_date_edit.date().toPyDate()
|
||||
rate = float(self.rate_line.text())
|
||||
slippage = float(self.slippage_line.text())
|
||||
size = float(self.size_line.text())
|
||||
pricetick = float(self.pricetick_line.text())
|
||||
capital = float(self.capital_line.text())
|
||||
|
||||
parameters = self.settings[class_name]
|
||||
dialog = OptimizationSettingEditor(class_name, parameters)
|
||||
i = dialog.exec()
|
||||
if i != dialog.Accepted:
|
||||
return
|
||||
|
||||
optimization_setting = dialog.get_setting()
|
||||
self.target_display = dialog.target_display
|
||||
|
||||
self.backtester_engine.start_optimization(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
optimization_setting
|
||||
)
|
||||
|
||||
self.result_button.setEnabled(False)
|
||||
|
||||
def show_optimization_result(self):
|
||||
""""""
|
||||
result_values = self.backtester_engine.get_result_values()
|
||||
|
||||
dialog = OptimizationResultMonitor(
|
||||
result_values,
|
||||
self.target_display
|
||||
)
|
||||
dialog.exec_()
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
@ -286,7 +353,7 @@ class StatisticsMonitor(QtWidgets.QTableWidget):
|
||||
cell.setText(str(value))
|
||||
|
||||
|
||||
class SettingEditor(QtWidgets.QDialog):
|
||||
class BacktestingSettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For creating new strategy and editing strategy parameters.
|
||||
"""
|
||||
@ -295,7 +362,7 @@ class SettingEditor(QtWidgets.QDialog):
|
||||
self, class_name: str, parameters: dict
|
||||
):
|
||||
""""""
|
||||
super(SettingEditor, self).__init__()
|
||||
super(BacktestingSettingEditor, self).__init__()
|
||||
|
||||
self.class_name = class_name
|
||||
self.parameters = parameters
|
||||
@ -474,3 +541,164 @@ class DateAxis(pg.AxisItem):
|
||||
dt = self.dates.get(v, "")
|
||||
strings.append(str(dt))
|
||||
return strings
|
||||
|
||||
|
||||
class OptimizationSettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For setting up parameters for optimization.
|
||||
"""
|
||||
DISPLAY_NAME_MAP = {
|
||||
"总收益率": "total_return",
|
||||
"夏普比率": "sharpe_ratio",
|
||||
"收益回撤比": "return_drawdown_ratio",
|
||||
"日均盈亏": "daily_net_pnl"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, class_name: str, parameters: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.class_name = class_name
|
||||
self.parameters = parameters
|
||||
self.edits = {}
|
||||
|
||||
self.optimization_setting = None
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
QLabel = QtWidgets.QLabel
|
||||
|
||||
self.target_combo = QtWidgets.QComboBox()
|
||||
self.target_combo.addItems(list(self.DISPLAY_NAME_MAP.keys()))
|
||||
|
||||
grid = QtWidgets.QGridLayout()
|
||||
grid.addWidget(QLabel("目标"), 0, 0)
|
||||
grid.addWidget(self.target_combo, 0, 1, 1, 3)
|
||||
grid.addWidget(QLabel("参数"), 1, 0)
|
||||
grid.addWidget(QLabel("开始"), 1, 1)
|
||||
grid.addWidget(QLabel("步进"), 1, 2)
|
||||
grid.addWidget(QLabel("结束"), 1, 3)
|
||||
|
||||
# Add vt_symbol and name edit if add new strategy
|
||||
self.setWindowTitle(f"优化参数配置:{self.class_name}")
|
||||
|
||||
validator = QtGui.QDoubleValidator()
|
||||
row = 2
|
||||
|
||||
for name, value in self.parameters.items():
|
||||
type_ = type(value)
|
||||
if type_ not in [int, float]:
|
||||
continue
|
||||
|
||||
start_edit = QtWidgets.QLineEdit(str(value))
|
||||
step_edit = QtWidgets.QLineEdit(str(1))
|
||||
end_edit = QtWidgets.QLineEdit(str(value))
|
||||
|
||||
for edit in [start_edit, step_edit, end_edit]:
|
||||
edit.setValidator(validator)
|
||||
|
||||
grid.addWidget(QLabel(name), row, 0)
|
||||
grid.addWidget(start_edit, row, 1)
|
||||
grid.addWidget(step_edit, row, 2)
|
||||
grid.addWidget(end_edit, row, 3)
|
||||
|
||||
self.edits[name] = {
|
||||
"type": type_,
|
||||
"start": start_edit,
|
||||
"step": step_edit,
|
||||
"end": end_edit
|
||||
}
|
||||
|
||||
row += 1
|
||||
|
||||
button = QtWidgets.QPushButton("确定")
|
||||
button.clicked.connect(self.generate_setting)
|
||||
grid.addWidget(button, row, 0, 1, 4)
|
||||
|
||||
self.setLayout(grid)
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
self.optimization_setting = OptimizationSetting()
|
||||
|
||||
self.target_display = self.target_combo.currentText()
|
||||
target_name = self.DISPLAY_NAME_MAP[self.target_display]
|
||||
self.optimization_setting.set_target(target_name)
|
||||
|
||||
for name, d in self.edits.items():
|
||||
type_ = d["type"]
|
||||
start_value = type_(d["start"].text())
|
||||
step_value = type_(d["step"].text())
|
||||
end_value = type_(d["end"].text())
|
||||
|
||||
if start_value == end_value:
|
||||
self.optimization_setting.add_parameter(name, start_value)
|
||||
else:
|
||||
self.optimization_setting.add_parameter(
|
||||
name,
|
||||
start_value,
|
||||
end_value,
|
||||
step_value
|
||||
)
|
||||
|
||||
self.accept()
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
return self.optimization_setting
|
||||
|
||||
|
||||
class OptimizationResultMonitor(QtWidgets.QDialog):
|
||||
"""
|
||||
For viewing optimization result.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, result_values: list, target_display: str
|
||||
):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.result_values = result_values
|
||||
self.target_display = target_display
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setWindowTitle("参数优化结果")
|
||||
self.resize(1100, 500)
|
||||
|
||||
table = QtWidgets.QTableWidget()
|
||||
|
||||
table.setColumnCount(2)
|
||||
table.setRowCount(len(self.result_values))
|
||||
table.setHorizontalHeaderLabels(["参数", self.target_display])
|
||||
table.verticalHeader().setVisible(False)
|
||||
|
||||
table.horizontalHeader().setSectionResizeMode(
|
||||
0, QtWidgets.QHeaderView.ResizeToContents
|
||||
)
|
||||
table.horizontalHeader().setSectionResizeMode(
|
||||
1, QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
|
||||
for n, tp in enumerate(self.result_values):
|
||||
setting, target_value, _ = tp
|
||||
setting_cell = QtWidgets.QTableWidgetItem(str(setting))
|
||||
target_cell = QtWidgets.QTableWidgetItem(str(target_value))
|
||||
|
||||
setting_cell.setTextAlignment(QtCore.Qt.AlignCenter)
|
||||
target_cell.setTextAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
table.setItem(n, 0, setting_cell)
|
||||
table.setItem(n, 1, target_cell)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(table)
|
||||
|
||||
self.setLayout(vbox)
|
||||
|
@ -460,7 +460,7 @@ class BacktestingEngine:
|
||||
|
||||
plt.show()
|
||||
|
||||
def run_optimization(self, optimization_setting: OptimizationSetting):
|
||||
def run_optimization(self, optimization_setting: OptimizationSetting, output=True):
|
||||
""""""
|
||||
# Get optimization setting and target
|
||||
settings = optimization_setting.generate_setting()
|
||||
@ -503,9 +503,10 @@ class BacktestingEngine:
|
||||
result_values = [result.get() for result in results]
|
||||
result_values.sort(reverse=True, key=lambda result: result[1])
|
||||
|
||||
for value in result_values:
|
||||
msg = f"参数:{value[0]}, 目标:{value[1]}"
|
||||
self.output(msg)
|
||||
if output:
|
||||
for value in result_values:
|
||||
msg = f"参数:{value[0]}, 目标:{value[1]}"
|
||||
self.output(msg)
|
||||
|
||||
return result_values
|
||||
|
||||
@ -957,7 +958,7 @@ def optimize(
|
||||
engine.load_data()
|
||||
engine.run_backtesting()
|
||||
engine.calculate_result()
|
||||
statistics = engine.calculate_statistics()
|
||||
statistics = engine.calculate_statistics(output=False)
|
||||
|
||||
target_value = statistics[target_name]
|
||||
return (str(setting), target_value, statistics)
|
||||
|
Loading…
Reference in New Issue
Block a user