[Mod] add trade scatter to show all backtesting trades
This commit is contained in:
parent
aedcf81d84
commit
66b46704e6
@ -9,7 +9,7 @@ from ..engine import (
|
||||
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
|
||||
OptimizationSetting
|
||||
)
|
||||
from vnpy.trader.constant import Interval
|
||||
from vnpy.trader.constant import Interval, Direction
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
|
||||
from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell
|
||||
@ -975,6 +975,7 @@ class CandleChartDialog(QtWidgets.QDialog):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.dt_ix_map = {}
|
||||
self.updated = False
|
||||
self.init_ui()
|
||||
|
||||
@ -983,6 +984,7 @@ class CandleChartDialog(QtWidgets.QDialog):
|
||||
self.setWindowTitle("回测K线图表")
|
||||
self.resize(1400, 800)
|
||||
|
||||
# Create chart widget
|
||||
self.chart = ChartWidget()
|
||||
self.chart.add_plot("candle", hide_x_axis=True)
|
||||
self.chart.add_plot("volume", maximum_height=200)
|
||||
@ -990,6 +992,12 @@ class CandleChartDialog(QtWidgets.QDialog):
|
||||
self.chart.add_item(VolumeItem, "volume", "volume")
|
||||
self.chart.add_cursor()
|
||||
|
||||
# Add scatter item for showing tradings
|
||||
self.trade_scatter = pg.ScatterPlotItem()
|
||||
candle_plot = self.chart.get_plot("candle")
|
||||
candle_plot.addItem(self.trade_scatter)
|
||||
|
||||
# Set layout
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(self.chart)
|
||||
self.setLayout(vbox)
|
||||
@ -999,14 +1007,41 @@ class CandleChartDialog(QtWidgets.QDialog):
|
||||
self.updated = True
|
||||
self.chart.update_history(history)
|
||||
|
||||
for ix, bar in enumerate(history):
|
||||
self.dt_ix_map[bar.datetime] = ix
|
||||
|
||||
def update_trades(self, trades: list):
|
||||
""""""
|
||||
pass
|
||||
trade_data = []
|
||||
|
||||
for trade in trades:
|
||||
ix = self.dt_ix_map[trade.datetime]
|
||||
|
||||
scatter = {
|
||||
"pos": (ix, trade.price),
|
||||
"data": 1,
|
||||
"size": 14,
|
||||
"pen": pg.mkPen((255, 255, 255))
|
||||
}
|
||||
|
||||
if trade.direction == Direction.LONG:
|
||||
scatter["symbol"] = "t1"
|
||||
scatter["brush"] = pg.mkBrush((255, 255, 0))
|
||||
else:
|
||||
scatter["symbol"] = "t"
|
||||
scatter["brush"] = pg.mkBrush((0, 0, 255))
|
||||
|
||||
trade_data.append(scatter)
|
||||
|
||||
self.trade_scatter.setData(trade_data)
|
||||
|
||||
def clear_data(self):
|
||||
""""""
|
||||
self.updated = False
|
||||
self.chart.clear_all()
|
||||
|
||||
self.dt_ix_map.clear()
|
||||
self.trade_scatter.clear()
|
||||
|
||||
def is_updated(self):
|
||||
""""""
|
||||
|
@ -14,6 +14,9 @@ from .axis import DatetimeAxis
|
||||
from .item import ChartItem
|
||||
|
||||
|
||||
pg.setConfigOptions(antialias=True)
|
||||
|
||||
|
||||
class ChartWidget(pg.PlotWidget):
|
||||
""""""
|
||||
MIN_BAR_COUNT = 100
|
||||
@ -122,7 +125,13 @@ class ChartWidget(pg.PlotWidget):
|
||||
self._layout.nextRow()
|
||||
self._layout.addItem(plot)
|
||||
|
||||
def get_all_plots(self):
|
||||
def get_plot(self, plot_name: str) -> pg.PlotItem:
|
||||
"""
|
||||
Get specific plot with its name.
|
||||
"""
|
||||
return self._plots.get(plot_name, None)
|
||||
|
||||
def get_all_plots(self) -> List[pg.PlotItem]:
|
||||
"""
|
||||
Get all plot objects.
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user