From 1074a26b77966f1f91a6c185584ef37c8c4a7153 Mon Sep 17 00:00:00 2001 From: nanoric Date: Fri, 12 Apr 2019 04:41:56 -0400 Subject: [PATCH] [Add] added support for PostgreSQL --- tests/app/test_csv_loader.py | 90 ++++++++++++++++++++++ tests/backtesting/getdata.py | 12 +-- vnpy/app/csv_loader/engine.py | 134 +++++++++++++++++++-------------- vnpy/trader/database.py | 137 ++++++++++++++++++++++++---------- vnpy/trader/setting.py | 15 ++-- 5 files changed, 278 insertions(+), 110 deletions(-) create mode 100644 tests/app/test_csv_loader.py diff --git a/tests/app/test_csv_loader.py b/tests/app/test_csv_loader.py new file mode 100644 index 00000000..b52a2183 --- /dev/null +++ b/tests/app/test_csv_loader.py @@ -0,0 +1,90 @@ +""" +Test if csv loader works fine +""" +import tempfile +import unittest + +from vnpy.app.csv_loader import CsvLoaderEngine +from vnpy.trader.constant import Exchange, Interval + + +class TestCsvLoader(unittest.TestCase): + + def setUp(self) -> None: + self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader + + def test_load(self): + data = """"Datetime","Open","High","Low","Close","Volume" +2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489 +2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302 +2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203 +2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280 +2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250 +2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109 +""" + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + def test_load_duplicated(self): + data = """"Datetime","Open","High","Low","Close","Volume" +2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489 +2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302 +2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203 +2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280 +2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250 +2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109 +""" + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/backtesting/getdata.py b/tests/backtesting/getdata.py index 77436bcc..62d02465 100644 --- a/tests/backtesting/getdata.py +++ b/tests/backtesting/getdata.py @@ -2,7 +2,7 @@ from time import time import rqdatac as rq -from vnpy.trader.database import DbBarData, DB +from vnpy.trader.database import DbBarData USERNAME = "" PASSWORD = "" @@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol): df = rq.get_price(symbol, frequency="1m", fields=FIELDS) - with DB.atomic(): - for ix, row in df.iterrows(): - print(row.name) - bar = generate_bar_from_row(row, symbol, exchange) - DbBarData.replace(bar.__data__).execute() + bars = [] + for ix, row in df.iterrows(): + bar = generate_bar_from_row(row, symbol, exchange) + bars.append(bar) + DbBarData.save_all(bars) end = time() cost = (end - start) * 1000 diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index fa88afe7..d9d29385 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -22,15 +22,13 @@ Sample csv file: import csv from datetime import datetime - -from peewee import chunked +from typing import TextIO from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData, DB +from vnpy.trader.database import DbBarData from vnpy.trader.engine import BaseEngine, MainEngine - APP_NAME = "CsvLoader" @@ -41,17 +39,70 @@ class CsvLoaderEngine(BaseEngine): """""" super().__init__(main_engine, event_engine, APP_NAME) - self.file_path: str = "" + self.file_path: str = '' self.symbol: str = "" self.exchange: Exchange = Exchange.SSE self.interval: Interval = Interval.MINUTE - self.datetime_head: str = "" - self.open_head: str = "" - self.close_head: str = "" - self.low_head: str = "" - self.high_head: str = "" - self.volume_head: str = "" + self.datetime_head: str = '' + self.open_head: str = '' + self.close_head: str = '' + self.low_head: str = '' + self.high_head: str = '' + self.volume_head: str = '' + + def load_by_handle( + self, + f: TextIO, + symbol: str, + exchange: Exchange, + interval: Interval, + datetime_head: str, + open_head: str, + close_head: str, + low_head: str, + high_head: str, + volume_head: str, + datetime_format: str + ): + """ + load by text mode file handle + """ + reader = csv.DictReader(f) + + db_bars = [] + start = None + count = 0 + for item in reader: + if datetime_format: + dt = datetime.strptime(item[datetime_head], datetime_format) + else: + dt = datetime.fromisoformat(item[datetime_head]) + + db_bar = DbBarData( + symbol=symbol, + exchange=exchange.value, + datetime=dt, + interval=interval.value, + volume=item[volume_head], + open_price=item[open_head], + high_price=item[high_head], + low_price=item[low_head], + close_price=item[close_head], + ) + + db_bars.append(db_bar) + + # do some statistics + count += 1 + if not start: + start = db_bar.datetime + end = db_bar.datetime + + # insert into database + DbBarData.save_all(db_bars) + + return start, end, count def load( self, @@ -67,47 +118,20 @@ class CsvLoaderEngine(BaseEngine): volume_head: str, datetime_format: str ): - """""" - vt_symbol = f"{symbol}.{exchange.value}" - - start = None - end = None - count = 0 - - with open(file_path, "rt") as f: - reader = csv.DictReader(f) - - db_bars = [] - - for item in reader: - dt = datetime.strptime(item[datetime_head], datetime_format) - - db_bar = { - "symbol": symbol, - "exchange": exchange.value, - "datetime": dt, - "interval": interval.value, - "volume": item[volume_head], - "open_price": item[open_head], - "high_price": item[high_head], - "low_price": item[low_head], - "close_price": item[close_head], - "vt_symbol": vt_symbol, - "gateway_name": "DB" - } - - db_bars.append(db_bar) - - # do some statistics - count += 1 - if not start: - start = db_bar["datetime"] - - end = db_bar["datetime"] - - # Insert into DB - with DB.atomic(): - for batch in chunked(db_bars, 50): - DbBarData.insert_many(batch).on_conflict_replace().execute() - - return start, end, count + """ + load by filename + """ + with open(file_path, 'rt') as f: + return self.load_by_handle( + f, + symbol=symbol, + exchange=exchange, + interval=interval, + datetime_head=datetime_head, + open_head=open_head, + close_head=close_head, + low_head=low_head, + high_head=high_head, + volume_head=volume_head, + datetime_format=datetime_format, + ) diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 1e90dcbc..ef0c9cc2 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,56 +1,85 @@ """""" +from enum import Enum +from typing import List -from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \ - SqliteDatabase +from peewee import ( + AutoField, + CharField, + Database, + DateTimeField, + FloatField, + Model, + MySQLDatabase, + PostgresqlDatabase, + SqliteDatabase, + chunked, +) from .constant import Exchange, Interval from .object import BarData, TickData from .setting import SETTINGS -from .utility import resolve_path +from .utility import get_file_path + + +class Driver(Enum): + SQLITE = "sqlite" + MYSQL = "mysql" + POSTGRESQL = "postgresql" + + +_db: Database +_driver: Driver def init(): - db_settings = SETTINGS['database'] - driver = db_settings["driver"] + global _driver + db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")} + _driver = Driver(db_settings["driver"]) init_funcs = { - "sqlite": init_sqlite, - "mysql": init_mysql, - "postgresql": init_postgresql, + Driver.SQLITE: init_sqlite, + Driver.MYSQL: init_mysql, + Driver.POSTGRESQL: init_postgresql, } - assert driver in init_funcs - del db_settings['driver'] - return init_funcs[driver](db_settings) + assert _driver in init_funcs + del db_settings["driver"] + return init_funcs[_driver](db_settings) def init_sqlite(settings: dict): - global DB - database = settings['database'] + global _db + database = settings["database"] - DB = SqliteDatabase(str(resolve_path(database))) + _db = SqliteDatabase(str(get_file_path(database))) def init_mysql(settings: dict): - global DB - DB = MySQLDatabase(**settings) + global _db + _db = MySQLDatabase(**settings) def init_postgresql(settings: dict): - global DB - DB = PostgresqlDatabase(**settings) + global _db + _db = PostgresqlDatabase(**settings) init() -class DbBarData(Model): +class ModelBase(Model): + def to_dict(self): + return self.__data__ + + +class DbBarData(ModelBase): """ Candlestick bar data for database storage. - Index is defined unique with vt_symbol, interval and datetime. + Index is defined unique with datetime, interval, symbol """ + id = AutoField() symbol = CharField() exchange = CharField() datetime = DateTimeField() @@ -62,12 +91,9 @@ class DbBarData(Model): low_price = FloatField() close_price = FloatField() - vt_symbol = CharField() - gateway_name = CharField() - class Meta: - database = DB - indexes = ((("vt_symbol", "interval", "datetime"), True),) + database = _db + indexes = ((("datetime", "interval", "symbol"), True),) @staticmethod def from_bar(bar: BarData): @@ -85,8 +111,6 @@ class DbBarData(Model): db_bar.high_price = bar.high_price db_bar.low_price = bar.low_price db_bar.close_price = bar.close_price - db_bar.vt_symbol = bar.vt_symbol - db_bar.gateway_name = "DB" return db_bar @@ -104,18 +128,40 @@ class DbBarData(Model): high_price=self.high_price, low_price=self.low_price, close_price=self.close_price, - gateway_name=self.gateway_name, + gateway_name="DB", ) return bar + @staticmethod + def save_all(objs: List["DbBarData"]): + """ + save a list of objects, update if exists. + """ + with _db.atomic(): + if _driver is Driver.POSTGRESQL: + for bar in objs: + DbBarData.insert(bar.to_dict()).on_conflict( + update=bar.to_dict(), + conflict_target=( + DbBarData.datetime, + DbBarData.interval, + DbBarData.symbol, + ), + ).execute() + else: + for c in chunked(objs, 50): + DbBarData.insert_many(c).on_conflict_replace() -class DbTickData(Model): + +class DbTickData(ModelBase): """ Tick data for database storage. - Index is defined unique with vt_symbol, interval and datetime. + Index is defined unique with (datetime, symbol) """ + id = AutoField() + symbol = CharField() exchange = CharField() datetime = DateTimeField() @@ -131,6 +177,7 @@ class DbTickData(Model): high_price = FloatField() low_price = FloatField() close_price = FloatField() + pre_close = FloatField() bid_price_1 = FloatField() bid_price_2 = FloatField() @@ -156,12 +203,9 @@ class DbTickData(Model): ask_volume_4 = FloatField() ask_volume_5 = FloatField() - vt_symbol = CharField() - gateway_name = CharField() - class Meta: - database = DB - indexes = ((("vt_symbol", "datetime"), True),) + database = _db + indexes = ((("datetime", "symbol"), True),) @staticmethod def from_tick(tick: TickData): @@ -210,9 +254,6 @@ class DbTickData(Model): db_tick.ask_volume_4 = tick.ask_volume_4 db_tick.ask_volume_5 = tick.ask_volume_5 - db_tick.vt_symbol = tick.vt_symbol - db_tick.gateway_name = "DB" - return tick def to_tick(self): @@ -237,7 +278,7 @@ class DbTickData(Model): ask_price_1=self.ask_price_1, bid_volume_1=self.bid_volume_1, ask_volume_1=self.ask_volume_1, - gateway_name=self.gateway_name, + gateway_name="DB", ) if self.bid_price_2: @@ -263,6 +304,20 @@ class DbTickData(Model): return tick + @staticmethod + def save_all(objs: List["DbTickData"]): + with _db.atomic(): + if _driver is Driver.POSTGRESQL: + for bar in objs: + DbTickData.insert(bar.to_dict()).on_conflict( + update=bar.to_dict(), + preserve=(DbTickData.id), + conflict_target=(DbTickData.datetime, DbTickData.symbol), + ).execute() + else: + for c in chunked(objs, 50): + DbBarData.insert_many(c).on_conflict_replace() -DB.connect() -DB.create_tables([DbBarData, DbTickData]) + +_db.connect() +_db.create_tables([DbBarData, DbTickData]) diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index b27dbf96..778b147c 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -24,14 +24,13 @@ SETTINGS = { "rqdata.username": "", "rqdata.password": "", - "database": { - "driver": "sqlite", # sqlite, mysql, postgresql - "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath - "host": "localhost", - "port": 3306, - "user": "root", - "password": "" - } + + "database.driver": "sqlite", # see database.Driver + "database.database": "database.db", # for sqlite, use this as filepath + "database.host": "localhost", + "database.port": 3306, + "database.user": "root", + "database.password": "" } # Load global setting from json file.