diff --git a/examples/candle_chart/run.py b/examples/candle_chart/run.py index 6a940833..61848f08 100644 --- a/examples/candle_chart/run.py +++ b/examples/candle_chart/run.py @@ -1,12 +1,12 @@ from datetime import datetime -from vnpy.trader.ui import QtWidgets -from vnpy.chart import ChartWidget +from vnpy.trader.ui import create_qapp, QtCore from vnpy.trader.database import database_manager from vnpy.trader.constant import Exchange, Interval +from vnpy.chart import ChartWidget, VolumeItem, CandleItem if __name__ == "__main__": - app = QtWidgets.QApplication([]) + app = create_qapp() bars = database_manager.load_bar_data( "IF88", @@ -17,7 +17,25 @@ if __name__ == "__main__": ) widget = ChartWidget() - widget.update_history(bars) - widget.show() + widget.add_plot("candle", hide_x_axis=True) + widget.add_plot("volume") + widget.add_item(CandleItem, "candle", "candle") + widget.add_item(VolumeItem, "volume", "volume") + widget.add_cursor() + n = 1000 + history = bars[:n] + new_data = bars[n:] + + widget.update_history(history) + + def update_bar(): + bar = new_data.pop(0) + widget.update_bar(bar) + + timer = QtCore.QTimer() + timer.timeout.connect(update_bar) + # timer.start(100) + + widget.show() app.exec_() diff --git a/vnpy/chart/__init__.py b/vnpy/chart/__init__.py index 27ede33a..4d090ff4 100644 --- a/vnpy/chart/__init__.py +++ b/vnpy/chart/__init__.py @@ -1 +1,2 @@ from .widget import ChartWidget +from .item import CandleItem, VolumeItem diff --git a/vnpy/chart/base.py b/vnpy/chart/base.py index 8d02527d..d2243454 100644 --- a/vnpy/chart/base.py +++ b/vnpy/chart/base.py @@ -2,8 +2,8 @@ WHITE_COLOR = (255, 255, 255) BLACK_COLOR = (0, 0, 0) GREY_COLOR = (100, 100, 100) -UP_COLOR = (85, 234, 204) -DOWN_COLOR = (218, 75, 61) +UP_COLOR = (255, 0, 0) +DOWN_COLOR = (0, 255, 0) PEN_WIDTH = 1 BAR_WIDTH = 0.4 diff --git a/vnpy/chart/cursor.py b/vnpy/chart/cursor.py deleted file mode 100644 index 9c84df9c..00000000 --- a/vnpy/chart/cursor.py +++ /dev/null @@ -1,9 +0,0 @@ -from vnpy.trader.ui import QtCore - - -class ChartCursor(QtCore.QObject): - """""" - - def __init__(self): - """""" - pass diff --git a/vnpy/chart/item.py b/vnpy/chart/item.py index a1937456..e1435ccc 100644 --- a/vnpy/chart/item.py +++ b/vnpy/chart/item.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import List, Dict +from typing import List, Dict, Tuple import pyqtgraph as pg @@ -40,12 +40,21 @@ class ChartItem(pg.GraphicsObject): pass @abstractmethod - def boundingRect(self): + def boundingRect(self) -> QtCore.QRectF: """ Get bounding rectangles for item. """ pass + @abstractmethod + def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]: + """ + Get range of y-axis with given x-axis range. + + If min_ix and max_ix not specified, then return range with whole data set. + """ + pass + def update_history(self, history: List[BarData]) -> BarData: """ Update a list of bar data. @@ -186,6 +195,15 @@ class CandleItem(ChartItem): ) return rect + def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]: + """ + Get range of y-axis with given x-axis range. + + If min_ix and max_ix not specified, then return range with whole data set. + """ + min_price, max_price = self._manager.get_price_range(min_ix, max_ix) + return min_price, max_price + class VolumeItem(ChartItem): """""" @@ -231,3 +249,12 @@ class VolumeItem(ChartItem): max_volume - min_volume ) return rect + + def get_y_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]: + """ + Get range of y-axis with given x-axis range. + + If min_ix and max_ix not specified, then return range with whole data set. + """ + min_volume, max_volume = self._manager.get_volume_range(min_ix, max_ix) + return min_volume, max_volume diff --git a/vnpy/chart/manager.py b/vnpy/chart/manager.py index 9744df7c..39f76876 100644 --- a/vnpy/chart/manager.py +++ b/vnpy/chart/manager.py @@ -46,7 +46,7 @@ class BarManager: self._datetime_index_map[dt] = ix self._index_datetime_map[ix] = dt - self.datetime_bar_map[dt] = bar + self._bars[dt] = bar self._clear_cache() diff --git a/vnpy/chart/widget.py b/vnpy/chart/widget.py index 730b2118..ad9d351c 100644 --- a/vnpy/chart/widget.py +++ b/vnpy/chart/widget.py @@ -1,14 +1,14 @@ -from typing import List +from typing import List, Dict, Type import pyqtgraph as pg -from vnpy.trader.ui import QtGui, QtWidgets +from vnpy.trader.ui import QtGui, QtWidgets, QtCore from vnpy.trader.object import BarData from .manager import BarManager -from .base import GREY_COLOR +from .base import GREY_COLOR, WHITE_COLOR from .axis import DatetimeAxis, AXIS_FONT -from .item import CandleItem, VolumeItem, ChartItem +from .item import ChartItem class ChartWidget(pg.PlotWidget): @@ -21,32 +21,43 @@ class ChartWidget(pg.PlotWidget): self._manager: BarManager = BarManager() - self._plots: List[ChartItem] = [] - self._items: List[pg.GraphicsObject] = [] + self._plots: Dict[str, ChartItem] = {} + self._items: Dict[str, pg.GraphicsObject] = {} + self._item_plot_map: Dict[ChartItem, pg.GraphicsObject] = {} - self._max_ix: int = 0 - self._bar_count: int = 0 + self._first_plot: pg.PlotItem = None + self._right_ix: int = 0 # Index of most right data + self._bar_count: int = 0 # Total bar visible in chart self.init_ui() def init_ui(self) -> None: """""" + self.setWindowTitle("ChartWidget of vn.py") + self._layout = pg.GraphicsLayout() self._layout.setContentsMargins(10, 10, 10, 10) self._layout.setSpacing(0) self._layout.setBorder(color=GREY_COLOR, width=0.8) self._layout.setZValue(0) - self.setCentralItem(self._layout) self._x_axis = DatetimeAxis(self._manager, orientation='bottom') - self.init_candle() - self.init_volume() - self._volume_plot.setXLink(self._candle_plot) - - def new_plot(self) -> None: + def add_cursor(self) -> None: """""" + self._cursor = ChartCursor(self, self._manager, self._plots) + + def add_plot( + self, + plot_name: str, + minimum_height: int = 80, + hide_x_axis: bool = False + ) -> None: + """ + Add plot area. + """ + # Create plot object plot = pg.PlotItem(axisItems={'bottom': self._x_axis}) plot.setMenuEnabled(False) plot.setClipToView(True) @@ -55,47 +66,56 @@ class ChartWidget(pg.PlotWidget): plot.setDownsampling(mode='peak') plot.setRange(xRange=(0, 1), yRange=(0, 1)) plot.hideButtons() + plot.setMinimumHeight(minimum_height) + if hide_x_axis: + plot.hideAxis("bottom") + + if not self._first_plot: + self._first_plot = plot + + # Connect view change signal to update y range function view = plot.getViewBox() - view.sigXRangeChanged.connect(self._change_plot_y_range) + view.sigXRangeChanged.connect(self._update_plot_range) view.setMouseEnabled(x=True, y=False) + # Set right axis right_axis = plot.getAxis('right') right_axis.setWidth(60) right_axis.setStyle(tickFont=AXIS_FONT) - return plot + # Connect x-axis link + if self._plots: + first_plot = list(self._plots.values())[0] + plot.setXLink(first_plot) - def init_candle(self) -> None: - """ - Initialize candle plot. - """ - self._candle_item = CandleItem(self._manager) - self._items.append(self._candle_item) + # Store plot object in dict + self._plots[plot_name] = plot - self._candle_plot = self.new_plot() - self._candle_plot.addItem(self._candle_item) - self._candle_plot.setMinimumHeight(80) - self._candle_plot.hideAxis('bottom') - self._plots.append(self._candle_plot) + def add_item( + self, + item_class: Type[ChartItem], + item_name: str, + plot_name: str + ): + """ + Add chart item. + """ + item = item_class(self._manager) + self._items[item_name] = item + + plot = self._plots.get(plot_name) + plot.addItem(item) + self._item_plot_map[item] = plot self._layout.nextRow() - self._layout.addItem(self._candle_plot) + self._layout.addItem(plot) - def init_volume(self) -> None: + def get_all_plots(self): """ - Initialize bar plot. + Get all plot objects. """ - self._volume_item = VolumeItem(self._manager) - self._items.append(self._volume_item) - - self._volume_plot = self.new_plot() - self._volume_plot.addItem(self._volume_item) - self._volume_plot.setMinimumHeight(80) - self._plots.append(self._volume_plot) - - self._layout.nextRow() - self._layout.addItem(self._volume_plot) + return self._plots.values() def clear_all(self) -> None: """ @@ -103,7 +123,7 @@ class ChartWidget(pg.PlotWidget): """ self._manager.clear_all() - for item in self._items: + for item in self._items.values(): item.clear_all() def update_history(self, history: List[BarData]) -> None: @@ -112,10 +132,10 @@ class ChartWidget(pg.PlotWidget): """ self._manager.update_history(history) - for item in self._items: + for item in self._items.values(): item.update_history(history) - self._update_plot_range() + self._update_plot_limits() def update_bar(self, bar: BarData) -> None: """ @@ -123,57 +143,131 @@ class ChartWidget(pg.PlotWidget): """ self._manager.update_bar(bar) - for item in self.items: + for item in self._items.values(): item.update_bar(bar) - self._update_plot_range() + self._update_plot_limits() + + def _update_plot_limits(self) -> None: + """ + Update the limit of plots. + """ + for item, plot in self._item_plot_map.items(): + min_value, max_value = item.get_y_range() + + plot.setLimits( + xMin=-self.MIN_BAR_COUNT, + xMax=self._manager.get_count(), + yMin=min_value, + yMax=max_value + ) def _update_plot_range(self) -> None: - """ - Update the range of plots. - """ - max_ix = self._max_ix - min_ix = self._max_ix - self._bar_count - - # Update limit and range for x-axis - for plot in self._plots: - plot.setLimits( - xMin=-self.MIN_BAR_COUNT, - xMax=self._manager.get_count() - ) - plot.setRange( - xRange=(min_ix, max_ix), - padding=0 - ) - - # Update limit for y-axis - min_price, max_price = self._manager.get_price_range() - self._candle_plot.setLimits(yMin=min_price, yMax=max_price) - - min_volume, max_volume = self._manager.get_volume_range() - self._volume_plot.setLimits(yMin=min_volume, yMax=max_volume) - - def _change_plot_y_range(self) -> None: """ Reset the y-axis range of plots. """ - view = self._candle_plot.getViewBox() + view = self._first_plot.getViewBox() view_range = view.viewRange() + min_ix = max(0, int(view_range[0][0])) max_ix = min(self._manager.get_count(), int(view_range[0][1])) - price_range = self._manager.get_price_range(min_ix, max_ix) - self._candle_plot.setRange(yRange=price_range) - - volume_range = self._manager.get_volume_range(min_ix, max_ix) - self._volume_plot.setRange(yRange=volume_range) + # Update limit for y-axis + for item, plot in self._item_plot_map.items(): + y_range = item.get_y_range(min_ix, max_ix) + plot.setRange(yRange=y_range) def paintEvent(self, event: QtGui.QPaintEvent) -> None: """ Reimplement this method of parent to update current max_ix value. """ - view = self._candle_plot.getViewBox() + view = self._first_plot.getViewBox() view_range = view.viewRange() - self._max_ix = max(0, view_range[0][1]) + self._right_ix = max(0, view_range[0][1]) super().paintEvent(event) + + +class ChartCursor(QtCore.QObject): + """""" + + def __init__( + self, + widget: ChartWidget, + manager: BarManager, + plots: Dict[str, pg.GraphicsObject] + ): + """""" + super().__init__() + + self._widget: ChartWidget = widget + self._manager: BarManager = manager + self._plots: Dict[str, pg.GraphicsObject] = plots + + self._x = 0 + self._y = 0 + self._plot_name = "" + + self.init_ui() + + def init_ui(self): + """""" + # Create line objects + self._v_lines: Dict[str, pg.InfiniteLine] = {} + self._h_lines: Dict[str, pg.InfiniteLine] = {} + self._views: Dict[str, pg.ViewBox] = {} + + pen = pg.mkPen(WHITE_COLOR) + + for plot_name, plot in self._plots.items(): + v_line = pg.InfiniteLine(angle=90, movable=False, pen=pen) + h_line = pg.InfiniteLine(angle=0, movable=False, pen=pen) + view = plot.getViewBox() + + for line in [v_line, h_line]: + line.setZValue(0) + line.hide() + view.addItem(line) + + self._v_lines[plot_name] = v_line + self._h_lines[plot_name] = h_line + self._views[plot_name] = view + + # Connect signal + self.proxy = pg.SignalProxy( + self._widget.scene().sigMouseMoved, + rateLimit=360, + slot=self.mouse_moved + ) + + def mouse_moved(self, evt: tuple): + """""" + if not self._manager.get_count(): + return + + pos = evt[0] + + for plot_name, view in self._views.items(): + rect = view.sceneBoundingRect() + + if rect.contains(pos): + mouse_point = view.mapSceneToView(pos) + self._x = mouse_point.x() + self._y = mouse_point.y() + self._plot_name = plot_name + break + + self.update_line() + + def update_line(self): + """""" + for v_line in self._v_lines.values(): + v_line.setPos(self._x) + v_line.show() + + for plot_name, h_line in self._h_lines.items(): + if plot_name == self._plot_name: + h_line.setPos(self._y) + h_line.show() + else: + h_line.hide()