diff --git a/tests/trader/run.py b/tests/trader/run.py index 5b385199..82b60deb 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -23,6 +23,7 @@ from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.algo_trading import AlgoTradingApp from vnpy.app.cta_backtester import CtaBacktesterApp +from vnpy.app.data_recorder import DataRecorderApp def main(): @@ -51,6 +52,7 @@ def main(): main_engine.add_app(CtaBacktesterApp) main_engine.add_app(CsvLoaderApp) main_engine.add_app(AlgoTradingApp) + main_engine.add_app(DataRecorderApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/app/data_recorder/__init__.py b/vnpy/app/data_recorder/__init__.py new file mode 100644 index 00000000..5d91b72b --- /dev/null +++ b/vnpy/app/data_recorder/__init__.py @@ -0,0 +1,19 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp +from vnpy.trader.constant import Direction +from vnpy.trader.object import TickData, BarData, TradeData, OrderData +from vnpy.trader.utility import BarGenerator, ArrayManager + +from .engine import RecorderEngine, APP_NAME + + +class DataRecorderApp(BaseApp): + """""" + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "行情记录" + engine_class = RecorderEngine + widget_name = "RecorderManager" + icon_name = "recorder.ico" diff --git a/vnpy/app/data_recorder/engine.py b/vnpy/app/data_recorder/engine.py new file mode 100644 index 00000000..da4c362c --- /dev/null +++ b/vnpy/app/data_recorder/engine.py @@ -0,0 +1,235 @@ +"""""" + +from threading import Thread +from queue import Queue, Empty +from copy import copy + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.object import ( + SubscribeRequest, + TickData, + BarData, + ContractData +) +from vnpy.trader.event import EVENT_TICK, EVENT_CONTRACT +from vnpy.trader.utility import load_json, save_json, BarGenerator +from vnpy.trader.database import database_manager + + +APP_NAME = "DataRecorder" + +EVENT_RECORDER_LOG = "eRecorderLog" +EVENT_RECORDER_UPDATE = "eRecorderUpdate" + + +class RecorderEngine(BaseEngine): + """""" + setting_filename = "data_recorder_setting.json" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.queue = Queue() + self.thread = Thread(target=self.run) + self.active = False + + self.tick_recordings = {} + self.bar_recordings = {} + self.bar_generators = {} + + self.load_setting() + self.register_event() + self.start() + self.put_event() + + def load_setting(self): + """""" + setting = load_json(self.setting_filename) + self.tick_recordings = setting.get("tick", {}) + self.bar_recordings = setting.get("bar", {}) + + def save_setting(self): + """""" + setting = { + "tick": self.tick_recordings, + "bar": self.bar_recordings + } + save_json(self.setting_filename, setting) + + def run(self): + """""" + while self.active: + try: + task = self.queue.get(timeout=1) + task_type, data = task + + if task_type == "tick": + database_manager.save_tick_data([data]) + elif task_type == "bar": + database_manager.save_bar_data([data]) + + except Empty: + continue + + def close(self): + """""" + self.active = False + + if self.thread.isAlive(): + self.thread.join() + + def start(self): + """""" + self.active = True + self.thread.start() + + def add_bar_recording(self, vt_symbol: str): + """""" + if vt_symbol in self.bar_recordings: + self.write_log(f"已在K线记录列表中:{vt_symbol}") + return + + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f"找不到合约:{vt_symbol}") + return + + self.bar_recordings[vt_symbol] = { + "symbol": contract.symbol, + "exchange": contract.exchange.value, + "gateway_name": contract.gateway_name + } + + self.subscribe(contract) + self.save_setting() + self.put_event() + + self.write_log(f"添加K线记录成功:{vt_symbol}") + + def add_tick_recording(self, vt_symbol: str): + """""" + if vt_symbol in self.tick_recordings: + self.write_log(f"已在Tick记录列表中:{vt_symbol}") + return + + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f"找不到合约:{vt_symbol}") + return + + self.tick_recordings[vt_symbol] = { + "symbol": contract.symbol, + "exchange": contract.exchange.value, + "gateway_name": contract.gateway_name + } + + self.subscribe(contract) + self.save_setting() + self.put_event() + + self.write_log(f"添加Tick记录成功:{vt_symbol}") + + def remove_bar_recording(self, vt_symbol: str): + """""" + if vt_symbol not in self.bar_recordings: + self.write_log(f"不在K线记录列表中:{vt_symbol}") + return + + self.bar_recordings.pop(vt_symbol) + self.save_setting() + self.put_event() + + self.write_log(f"移除K线记录成功:{vt_symbol}") + + def remove_tick_recording(self, vt_symbol: str): + """""" + if vt_symbol not in self.tick_recordings: + self.write_log(f"不在Tick记录列表中:{vt_symbol}") + return + + self.tick_recordings.pop(vt_symbol) + self.save_setting() + self.put_event() + + self.write_log(f"移除Tick记录成功:{vt_symbol}") + + def register_event(self): + """""" + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_CONTRACT, self.process_contract_event) + + def process_tick_event(self, event: Event): + """""" + tick = event.data + + if tick.vt_symbol in self.tick_recordings: + self.record_tick(tick) + + if tick.vt_symbol in self.bar_recordings: + bg = self.get_bar_generator(tick.vt_symbol) + bg.update_tick(tick) + + def process_contract_event(self, event: Event): + """""" + contract = event.data + vt_symbol = contract.vt_symbol + + if (vt_symbol in self.tick_recordings or vt_symbol in self.bar_recordings): + self.subscribe(contract) + + def write_log(self, msg: str): + """""" + event = Event( + EVENT_RECORDER_LOG, + msg + ) + self.event_engine.put(event) + + def put_event(self): + """""" + tick_symbols = list(self.tick_recordings.keys()) + tick_symbols.sort() + + bar_symbols = list(self.bar_recordings.keys()) + bar_symbols.sort() + + data = { + "tick": tick_symbols, + "bar": bar_symbols + } + + event = Event( + EVENT_RECORDER_UPDATE, + data + ) + self.event_engine.put(event) + + def record_tick(self, tick: TickData): + """""" + task = ("tick", copy(tick)) + self.queue.put(task) + + def record_bar(self, bar: BarData): + """""" + task = ("bar", copy(bar)) + self.queue.put(task) + + def get_bar_generator(self, vt_symbol: str): + """""" + bg = self.bar_generators.get(vt_symbol, None) + + if not bg: + bg = BarGenerator(self.record_bar) + self.bar_generators[vt_symbol] = bg + + return bg + + def subscribe(self, contract: ContractData): + """""" + req = SubscribeRequest( + symbol=contract.symbol, + exchange=contract.exchange + ) + self.main_engine.subscribe(req, contract.gateway_name) diff --git a/vnpy/app/data_recorder/ui/__init__.py b/vnpy/app/data_recorder/ui/__init__.py new file mode 100644 index 00000000..7339f138 --- /dev/null +++ b/vnpy/app/data_recorder/ui/__init__.py @@ -0,0 +1 @@ +from .widget import RecorderManager diff --git a/vnpy/app/data_recorder/ui/recorder.ico b/vnpy/app/data_recorder/ui/recorder.ico new file mode 100644 index 00000000..1ddc8b85 Binary files /dev/null and b/vnpy/app/data_recorder/ui/recorder.ico differ diff --git a/vnpy/app/data_recorder/ui/widget.py b/vnpy/app/data_recorder/ui/widget.py new file mode 100644 index 00000000..44809baa --- /dev/null +++ b/vnpy/app/data_recorder/ui/widget.py @@ -0,0 +1,158 @@ +from datetime import datetime + + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtCore, QtGui, QtWidgets +from vnpy.trader.event import EVENT_CONTRACT + +from ..engine import ( + APP_NAME, + EVENT_RECORDER_LOG, + EVENT_RECORDER_UPDATE +) + + +class RecorderManager(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_update = QtCore.pyqtSignal(Event) + signal_contract = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + self.recorder_engine = main_engine.get_engine(APP_NAME) + + self.init_ui() + self.register_event() + self.recorder_engine.put_event() + + def init_ui(self): + """""" + self.setWindowTitle("行情记录") + self.resize(1000, 600) + + # Create widgets + self.symbol_line = QtWidgets.QLineEdit() + self.symbol_line.setFixedHeight( + self.symbol_line.sizeHint().height() * 2) + + contracts = self.main_engine.get_all_contracts() + self.vt_symbols = [contract.vt_symbol for contract in contracts] + + self.symbol_completer = QtWidgets.QCompleter(self.vt_symbols) + self.symbol_completer.setFilterMode(QtCore.Qt.MatchContains) + self.symbol_completer.setCompletionMode( + self.symbol_completer.PopupCompletion) + self.symbol_line.setCompleter(self.symbol_completer) + + add_bar_button = QtWidgets.QPushButton("添加") + add_bar_button.clicked.connect(self.add_bar_recording) + + remove_bar_button = QtWidgets.QPushButton("移除") + remove_bar_button.clicked.connect(self.remove_bar_recording) + + add_tick_button = QtWidgets.QPushButton("添加") + add_tick_button.clicked.connect(self.add_tick_recording) + + remove_tick_button = QtWidgets.QPushButton("移除") + remove_tick_button.clicked.connect(self.remove_tick_recording) + + self.bar_recording_edit = QtWidgets.QTextEdit() + self.bar_recording_edit.setReadOnly(True) + + self.tick_recording_edit = QtWidgets.QTextEdit() + self.tick_recording_edit.setReadOnly(True) + + self.log_edit = QtWidgets.QTextEdit() + self.log_edit.setReadOnly(True) + + # Set layout + grid = QtWidgets.QGridLayout() + grid.addWidget(QtWidgets.QLabel("K线记录"), 0, 0) + grid.addWidget(add_bar_button, 0, 1) + grid.addWidget(remove_bar_button, 0, 2) + grid.addWidget(QtWidgets.QLabel("Tick记录"), 1, 0) + grid.addWidget(add_tick_button, 1, 1) + grid.addWidget(remove_tick_button, 1, 2) + + hbox = QtWidgets.QHBoxLayout() + hbox.addWidget(QtWidgets.QLabel("本地代码")) + hbox.addWidget(self.symbol_line) + hbox.addWidget(QtWidgets.QLabel(" ")) + hbox.addLayout(grid) + hbox.addStretch() + + grid2 = QtWidgets.QGridLayout() + grid2.addWidget(QtWidgets.QLabel("K线记录列表"), 0, 0) + grid2.addWidget(QtWidgets.QLabel("Tick记录列表"), 0, 1) + grid2.addWidget(self.bar_recording_edit, 1, 0) + grid2.addWidget(self.tick_recording_edit, 1, 1) + grid2.addWidget(self.log_edit, 2, 0, 1, 2) + + vbox = QtWidgets.QVBoxLayout() + vbox.addLayout(hbox) + vbox.addLayout(grid2) + self.setLayout(vbox) + + def register_event(self): + """""" + self.signal_log.connect(self.process_log_event) + self.signal_contract.connect(self.process_contract_event) + self.signal_update.connect(self.process_update_event) + + self.event_engine.register(EVENT_CONTRACT, self.signal_contract.emit) + self.event_engine.register( + EVENT_RECORDER_LOG, self.signal_log.emit) + self.event_engine.register( + EVENT_RECORDER_UPDATE, self.signal_update.emit) + + def process_log_event(self, event: Event): + """""" + timestamp = datetime.now().strftime("%H:%M:%S") + msg = f"{timestamp}\t{event.data}" + self.log_edit.append(msg) + + def process_update_event(self, event: Event): + """""" + data = event.data + + self.bar_recording_edit.clear() + bar_text = "\n".join(data["bar"]) + self.bar_recording_edit.setText(bar_text) + + self.tick_recording_edit.clear() + tick_text = "\n".join(data["tick"]) + self.tick_recording_edit.setText(tick_text) + + def process_contract_event(self, event: Event): + """""" + contract = event.data + self.vt_symbols.append(contract.vt_symbol) + + model = self.symbol_completer.model() + model.setStringList(self.vt_symbols) + + def add_bar_recording(self): + """""" + vt_symbol = self.symbol_line.text() + self.recorder_engine.add_bar_recording(vt_symbol) + + def add_tick_recording(self): + """""" + vt_symbol = self.symbol_line.text() + self.recorder_engine.add_tick_recording(vt_symbol) + + def remove_bar_recording(self): + """""" + vt_symbol = self.symbol_line.text() + self.recorder_engine.remove_bar_recording(vt_symbol) + + def remove_tick_recording(self): + """""" + vt_symbol = self.symbol_line.text() + self.recorder_engine.remove_tick_recording(vt_symbol) diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py index 37c2e609..7452e467 100644 --- a/vnpy/trader/database/database_sql.py +++ b/vnpy/trader/database/database_sql.py @@ -141,7 +141,8 @@ def init_models(db: Database, driver: Driver): ).execute() else: for c in chunked(dicts, 50): - DbBarData.insert_many(c).on_conflict_replace().execute() + DbBarData.insert_many( + c).on_conflict_replace().execute() class DbTickData(ModelBase): """ @@ -309,7 +310,8 @@ def init_models(db: Database, driver: Driver): ).execute() else: for c in chunked(dicts, 50): - DbTickData.insert_many(c).on_conflict_replace().execute() + DbTickData.insert_many( + c).on_conflict_replace().execute() db.connect() db.create_tables([DbBarData, DbTickData]) @@ -332,11 +334,11 @@ class SqlManager(BaseDatabaseManager): s = ( self.class_bar.select() .where( - (self.class_bar.symbol == symbol) - & (self.class_bar.exchange == exchange.value) - & (self.class_bar.interval == interval.value) - & (self.class_bar.datetime >= start) - & (self.class_bar.datetime <= end) + (self.class_bar.symbol == symbol) & + (self.class_bar.exchange == exchange.value) & + (self.class_bar.interval == interval.value) & + (self.class_bar.datetime >= start) & + (self.class_bar.datetime <= end) ) .order_by(self.class_bar.datetime) ) @@ -349,10 +351,10 @@ class SqlManager(BaseDatabaseManager): s = ( self.class_tick.select() .where( - (self.class_tick.symbol == symbol) - & (self.class_tick.exchange == exchange.value) - & (self.class_tick.datetime >= start) - & (self.class_tick.datetime <= end) + (self.class_tick.symbol == symbol) & + (self.class_tick.exchange == exchange.value) & + (self.class_tick.datetime >= start) & + (self.class_tick.datetime <= end) ) .order_by(self.class_tick.datetime) ) @@ -374,9 +376,9 @@ class SqlManager(BaseDatabaseManager): s = ( self.class_bar.select() .where( - (self.class_bar.symbol == symbol) - & (self.class_bar.exchange == exchange.value) - & (self.class_bar.interval == interval.value) + (self.class_bar.symbol == symbol) & + (self.class_bar.exchange == exchange.value) & + (self.class_bar.interval == interval.value) ) .order_by(self.class_bar.datetime.desc()) .first() @@ -391,8 +393,8 @@ class SqlManager(BaseDatabaseManager): s = ( self.class_tick.select() .where( - (self.class_tick.symbol == symbol) - & (self.class_tick.exchange == exchange.value) + (self.class_tick.symbol == symbol) & + (self.class_tick.exchange == exchange.value) ) .order_by(self.class_tick.datetime.desc()) .first() diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index ef245503..ab1406de 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -10,7 +10,7 @@ import numpy as np import talib from .object import BarData, TickData -from .constant import Exchange +from .constant import Exchange, Interval def extract_vt_symbol(vt_symbol: str): @@ -135,6 +135,10 @@ class BarGenerator: """ new_minute = False + # Filter tick data with 0 last price + if not tick.last_price: + return + if not self.bar: new_minute = True elif self.bar.datetime.minute != tick.datetime.minute: @@ -149,6 +153,7 @@ class BarGenerator: self.bar = BarData( symbol=tick.symbol, exchange=tick.exchange, + interval=Interval.MINUTE, datetime=tick.datetime, gateway_name=tick.gateway_name, open_price=tick.last_price,