diff --git a/examples/candle_chart/run.py b/examples/candle_chart/run.py index 61848f08..1059fe6c 100644 --- a/examples/candle_chart/run.py +++ b/examples/candle_chart/run.py @@ -18,7 +18,7 @@ if __name__ == "__main__": widget = ChartWidget() widget.add_plot("candle", hide_x_axis=True) - widget.add_plot("volume") + widget.add_plot("volume", maximum_height=200) widget.add_item(CandleItem, "candle", "candle") widget.add_item(VolumeItem, "volume", "volume") widget.add_cursor() diff --git a/vnpy/chart/base.py b/vnpy/chart/base.py index d2243454..629a1fec 100644 --- a/vnpy/chart/base.py +++ b/vnpy/chart/base.py @@ -3,7 +3,13 @@ BLACK_COLOR = (0, 0, 0) GREY_COLOR = (100, 100, 100) UP_COLOR = (255, 0, 0) -DOWN_COLOR = (0, 255, 0) +DOWN_COLOR = (0, 255, 50) +CURSOR_COLOR = (255, 245, 162) PEN_WIDTH = 1 BAR_WIDTH = 0.4 + + +def to_int(value: float) -> int: + """""" + return int(round(value, 0)) diff --git a/vnpy/chart/manager.py b/vnpy/chart/manager.py index 39f76876..a786eed6 100644 --- a/vnpy/chart/manager.py +++ b/vnpy/chart/manager.py @@ -1,9 +1,10 @@ from typing import Dict, List, Tuple from datetime import datetime -from functools import lru_cache from vnpy.trader.object import BarData +from .base import to_int + class BarManager: """""" @@ -14,6 +15,9 @@ class BarManager: self._datetime_index_map: Dict[datetime, int] = {} self._index_datetime_map: Dict[int, datetime] = {} + self._price_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {} + self._volume_ranges: Dict[Tuple[int, int], Tuple[float, float]] = {} + def update_history(self, history: List[BarData]) -> None: """ Update a list of bar data. @@ -62,16 +66,18 @@ class BarManager: """ return self._datetime_index_map.get(dt, None) - def get_datetime(self, ix: int) -> datetime: + def get_datetime(self, ix: float) -> datetime: """ Get datetime with index. """ + ix = to_int(ix) return self._index_datetime_map.get(ix, None) - def get_bar(self, ix: int) -> BarData: + def get_bar(self, ix: float) -> BarData: """ Get bar data with index. """ + ix = to_int(ix) dt = self._index_datetime_map.get(ix, None) if not dt: return None @@ -84,8 +90,7 @@ class BarManager: """ return list(self._bars.values()) - @lru_cache(maxsize=99999) - def get_price_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]: + def get_price_range(self, min_ix: float = None, max_ix: float = None) -> Tuple[float, float]: """ Get price range to show within given index range. """ @@ -96,8 +101,14 @@ class BarManager: min_ix = 0 max_ix = len(self._bars) - 1 else: + min_ix = to_int(min_ix) + max_ix = to_int(max_ix) max_ix = min(max_ix, self.get_count()) + buf = self._price_ranges.get((min_ix, max_ix), None) + if buf: + return buf + bar_list = list(self._bars.values())[min_ix:max_ix + 1] first_bar = bar_list[0] max_price = first_bar.high_price @@ -109,8 +120,7 @@ class BarManager: return min_price, max_price - @lru_cache(maxsize=99999) - def get_volume_range(self, min_ix: int = None, max_ix: int = None) -> Tuple[float, float]: + def get_volume_range(self, min_ix: float = None, max_ix: float = None) -> Tuple[float, float]: """ Get volume range to show within given index range. """ @@ -120,8 +130,15 @@ class BarManager: if not min_ix: min_ix = 0 max_ix = len(self._bars) - 1 + else: + min_ix = to_int(min_ix) + max_ix = to_int(max_ix) + max_ix = min(max_ix, self.get_count()) + + buf = self._volume_ranges.get((min_ix, max_ix), None) + if buf: + return buf - max_ix = min(max_ix, self.get_count()) bar_list = list(self._bars.values())[min_ix:max_ix + 1] first_bar = bar_list[0] @@ -135,10 +152,10 @@ class BarManager: def _clear_cache(self) -> None: """ - Clear lru_cache range data. + Clear cached range data. """ - self.get_price_range.cache_clear() - self.get_volume_range.cache_clear() + self._price_ranges.clear() + self._volume_ranges.clear() def clear_all(self) -> None: """ diff --git a/vnpy/chart/widget.py b/vnpy/chart/widget.py index ad9d351c..4b1afd19 100644 --- a/vnpy/chart/widget.py +++ b/vnpy/chart/widget.py @@ -6,7 +6,7 @@ from vnpy.trader.ui import QtGui, QtWidgets, QtCore from vnpy.trader.object import BarData from .manager import BarManager -from .base import GREY_COLOR, WHITE_COLOR +from .base import GREY_COLOR, WHITE_COLOR, CURSOR_COLOR, BLACK_COLOR, to_int from .axis import DatetimeAxis, AXIS_FONT from .item import ChartItem @@ -26,12 +26,14 @@ class ChartWidget(pg.PlotWidget): self._item_plot_map: Dict[ChartItem, pg.GraphicsObject] = {} self._first_plot: pg.PlotItem = None - self._right_ix: int = 0 # Index of most right data - self._bar_count: int = 0 # Total bar visible in chart + self._cursor: ChartCursor = None - self.init_ui() + self._right_ix: int = 0 # Index of most right data + self._bar_count: int = self.MIN_BAR_COUNT # Total bar visible in chart - def init_ui(self) -> None: + self._init_ui() + + def _init_ui(self) -> None: """""" self.setWindowTitle("ChartWidget of vn.py") @@ -46,12 +48,14 @@ class ChartWidget(pg.PlotWidget): def add_cursor(self) -> None: """""" - self._cursor = ChartCursor(self, self._manager, self._plots) + if not self._cursor: + self._cursor = ChartCursor(self, self._manager, self._plots) def add_plot( self, plot_name: str, minimum_height: int = 80, + maximum_height: int = None, hide_x_axis: bool = False ) -> None: """ @@ -68,6 +72,9 @@ class ChartWidget(pg.PlotWidget): plot.hideButtons() plot.setMinimumHeight(minimum_height) + if maximum_height: + plot.setMaximumHeight(maximum_height) + if hide_x_axis: plot.hideAxis("bottom") @@ -76,7 +83,7 @@ class ChartWidget(pg.PlotWidget): # Connect view change signal to update y range function view = plot.getViewBox() - view.sigXRangeChanged.connect(self._update_plot_range) + view.sigXRangeChanged.connect(self._update_y_range) view.setMouseEnabled(x=True, y=False) # Set right axis @@ -126,6 +133,9 @@ class ChartWidget(pg.PlotWidget): for item in self._items.values(): item.clear_all() + if self._cursor: + self._cursor.clear_all() + def update_history(self, history: List[BarData]) -> None: """ Update a list of bar data. @@ -137,6 +147,8 @@ class ChartWidget(pg.PlotWidget): self._update_plot_limits() + self.move_to_right() + def update_bar(self, bar: BarData) -> None: """ Update single bar data. @@ -148,6 +160,9 @@ class ChartWidget(pg.PlotWidget): self._update_plot_limits() + if self._right_ix >= (self._manager.get_count() - self._bar_count / 2): + self.move_to_right() + def _update_plot_limits(self) -> None: """ Update the limit of plots. @@ -156,15 +171,25 @@ class ChartWidget(pg.PlotWidget): min_value, max_value = item.get_y_range() plot.setLimits( - xMin=-self.MIN_BAR_COUNT, + xMin=-1, xMax=self._manager.get_count(), yMin=min_value, yMax=max_value ) - def _update_plot_range(self) -> None: + def _update_x_range(self) -> None: """ - Reset the y-axis range of plots. + Update the x-axis range of plots. + """ + max_ix = self._right_ix + min_ix = self._right_ix - self._bar_count + + for plot in self._plots.values(): + plot.setRange(xRange=(min_ix, max_ix), padding=0) + + def _update_y_range(self) -> None: + """ + Update the y-axis range of plots. """ view = self._first_plot.getViewBox() view_range = view.viewRange() @@ -187,6 +212,79 @@ class ChartWidget(pg.PlotWidget): super().paintEvent(event) + def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: + """ + Reimplement this method of parent to move chart horizontally and zoom in/out. + """ + if event.key() == QtCore.Qt.Key_Left: + self._on_key_left() + elif event.key() == QtCore.Qt.Key_Right: + self._on_key_right() + elif event.key() == QtCore.Qt.Key_Up: + self._on_key_up() + elif event.key() == QtCore.Qt.Key_Down: + self._on_key_down() + + def wheelEvent(self, event: QtGui.QWheelEvent) -> None: + """ + Reimplement this method of parent to zoom in/out. + """ + delta = event.angleDelta() + + if delta.y() > 0: + self._on_key_up() + elif delta.y() < 0: + self._on_key_down() + + def _on_key_left(self) -> None: + """ + Move chart to left. + """ + self._right_ix -= 1 + self._right_ix = max(self._right_ix, self._bar_count) + + self._update_x_range() + self._cursor.move_left() + self._cursor.update_info() + + def _on_key_right(self) -> None: + """ + Move chart to right. + """ + self._right_ix += 1 + self._right_ix = min(self._right_ix, self._manager.get_count()) + + self._update_x_range() + self._cursor.move_right() + self._cursor.update_info() + + def _on_key_down(self) -> None: + """ + Zoom out the chart. + """ + self._bar_count *= 1.2 + self._bar_count = min(int(self._bar_count), self._manager.get_count()) + + self._update_x_range() + self._cursor.update_info() + + def _on_key_up(self) -> None: + """ + Zoom in the chart. + """ + self._bar_count /= 1.2 + self._bar_count = max(int(self._bar_count), self.MIN_BAR_COUNT) + + self._update_x_range() + self._cursor.update_info() + + def move_to_right(self) -> None: + """ + Move chart to the most right. + """ + self._right_ix = self._manager.get_count() + self._update_x_range() + class ChartCursor(QtCore.QObject): """""" @@ -204,15 +302,23 @@ class ChartCursor(QtCore.QObject): self._manager: BarManager = manager self._plots: Dict[str, pg.GraphicsObject] = plots - self._x = 0 - self._y = 0 - self._plot_name = "" + self._x: int = 0 + self._y: int = 0 + self._plot_name: str = "" - self.init_ui() + self._init_ui() + self._connect_signal() - def init_ui(self): + def _init_ui(self): """""" - # Create line objects + self._init_line() + self._init_label() + self._init_info() + + def _init_line(self) -> None: + """ + Create line objects. + """ self._v_lines: Dict[str, pg.InfiniteLine] = {} self._h_lines: Dict[str, pg.InfiniteLine] = {} self._views: Dict[str, pg.ViewBox] = {} @@ -233,33 +339,69 @@ class ChartCursor(QtCore.QObject): self._h_lines[plot_name] = h_line self._views[plot_name] = view - # Connect signal - self.proxy = pg.SignalProxy( + def _init_label(self) -> None: + """ + Create label objects on axis. + """ + self._y_labels: Dict[str, pg.TextItem] = {} + for plot_name, plot in self._plots.items(): + label = pg.TextItem(plot_name, fill=CURSOR_COLOR, color=BLACK_COLOR) + label.hide() + label.setZValue(2) + plot.addItem(label, ignoreBounds=True) + self._y_labels[plot_name] = label + + self._x_label: pg.TextItem = pg.TextItem( + "datetime", fill=CURSOR_COLOR, color=BLACK_COLOR) + self._x_label.hide() + self._x_label.setZValue(2) + plot.addItem(self._x_label, ignoreBounds=True) + + def _init_info(self) -> None: + """ + """ + self._info = pg.TextItem("info", color=CURSOR_COLOR) + self._info.hide() + self._info.setZValue(2) + + plot = list(self._plots.values())[0] + plot.addItem(self._info, ignoreBounds=True) + + def _connect_signal(self) -> None: + """ + Connect mouse move signal to update function. + """ + self._proxy = pg.SignalProxy( self._widget.scene().sigMouseMoved, rateLimit=360, - slot=self.mouse_moved + slot=self._mouse_moved ) - def mouse_moved(self, evt: tuple): - """""" + def _mouse_moved(self, evt: tuple) -> None: + """ + Callback function when mouse is moved. + """ if not self._manager.get_count(): return + # First get current mouse point pos = evt[0] - for plot_name, view in self._views.items(): rect = view.sceneBoundingRect() if rect.contains(pos): mouse_point = view.mapSceneToView(pos) - self._x = mouse_point.x() + self._x = to_int(mouse_point.x()) self._y = mouse_point.y() self._plot_name = plot_name break - self.update_line() + # Then update cursor component + self._update_line() + self._update_label() + self.update_info() - def update_line(self): + def _update_line(self) -> None: """""" for v_line in self._v_lines.values(): v_line.setPos(self._x) @@ -271,3 +413,95 @@ class ChartCursor(QtCore.QObject): h_line.show() else: h_line.hide() + + def _update_label(self) -> None: + """""" + bottom_plot = list(self._plots.values())[-1] + axis_width = bottom_plot.getAxis("right").width() + axis_height = bottom_plot.getAxis("bottom").height() + axis_offset = QtCore.QPointF(axis_width, axis_height) + + bottom_view = list(self._views.values())[-1] + bottom_right = bottom_view.mapSceneToView( + bottom_view.sceneBoundingRect().bottomRight() - axis_offset + ) + + for plot_name, label in self._y_labels.items(): + if plot_name == self._plot_name: + label.setText(str(self._y)) + label.show() + label.setPos(bottom_right.x(), self._y) + else: + label.hide() + + dt = self._manager.get_datetime(self._x) + if dt: + self._x_label.setText(dt.strftime("%Y-%m-%d %H:%M:%S")) + self._x_label.show() + self._x_label.setPos(self._x, bottom_right.y()) + self._x_label.setAnchor((0, 0)) + + def update_info(self) -> None: + """""" + bar = self._manager.get_bar(self._x) + + if bar: + op = bar.open_price + hp = bar.high_price + lp = bar.low_price + cp = bar.close_price + v = bar.volume + text = f"(open){op} (high){hp} (low){lp} (close){cp} (volume){v}" + else: + text = "" + + self._info.setText(text) + self._info.show() + + view = list(self._views.values())[0] + top_left = view.mapSceneToView(view.sceneBoundingRect().topLeft()) + self._info.setPos(top_left) + + def move_right(self) -> None: + """ + Move cursor index to right by 1. + """ + if self._x == self._manager.get_count() - 1: + return + self._x += 1 + + self._update_after_move() + + def move_left(self) -> None: + """ + Move cursor index to left by 1. + """ + if self._x == 0: + return + self._x -= 1 + + self._update_after_move() + + def _update_after_move(self) -> None: + """ + Update cursor after moved by left/right. + """ + bar = self._manager.get_bar(self._x) + self._y = bar.close_price + + self._update_line() + self._update_label() + + def clear_all(self) -> None: + """ + Clear all data. + """ + self._x = 0 + self._y = 0 + self._plot_name = "" + + for line in self._v_lines + self._h_lines: + line.hide() + + for label in self._y_labels + [self._x_label]: + label.hide()