[Add]CtaBackteserApp for GUI based backtesting

This commit is contained in:
vn.py 2019-04-11 14:01:29 +08:00
parent 80bf809ccb
commit bee61d79b0
8 changed files with 699 additions and 2 deletions

View File

@ -16,6 +16,7 @@ from vnpy.gateway.huobi import HuobiGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp
from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp
def main():
@ -35,6 +36,7 @@ def main():
main_engine.add_gateway(HuobiGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)
main_engine.add_app(CsvLoaderApp)
main_engine.add_app(AlgoTradingApp)

View File

@ -0,0 +1,17 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from .engine import BacktesterEngine, APP_NAME
class CtaBacktesterApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "CTA回测"
engine_class = BacktesterEngine
widget_name = "BacktesterManager"
icon_name = "backtester.ico"

View File

@ -0,0 +1,199 @@
import os
import importlib
from datetime import datetime
from threading import Thread
from pathlib import Path
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.app.cta_strategy import (
CtaTemplate,
BacktestingEngine,
OptimizationSetting
)
APP_NAME = "CtaBacktester"
EVENT_BACKTESTER_LOG = "eBacktesterLog"
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
class BacktesterEngine(BaseEngine):
"""
For running CTA strategy backtesting.
"""
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.classes = {}
self.backtesting_engine = None
self.thread = None
self.result_df = None
self.result_statistics = None
self.load_strategy_class()
def init_engine(self):
""""""
self.write_log("初始化CTA回测引擎")
self.backtesting_engine = BacktestingEngine()
# Redirect log from backtesting engine outside.
self.backtesting_engine.output = self.write_log
self.write_log("策略文件加载完成")
def write_log(self, msg: str):
""""""
event = Event(EVENT_BACKTESTER_LOG)
event.data = msg
self.event_engine.put(event)
def load_strategy_class(self):
"""
Load strategy class from source code.
"""
app_path = Path(__file__).parent.parent
path1 = app_path.joinpath("cta_strategy", "strategies")
self.load_strategy_class_from_folder(
path1, "vnpy.app.cta_strategy.strategies")
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies")
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
if filename.endswith(".py"):
strategy_module_name = ".".join(
[module_name, filename.replace(".py", "")])
self.load_strategy_class_from_module(strategy_module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
Load strategy class from module file.
"""
try:
module = importlib.import_module(module_name)
for name in dir(module):
value = getattr(module, name)
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
self.classes[value.__name__] = value
except: # noqa
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
self.write_log(msg)
def get_strategy_class_names(self):
""""""
return list(self.classes.keys())
def run_backtesting(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
setting: dict
):
""""""
self.result_df = None
self.result_statistics = None
engine = self.backtesting_engine
engine.clear_data()
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start,
end=end,
rate=rate,
slippage=slippage,
size=size,
pricetick=pricetick,
capital=capital
)
strategy_class = self.classes[class_name]
engine.add_strategy(
strategy_class,
setting
)
engine.load_data()
engine.run_backtesting()
self.result_df = engine.calculate_result()
self.result_statistics = engine.calculate_statistics(output=False)
# Clear thread object handler.
self.thread = None
# Put backtesting done event
event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
self.event_engine.put(event)
def start_backtesting(
self,
class_name: str,
vt_symbol: str,
interval: str,
start: datetime,
end: datetime,
rate: float,
slippage: float,
size: int,
pricetick: float,
capital: int,
setting: dict
):
if self.thread:
self.write_log("已有回测在运行中,请等待完成")
return False
self.thread = Thread(
target=self.run_backtesting,
args=(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
setting
)
)
self.thread.start()
return True
def get_result_df(self):
""""""
return self.result_df
def get_result_statistics(self):
""""""
return self.result_statistics
def get_default_setting(self, class_name: str):
""""""
strategy_class = self.classes[class_name]
return strategy_class.get_class_parameters()

View File

@ -0,0 +1 @@
from .widget import BacktesterManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

View File

@ -0,0 +1,476 @@
from datetime import datetime, timedelta
import pyqtgraph as pg
import numpy as np
from vnpy.event import Event, EventEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.engine import MainEngine
from vnpy.trader.constant import Interval
from ..engine import (
APP_NAME,
EVENT_BACKTESTER_LOG,
EVENT_BACKTESTER_BACKTESTING_FINISHED,
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
)
class BacktesterManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_backtesting_finished = QtCore.pyqtSignal(Event)
signal_optimization_finished = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__()
self.main_engine = main_engine
self.event_engine = event_engine
self.backtester_engine = main_engine.get_engine(APP_NAME)
self.class_names = []
self.settings = {}
self.init_strategy_settings()
self.init_ui()
self.register_event()
self.backtester_engine.init_engine()
def init_strategy_settings(self):
""""""
self.class_names = self.backtester_engine.get_strategy_class_names()
for class_name in self.class_names:
setting = self.backtester_engine.get_default_setting(class_name)
self.settings[class_name] = setting
def init_ui(self):
""""""
self.setWindowTitle("CTA回测")
# Setting Part
self.class_combo = QtWidgets.QComboBox()
self.class_combo.addItems(self.class_names)
self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")
self.interval_combo = QtWidgets.QComboBox()
for inteval in Interval:
self.interval_combo.addItem(inteval.value)
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=3 * 365)
self.start_date_edit = QtWidgets.QDateEdit(
QtCore.QDate(
start_dt.year,
start_dt.month,
start_dt.day
)
)
self.end_date_edit = QtWidgets.QDateEdit(
QtCore.QDate.currentDate()
)
self.rate_line = QtWidgets.QLineEdit("0.000025")
self.slippage_line = QtWidgets.QLineEdit("0.2")
self.size_line = QtWidgets.QLineEdit("300")
self.pricetick_line = QtWidgets.QLineEdit("0.2")
self.capital_line = QtWidgets.QLineEdit("1000000")
start_button = QtWidgets.QPushButton("开始回测")
start_button.clicked.connect(self.start_backtesting)
form = QtWidgets.QFormLayout()
form.addRow("交易策略", self.class_combo)
form.addRow("本地代码", self.symbol_line)
form.addRow("K线周期", self.interval_combo)
form.addRow("开始日期", self.start_date_edit)
form.addRow("结束日期", self.end_date_edit)
form.addRow("手续费率", self.rate_line)
form.addRow("交易滑点", self.slippage_line)
form.addRow("合约乘数", self.size_line)
form.addRow("价格跳动", self.pricetick_line)
form.addRow("回测资金", self.capital_line)
form.addRow(start_button)
# Result part
self.statistics_monitor = StatisticsMonitor()
self.log_monitor = QtWidgets.QTextEdit()
self.log_monitor.setMaximumHeight(400)
self.chart = BacktesterChart()
self.chart.setMinimumWidth(1000)
# Layout
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(self.statistics_monitor)
vbox.addWidget(self.log_monitor)
hbox = QtWidgets.QHBoxLayout()
hbox.addLayout(form)
hbox.addLayout(vbox)
hbox.addWidget(self.chart)
self.setLayout(hbox)
def register_event(self):
""""""
self.signal_log.connect(self.process_log_event)
self.signal_backtesting_finished.connect(
self.process_backtesting_finished_event)
self.signal_optimization_finished.connect(
self.process_optimization_finished_event)
self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit)
self.event_engine.register(
EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit)
self.event_engine.register(
EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit)
def process_log_event(self, event: Event):
""""""
msg = event.data
timestamp = datetime.now().strftime("%H:%M:%S")
msg = f"{timestamp}\t{msg}"
self.log_monitor.append(msg)
def process_backtesting_finished_event(self, event: Event):
""""""
statistics = self.backtester_engine.get_result_statistics()
self.statistics_monitor.set_data(statistics)
df = self.backtester_engine.get_result_df()
self.chart.set_data(df)
def process_optimization_finished_event(self, event: Event):
""""""
pass
def start_backtesting(self):
""""""
class_name = self.class_combo.currentText()
vt_symbol = self.symbol_line.text()
interval = self.interval_combo.currentText()
start = self.start_date_edit.date().toPyDate()
end = self.end_date_edit.date().toPyDate()
rate = float(self.rate_line.text())
slippage = float(self.slippage_line.text())
size = float(self.size_line.text())
pricetick = float(self.pricetick_line.text())
capital = float(self.capital_line.text())
old_setting = self.settings[class_name]
dialog = SettingEditor(class_name, old_setting)
i = dialog.exec()
if i != dialog.Accepted:
return
new_setting = dialog.get_setting()
self.settings[class_name] = new_setting
result = self.backtester_engine.start_backtesting(
class_name,
vt_symbol,
interval,
start,
end,
rate,
slippage,
size,
pricetick,
capital,
new_setting
)
if result:
self.statistics_monitor.clear_data()
self.chart.clear_data()
def show(self):
""""""
self.showMaximized()
class StatisticsMonitor(QtWidgets.QTableWidget):
""""""
KEY_NAME_MAP = {
"start_date": "首个交易日",
"end_date": "最后交易日",
"total_days": "总交易日",
"profit_days": "盈利交易日",
"loss_days": "亏损交易日",
"capital": "起始资金",
"end_balance": "结束资金",
"total_return": "总收益率",
"annual_return": "年化收益",
"max_drawdown": "最大回撤",
"max_ddpercent": "百分比最大回撤",
"total_net_pnl": "总盈亏",
"total_commission": "总手续费",
"total_slippage": "总滑点",
"total_turnover": "总成交额",
"total_trade_count": "总成交笔数",
"daily_net_pnl": "日均盈亏",
"daily_commission": "日均手续费",
"daily_slippage": "日均滑点",
"daily_turnover": "日均成交额",
"daily_trade_count": "日均成交笔数",
"daily_return": "日均收益率",
"return_std": "收益标准差",
"sharpe_ratio": "夏普比率",
"return_drawdown_ratio": "收益回撤比"
}
def __init__(self):
""""""
super().__init__()
self.cells = {}
self.init_ui()
def init_ui(self):
""""""
self.setRowCount(len(self.KEY_NAME_MAP))
self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))
self.setColumnCount(1)
self.horizontalHeader().setVisible(False)
self.horizontalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
for row, key in enumerate(self.KEY_NAME_MAP.keys()):
cell = QtWidgets.QTableWidgetItem()
self.setItem(row, 0, cell)
self.cells[key] = cell
def clear_data(self):
""""""
for cell in self.cells.values():
cell.setText("")
def set_data(self, data: dict):
""""""
data["capital"] = f"{data['capital']:,.2f}"
data["end_balance"] = f"{data['end_balance']:,.2f}"
data["total_return"] = f"{data['total_return']:,.2f}%"
data["annual_return"] = f"{data['annual_return']:,.2f}%"
data["max_drawdown"] = f"{data['max_drawdown']:,.2f}"
data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%"
data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}"
data["total_commission"] = f"{data['total_commission']:,.2f}"
data["total_slippage"] = f"{data['total_slippage']:,.2f}"
data["total_turnover"] = f"{data['total_turnover']:,.2f}"
data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}"
data["daily_commission"] = f"{data['daily_commission']:,.2f}"
data["daily_slippage"] = f"{data['daily_slippage']:,.2f}"
data["daily_turnover"] = f"{data['daily_turnover']:,.2f}"
data["daily_return"] = f"{data['daily_return']:,.2f}%"
data["return_std"] = f"{data['return_std']:,.2f}%"
data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}"
data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}"
for key, cell in self.cells.items():
value = data.get(key, "")
cell.setText(str(value))
class SettingEditor(QtWidgets.QDialog):
"""
For creating new strategy and editing strategy parameters.
"""
def __init__(
self, class_name: str, parameters: dict
):
""""""
super(SettingEditor, self).__init__()
self.class_name = class_name
self.parameters = parameters
self.edits = {}
self.init_ui()
def init_ui(self):
""""""
form = QtWidgets.QFormLayout()
# Add vt_symbol and name edit if add new strategy
self.setWindowTitle(f"策略参数配置:{self.class_name}")
button_text = "确定"
parameters = self.parameters
for name, value in parameters.items():
type_ = type(value)
edit = QtWidgets.QLineEdit(str(value))
if type_ is int:
validator = QtGui.QIntValidator()
edit.setValidator(validator)
elif type_ is float:
validator = QtGui.QDoubleValidator()
edit.setValidator(validator)
form.addRow(f"{name} {type_}", edit)
self.edits[name] = (edit, type_)
button = QtWidgets.QPushButton(button_text)
button.clicked.connect(self.accept)
form.addRow(button)
self.setLayout(form)
def get_setting(self):
""""""
setting = {}
for name, tp in self.edits.items():
edit, type_ = tp
value_text = edit.text()
if type_ == bool:
if value_text == "True":
value = True
else:
value = False
else:
value = type_(value_text)
setting[name] = value
return setting
class BacktesterChart(pg.GraphicsWindow):
""""""
def __init__(self):
""""""
super().__init__(title="Backtester Chart")
self.dates = {}
self.init_ui()
def init_ui(self):
""""""
pg.setConfigOptions(antialias=True)
# Create plot widgets
self.balance_plot = self.addPlot(
title="账户净值",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.drawdown_plot = self.addPlot(
title="净值回撤",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.pnl_plot = self.addPlot(
title="每日盈亏",
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
)
self.nextRow()
self.distribution_plot = self.addPlot(title="盈亏分布")
# Add curves and bars on plot widgets
self.balance_curve = self.balance_plot.plot(
pen=pg.mkPen("#ffc107", width=3)
)
dd_color = "#303f9f"
self.drawdown_curve = self.drawdown_plot.plot(
fillLevel=-0.3, brush=dd_color, pen=dd_color
)
profit_color = 'r'
loss_color = 'g'
self.profit_pnl_bar = pg.BarGraphItem(
x=[], height=[], width=0.3, brush=profit_color, pen=profit_color
)
self.loss_pnl_bar = pg.BarGraphItem(
x=[], height=[], width=0.3, brush=loss_color, pen=loss_color
)
self.pnl_plot.addItem(self.profit_pnl_bar)
self.pnl_plot.addItem(self.loss_pnl_bar)
distribution_color = "#6d4c41"
self.distribution_curve = self.distribution_plot.plot(
fillLevel=-0.3, brush=distribution_color, pen=distribution_color
)
def clear_data(self):
""""""
self.balance_curve.setData([], [])
self.drawdown_curve.setData([], [])
self.profit_pnl_bar.setOpts(x=[], height=[])
self.loss_pnl_bar.setOpts(x=[], height=[])
self.distribution_curve.setData([], [])
def set_data(self, df):
""""""
count = len(df)
self.dates.clear()
for n, date in enumerate(df.index):
self.dates[n] = date
# Set data for curve of balance and drawdown
self.balance_curve.setData(df["balance"])
self.drawdown_curve.setData(df["drawdown"])
# Set data for daily pnl bar
profit_pnl_x = []
profit_pnl_height = []
loss_pnl_x = []
loss_pnl_height = []
for count, pnl in enumerate(df["net_pnl"]):
if pnl >= 0:
profit_pnl_height.append(pnl)
profit_pnl_x.append(count)
else:
loss_pnl_height.append(pnl)
loss_pnl_x.append(count)
self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height)
self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height)
# Set data for pnl distribution
hist, x = np.histogram(df["net_pnl"], bins="auto")
x = x[:-1]
self.distribution_curve.setData(x, hist)
class DateAxis(pg.AxisItem):
"""Axis for showing date data"""
def __init__(self, dates: dict, *args, **kwargs):
""""""
super().__init__(*args, **kwargs)
self.dates = dates
def tickStrings(self, values, scale, spacing):
""""""
strings = []
for v in values:
dt = self.dates.get(v, "")
strings.append(str(dt))
return strings

View File

@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager
from .base import APP_NAME, StopOrder
from .engine import CtaEngine
from .backtesting import BacktestingEngine, OptimizationSetting
from .template import CtaTemplate, CtaSignal, TargetPosTemplate

View File

@ -293,7 +293,7 @@ class BacktestingEngine:
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None, Output=True):
def calculate_statistics(self, df: DataFrame = None, output=True):
""""""
self.output("开始计算策略统计指标")
@ -377,7 +377,7 @@ class BacktestingEngine:
return_drawdown_ratio = -total_return / max_ddpercent
# Output
if Output:
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
@ -417,6 +417,7 @@ class BacktestingEngine:
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,