diff --git a/tests/trader/run.py b/tests/trader/run.py index 1ab073c5..2cda6ae2 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -16,6 +16,7 @@ from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.algo_trading import AlgoTradingApp +from vnpy.app.cta_backtester import CtaBacktesterApp def main(): @@ -35,6 +36,7 @@ def main(): main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) + main_engine.add_app(CtaBacktesterApp) main_engine.add_app(CsvLoaderApp) main_engine.add_app(AlgoTradingApp) diff --git a/vnpy/app/cta_backtester/__init__.py b/vnpy/app/cta_backtester/__init__.py new file mode 100644 index 00000000..c589e02f --- /dev/null +++ b/vnpy/app/cta_backtester/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import BacktesterEngine, APP_NAME + + +class CtaBacktesterApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "CTA回测" + engine_class = BacktesterEngine + widget_name = "BacktesterManager" + icon_name = "backtester.ico" diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py new file mode 100644 index 00000000..621330d2 --- /dev/null +++ b/vnpy/app/cta_backtester/engine.py @@ -0,0 +1,199 @@ +import os +import importlib +from datetime import datetime +from threading import Thread +from pathlib import Path + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.app.cta_strategy import ( + CtaTemplate, + BacktestingEngine, + OptimizationSetting +) + + +APP_NAME = "CtaBacktester" + +EVENT_BACKTESTER_LOG = "eBacktesterLog" +EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished" +EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished" + + +class BacktesterEngine(BaseEngine): + """ + For running CTA strategy backtesting. + """ + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.classes = {} + self.backtesting_engine = None + self.thread = None + + self.result_df = None + self.result_statistics = None + + self.load_strategy_class() + + def init_engine(self): + """""" + self.write_log("初始化CTA回测引擎") + + self.backtesting_engine = BacktestingEngine() + # Redirect log from backtesting engine outside. + self.backtesting_engine.output = self.write_log + + self.write_log("策略文件加载完成") + + def write_log(self, msg: str): + """""" + event = Event(EVENT_BACKTESTER_LOG) + event.data = msg + self.event_engine.put(event) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + app_path = Path(__file__).parent.parent + path1 = app_path.joinpath("cta_strategy", "strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_strategy.strategies") + + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + self.classes[value.__name__] = value + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg) + + def get_strategy_class_names(self): + """""" + return list(self.classes.keys()) + + def run_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + """""" + self.result_df = None + self.result_statistics = None + + engine = self.backtesting_engine + engine.clear_data() + + engine.set_parameters( + vt_symbol=vt_symbol, + interval=interval, + start=start, + end=end, + rate=rate, + slippage=slippage, + size=size, + pricetick=pricetick, + capital=capital + ) + + strategy_class = self.classes[class_name] + engine.add_strategy( + strategy_class, + setting + ) + + engine.load_data() + engine.run_backtesting() + self.result_df = engine.calculate_result() + self.result_statistics = engine.calculate_statistics(output=False) + + # Clear thread object handler. + self.thread = None + + # Put backtesting done event + event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED) + self.event_engine.put(event) + + def start_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + if self.thread: + self.write_log("已有回测在运行中,请等待完成") + return False + + self.thread = Thread( + target=self.run_backtesting, + args=( + class_name, + vt_symbol, + interval, + start, + end, + rate, + slippage, + size, + pricetick, + capital, + setting + ) + ) + self.thread.start() + + return True + + def get_result_df(self): + """""" + return self.result_df + + def get_result_statistics(self): + """""" + return self.result_statistics + + def get_default_setting(self, class_name: str): + """""" + strategy_class = self.classes[class_name] + return strategy_class.get_class_parameters() diff --git a/vnpy/app/cta_backtester/ui/__init__.py b/vnpy/app/cta_backtester/ui/__init__.py new file mode 100644 index 00000000..a02dafb3 --- /dev/null +++ b/vnpy/app/cta_backtester/ui/__init__.py @@ -0,0 +1 @@ +from .widget import BacktesterManager diff --git a/vnpy/app/cta_backtester/ui/backtester.ico b/vnpy/app/cta_backtester/ui/backtester.ico new file mode 100644 index 00000000..647b8ca9 Binary files /dev/null and b/vnpy/app/cta_backtester/ui/backtester.ico differ diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py new file mode 100644 index 00000000..d8cde991 --- /dev/null +++ b/vnpy/app/cta_backtester/ui/widget.py @@ -0,0 +1,476 @@ +from datetime import datetime, timedelta + +import pyqtgraph as pg +import numpy as np + +from vnpy.event import Event, EventEngine +from vnpy.trader.ui import QtCore, QtWidgets, QtGui +from vnpy.trader.engine import MainEngine +from vnpy.trader.constant import Interval + +from ..engine import ( + APP_NAME, + EVENT_BACKTESTER_LOG, + EVENT_BACKTESTER_BACKTESTING_FINISHED, + EVENT_BACKTESTER_OPTIMIZATION_FINISHED +) + + +class BacktesterManager(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_backtesting_finished = QtCore.pyqtSignal(Event) + signal_optimization_finished = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + + self.backtester_engine = main_engine.get_engine(APP_NAME) + self.class_names = [] + self.settings = {} + + self.init_strategy_settings() + self.init_ui() + self.register_event() + self.backtester_engine.init_engine() + + def init_strategy_settings(self): + """""" + self.class_names = self.backtester_engine.get_strategy_class_names() + + for class_name in self.class_names: + setting = self.backtester_engine.get_default_setting(class_name) + self.settings[class_name] = setting + + def init_ui(self): + """""" + self.setWindowTitle("CTA回测") + + # Setting Part + self.class_combo = QtWidgets.QComboBox() + self.class_combo.addItems(self.class_names) + + self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX") + + self.interval_combo = QtWidgets.QComboBox() + for inteval in Interval: + self.interval_combo.addItem(inteval.value) + + end_dt = datetime.now() + start_dt = end_dt - timedelta(days=3 * 365) + + self.start_date_edit = QtWidgets.QDateEdit( + QtCore.QDate( + start_dt.year, + start_dt.month, + start_dt.day + ) + ) + self.end_date_edit = QtWidgets.QDateEdit( + QtCore.QDate.currentDate() + ) + + self.rate_line = QtWidgets.QLineEdit("0.000025") + self.slippage_line = QtWidgets.QLineEdit("0.2") + self.size_line = QtWidgets.QLineEdit("300") + self.pricetick_line = QtWidgets.QLineEdit("0.2") + self.capital_line = QtWidgets.QLineEdit("1000000") + + start_button = QtWidgets.QPushButton("开始回测") + start_button.clicked.connect(self.start_backtesting) + + form = QtWidgets.QFormLayout() + form.addRow("交易策略", self.class_combo) + form.addRow("本地代码", self.symbol_line) + form.addRow("K线周期", self.interval_combo) + form.addRow("开始日期", self.start_date_edit) + form.addRow("结束日期", self.end_date_edit) + form.addRow("手续费率", self.rate_line) + form.addRow("交易滑点", self.slippage_line) + form.addRow("合约乘数", self.size_line) + form.addRow("价格跳动", self.pricetick_line) + form.addRow("回测资金", self.capital_line) + form.addRow(start_button) + + # Result part + self.statistics_monitor = StatisticsMonitor() + + self.log_monitor = QtWidgets.QTextEdit() + self.log_monitor.setMaximumHeight(400) + + self.chart = BacktesterChart() + self.chart.setMinimumWidth(1000) + + # Layout + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(self.statistics_monitor) + vbox.addWidget(self.log_monitor) + + hbox = QtWidgets.QHBoxLayout() + hbox.addLayout(form) + hbox.addLayout(vbox) + hbox.addWidget(self.chart) + self.setLayout(hbox) + + def register_event(self): + """""" + self.signal_log.connect(self.process_log_event) + self.signal_backtesting_finished.connect( + self.process_backtesting_finished_event) + self.signal_optimization_finished.connect( + self.process_optimization_finished_event) + + self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit) + self.event_engine.register( + EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit) + self.event_engine.register( + EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit) + + def process_log_event(self, event: Event): + """""" + msg = event.data + timestamp = datetime.now().strftime("%H:%M:%S") + msg = f"{timestamp}\t{msg}" + self.log_monitor.append(msg) + + def process_backtesting_finished_event(self, event: Event): + """""" + statistics = self.backtester_engine.get_result_statistics() + self.statistics_monitor.set_data(statistics) + + df = self.backtester_engine.get_result_df() + self.chart.set_data(df) + + def process_optimization_finished_event(self, event: Event): + """""" + pass + + def start_backtesting(self): + """""" + class_name = self.class_combo.currentText() + vt_symbol = self.symbol_line.text() + interval = self.interval_combo.currentText() + start = self.start_date_edit.date().toPyDate() + end = self.end_date_edit.date().toPyDate() + rate = float(self.rate_line.text()) + slippage = float(self.slippage_line.text()) + size = float(self.size_line.text()) + pricetick = float(self.pricetick_line.text()) + capital = float(self.capital_line.text()) + + old_setting = self.settings[class_name] + dialog = SettingEditor(class_name, old_setting) + i = dialog.exec() + if i != dialog.Accepted: + return + + new_setting = dialog.get_setting() + self.settings[class_name] = new_setting + + result = self.backtester_engine.start_backtesting( + class_name, + vt_symbol, + interval, + start, + end, + rate, + slippage, + size, + pricetick, + capital, + new_setting + ) + + if result: + self.statistics_monitor.clear_data() + self.chart.clear_data() + + def show(self): + """""" + self.showMaximized() + + +class StatisticsMonitor(QtWidgets.QTableWidget): + """""" + KEY_NAME_MAP = { + "start_date": "首个交易日", + "end_date": "最后交易日", + + "total_days": "总交易日", + "profit_days": "盈利交易日", + "loss_days": "亏损交易日", + + "capital": "起始资金", + "end_balance": "结束资金", + + "total_return": "总收益率", + "annual_return": "年化收益", + "max_drawdown": "最大回撤", + "max_ddpercent": "百分比最大回撤", + + "total_net_pnl": "总盈亏", + "total_commission": "总手续费", + "total_slippage": "总滑点", + "total_turnover": "总成交额", + "total_trade_count": "总成交笔数", + + "daily_net_pnl": "日均盈亏", + "daily_commission": "日均手续费", + "daily_slippage": "日均滑点", + "daily_turnover": "日均成交额", + "daily_trade_count": "日均成交笔数", + + "daily_return": "日均收益率", + "return_std": "收益标准差", + "sharpe_ratio": "夏普比率", + "return_drawdown_ratio": "收益回撤比" + } + + def __init__(self): + """""" + super().__init__() + + self.cells = {} + + self.init_ui() + + def init_ui(self): + """""" + self.setRowCount(len(self.KEY_NAME_MAP)) + self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values())) + + self.setColumnCount(1) + self.horizontalHeader().setVisible(False) + self.horizontalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.Stretch + ) + + for row, key in enumerate(self.KEY_NAME_MAP.keys()): + cell = QtWidgets.QTableWidgetItem() + self.setItem(row, 0, cell) + self.cells[key] = cell + + def clear_data(self): + """""" + for cell in self.cells.values(): + cell.setText("") + + def set_data(self, data: dict): + """""" + data["capital"] = f"{data['capital']:,.2f}" + data["end_balance"] = f"{data['end_balance']:,.2f}" + data["total_return"] = f"{data['total_return']:,.2f}%" + data["annual_return"] = f"{data['annual_return']:,.2f}%" + data["max_drawdown"] = f"{data['max_drawdown']:,.2f}" + data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%" + data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}" + data["total_commission"] = f"{data['total_commission']:,.2f}" + data["total_slippage"] = f"{data['total_slippage']:,.2f}" + data["total_turnover"] = f"{data['total_turnover']:,.2f}" + data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}" + data["daily_commission"] = f"{data['daily_commission']:,.2f}" + data["daily_slippage"] = f"{data['daily_slippage']:,.2f}" + data["daily_turnover"] = f"{data['daily_turnover']:,.2f}" + data["daily_return"] = f"{data['daily_return']:,.2f}%" + data["return_std"] = f"{data['return_std']:,.2f}%" + data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}" + data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}" + + for key, cell in self.cells.items(): + value = data.get(key, "") + cell.setText(str(value)) + + +class SettingEditor(QtWidgets.QDialog): + """ + For creating new strategy and editing strategy parameters. + """ + + def __init__( + self, class_name: str, parameters: dict + ): + """""" + super(SettingEditor, self).__init__() + + self.class_name = class_name + self.parameters = parameters + self.edits = {} + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + + # Add vt_symbol and name edit if add new strategy + self.setWindowTitle(f"策略参数配置:{self.class_name}") + button_text = "确定" + parameters = self.parameters + + for name, value in parameters.items(): + type_ = type(value) + + edit = QtWidgets.QLineEdit(str(value)) + if type_ is int: + validator = QtGui.QIntValidator() + edit.setValidator(validator) + elif type_ is float: + validator = QtGui.QDoubleValidator() + edit.setValidator(validator) + + form.addRow(f"{name} {type_}", edit) + + self.edits[name] = (edit, type_) + + button = QtWidgets.QPushButton(button_text) + button.clicked.connect(self.accept) + form.addRow(button) + + self.setLayout(form) + + def get_setting(self): + """""" + setting = {} + + for name, tp in self.edits.items(): + edit, type_ = tp + value_text = edit.text() + + if type_ == bool: + if value_text == "True": + value = True + else: + value = False + else: + value = type_(value_text) + + setting[name] = value + + return setting + + +class BacktesterChart(pg.GraphicsWindow): + """""" + + def __init__(self): + """""" + super().__init__(title="Backtester Chart") + + self.dates = {} + + self.init_ui() + + def init_ui(self): + """""" + pg.setConfigOptions(antialias=True) + + # Create plot widgets + self.balance_plot = self.addPlot( + title="账户净值", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.drawdown_plot = self.addPlot( + title="净值回撤", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.pnl_plot = self.addPlot( + title="每日盈亏", + axisItems={"bottom": DateAxis(self.dates, orientation="bottom")} + ) + self.nextRow() + + self.distribution_plot = self.addPlot(title="盈亏分布") + + # Add curves and bars on plot widgets + self.balance_curve = self.balance_plot.plot( + pen=pg.mkPen("#ffc107", width=3) + ) + + dd_color = "#303f9f" + self.drawdown_curve = self.drawdown_plot.plot( + fillLevel=-0.3, brush=dd_color, pen=dd_color + ) + + profit_color = 'r' + loss_color = 'g' + self.profit_pnl_bar = pg.BarGraphItem( + x=[], height=[], width=0.3, brush=profit_color, pen=profit_color + ) + self.loss_pnl_bar = pg.BarGraphItem( + x=[], height=[], width=0.3, brush=loss_color, pen=loss_color + ) + self.pnl_plot.addItem(self.profit_pnl_bar) + self.pnl_plot.addItem(self.loss_pnl_bar) + + distribution_color = "#6d4c41" + self.distribution_curve = self.distribution_plot.plot( + fillLevel=-0.3, brush=distribution_color, pen=distribution_color + ) + + def clear_data(self): + """""" + self.balance_curve.setData([], []) + self.drawdown_curve.setData([], []) + self.profit_pnl_bar.setOpts(x=[], height=[]) + self.loss_pnl_bar.setOpts(x=[], height=[]) + self.distribution_curve.setData([], []) + + def set_data(self, df): + """""" + count = len(df) + + self.dates.clear() + for n, date in enumerate(df.index): + self.dates[n] = date + + # Set data for curve of balance and drawdown + self.balance_curve.setData(df["balance"]) + self.drawdown_curve.setData(df["drawdown"]) + + # Set data for daily pnl bar + profit_pnl_x = [] + profit_pnl_height = [] + loss_pnl_x = [] + loss_pnl_height = [] + + for count, pnl in enumerate(df["net_pnl"]): + if pnl >= 0: + profit_pnl_height.append(pnl) + profit_pnl_x.append(count) + else: + loss_pnl_height.append(pnl) + loss_pnl_x.append(count) + + self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height) + self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height) + + # Set data for pnl distribution + hist, x = np.histogram(df["net_pnl"], bins="auto") + x = x[:-1] + self.distribution_curve.setData(x, hist) + + +class DateAxis(pg.AxisItem): + """Axis for showing date data""" + + def __init__(self, dates: dict, *args, **kwargs): + """""" + super().__init__(*args, **kwargs) + self.dates = dates + + def tickStrings(self, values, scale, spacing): + """""" + strings = [] + for v in values: + dt = self.dates.get(v, "") + strings.append(str(dt)) + return strings diff --git a/vnpy/app/cta_strategy/__init__.py b/vnpy/app/cta_strategy/__init__.py index 9d079f9b..e753746b 100644 --- a/vnpy/app/cta_strategy/__init__.py +++ b/vnpy/app/cta_strategy/__init__.py @@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager from .base import APP_NAME, StopOrder from .engine import CtaEngine +from .backtesting import BacktestingEngine, OptimizationSetting from .template import CtaTemplate, CtaSignal, TargetPosTemplate diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 39768844..5712f798 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -293,7 +293,7 @@ class BacktestingEngine: self.output("逐日盯市盈亏计算完成") return self.daily_df - def calculate_statistics(self, df: DataFrame = None, Output=True): + def calculate_statistics(self, df: DataFrame = None, output=True): """""" self.output("开始计算策略统计指标") @@ -377,7 +377,7 @@ class BacktestingEngine: return_drawdown_ratio = -total_return / max_ddpercent # Output - if Output: + if output: self.output("-" * 30) self.output(f"首个交易日:\t{start_date}") self.output(f"最后交易日:\t{end_date}") @@ -417,6 +417,7 @@ class BacktestingEngine: "total_days": total_days, "profit_days": profit_days, "loss_days": loss_days, + "capital": self.capital, "end_balance": end_balance, "max_drawdown": max_drawdown, "max_ddpercent": max_ddpercent,