diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py index 4214baa3..2c47f4dc 100644 --- a/vnpy/app/cta_backtester/engine.py +++ b/vnpy/app/cta_backtester/engine.py @@ -4,6 +4,7 @@ import traceback from datetime import datetime from threading import Thread from pathlib import Path +from inspect import getfile from vnpy.event import Event, EventEngine from vnpy.trader.engine import BaseEngine, MainEngine @@ -116,6 +117,12 @@ class BacktesterEngine(BaseEngine): msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" self.write_log(msg) + def reload_strategy_class(self): + """""" + self.classes.clear() + self.load_strategy_class() + self.write_log("策略文件重载刷新完成") + def get_strategy_class_names(self): """""" return list(self.classes.keys()) @@ -425,3 +432,9 @@ class BacktesterEngine(BaseEngine): def get_history_data(self): """""" return self.backtesting_engine.history_data + + def get_strategy_class_file(self, class_name: str): + """""" + strategy_class = self.classes[class_name] + file_path = getfile(strategy_class) + return file_path diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py index 87d1ad0a..1aba2e23 100644 --- a/vnpy/app/cta_backtester/ui/widget.py +++ b/vnpy/app/cta_backtester/ui/widget.py @@ -13,6 +13,7 @@ from vnpy.trader.constant import Interval, Direction from vnpy.trader.engine import MainEngine from vnpy.trader.ui import QtCore, QtWidgets, QtGui from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell +from vnpy.trader.ui.editor import CodeEditor from vnpy.event import Event, EventEngine from vnpy.chart import ChartWidget, CandleItem, VolumeItem @@ -117,6 +118,12 @@ class BacktesterManager(QtWidgets.QWidget): self.candle_button.clicked.connect(self.show_candle_chart) self.candle_button.setEnabled(False) + edit_button = QtWidgets.QPushButton("代码编辑") + edit_button.clicked.connect(self.edit_strategy_code) + + reload_button = QtWidgets.QPushButton("策略重载") + reload_button.clicked.connect(self.reload_strategy_class) + for button in [ backtesting_button, optimization_button, @@ -125,7 +132,9 @@ class BacktesterManager(QtWidgets.QWidget): self.order_button, self.trade_button, self.daily_button, - self.candle_button + self.candle_button, + edit_button, + reload_button ]: button.setFixedHeight(button.sizeHint().height() * 2) @@ -142,18 +151,24 @@ class BacktesterManager(QtWidgets.QWidget): form.addRow("回测资金", self.capital_line) form.addRow("合约模式", self.inverse_combo) + result_grid = QtWidgets.QGridLayout() + result_grid.addWidget(self.trade_button, 0, 0) + result_grid.addWidget(self.order_button, 0, 1) + result_grid.addWidget(self.daily_button, 1, 0) + result_grid.addWidget(self.candle_button, 1, 1) + left_vbox = QtWidgets.QVBoxLayout() left_vbox.addLayout(form) left_vbox.addWidget(backtesting_button) left_vbox.addWidget(downloading_button) left_vbox.addStretch() - left_vbox.addWidget(self.trade_button) - left_vbox.addWidget(self.order_button) - left_vbox.addWidget(self.daily_button) - left_vbox.addWidget(self.candle_button) + left_vbox.addLayout(result_grid) left_vbox.addStretch() left_vbox.addWidget(optimization_button) left_vbox.addWidget(self.result_button) + left_vbox.addStretch() + left_vbox.addWidget(edit_button) + left_vbox.addWidget(reload_button) # Result part self.statistics_monitor = StatisticsMonitor() @@ -197,6 +212,9 @@ class BacktesterManager(QtWidgets.QWidget): hbox.addWidget(self.chart) self.setLayout(hbox) + # Code Editor + self.editor = CodeEditor(self.main_engine, self.event_engine) + def register_event(self): """""" self.signal_log.connect(self.process_log_event) @@ -403,6 +421,21 @@ class BacktesterManager(QtWidgets.QWidget): self.candle_dialog.exec_() + def edit_strategy_code(self): + """""" + class_name = self.class_combo.currentText() + file_path = self.backtester_engine.get_strategy_class_file(class_name) + + self.editor.open_editor(file_path) + self.editor.show() + + def reload_strategy_class(self): + """""" + self.backtester_engine.reload_strategy_class() + + self.class_combo.clear() + self.init_strategy_settings() + def show(self): """""" self.showMaximized()