diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py index ffa43db3..5c5787af 100644 --- a/vnpy/app/cta_backtester/ui/widget.py +++ b/vnpy/app/cta_backtester/ui/widget.py @@ -9,7 +9,7 @@ from ..engine import ( EVENT_BACKTESTER_OPTIMIZATION_FINISHED, OptimizationSetting ) -from vnpy.trader.constant import Interval +from vnpy.trader.constant import Interval, Direction from vnpy.trader.engine import MainEngine from vnpy.trader.ui import QtCore, QtWidgets, QtGui from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell @@ -975,6 +975,7 @@ class CandleChartDialog(QtWidgets.QDialog): """""" super().__init__() + self.dt_ix_map = {} self.updated = False self.init_ui() @@ -983,6 +984,7 @@ class CandleChartDialog(QtWidgets.QDialog): self.setWindowTitle("回测K线图表") self.resize(1400, 800) + # Create chart widget self.chart = ChartWidget() self.chart.add_plot("candle", hide_x_axis=True) self.chart.add_plot("volume", maximum_height=200) @@ -990,6 +992,12 @@ class CandleChartDialog(QtWidgets.QDialog): self.chart.add_item(VolumeItem, "volume", "volume") self.chart.add_cursor() + # Add scatter item for showing tradings + self.trade_scatter = pg.ScatterPlotItem() + candle_plot = self.chart.get_plot("candle") + candle_plot.addItem(self.trade_scatter) + + # Set layout vbox = QtWidgets.QVBoxLayout() vbox.addWidget(self.chart) self.setLayout(vbox) @@ -999,14 +1007,41 @@ class CandleChartDialog(QtWidgets.QDialog): self.updated = True self.chart.update_history(history) + for ix, bar in enumerate(history): + self.dt_ix_map[bar.datetime] = ix + def update_trades(self, trades: list): """""" - pass + trade_data = [] + + for trade in trades: + ix = self.dt_ix_map[trade.datetime] + scatter = { + "pos": (ix, trade.price), + "data": 1, + "size": 14, + "pen": pg.mkPen((255, 255, 255)) + } + + if trade.direction == Direction.LONG: + scatter["symbol"] = "t1" + scatter["brush"] = pg.mkBrush((255, 255, 0)) + else: + scatter["symbol"] = "t" + scatter["brush"] = pg.mkBrush((0, 0, 255)) + + trade_data.append(scatter) + + self.trade_scatter.setData(trade_data) + def clear_data(self): """""" self.updated = False self.chart.clear_all() + + self.dt_ix_map.clear() + self.trade_scatter.clear() def is_updated(self): """""" diff --git a/vnpy/chart/widget.py b/vnpy/chart/widget.py index 4c13a900..e00aab1b 100644 --- a/vnpy/chart/widget.py +++ b/vnpy/chart/widget.py @@ -14,6 +14,9 @@ from .axis import DatetimeAxis from .item import ChartItem +pg.setConfigOptions(antialias=True) + + class ChartWidget(pg.PlotWidget): """""" MIN_BAR_COUNT = 100 @@ -122,7 +125,13 @@ class ChartWidget(pg.PlotWidget): self._layout.nextRow() self._layout.addItem(plot) - def get_all_plots(self): + def get_plot(self, plot_name: str) -> pg.PlotItem: + """ + Get specific plot with its name. + """ + return self._plots.get(plot_name, None) + + def get_all_plots(self) -> List[pg.PlotItem]: """ Get all plot objects. """