[Mod] add trade scatter to show all backtesting trades

This commit is contained in:
vn.py 2019-07-18 22:54:47 +08:00
parent aedcf81d84
commit 66b46704e6
2 changed files with 47 additions and 3 deletions

View File

@ -9,7 +9,7 @@ from ..engine import (
EVENT_BACKTESTER_OPTIMIZATION_FINISHED,
OptimizationSetting
)
from vnpy.trader.constant import Interval
from vnpy.trader.constant import Interval, Direction
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtWidgets, QtGui
from vnpy.trader.ui.widget import BaseMonitor, BaseCell, DirectionCell, EnumCell
@ -975,6 +975,7 @@ class CandleChartDialog(QtWidgets.QDialog):
""""""
super().__init__()
self.dt_ix_map = {}
self.updated = False
self.init_ui()
@ -983,6 +984,7 @@ class CandleChartDialog(QtWidgets.QDialog):
self.setWindowTitle("回测K线图表")
self.resize(1400, 800)
# Create chart widget
self.chart = ChartWidget()
self.chart.add_plot("candle", hide_x_axis=True)
self.chart.add_plot("volume", maximum_height=200)
@ -990,6 +992,12 @@ class CandleChartDialog(QtWidgets.QDialog):
self.chart.add_item(VolumeItem, "volume", "volume")
self.chart.add_cursor()
# Add scatter item for showing tradings
self.trade_scatter = pg.ScatterPlotItem()
candle_plot = self.chart.get_plot("candle")
candle_plot.addItem(self.trade_scatter)
# Set layout
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(self.chart)
self.setLayout(vbox)
@ -999,14 +1007,41 @@ class CandleChartDialog(QtWidgets.QDialog):
self.updated = True
self.chart.update_history(history)
for ix, bar in enumerate(history):
self.dt_ix_map[bar.datetime] = ix
def update_trades(self, trades: list):
""""""
pass
trade_data = []
for trade in trades:
ix = self.dt_ix_map[trade.datetime]
scatter = {
"pos": (ix, trade.price),
"data": 1,
"size": 14,
"pen": pg.mkPen((255, 255, 255))
}
if trade.direction == Direction.LONG:
scatter["symbol"] = "t1"
scatter["brush"] = pg.mkBrush((255, 255, 0))
else:
scatter["symbol"] = "t"
scatter["brush"] = pg.mkBrush((0, 0, 255))
trade_data.append(scatter)
self.trade_scatter.setData(trade_data)
def clear_data(self):
""""""
self.updated = False
self.chart.clear_all()
self.dt_ix_map.clear()
self.trade_scatter.clear()
def is_updated(self):
""""""

View File

@ -14,6 +14,9 @@ from .axis import DatetimeAxis
from .item import ChartItem
pg.setConfigOptions(antialias=True)
class ChartWidget(pg.PlotWidget):
""""""
MIN_BAR_COUNT = 100
@ -122,7 +125,13 @@ class ChartWidget(pg.PlotWidget):
self._layout.nextRow()
self._layout.addItem(plot)
def get_all_plots(self):
def get_plot(self, plot_name: str) -> pg.PlotItem:
"""
Get specific plot with its name.
"""
return self._plots.get(plot_name, None)
def get_all_plots(self) -> List[pg.PlotItem]:
"""
Get all plot objects.
"""