[Add] DataRecorderApp

This commit is contained in:
vn.py 2019-05-07 12:04:41 +08:00
parent 203d7cba7c
commit 1831e2bc23
8 changed files with 439 additions and 17 deletions

View File

@ -23,6 +23,7 @@ from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.csv_loader import CsvLoaderApp
from vnpy.app.algo_trading import AlgoTradingApp from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp from vnpy.app.cta_backtester import CtaBacktesterApp
from vnpy.app.data_recorder import DataRecorderApp
def main(): def main():
@ -51,6 +52,7 @@ def main():
main_engine.add_app(CtaBacktesterApp) main_engine.add_app(CtaBacktesterApp)
main_engine.add_app(CsvLoaderApp) main_engine.add_app(CsvLoaderApp)
main_engine.add_app(AlgoTradingApp) main_engine.add_app(AlgoTradingApp)
main_engine.add_app(DataRecorderApp)
main_window = MainWindow(main_engine, event_engine) main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized() main_window.showMaximized()

View File

@ -0,0 +1,19 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from vnpy.trader.constant import Direction
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
from vnpy.trader.utility import BarGenerator, ArrayManager
from .engine import RecorderEngine, APP_NAME
class DataRecorderApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "行情记录"
engine_class = RecorderEngine
widget_name = "RecorderManager"
icon_name = "recorder.ico"

View File

@ -0,0 +1,235 @@
""""""
from threading import Thread
from queue import Queue, Empty
from copy import copy
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import (
SubscribeRequest,
TickData,
BarData,
ContractData
)
from vnpy.trader.event import EVENT_TICK, EVENT_CONTRACT
from vnpy.trader.utility import load_json, save_json, BarGenerator
from vnpy.trader.database import database_manager
APP_NAME = "DataRecorder"
EVENT_RECORDER_LOG = "eRecorderLog"
EVENT_RECORDER_UPDATE = "eRecorderUpdate"
class RecorderEngine(BaseEngine):
""""""
setting_filename = "data_recorder_setting.json"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.queue = Queue()
self.thread = Thread(target=self.run)
self.active = False
self.tick_recordings = {}
self.bar_recordings = {}
self.bar_generators = {}
self.load_setting()
self.register_event()
self.start()
self.put_event()
def load_setting(self):
""""""
setting = load_json(self.setting_filename)
self.tick_recordings = setting.get("tick", {})
self.bar_recordings = setting.get("bar", {})
def save_setting(self):
""""""
setting = {
"tick": self.tick_recordings,
"bar": self.bar_recordings
}
save_json(self.setting_filename, setting)
def run(self):
""""""
while self.active:
try:
task = self.queue.get(timeout=1)
task_type, data = task
if task_type == "tick":
database_manager.save_tick_data([data])
elif task_type == "bar":
database_manager.save_bar_data([data])
except Empty:
continue
def close(self):
""""""
self.active = False
if self.thread.isAlive():
self.thread.join()
def start(self):
""""""
self.active = True
self.thread.start()
def add_bar_recording(self, vt_symbol: str):
""""""
if vt_symbol in self.bar_recordings:
self.write_log(f"已在K线记录列表中{vt_symbol}")
return
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f"找不到合约:{vt_symbol}")
return
self.bar_recordings[vt_symbol] = {
"symbol": contract.symbol,
"exchange": contract.exchange.value,
"gateway_name": contract.gateway_name
}
self.subscribe(contract)
self.save_setting()
self.put_event()
self.write_log(f"添加K线记录成功{vt_symbol}")
def add_tick_recording(self, vt_symbol: str):
""""""
if vt_symbol in self.tick_recordings:
self.write_log(f"已在Tick记录列表中{vt_symbol}")
return
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f"找不到合约:{vt_symbol}")
return
self.tick_recordings[vt_symbol] = {
"symbol": contract.symbol,
"exchange": contract.exchange.value,
"gateway_name": contract.gateway_name
}
self.subscribe(contract)
self.save_setting()
self.put_event()
self.write_log(f"添加Tick记录成功{vt_symbol}")
def remove_bar_recording(self, vt_symbol: str):
""""""
if vt_symbol not in self.bar_recordings:
self.write_log(f"不在K线记录列表中{vt_symbol}")
return
self.bar_recordings.pop(vt_symbol)
self.save_setting()
self.put_event()
self.write_log(f"移除K线记录成功{vt_symbol}")
def remove_tick_recording(self, vt_symbol: str):
""""""
if vt_symbol not in self.tick_recordings:
self.write_log(f"不在Tick记录列表中{vt_symbol}")
return
self.tick_recordings.pop(vt_symbol)
self.save_setting()
self.put_event()
self.write_log(f"移除Tick记录成功{vt_symbol}")
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_CONTRACT, self.process_contract_event)
def process_tick_event(self, event: Event):
""""""
tick = event.data
if tick.vt_symbol in self.tick_recordings:
self.record_tick(tick)
if tick.vt_symbol in self.bar_recordings:
bg = self.get_bar_generator(tick.vt_symbol)
bg.update_tick(tick)
def process_contract_event(self, event: Event):
""""""
contract = event.data
vt_symbol = contract.vt_symbol
if (vt_symbol in self.tick_recordings or vt_symbol in self.bar_recordings):
self.subscribe(contract)
def write_log(self, msg: str):
""""""
event = Event(
EVENT_RECORDER_LOG,
msg
)
self.event_engine.put(event)
def put_event(self):
""""""
tick_symbols = list(self.tick_recordings.keys())
tick_symbols.sort()
bar_symbols = list(self.bar_recordings.keys())
bar_symbols.sort()
data = {
"tick": tick_symbols,
"bar": bar_symbols
}
event = Event(
EVENT_RECORDER_UPDATE,
data
)
self.event_engine.put(event)
def record_tick(self, tick: TickData):
""""""
task = ("tick", copy(tick))
self.queue.put(task)
def record_bar(self, bar: BarData):
""""""
task = ("bar", copy(bar))
self.queue.put(task)
def get_bar_generator(self, vt_symbol: str):
""""""
bg = self.bar_generators.get(vt_symbol, None)
if not bg:
bg = BarGenerator(self.record_bar)
self.bar_generators[vt_symbol] = bg
return bg
def subscribe(self, contract: ContractData):
""""""
req = SubscribeRequest(
symbol=contract.symbol,
exchange=contract.exchange
)
self.main_engine.subscribe(req, contract.gateway_name)

View File

@ -0,0 +1 @@
from .widget import RecorderManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View File

@ -0,0 +1,158 @@
from datetime import datetime
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.event import EVENT_CONTRACT
from ..engine import (
APP_NAME,
EVENT_RECORDER_LOG,
EVENT_RECORDER_UPDATE
)
class RecorderManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_update = QtCore.pyqtSignal(Event)
signal_contract = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
super().__init__()
self.main_engine = main_engine
self.event_engine = event_engine
self.recorder_engine = main_engine.get_engine(APP_NAME)
self.init_ui()
self.register_event()
self.recorder_engine.put_event()
def init_ui(self):
""""""
self.setWindowTitle("行情记录")
self.resize(1000, 600)
# Create widgets
self.symbol_line = QtWidgets.QLineEdit()
self.symbol_line.setFixedHeight(
self.symbol_line.sizeHint().height() * 2)
contracts = self.main_engine.get_all_contracts()
self.vt_symbols = [contract.vt_symbol for contract in contracts]
self.symbol_completer = QtWidgets.QCompleter(self.vt_symbols)
self.symbol_completer.setFilterMode(QtCore.Qt.MatchContains)
self.symbol_completer.setCompletionMode(
self.symbol_completer.PopupCompletion)
self.symbol_line.setCompleter(self.symbol_completer)
add_bar_button = QtWidgets.QPushButton("添加")
add_bar_button.clicked.connect(self.add_bar_recording)
remove_bar_button = QtWidgets.QPushButton("移除")
remove_bar_button.clicked.connect(self.remove_bar_recording)
add_tick_button = QtWidgets.QPushButton("添加")
add_tick_button.clicked.connect(self.add_tick_recording)
remove_tick_button = QtWidgets.QPushButton("移除")
remove_tick_button.clicked.connect(self.remove_tick_recording)
self.bar_recording_edit = QtWidgets.QTextEdit()
self.bar_recording_edit.setReadOnly(True)
self.tick_recording_edit = QtWidgets.QTextEdit()
self.tick_recording_edit.setReadOnly(True)
self.log_edit = QtWidgets.QTextEdit()
self.log_edit.setReadOnly(True)
# Set layout
grid = QtWidgets.QGridLayout()
grid.addWidget(QtWidgets.QLabel("K线记录"), 0, 0)
grid.addWidget(add_bar_button, 0, 1)
grid.addWidget(remove_bar_button, 0, 2)
grid.addWidget(QtWidgets.QLabel("Tick记录"), 1, 0)
grid.addWidget(add_tick_button, 1, 1)
grid.addWidget(remove_tick_button, 1, 2)
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(QtWidgets.QLabel("本地代码"))
hbox.addWidget(self.symbol_line)
hbox.addWidget(QtWidgets.QLabel(" "))
hbox.addLayout(grid)
hbox.addStretch()
grid2 = QtWidgets.QGridLayout()
grid2.addWidget(QtWidgets.QLabel("K线记录列表"), 0, 0)
grid2.addWidget(QtWidgets.QLabel("Tick记录列表"), 0, 1)
grid2.addWidget(self.bar_recording_edit, 1, 0)
grid2.addWidget(self.tick_recording_edit, 1, 1)
grid2.addWidget(self.log_edit, 2, 0, 1, 2)
vbox = QtWidgets.QVBoxLayout()
vbox.addLayout(hbox)
vbox.addLayout(grid2)
self.setLayout(vbox)
def register_event(self):
""""""
self.signal_log.connect(self.process_log_event)
self.signal_contract.connect(self.process_contract_event)
self.signal_update.connect(self.process_update_event)
self.event_engine.register(EVENT_CONTRACT, self.signal_contract.emit)
self.event_engine.register(
EVENT_RECORDER_LOG, self.signal_log.emit)
self.event_engine.register(
EVENT_RECORDER_UPDATE, self.signal_update.emit)
def process_log_event(self, event: Event):
""""""
timestamp = datetime.now().strftime("%H:%M:%S")
msg = f"{timestamp}\t{event.data}"
self.log_edit.append(msg)
def process_update_event(self, event: Event):
""""""
data = event.data
self.bar_recording_edit.clear()
bar_text = "\n".join(data["bar"])
self.bar_recording_edit.setText(bar_text)
self.tick_recording_edit.clear()
tick_text = "\n".join(data["tick"])
self.tick_recording_edit.setText(tick_text)
def process_contract_event(self, event: Event):
""""""
contract = event.data
self.vt_symbols.append(contract.vt_symbol)
model = self.symbol_completer.model()
model.setStringList(self.vt_symbols)
def add_bar_recording(self):
""""""
vt_symbol = self.symbol_line.text()
self.recorder_engine.add_bar_recording(vt_symbol)
def add_tick_recording(self):
""""""
vt_symbol = self.symbol_line.text()
self.recorder_engine.add_tick_recording(vt_symbol)
def remove_bar_recording(self):
""""""
vt_symbol = self.symbol_line.text()
self.recorder_engine.remove_bar_recording(vt_symbol)
def remove_tick_recording(self):
""""""
vt_symbol = self.symbol_line.text()
self.recorder_engine.remove_tick_recording(vt_symbol)

View File

@ -141,7 +141,8 @@ def init_models(db: Database, driver: Driver):
).execute() ).execute()
else: else:
for c in chunked(dicts, 50): for c in chunked(dicts, 50):
DbBarData.insert_many(c).on_conflict_replace().execute() DbBarData.insert_many(
c).on_conflict_replace().execute()
class DbTickData(ModelBase): class DbTickData(ModelBase):
""" """
@ -309,7 +310,8 @@ def init_models(db: Database, driver: Driver):
).execute() ).execute()
else: else:
for c in chunked(dicts, 50): for c in chunked(dicts, 50):
DbTickData.insert_many(c).on_conflict_replace().execute() DbTickData.insert_many(
c).on_conflict_replace().execute()
db.connect() db.connect()
db.create_tables([DbBarData, DbTickData]) db.create_tables([DbBarData, DbTickData])
@ -332,11 +334,11 @@ class SqlManager(BaseDatabaseManager):
s = ( s = (
self.class_bar.select() self.class_bar.select()
.where( .where(
(self.class_bar.symbol == symbol) (self.class_bar.symbol == symbol) &
& (self.class_bar.exchange == exchange.value) (self.class_bar.exchange == exchange.value) &
& (self.class_bar.interval == interval.value) (self.class_bar.interval == interval.value) &
& (self.class_bar.datetime >= start) (self.class_bar.datetime >= start) &
& (self.class_bar.datetime <= end) (self.class_bar.datetime <= end)
) )
.order_by(self.class_bar.datetime) .order_by(self.class_bar.datetime)
) )
@ -349,10 +351,10 @@ class SqlManager(BaseDatabaseManager):
s = ( s = (
self.class_tick.select() self.class_tick.select()
.where( .where(
(self.class_tick.symbol == symbol) (self.class_tick.symbol == symbol) &
& (self.class_tick.exchange == exchange.value) (self.class_tick.exchange == exchange.value) &
& (self.class_tick.datetime >= start) (self.class_tick.datetime >= start) &
& (self.class_tick.datetime <= end) (self.class_tick.datetime <= end)
) )
.order_by(self.class_tick.datetime) .order_by(self.class_tick.datetime)
) )
@ -374,9 +376,9 @@ class SqlManager(BaseDatabaseManager):
s = ( s = (
self.class_bar.select() self.class_bar.select()
.where( .where(
(self.class_bar.symbol == symbol) (self.class_bar.symbol == symbol) &
& (self.class_bar.exchange == exchange.value) (self.class_bar.exchange == exchange.value) &
& (self.class_bar.interval == interval.value) (self.class_bar.interval == interval.value)
) )
.order_by(self.class_bar.datetime.desc()) .order_by(self.class_bar.datetime.desc())
.first() .first()
@ -391,8 +393,8 @@ class SqlManager(BaseDatabaseManager):
s = ( s = (
self.class_tick.select() self.class_tick.select()
.where( .where(
(self.class_tick.symbol == symbol) (self.class_tick.symbol == symbol) &
& (self.class_tick.exchange == exchange.value) (self.class_tick.exchange == exchange.value)
) )
.order_by(self.class_tick.datetime.desc()) .order_by(self.class_tick.datetime.desc())
.first() .first()

View File

@ -10,7 +10,7 @@ import numpy as np
import talib import talib
from .object import BarData, TickData from .object import BarData, TickData
from .constant import Exchange from .constant import Exchange, Interval
def extract_vt_symbol(vt_symbol: str): def extract_vt_symbol(vt_symbol: str):
@ -135,6 +135,10 @@ class BarGenerator:
""" """
new_minute = False new_minute = False
# Filter tick data with 0 last price
if not tick.last_price:
return
if not self.bar: if not self.bar:
new_minute = True new_minute = True
elif self.bar.datetime.minute != tick.datetime.minute: elif self.bar.datetime.minute != tick.datetime.minute:
@ -149,6 +153,7 @@ class BarGenerator:
self.bar = BarData( self.bar = BarData(
symbol=tick.symbol, symbol=tick.symbol,
exchange=tick.exchange, exchange=tick.exchange,
interval=Interval.MINUTE,
datetime=tick.datetime, datetime=tick.datetime,
gateway_name=tick.gateway_name, gateway_name=tick.gateway_name,
open_price=tick.last_price, open_price=tick.last_price,