[Add]CtaBackteserApp for GUI based backtesting
This commit is contained in:
parent
80bf809ccb
commit
bee61d79b0
@ -16,6 +16,7 @@ from vnpy.gateway.huobi import HuobiGateway
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
from vnpy.app.csv_loader import CsvLoaderApp
|
||||
from vnpy.app.algo_trading import AlgoTradingApp
|
||||
from vnpy.app.cta_backtester import CtaBacktesterApp
|
||||
|
||||
|
||||
def main():
|
||||
@ -35,6 +36,7 @@ def main():
|
||||
main_engine.add_gateway(HuobiGateway)
|
||||
|
||||
main_engine.add_app(CtaStrategyApp)
|
||||
main_engine.add_app(CtaBacktesterApp)
|
||||
main_engine.add_app(CsvLoaderApp)
|
||||
main_engine.add_app(AlgoTradingApp)
|
||||
|
||||
|
17
vnpy/app/cta_backtester/__init__.py
Normal file
17
vnpy/app/cta_backtester/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
|
||||
from .engine import BacktesterEngine, APP_NAME
|
||||
|
||||
|
||||
class CtaBacktesterApp(BaseApp):
|
||||
""""""
|
||||
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "CTA回测"
|
||||
engine_class = BacktesterEngine
|
||||
widget_name = "BacktesterManager"
|
||||
icon_name = "backtester.ico"
|
199
vnpy/app/cta_backtester/engine.py
Normal file
199
vnpy/app/cta_backtester/engine.py
Normal file
@ -0,0 +1,199 @@
|
||||
import os
|
||||
import importlib
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.app.cta_strategy import (
|
||||
CtaTemplate,
|
||||
BacktestingEngine,
|
||||
OptimizationSetting
|
||||
)
|
||||
|
||||
|
||||
APP_NAME = "CtaBacktester"
|
||||
|
||||
EVENT_BACKTESTER_LOG = "eBacktesterLog"
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished"
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished"
|
||||
|
||||
|
||||
class BacktesterEngine(BaseEngine):
|
||||
"""
|
||||
For running CTA strategy backtesting.
|
||||
"""
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.classes = {}
|
||||
self.backtesting_engine = None
|
||||
self.thread = None
|
||||
|
||||
self.result_df = None
|
||||
self.result_statistics = None
|
||||
|
||||
self.load_strategy_class()
|
||||
|
||||
def init_engine(self):
|
||||
""""""
|
||||
self.write_log("初始化CTA回测引擎")
|
||||
|
||||
self.backtesting_engine = BacktestingEngine()
|
||||
# Redirect log from backtesting engine outside.
|
||||
self.backtesting_engine.output = self.write_log
|
||||
|
||||
self.write_log("策略文件加载完成")
|
||||
|
||||
def write_log(self, msg: str):
|
||||
""""""
|
||||
event = Event(EVENT_BACKTESTER_LOG)
|
||||
event.data = msg
|
||||
self.event_engine.put(event)
|
||||
|
||||
def load_strategy_class(self):
|
||||
"""
|
||||
Load strategy class from source code.
|
||||
"""
|
||||
app_path = Path(__file__).parent.parent
|
||||
path1 = app_path.joinpath("cta_strategy", "strategies")
|
||||
self.load_strategy_class_from_folder(
|
||||
path1, "vnpy.app.cta_strategy.strategies")
|
||||
|
||||
path2 = Path.cwd().joinpath("strategies")
|
||||
self.load_strategy_class_from_folder(path2, "strategies")
|
||||
|
||||
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
|
||||
"""
|
||||
Load strategy class from certain folder.
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
[module_name, filename.replace(".py", "")])
|
||||
self.load_strategy_class_from_module(strategy_module_name)
|
||||
|
||||
def load_strategy_class_from_module(self, module_name: str):
|
||||
"""
|
||||
Load strategy class from module file.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for name in dir(module):
|
||||
value = getattr(module, name)
|
||||
if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate):
|
||||
self.classes[value.__name__] = value
|
||||
except: # noqa
|
||||
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
|
||||
self.write_log(msg)
|
||||
|
||||
def get_strategy_class_names(self):
|
||||
""""""
|
||||
return list(self.classes.keys())
|
||||
|
||||
def run_backtesting(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
self.result_df = None
|
||||
self.result_statistics = None
|
||||
|
||||
engine = self.backtesting_engine
|
||||
engine.clear_data()
|
||||
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
rate=rate,
|
||||
slippage=slippage,
|
||||
size=size,
|
||||
pricetick=pricetick,
|
||||
capital=capital
|
||||
)
|
||||
|
||||
strategy_class = self.classes[class_name]
|
||||
engine.add_strategy(
|
||||
strategy_class,
|
||||
setting
|
||||
)
|
||||
|
||||
engine.load_data()
|
||||
engine.run_backtesting()
|
||||
self.result_df = engine.calculate_result()
|
||||
self.result_statistics = engine.calculate_statistics(output=False)
|
||||
|
||||
# Clear thread object handler.
|
||||
self.thread = None
|
||||
|
||||
# Put backtesting done event
|
||||
event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def start_backtesting(
|
||||
self,
|
||||
class_name: str,
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
rate: float,
|
||||
slippage: float,
|
||||
size: int,
|
||||
pricetick: float,
|
||||
capital: int,
|
||||
setting: dict
|
||||
):
|
||||
if self.thread:
|
||||
self.write_log("已有回测在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.thread = Thread(
|
||||
target=self.run_backtesting,
|
||||
args=(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
setting
|
||||
)
|
||||
)
|
||||
self.thread.start()
|
||||
|
||||
return True
|
||||
|
||||
def get_result_df(self):
|
||||
""""""
|
||||
return self.result_df
|
||||
|
||||
def get_result_statistics(self):
|
||||
""""""
|
||||
return self.result_statistics
|
||||
|
||||
def get_default_setting(self, class_name: str):
|
||||
""""""
|
||||
strategy_class = self.classes[class_name]
|
||||
return strategy_class.get_class_parameters()
|
1
vnpy/app/cta_backtester/ui/__init__.py
Normal file
1
vnpy/app/cta_backtester/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .widget import BacktesterManager
|
BIN
vnpy/app/cta_backtester/ui/backtester.ico
Normal file
BIN
vnpy/app/cta_backtester/ui/backtester.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 63 KiB |
476
vnpy/app/cta_backtester/ui/widget.py
Normal file
476
vnpy/app/cta_backtester/ui/widget.py
Normal file
@ -0,0 +1,476 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pyqtgraph as pg
|
||||
import numpy as np
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.constant import Interval
|
||||
|
||||
from ..engine import (
|
||||
APP_NAME,
|
||||
EVENT_BACKTESTER_LOG,
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED,
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED
|
||||
)
|
||||
|
||||
|
||||
class BacktesterManager(QtWidgets.QWidget):
|
||||
""""""
|
||||
|
||||
signal_log = QtCore.pyqtSignal(Event)
|
||||
signal_backtesting_finished = QtCore.pyqtSignal(Event)
|
||||
signal_optimization_finished = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.main_engine = main_engine
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.backtester_engine = main_engine.get_engine(APP_NAME)
|
||||
self.class_names = []
|
||||
self.settings = {}
|
||||
|
||||
self.init_strategy_settings()
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
self.backtester_engine.init_engine()
|
||||
|
||||
def init_strategy_settings(self):
|
||||
""""""
|
||||
self.class_names = self.backtester_engine.get_strategy_class_names()
|
||||
|
||||
for class_name in self.class_names:
|
||||
setting = self.backtester_engine.get_default_setting(class_name)
|
||||
self.settings[class_name] = setting
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setWindowTitle("CTA回测")
|
||||
|
||||
# Setting Part
|
||||
self.class_combo = QtWidgets.QComboBox()
|
||||
self.class_combo.addItems(self.class_names)
|
||||
|
||||
self.symbol_line = QtWidgets.QLineEdit("IF88.CFFEX")
|
||||
|
||||
self.interval_combo = QtWidgets.QComboBox()
|
||||
for inteval in Interval:
|
||||
self.interval_combo.addItem(inteval.value)
|
||||
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=3 * 365)
|
||||
|
||||
self.start_date_edit = QtWidgets.QDateEdit(
|
||||
QtCore.QDate(
|
||||
start_dt.year,
|
||||
start_dt.month,
|
||||
start_dt.day
|
||||
)
|
||||
)
|
||||
self.end_date_edit = QtWidgets.QDateEdit(
|
||||
QtCore.QDate.currentDate()
|
||||
)
|
||||
|
||||
self.rate_line = QtWidgets.QLineEdit("0.000025")
|
||||
self.slippage_line = QtWidgets.QLineEdit("0.2")
|
||||
self.size_line = QtWidgets.QLineEdit("300")
|
||||
self.pricetick_line = QtWidgets.QLineEdit("0.2")
|
||||
self.capital_line = QtWidgets.QLineEdit("1000000")
|
||||
|
||||
start_button = QtWidgets.QPushButton("开始回测")
|
||||
start_button.clicked.connect(self.start_backtesting)
|
||||
|
||||
form = QtWidgets.QFormLayout()
|
||||
form.addRow("交易策略", self.class_combo)
|
||||
form.addRow("本地代码", self.symbol_line)
|
||||
form.addRow("K线周期", self.interval_combo)
|
||||
form.addRow("开始日期", self.start_date_edit)
|
||||
form.addRow("结束日期", self.end_date_edit)
|
||||
form.addRow("手续费率", self.rate_line)
|
||||
form.addRow("交易滑点", self.slippage_line)
|
||||
form.addRow("合约乘数", self.size_line)
|
||||
form.addRow("价格跳动", self.pricetick_line)
|
||||
form.addRow("回测资金", self.capital_line)
|
||||
form.addRow(start_button)
|
||||
|
||||
# Result part
|
||||
self.statistics_monitor = StatisticsMonitor()
|
||||
|
||||
self.log_monitor = QtWidgets.QTextEdit()
|
||||
self.log_monitor.setMaximumHeight(400)
|
||||
|
||||
self.chart = BacktesterChart()
|
||||
self.chart.setMinimumWidth(1000)
|
||||
|
||||
# Layout
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(self.statistics_monitor)
|
||||
vbox.addWidget(self.log_monitor)
|
||||
|
||||
hbox = QtWidgets.QHBoxLayout()
|
||||
hbox.addLayout(form)
|
||||
hbox.addLayout(vbox)
|
||||
hbox.addWidget(self.chart)
|
||||
self.setLayout(hbox)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.signal_log.connect(self.process_log_event)
|
||||
self.signal_backtesting_finished.connect(
|
||||
self.process_backtesting_finished_event)
|
||||
self.signal_optimization_finished.connect(
|
||||
self.process_optimization_finished_event)
|
||||
|
||||
self.event_engine.register(EVENT_BACKTESTER_LOG, self.signal_log.emit)
|
||||
self.event_engine.register(
|
||||
EVENT_BACKTESTER_BACKTESTING_FINISHED, self.signal_backtesting_finished.emit)
|
||||
self.event_engine.register(
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED, self.signal_optimization_finished.emit)
|
||||
|
||||
def process_log_event(self, event: Event):
|
||||
""""""
|
||||
msg = event.data
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
msg = f"{timestamp}\t{msg}"
|
||||
self.log_monitor.append(msg)
|
||||
|
||||
def process_backtesting_finished_event(self, event: Event):
|
||||
""""""
|
||||
statistics = self.backtester_engine.get_result_statistics()
|
||||
self.statistics_monitor.set_data(statistics)
|
||||
|
||||
df = self.backtester_engine.get_result_df()
|
||||
self.chart.set_data(df)
|
||||
|
||||
def process_optimization_finished_event(self, event: Event):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def start_backtesting(self):
|
||||
""""""
|
||||
class_name = self.class_combo.currentText()
|
||||
vt_symbol = self.symbol_line.text()
|
||||
interval = self.interval_combo.currentText()
|
||||
start = self.start_date_edit.date().toPyDate()
|
||||
end = self.end_date_edit.date().toPyDate()
|
||||
rate = float(self.rate_line.text())
|
||||
slippage = float(self.slippage_line.text())
|
||||
size = float(self.size_line.text())
|
||||
pricetick = float(self.pricetick_line.text())
|
||||
capital = float(self.capital_line.text())
|
||||
|
||||
old_setting = self.settings[class_name]
|
||||
dialog = SettingEditor(class_name, old_setting)
|
||||
i = dialog.exec()
|
||||
if i != dialog.Accepted:
|
||||
return
|
||||
|
||||
new_setting = dialog.get_setting()
|
||||
self.settings[class_name] = new_setting
|
||||
|
||||
result = self.backtester_engine.start_backtesting(
|
||||
class_name,
|
||||
vt_symbol,
|
||||
interval,
|
||||
start,
|
||||
end,
|
||||
rate,
|
||||
slippage,
|
||||
size,
|
||||
pricetick,
|
||||
capital,
|
||||
new_setting
|
||||
)
|
||||
|
||||
if result:
|
||||
self.statistics_monitor.clear_data()
|
||||
self.chart.clear_data()
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
|
||||
|
||||
class StatisticsMonitor(QtWidgets.QTableWidget):
|
||||
""""""
|
||||
KEY_NAME_MAP = {
|
||||
"start_date": "首个交易日",
|
||||
"end_date": "最后交易日",
|
||||
|
||||
"total_days": "总交易日",
|
||||
"profit_days": "盈利交易日",
|
||||
"loss_days": "亏损交易日",
|
||||
|
||||
"capital": "起始资金",
|
||||
"end_balance": "结束资金",
|
||||
|
||||
"total_return": "总收益率",
|
||||
"annual_return": "年化收益",
|
||||
"max_drawdown": "最大回撤",
|
||||
"max_ddpercent": "百分比最大回撤",
|
||||
|
||||
"total_net_pnl": "总盈亏",
|
||||
"total_commission": "总手续费",
|
||||
"total_slippage": "总滑点",
|
||||
"total_turnover": "总成交额",
|
||||
"total_trade_count": "总成交笔数",
|
||||
|
||||
"daily_net_pnl": "日均盈亏",
|
||||
"daily_commission": "日均手续费",
|
||||
"daily_slippage": "日均滑点",
|
||||
"daily_turnover": "日均成交额",
|
||||
"daily_trade_count": "日均成交笔数",
|
||||
|
||||
"daily_return": "日均收益率",
|
||||
"return_std": "收益标准差",
|
||||
"sharpe_ratio": "夏普比率",
|
||||
"return_drawdown_ratio": "收益回撤比"
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.cells = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setRowCount(len(self.KEY_NAME_MAP))
|
||||
self.setVerticalHeaderLabels(list(self.KEY_NAME_MAP.values()))
|
||||
|
||||
self.setColumnCount(1)
|
||||
self.horizontalHeader().setVisible(False)
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
|
||||
for row, key in enumerate(self.KEY_NAME_MAP.keys()):
|
||||
cell = QtWidgets.QTableWidgetItem()
|
||||
self.setItem(row, 0, cell)
|
||||
self.cells[key] = cell
|
||||
|
||||
def clear_data(self):
|
||||
""""""
|
||||
for cell in self.cells.values():
|
||||
cell.setText("")
|
||||
|
||||
def set_data(self, data: dict):
|
||||
""""""
|
||||
data["capital"] = f"{data['capital']:,.2f}"
|
||||
data["end_balance"] = f"{data['end_balance']:,.2f}"
|
||||
data["total_return"] = f"{data['total_return']:,.2f}%"
|
||||
data["annual_return"] = f"{data['annual_return']:,.2f}%"
|
||||
data["max_drawdown"] = f"{data['max_drawdown']:,.2f}"
|
||||
data["max_ddpercent"] = f"{data['max_ddpercent']:,.2f}%"
|
||||
data["total_net_pnl"] = f"{data['total_net_pnl']:,.2f}"
|
||||
data["total_commission"] = f"{data['total_commission']:,.2f}"
|
||||
data["total_slippage"] = f"{data['total_slippage']:,.2f}"
|
||||
data["total_turnover"] = f"{data['total_turnover']:,.2f}"
|
||||
data["daily_net_pnl"] = f"{data['daily_net_pnl']:,.2f}"
|
||||
data["daily_commission"] = f"{data['daily_commission']:,.2f}"
|
||||
data["daily_slippage"] = f"{data['daily_slippage']:,.2f}"
|
||||
data["daily_turnover"] = f"{data['daily_turnover']:,.2f}"
|
||||
data["daily_return"] = f"{data['daily_return']:,.2f}%"
|
||||
data["return_std"] = f"{data['return_std']:,.2f}%"
|
||||
data["sharpe_ratio"] = f"{data['sharpe_ratio']:,.2f}"
|
||||
data["return_drawdown_ratio"] = f"{data['return_drawdown_ratio']:,.2f}"
|
||||
|
||||
for key, cell in self.cells.items():
|
||||
value = data.get(key, "")
|
||||
cell.setText(str(value))
|
||||
|
||||
|
||||
class SettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For creating new strategy and editing strategy parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, class_name: str, parameters: dict
|
||||
):
|
||||
""""""
|
||||
super(SettingEditor, self).__init__()
|
||||
|
||||
self.class_name = class_name
|
||||
self.parameters = parameters
|
||||
self.edits = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
form = QtWidgets.QFormLayout()
|
||||
|
||||
# Add vt_symbol and name edit if add new strategy
|
||||
self.setWindowTitle(f"策略参数配置:{self.class_name}")
|
||||
button_text = "确定"
|
||||
parameters = self.parameters
|
||||
|
||||
for name, value in parameters.items():
|
||||
type_ = type(value)
|
||||
|
||||
edit = QtWidgets.QLineEdit(str(value))
|
||||
if type_ is int:
|
||||
validator = QtGui.QIntValidator()
|
||||
edit.setValidator(validator)
|
||||
elif type_ is float:
|
||||
validator = QtGui.QDoubleValidator()
|
||||
edit.setValidator(validator)
|
||||
|
||||
form.addRow(f"{name} {type_}", edit)
|
||||
|
||||
self.edits[name] = (edit, type_)
|
||||
|
||||
button = QtWidgets.QPushButton(button_text)
|
||||
button.clicked.connect(self.accept)
|
||||
form.addRow(button)
|
||||
|
||||
self.setLayout(form)
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
setting = {}
|
||||
|
||||
for name, tp in self.edits.items():
|
||||
edit, type_ = tp
|
||||
value_text = edit.text()
|
||||
|
||||
if type_ == bool:
|
||||
if value_text == "True":
|
||||
value = True
|
||||
else:
|
||||
value = False
|
||||
else:
|
||||
value = type_(value_text)
|
||||
|
||||
setting[name] = value
|
||||
|
||||
return setting
|
||||
|
||||
|
||||
class BacktesterChart(pg.GraphicsWindow):
|
||||
""""""
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
super().__init__(title="Backtester Chart")
|
||||
|
||||
self.dates = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
pg.setConfigOptions(antialias=True)
|
||||
|
||||
# Create plot widgets
|
||||
self.balance_plot = self.addPlot(
|
||||
title="账户净值",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.drawdown_plot = self.addPlot(
|
||||
title="净值回撤",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.pnl_plot = self.addPlot(
|
||||
title="每日盈亏",
|
||||
axisItems={"bottom": DateAxis(self.dates, orientation="bottom")}
|
||||
)
|
||||
self.nextRow()
|
||||
|
||||
self.distribution_plot = self.addPlot(title="盈亏分布")
|
||||
|
||||
# Add curves and bars on plot widgets
|
||||
self.balance_curve = self.balance_plot.plot(
|
||||
pen=pg.mkPen("#ffc107", width=3)
|
||||
)
|
||||
|
||||
dd_color = "#303f9f"
|
||||
self.drawdown_curve = self.drawdown_plot.plot(
|
||||
fillLevel=-0.3, brush=dd_color, pen=dd_color
|
||||
)
|
||||
|
||||
profit_color = 'r'
|
||||
loss_color = 'g'
|
||||
self.profit_pnl_bar = pg.BarGraphItem(
|
||||
x=[], height=[], width=0.3, brush=profit_color, pen=profit_color
|
||||
)
|
||||
self.loss_pnl_bar = pg.BarGraphItem(
|
||||
x=[], height=[], width=0.3, brush=loss_color, pen=loss_color
|
||||
)
|
||||
self.pnl_plot.addItem(self.profit_pnl_bar)
|
||||
self.pnl_plot.addItem(self.loss_pnl_bar)
|
||||
|
||||
distribution_color = "#6d4c41"
|
||||
self.distribution_curve = self.distribution_plot.plot(
|
||||
fillLevel=-0.3, brush=distribution_color, pen=distribution_color
|
||||
)
|
||||
|
||||
def clear_data(self):
|
||||
""""""
|
||||
self.balance_curve.setData([], [])
|
||||
self.drawdown_curve.setData([], [])
|
||||
self.profit_pnl_bar.setOpts(x=[], height=[])
|
||||
self.loss_pnl_bar.setOpts(x=[], height=[])
|
||||
self.distribution_curve.setData([], [])
|
||||
|
||||
def set_data(self, df):
|
||||
""""""
|
||||
count = len(df)
|
||||
|
||||
self.dates.clear()
|
||||
for n, date in enumerate(df.index):
|
||||
self.dates[n] = date
|
||||
|
||||
# Set data for curve of balance and drawdown
|
||||
self.balance_curve.setData(df["balance"])
|
||||
self.drawdown_curve.setData(df["drawdown"])
|
||||
|
||||
# Set data for daily pnl bar
|
||||
profit_pnl_x = []
|
||||
profit_pnl_height = []
|
||||
loss_pnl_x = []
|
||||
loss_pnl_height = []
|
||||
|
||||
for count, pnl in enumerate(df["net_pnl"]):
|
||||
if pnl >= 0:
|
||||
profit_pnl_height.append(pnl)
|
||||
profit_pnl_x.append(count)
|
||||
else:
|
||||
loss_pnl_height.append(pnl)
|
||||
loss_pnl_x.append(count)
|
||||
|
||||
self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height)
|
||||
self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height)
|
||||
|
||||
# Set data for pnl distribution
|
||||
hist, x = np.histogram(df["net_pnl"], bins="auto")
|
||||
x = x[:-1]
|
||||
self.distribution_curve.setData(x, hist)
|
||||
|
||||
|
||||
class DateAxis(pg.AxisItem):
|
||||
"""Axis for showing date data"""
|
||||
|
||||
def __init__(self, dates: dict, *args, **kwargs):
|
||||
""""""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dates = dates
|
||||
|
||||
def tickStrings(self, values, scale, spacing):
|
||||
""""""
|
||||
strings = []
|
||||
for v in values:
|
||||
dt = self.dates.get(v, "")
|
||||
strings.append(str(dt))
|
||||
return strings
|
@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager
|
||||
|
||||
from .base import APP_NAME, StopOrder
|
||||
from .engine import CtaEngine
|
||||
from .backtesting import BacktestingEngine, OptimizationSetting
|
||||
from .template import CtaTemplate, CtaSignal, TargetPosTemplate
|
||||
|
||||
|
||||
|
@ -293,7 +293,7 @@ class BacktestingEngine:
|
||||
self.output("逐日盯市盈亏计算完成")
|
||||
return self.daily_df
|
||||
|
||||
def calculate_statistics(self, df: DataFrame = None, Output=True):
|
||||
def calculate_statistics(self, df: DataFrame = None, output=True):
|
||||
""""""
|
||||
self.output("开始计算策略统计指标")
|
||||
|
||||
@ -377,7 +377,7 @@ class BacktestingEngine:
|
||||
return_drawdown_ratio = -total_return / max_ddpercent
|
||||
|
||||
# Output
|
||||
if Output:
|
||||
if output:
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
@ -417,6 +417,7 @@ class BacktestingEngine:
|
||||
"total_days": total_days,
|
||||
"profit_days": profit_days,
|
||||
"loss_days": loss_days,
|
||||
"capital": self.capital,
|
||||
"end_balance": end_balance,
|
||||
"max_drawdown": max_drawdown,
|
||||
"max_ddpercent": max_ddpercent,
|
||||
|
Loading…
Reference in New Issue
Block a user