From 5d0bf006c65200a933fd25a43a27cd4e48909592 Mon Sep 17 00:00:00 2001 From: nanoric Date: Sat, 13 Apr 2019 23:02:30 -0400 Subject: [PATCH] [Add] added support for mongodb [Add] added tests for database and CsvLoaderEngine [Mod] changed .travis.yaml --- .travis.yml | 8 +- tests/app/__init__.py | 1 + tests/test_all.py | 20 + tests/{load_all.py => test_import_all.py} | 18 + tests/trader/__init__.py | 2 + tests/trader/test_database.py | 125 ++++++ tests/trader/test_settings.py | 24 ++ tests/travis_env.sh | 16 + vnpy/app/csv_loader/engine.py | 21 +- vnpy/app/cta_strategy/backtesting.py | 58 ++- vnpy/app/cta_strategy/engine.py | 68 ++-- vnpy/trader/database/__init__.py | 12 + vnpy/trader/database/database.py | 53 +++ .../database_mongo.py} | 289 +++++++------- vnpy/trader/database/database_sql.py | 370 ++++++++++++++++++ vnpy/trader/database/initialize.py | 24 ++ vnpy/trader/setting.py | 8 +- vnpy/trader/utility.py | 20 +- 18 files changed, 892 insertions(+), 245 deletions(-) create mode 100644 tests/app/__init__.py create mode 100644 tests/test_all.py rename tests/{load_all.py => test_import_all.py} (63%) create mode 100644 tests/trader/__init__.py create mode 100644 tests/trader/test_database.py create mode 100644 tests/trader/test_settings.py create mode 100644 tests/travis_env.sh create mode 100644 vnpy/trader/database/__init__.py create mode 100644 vnpy/trader/database/database.py rename vnpy/trader/{database.py => database/database_mongo.py} (51%) create mode 100644 vnpy/trader/database/database_sql.py create mode 100644 vnpy/trader/database/initialize.py diff --git a/.travis.yml b/.travis.yml index 8622616d..d38b8fed 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,19 @@ language: python +cache: pip + dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069) python: - "3.7" +services: + - mongodb + - mysql + - postgresql script: # todo: use python unittest - - mkdir run; cd run; python ../tests/load_all.py + - cd tests; source travis_env.sh; python test_all.py matrix: include: diff --git a/tests/app/__init__.py b/tests/app/__init__.py new file mode 100644 index 00000000..9aafa67d --- /dev/null +++ b/tests/app/__init__.py @@ -0,0 +1 @@ +from .test_csv_loader import * diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 00000000..e824ff21 --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,20 @@ +# tests/runner.py +import unittest + +# import your test modules +import test_import_all +import trader +import app + +# initialize the test suite +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# add tests to the test suite +suite.addTests(loader.loadTestsFromModule(test_import_all)) +suite.addTests(loader.loadTestsFromModule(trader)) +suite.addTests(loader.loadTestsFromModule(app)) + +# initialize a runner, pass it your suite and run it +runner = unittest.TextTestRunner(verbosity=3) +result = runner.run(suite) diff --git a/tests/load_all.py b/tests/test_import_all.py similarity index 63% rename from tests/load_all.py rename to tests/test_import_all.py index feae9c12..df930c89 100644 --- a/tests/load_all.py +++ b/tests/test_import_all.py @@ -2,23 +2,41 @@ import unittest +# noinspection PyUnresolvedReferences class ImportTest(unittest.TestCase): # noinspection PyUnresolvedReferences def test_import_all(self): from vnpy.event import EventEngine + def test_import_main_engine(self): from vnpy.trader.engine import MainEngine + + def test_import_ui(self): from vnpy.trader.ui import MainWindow, create_qapp + def test_import_bitmex_gateway(self): from vnpy.gateway.bitmex import BitmexGateway + + def test_import_futu_gateway(self): from vnpy.gateway.futu import FutuGateway + + def test_import_ib_gateway(self): from vnpy.gateway.ib import IbGateway + + def test_import_ctp_gateway(self): from vnpy.gateway.ctp import CtpGateway + + def test_import_tiger_gateway(self): from vnpy.gateway.tiger import TigerGateway + + def test_import_oes_gateway(self): from vnpy.gateway.oes import OesGateway + def test_import_cta_strategy_app(self): from vnpy.app.cta_strategy import CtaStrategyApp + + def test_import_csv_loader_app(self): from vnpy.app.csv_loader import CsvLoaderApp diff --git a/tests/trader/__init__.py b/tests/trader/__init__.py new file mode 100644 index 00000000..c5776fc1 --- /dev/null +++ b/tests/trader/__init__.py @@ -0,0 +1,2 @@ +from .test_database import * +from .test_settings import * diff --git a/tests/trader/test_database.py b/tests/trader/test_database.py new file mode 100644 index 00000000..59f1dcfc --- /dev/null +++ b/tests/trader/test_database.py @@ -0,0 +1,125 @@ +""" +Test if database works fine +""" +import os +import unittest +from datetime import datetime, timedelta + +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.database.database import Driver +from vnpy.trader.object import BarData, TickData + +os.environ['VNPY_TESTING'] = '1' + +profiles = { + Driver.SQLITE: { + "driver": "sqlite", + "database": "test_db.db", + }, + Driver.MYSQL: { + "driver": "mysql", + "database": os.environ['VNPY_TEST_MYSQL_DATABASE'], + "host": os.environ['VNPY_TEST_MYSQL_HOST'], + "port": int(os.environ['VNPY_TEST_MYSQL_PORT']), + "user": os.environ["VNPY_TEST_MYSQL_USER"], + "password": os.environ['VNPY_TEST_MYSQL_PASSWORD'], + }, + Driver.POSTGRESQL: { + "driver": "postgresql", + "database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'], + "host": os.environ['VNPY_TEST_POSTGRESQL_HOST'], + "port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']), + "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], + "password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'], + }, + Driver.MONGODB: { + "driver": "mongodb", + "database": os.environ['VNPY_TEST_MONGODB_DATABASE'], + "host": os.environ['VNPY_TEST_MONGODB_HOST'], + "port": int(os.environ['VNPY_TEST_MONGODB_PORT']), + "user": "", + "password": "", + "authentication_source": "", + }, +} + + +def now(): + return datetime.utcnow() + + +bar = BarData( + gateway_name="DB", + symbol="test_symbol", + exchange=Exchange.BITMEX, + datetime=now(), + interval=Interval.MINUTE, +) + +tick = TickData( + gateway_name="DB", + symbol="test_symbol", + exchange=Exchange.BITMEX, + datetime=now(), + name="DB_test_symbol", +) + + +class TestDatabase(unittest.TestCase): + + def connect(self, settings: dict): + from vnpy.trader.database.initialize import init # noqa + self.manager = init(settings) + + def test_upsert_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + self.manager.save_bar_data([bar]) + self.manager.save_bar_data([bar]) + + def test_save_load_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + # save first + self.manager.save_bar_data([bar]) + + # and load + results = self.manager.load_bar_data( + symbol=bar.symbol, + exchange=bar.exchange, + interval=bar.interval, + start=bar.datetime - timedelta(seconds=1), # time is not accuracy + end=now() + timedelta(seconds=1), # time is not accuracy + ) + count = len(results) + self.assertNotEqual(count, 0) + + def test_upsert_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + self.manager.save_tick_data([tick]) + self.manager.save_tick_data([tick]) + + def test_save_load_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + # save first + self.manager.save_tick_data([tick]) + + # and load + results = self.manager.load_tick_data( + symbol=bar.symbol, + exchange=bar.exchange, + start=bar.datetime - timedelta(seconds=1), # time is not accuracy + end=now() + timedelta(seconds=1), # time is not accuracy + ) + count = len(results) + self.assertNotEqual(count, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trader/test_settings.py b/tests/trader/test_settings.py new file mode 100644 index 00000000..c8e86132 --- /dev/null +++ b/tests/trader/test_settings.py @@ -0,0 +1,24 @@ +""" +Test if database works fine +""" +import unittest +from vnpy.trader.setting import SETTINGS, get_settings + + +class TestSettings(unittest.TestCase): + + def test_get_settings(self): + SETTINGS['a'] = 1 + got = get_settings() + self.assertIn('a', got) + self.assertEqual(got['a'], 1) + + def test_get_settings_with_prefix(self): + SETTINGS['a.a'] = 1 + got = get_settings() + self.assertIn('a', got) + self.assertEqual(got['a'], 1) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/travis_env.sh b/tests/travis_env.sh new file mode 100644 index 00000000..030d79c5 --- /dev/null +++ b/tests/travis_env.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +export VNPY_TEST_MYSQL_DATABASE=vnpy +export VNPY_TEST_MYSQL_HOST=127.0.0.1 +export VNPY_TEST_MYSQL_PORT=3306 +export VNPY_TEST_MYSQL_USER=root +export VNPY_TEST_MYSQL_PASSWORD= +export VNPY_TEST_POSTGRESQL_DATABASE=vnpy +export VNPY_TEST_POSTGRESQL_HOST=127.0.0.1 +export VNPY_TEST_POSTGRESQL_PORT=5432 +export VNPY_TEST_POSTGRESQL_USER=postgres +export VNPY_TEST_POSTGRESQL_PASSWORD= +export VNPY_TEST_MONGODB_DATABASE=vnpy +export VNPY_TEST_MONGODB_HOST=127.0.0.1 +export VNPY_TEST_MONGODB_PORT=27017 + diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index 7125e3ad..a9ff1936 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -26,8 +26,9 @@ from typing import TextIO from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData +from vnpy.trader.database import database_manager from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.object import BarData APP_NAME = "CsvLoader" @@ -70,7 +71,7 @@ class CsvLoaderEngine(BaseEngine): """ reader = csv.DictReader(f) - db_bars = [] + bars = [] start = None count = 0 for item in reader: @@ -79,29 +80,29 @@ class CsvLoaderEngine(BaseEngine): else: dt = datetime.fromisoformat(item[datetime_head]) - db_bar = DbBarData( + bar = BarData( symbol=symbol, - exchange=exchange.value, + exchange=exchange, datetime=dt, - interval=interval.value, + interval=interval, volume=item[volume_head], open_price=item[open_head], high_price=item[high_head], low_price=item[low_head], close_price=item[close_head], + gateway_name="DB", ) - db_bars.append(db_bar) + bars.append(bar) # do some statistics count += 1 if not start: - start = db_bar.datetime - end = db_bar.datetime + start = bar.datetime + end = bar.datetime # insert into database - DbBarData.save_all(db_bars) - + database_manager.save_bar_data(bars) return start, end, count def load( diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index f1458fe7..e4661107 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -12,9 +12,9 @@ from pandas import DataFrame from vnpy.trader.constant import (Direction, Offset, Exchange, Interval, Status) -from vnpy.trader.database import DbBarData, DbTickData -from vnpy.trader.object import OrderData, TradeData -from vnpy.trader.utility import round_to_pricetick +from vnpy.trader.database import database_manager +from vnpy.trader.object import OrderData, TradeData, BarData, TickData +from vnpy.trader.utility import round_to_pricetick, extract_vt_symbol from .base import ( BacktestingMode, @@ -103,8 +103,8 @@ class BacktestingEngine: self.strategy_class = None self.strategy = None - self.tick = None - self.bar = None + self.tick: TickData + self.bar: BarData self.datetime = None self.interval = None @@ -199,14 +199,16 @@ class BacktestingEngine: if self.mode == BacktestingMode.BAR: self.history_data = load_bar_data( - self.vt_symbol, + self.symbol, + self.exchange, self.interval, self.start, self.end ) else: self.history_data = load_tick_data( - self.vt_symbol, + self.symbol, + self.exchange, self.start, self.end ) @@ -519,7 +521,7 @@ class BacktestingEngine: else: self.daily_results[d] = DailyResult(d, price) - def new_bar(self, bar: DbBarData): + def new_bar(self, bar: BarData): """""" self.bar = bar self.datetime = bar.datetime @@ -530,7 +532,7 @@ class BacktestingEngine: self.update_daily_close(bar.close_price) - def new_tick(self, tick: DbTickData): + def new_tick(self, tick: TickData): """""" self.tick = tick self.datetime = tick.datetime @@ -965,41 +967,27 @@ def optimize( @lru_cache(maxsize=10) def load_bar_data( - vt_symbol: str, - interval: str, - start: datetime, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, end: datetime ): """""" - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + return database_manager.load_bar_data( + symbol, exchange, interval, start, end ) - data = [db_bar.to_bar() for db_bar in s] - return data @lru_cache(maxsize=10) def load_tick_data( - vt_symbol: str, - start: datetime, + symbol: str, + exchange: Exchange, + start: datetime, end: datetime ): """""" - s = ( - DbTickData.select() - .where( - (DbTickData.vt_symbol == vt_symbol) - & (DbTickData.datetime >= start) - & (DbTickData.datetime <= end) - ) - .order_by(DbTickData.datetime) + return database_manager.load_tick_data( + symbol, exchange, start, end ) - data = [db_tick.db_tick() for db_tick in s] - return data + diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 37e5973a..7824ff2d 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -5,7 +5,7 @@ import os import traceback from collections import defaultdict from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, List from datetime import datetime, timedelta from threading import Thread from queue import Queue @@ -36,7 +36,7 @@ from vnpy.trader.constant import ( Status ) from vnpy.trader.utility import load_json, save_json -from vnpy.trader.database import DbTickData, DbBarData +from vnpy.trader.database import database_manager from vnpy.trader.setting import SETTINGS from .base import ( @@ -146,13 +146,12 @@ class CtaEngine(BaseEngine): self.write_log("RQData数据接口初始化成功") def query_bar_from_rq( - self, vt_symbol: str, interval: Interval, start: datetime, end: datetime + self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime ): """ Query bar data from RQData. """ - symbol, exchange_str = vt_symbol.split(".") - rq_symbol = to_rq_symbol(vt_symbol) + rq_symbol = to_rq_symbol(symbol, exchange) if rq_symbol not in self.rq_symbols: return None @@ -166,11 +165,11 @@ class CtaEngine(BaseEngine): end_date=end ) - data = [] + data: List[BarData] = [] for ix, row in df.iterrows(): bar = BarData( symbol=symbol, - exchange=Exchange(exchange_str), + exchange=exchange, interval=interval, datetime=row.name.to_pydatetime(), open_price=row["open"], @@ -529,46 +528,41 @@ class CtaEngine(BaseEngine): return self.engine_type def load_bar( - self, vt_symbol: str, days: int, interval: Interval, callback: Callable + self, symbol: str, exchange: Exchange, days: int, interval: Interval, + callback: Callable[[BarData], None] ): """""" end = datetime.now() start = end - timedelta(days) - # Query data from RQData by default, if not found, load from database. - data = self.query_bar_from_rq(vt_symbol, interval, start, end) - if not data: - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval.value) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + # Query bars from RQData by default, if not found, load from database. + bars = self.query_bar_from_rq(symbol, exchange, interval, start, end) + if not bars: + bars = database_manager.load_bar_data( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end, ) - data = [db_bar.to_bar() for db_bar in s] - for bar in data: + for bar in bars: callback(bar) - def load_tick(self, vt_symbol: str, days: int, callback: Callable): + def load_tick(self, symbol: str, exchange: Exchange, days: int, + callback: Callable[[TickData], None]): """""" end = datetime.now() start = end - timedelta(days) - s = ( - DbTickData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + ticks = database_manager.load_tick_data( + symbol=symbol, + exchange=exchange, + start=start, + end=end, ) - for tick in s: + for tick in ticks: callback(tick) def call_strategy_func( @@ -757,7 +751,7 @@ class CtaEngine(BaseEngine): """ Load strategy class from certain folder. """ - for dirpath, dirnames, filenames in os.walk(path): + for dirpath, dirnames, filenames in os.walk(str(path)): for filename in filenames: if filename.endswith(".py"): strategy_module_name = ".".join( @@ -914,19 +908,19 @@ class CtaEngine(BaseEngine): self.main_engine.send_email(subject, msg) -def to_rq_symbol(vt_symbol: str): +def to_rq_symbol(symbol: str, exchange: Exchange): """ CZCE product of RQData has symbol like "TA1905" while vt symbol is "TA905.CZCE" so need to add "1" in symbol. """ - symbol, exchange_str = vt_symbol.split(".") - if exchange_str != "CZCE": + if exchange is not Exchange.CZCE: return symbol.upper() for count, word in enumerate(symbol): if word.isdigit(): break - + + # noinspection PyUnboundLocalVariable product = symbol[:count] year = symbol[count] month = symbol[count + 1:] diff --git a/vnpy/trader/database/__init__.py b/vnpy/trader/database/__init__.py new file mode 100644 index 00000000..16cb96b8 --- /dev/null +++ b/vnpy/trader/database/__init__.py @@ -0,0 +1,12 @@ +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vnpy.trader.database.database import BaseDatabaseManager + +if "VNPY_TESTING" not in os.environ: + from vnpy.trader.setting import get_settings + from .initialize import init + + settings = get_settings("database.") + database_manager: "BaseDatabaseManager" = init(settings=settings) diff --git a/vnpy/trader/database/database.py b/vnpy/trader/database/database.py new file mode 100644 index 00000000..06077383 --- /dev/null +++ b/vnpy/trader/database/database.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from enum import Enum +from typing import Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from vnpy.trader.constant import Interval, Exchange + from vnpy.trader.object import BarData, TickData + + +class Driver(Enum): + SQLITE = "sqlite" + MYSQL = "mysql" + POSTGRESQL = "postgresql" + MONGODB = "mongodb" + + +class BaseDatabaseManager(ABC): + + @abstractmethod + def load_bar_data( + self, + symbol: str, + exchange: "Exchange", + interval: "Interval", + start: datetime, + end: datetime + ) -> Sequence["BarData"]: + pass + + @abstractmethod + def load_tick_data( + self, + symbol: str, + exchange: "Exchange", + start: datetime, + end: datetime + ) -> Sequence["TickData"]: + pass + + @abstractmethod + def save_bar_data( + self, + datas: Sequence["BarData"], + ): + pass + + @abstractmethod + def save_tick_data( + self, + datas: Sequence["TickData"], + ): + pass diff --git a/vnpy/trader/database.py b/vnpy/trader/database/database_mongo.py similarity index 51% rename from vnpy/trader/database.py rename to vnpy/trader/database/database_mongo.py index ef0c9cc2..ef5038e1 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database/database_mongo.py @@ -1,99 +1,59 @@ -"""""" +from datetime import datetime from enum import Enum -from typing import List +from typing import Sequence -from peewee import ( - AutoField, - CharField, - Database, - DateTimeField, - FloatField, - Model, - MySQLDatabase, - PostgresqlDatabase, - SqliteDatabase, - chunked, -) +from mongoengine import DateTimeField, Document, FloatField, StringField, connect -from .constant import Exchange, Interval -from .object import BarData, TickData -from .setting import SETTINGS -from .utility import get_file_path +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.object import BarData, TickData +from .database import BaseDatabaseManager, Driver -class Driver(Enum): - SQLITE = "sqlite" - MYSQL = "mysql" - POSTGRESQL = "postgresql" - - -_db: Database -_driver: Driver - - -def init(): - global _driver - db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")} - _driver = Driver(db_settings["driver"]) - - init_funcs = { - Driver.SQLITE: init_sqlite, - Driver.MYSQL: init_mysql, - Driver.POSTGRESQL: init_postgresql, - } - - assert _driver in init_funcs - del db_settings["driver"] - return init_funcs[_driver](db_settings) - - -def init_sqlite(settings: dict): - global _db +def init(_: Driver, settings: dict): database = settings["database"] - - _db = SqliteDatabase(str(get_file_path(database))) + host = settings["host"] + port = settings["port"] + username = settings["user"] + password = settings["password"] + authentication_source = settings["authentication_source"] + if not username: # if username == '' or None, skip username + username = None + password = None + authentication_source = None + connect( + db=database, + host=host, + port=port, + username=username, + password=password, + authentication_source=authentication_source, + ) + return MongoManager() -def init_mysql(settings: dict): - global _db - _db = MySQLDatabase(**settings) - - -def init_postgresql(settings: dict): - global _db - _db = PostgresqlDatabase(**settings) - - -init() - - -class ModelBase(Model): - def to_dict(self): - return self.__data__ - - -class DbBarData(ModelBase): +class DbBarData(Document): """ Candlestick bar data for database storage. Index is defined unique with datetime, interval, symbol """ - id = AutoField() - symbol = CharField() - exchange = CharField() - datetime = DateTimeField() - interval = CharField() + symbol: str = StringField() + exchange: str = StringField() + datetime: datetime = DateTimeField() + interval: str = StringField() - volume = FloatField() - open_price = FloatField() - high_price = FloatField() - low_price = FloatField() - close_price = FloatField() + volume: float = FloatField() + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() - class Meta: - database = _db - indexes = ((("datetime", "interval", "symbol"), True),) + meta = { + "indexes": [ + {"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True} + ] + } @staticmethod def from_bar(bar: BarData): @@ -132,80 +92,56 @@ class DbBarData(ModelBase): ) return bar - @staticmethod - def save_all(objs: List["DbBarData"]): - """ - save a list of objects, update if exists. - """ - with _db.atomic(): - if _driver is Driver.POSTGRESQL: - for bar in objs: - DbBarData.insert(bar.to_dict()).on_conflict( - update=bar.to_dict(), - conflict_target=( - DbBarData.datetime, - DbBarData.interval, - DbBarData.symbol, - ), - ).execute() - else: - for c in chunked(objs, 50): - DbBarData.insert_many(c).on_conflict_replace() - -class DbTickData(ModelBase): +class DbTickData(Document): """ Tick data for database storage. Index is defined unique with (datetime, symbol) """ - id = AutoField() + symbol: str = StringField() + exchange: str = StringField() + datetime: datetime = DateTimeField() - symbol = CharField() - exchange = CharField() - datetime = DateTimeField() + name: str = StringField() + volume: float = FloatField() + last_price: float = FloatField() + last_volume: float = FloatField() + limit_up: float = FloatField() + limit_down: float = FloatField() - name = CharField() - volume = FloatField() - last_price = FloatField() - last_volume = FloatField() - limit_up = FloatField() - limit_down = FloatField() + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() + pre_close: float = FloatField() - open_price = FloatField() - high_price = FloatField() - low_price = FloatField() - close_price = FloatField() - pre_close = FloatField() + bid_price_1: float = FloatField() + bid_price_2: float = FloatField() + bid_price_3: float = FloatField() + bid_price_4: float = FloatField() + bid_price_5: float = FloatField() - bid_price_1 = FloatField() - bid_price_2 = FloatField() - bid_price_3 = FloatField() - bid_price_4 = FloatField() - bid_price_5 = FloatField() + ask_price_1: float = FloatField() + ask_price_2: float = FloatField() + ask_price_3: float = FloatField() + ask_price_4: float = FloatField() + ask_price_5: float = FloatField() - ask_price_1 = FloatField() - ask_price_2 = FloatField() - ask_price_3 = FloatField() - ask_price_4 = FloatField() - ask_price_5 = FloatField() + bid_volume_1: float = FloatField() + bid_volume_2: float = FloatField() + bid_volume_3: float = FloatField() + bid_volume_4: float = FloatField() + bid_volume_5: float = FloatField() - bid_volume_1 = FloatField() - bid_volume_2 = FloatField() - bid_volume_3 = FloatField() - bid_volume_4 = FloatField() - bid_volume_5 = FloatField() + ask_volume_1: float = FloatField() + ask_volume_2: float = FloatField() + ask_volume_3: float = FloatField() + ask_volume_4: float = FloatField() + ask_volume_5: float = FloatField() - ask_volume_1 = FloatField() - ask_volume_2 = FloatField() - ask_volume_3 = FloatField() - ask_volume_4 = FloatField() - ask_volume_5 = FloatField() - - class Meta: - database = _db - indexes = ((("datetime", "symbol"), True),) + meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]} @staticmethod def from_tick(tick: TickData): @@ -254,7 +190,7 @@ class DbTickData(ModelBase): db_tick.ask_volume_4 = tick.ask_volume_4 db_tick.ask_volume_5 = tick.ask_volume_5 - return tick + return db_tick def to_tick(self): """ @@ -304,20 +240,63 @@ class DbTickData(ModelBase): return tick + +class MongoManager(BaseDatabaseManager): + def load_bar_data( + self, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, + end: datetime, + ) -> Sequence[BarData]: + s = DbBarData.objects( + symbol=symbol, + exchange=exchange.value, + interval=interval.value, + datetime__gte=start, + datetime__lte=end, + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + def load_tick_data( + self, symbol: str, exchange: Exchange, start: datetime, end: datetime + ) -> Sequence[TickData]: + s = DbTickData.objects( + symbol=symbol, + exchange=exchange.value, + datetime__gte=start, + datetime__lte=end, + ) + data = [db_tick.to_tick() for db_tick in s] + return data + @staticmethod - def save_all(objs: List["DbTickData"]): - with _db.atomic(): - if _driver is Driver.POSTGRESQL: - for bar in objs: - DbTickData.insert(bar.to_dict()).on_conflict( - update=bar.to_dict(), - preserve=(DbTickData.id), - conflict_target=(DbTickData.datetime, DbTickData.symbol), - ).execute() - else: - for c in chunked(objs, 50): - DbBarData.insert_many(c).on_conflict_replace() + def to_update_param(d): + return { + "set__" + k: v.value if isinstance(v, Enum) else v + for k, v in d.__dict__.items() + } + def save_bar_data(self, datas: Sequence[BarData]): + for d in datas: + updates = self.to_update_param(d) + updates.pop("set__gateway_name") + updates.pop("set__vt_symbol") + ( + DbBarData.objects( + symbol=d.symbol, interval=d.interval.value, datetime=d.datetime + ).update_one(upsert=True, **updates) + ) -_db.connect() -_db.create_tables([DbBarData, DbTickData]) + def save_tick_data(self, datas: Sequence[TickData]): + for d in datas: + updates = self.to_update_param(d) + updates.pop("set__gateway_name") + updates.pop("set__vt_symbol") + ( + DbTickData.objects( + symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime + ).update_one(upsert=True, **updates) + ) diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py new file mode 100644 index 00000000..a82802ad --- /dev/null +++ b/vnpy/trader/database/database_sql.py @@ -0,0 +1,370 @@ +"""""" +from datetime import datetime +from typing import List, Sequence, Type + +from peewee import ( + AutoField, + CharField, + Database, + DateTimeField, + FloatField, + Model, + MySQLDatabase, + PostgresqlDatabase, + SqliteDatabase, + chunked, +) + +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.object import BarData, TickData +from vnpy.trader.utility import get_file_path +from .database import BaseDatabaseManager, Driver + + +def init(driver: Driver, settings: dict): + init_funcs = { + Driver.SQLITE: init_sqlite, + Driver.MYSQL: init_mysql, + Driver.POSTGRESQL: init_postgresql, + } + assert driver in init_funcs + + db = init_funcs[driver](settings) + bar, tick = init_models(db, driver) + return SqlManager(bar, tick) + + +def init_sqlite(settings: dict): + database = settings["database"] + path = str(get_file_path(database)) + db = SqliteDatabase(path) + return db + + +def init_mysql(settings: dict): + keys = {"database", "user", "password", "host", "port"} + settings = {k: v for k, v in settings.items() if k in keys} + db = MySQLDatabase(**settings) + return db + + +def init_postgresql(settings: dict): + keys = {"database", "user", "password", "host", "port"} + settings = {k: v for k, v in settings.items() if k in keys} + db = PostgresqlDatabase(**settings) + return db + + +class ModelBase(Model): + + def to_dict(self): + return self.__data__ + + +def init_models(db: Database, driver: Driver): + class DbBarData(ModelBase): + """ + Candlestick bar data for database storage. + + Index is defined unique with datetime, interval, symbol + """ + + id = AutoField() + symbol: str = CharField() + exchange: str = CharField() + datetime: datetime = DateTimeField() + interval: str = CharField() + + volume: float = FloatField() + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() + + class Meta: + database = db + indexes = ((("datetime", "interval", "symbol", "exchange"), True),) + + @staticmethod + def from_bar(bar: BarData): + """ + Generate DbBarData object from BarData. + """ + db_bar = DbBarData() + + db_bar.symbol = bar.symbol + db_bar.exchange = bar.exchange.value + db_bar.datetime = bar.datetime + db_bar.interval = bar.interval.value + db_bar.volume = bar.volume + db_bar.open_price = bar.open_price + db_bar.high_price = bar.high_price + db_bar.low_price = bar.low_price + db_bar.close_price = bar.close_price + + return db_bar + + def to_bar(self): + """ + Generate BarData object from DbBarData. + """ + bar = BarData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + interval=Interval(self.interval), + volume=self.volume, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + close_price=self.close_price, + gateway_name="DB", + ) + return bar + + @staticmethod + def save_all(objs: List["DbBarData"]): + """ + save a list of objects, update if exists. + """ + dicts = [i.to_dict() for i in objs] + with db.atomic(): + if driver is Driver.POSTGRESQL: + for bar in dicts: + DbBarData.insert(bar).on_conflict( + update=bar, + conflict_target=( + DbBarData.datetime, + DbBarData.interval, + DbBarData.symbol, + DbBarData.exchange, + ), + ).execute() + else: + for c in chunked(dicts, 50): + DbBarData.insert_many(c).on_conflict_replace().execute() + + class DbTickData(ModelBase): + """ + Tick data for database storage. + + Index is defined unique with (datetime, symbol) + """ + + id = AutoField() + + symbol: str = CharField() + exchange: str = CharField() + datetime: datetime = DateTimeField() + + name: str = CharField() + volume: float = FloatField() + last_price: float = FloatField() + last_volume: float = FloatField() + limit_up: float = FloatField() + limit_down: float = FloatField() + + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + pre_close: float = FloatField() + + bid_price_1: float = FloatField() + bid_price_2: float = FloatField(null=True) + bid_price_3: float = FloatField(null=True) + bid_price_4: float = FloatField(null=True) + bid_price_5: float = FloatField(null=True) + + ask_price_1: float = FloatField() + ask_price_2: float = FloatField(null=True) + ask_price_3: float = FloatField(null=True) + ask_price_4: float = FloatField(null=True) + ask_price_5: float = FloatField(null=True) + + bid_volume_1: float = FloatField() + bid_volume_2: float = FloatField(null=True) + bid_volume_3: float = FloatField(null=True) + bid_volume_4: float = FloatField(null=True) + bid_volume_5: float = FloatField(null=True) + + ask_volume_1: float = FloatField() + ask_volume_2: float = FloatField(null=True) + ask_volume_3: float = FloatField(null=True) + ask_volume_4: float = FloatField(null=True) + ask_volume_5: float = FloatField(null=True) + + class Meta: + database = db + indexes = ((("datetime", "symbol", "exchange"), True),) + + @staticmethod + def from_tick(tick: TickData): + """ + Generate DbTickData object from TickData. + """ + db_tick = DbTickData() + + db_tick.symbol = tick.symbol + db_tick.exchange = tick.exchange.value + db_tick.datetime = tick.datetime + db_tick.name = tick.name + db_tick.volume = tick.volume + db_tick.last_price = tick.last_price + db_tick.last_volume = tick.last_volume + db_tick.limit_up = tick.limit_up + db_tick.limit_down = tick.limit_down + db_tick.open_price = tick.open_price + db_tick.high_price = tick.high_price + db_tick.low_price = tick.low_price + db_tick.pre_close = tick.pre_close + + db_tick.bid_price_1 = tick.bid_price_1 + db_tick.ask_price_1 = tick.ask_price_1 + db_tick.bid_volume_1 = tick.bid_volume_1 + db_tick.ask_volume_1 = tick.ask_volume_1 + + if tick.bid_price_2: + db_tick.bid_price_2 = tick.bid_price_2 + db_tick.bid_price_3 = tick.bid_price_3 + db_tick.bid_price_4 = tick.bid_price_4 + db_tick.bid_price_5 = tick.bid_price_5 + + db_tick.ask_price_2 = tick.ask_price_2 + db_tick.ask_price_3 = tick.ask_price_3 + db_tick.ask_price_4 = tick.ask_price_4 + db_tick.ask_price_5 = tick.ask_price_5 + + db_tick.bid_volume_2 = tick.bid_volume_2 + db_tick.bid_volume_3 = tick.bid_volume_3 + db_tick.bid_volume_4 = tick.bid_volume_4 + db_tick.bid_volume_5 = tick.bid_volume_5 + + db_tick.ask_volume_2 = tick.ask_volume_2 + db_tick.ask_volume_3 = tick.ask_volume_3 + db_tick.ask_volume_4 = tick.ask_volume_4 + db_tick.ask_volume_5 = tick.ask_volume_5 + + return db_tick + + def to_tick(self): + """ + Generate TickData object from DbTickData. + """ + tick = TickData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + name=self.name, + volume=self.volume, + last_price=self.last_price, + last_volume=self.last_volume, + limit_up=self.limit_up, + limit_down=self.limit_down, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + pre_close=self.pre_close, + bid_price_1=self.bid_price_1, + ask_price_1=self.ask_price_1, + bid_volume_1=self.bid_volume_1, + ask_volume_1=self.ask_volume_1, + gateway_name="DB", + ) + + if self.bid_price_2: + tick.bid_price_2 = self.bid_price_2 + tick.bid_price_3 = self.bid_price_3 + tick.bid_price_4 = self.bid_price_4 + tick.bid_price_5 = self.bid_price_5 + + tick.ask_price_2 = self.ask_price_2 + tick.ask_price_3 = self.ask_price_3 + tick.ask_price_4 = self.ask_price_4 + tick.ask_price_5 = self.ask_price_5 + + tick.bid_volume_2 = self.bid_volume_2 + tick.bid_volume_3 = self.bid_volume_3 + tick.bid_volume_4 = self.bid_volume_4 + tick.bid_volume_5 = self.bid_volume_5 + + tick.ask_volume_2 = self.ask_volume_2 + tick.ask_volume_3 = self.ask_volume_3 + tick.ask_volume_4 = self.ask_volume_4 + tick.ask_volume_5 = self.ask_volume_5 + + return tick + + @staticmethod + def save_all(objs: List["DbTickData"]): + dicts = [i.to_dict() for i in objs] + with db.atomic(): + if driver is Driver.POSTGRESQL: + for tick in dicts: + DbTickData.insert(tick).on_conflict( + update=tick, + conflict_target=( + DbTickData.datetime, + DbTickData.symbol, + DbTickData.exchange, + ), + ).execute() + else: + for c in chunked(dicts, 50): + DbTickData.insert_many(c).on_conflict_replace().execute() + + db.connect() + db.create_tables([DbBarData, DbTickData]) + return DbBarData, DbTickData + + +class SqlManager(BaseDatabaseManager): + + def __init__(self, class_bar: Type[Model], class_tick: Type[Model]): + self.class_bar = class_bar + self.class_tick = class_tick + + def load_bar_data( + self, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, + end: datetime, + ) -> Sequence[BarData]: + s = ( + self.class_bar.select() + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) + & (self.class_bar.interval == interval.value) + & (self.class_bar.datetime >= start) + & (self.class_bar.datetime <= end) + ) + .order_by(self.class_bar.datetime) + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + def load_tick_data( + self, symbol: str, exchange: Exchange, start: datetime, end: datetime + ) -> Sequence[TickData]: + s = ( + self.class_tick.select() + .where( + (self.class_tick.symbol == symbol) + & (self.class_tick.exchange == exchange.value) + & (self.class_tick.datetime >= start) + & (self.class_tick.datetime <= end) + ) + .order_by(self.class_tick.datetime) + ) + data = [db_tick.to_tick() for db_tick in s] + return data + + def save_bar_data(self, datas: Sequence[BarData]): + ds = [self.class_bar.from_bar(i) for i in datas] + self.class_bar.save_all(ds) + + def save_tick_data(self, datas: Sequence[TickData]): + ds = [self.class_tick.from_tick(i) for i in datas] + self.class_tick.save_all(ds) diff --git a/vnpy/trader/database/initialize.py b/vnpy/trader/database/initialize.py new file mode 100644 index 00000000..a15b484c --- /dev/null +++ b/vnpy/trader/database/initialize.py @@ -0,0 +1,24 @@ +"""""" +from .database import BaseDatabaseManager, Driver + + +def init(settings: dict) -> BaseDatabaseManager: + driver = Driver(settings["driver"]) + if driver is Driver.MONGODB: + return init_nosql(driver=driver, settings=settings) + else: + return init_sql(driver=driver, settings=settings) + + +def init_sql(driver: Driver, settings: dict): + from .database_sql import init + keys = {'database', "host", "port", "user", "password"} + settings = {k: v for k, v in settings.items() if k in keys} + _database_manager = init(driver, settings) + return _database_manager + + +def init_nosql(driver: Driver, settings: dict): + from .database_mongo import init + _database_manager = init(driver, settings=settings) + return _database_manager diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index 778b147c..1dd8b0ff 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -30,9 +30,15 @@ SETTINGS = { "database.host": "localhost", "database.port": 3306, "database.user": "root", - "database.password": "" + "database.password": "", + "database.authentication_source": "admin", # for mongodb } # Load global setting from json file. SETTING_FILENAME = "vt_setting.json" SETTINGS.update(load_json(SETTING_FILENAME)) + + +def get_settings(prefix: str = ""): + prefix_length = len(prefix) + return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)} diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 216dcbae..69fa6991 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,20 +3,28 @@ General utility functions. """ import json -import os from pathlib import Path -from typing import Callable +from typing import Callable, TYPE_CHECKING import numpy as np import talib from .object import BarData, TickData +if TYPE_CHECKING: + from vnpy.trader.constant import Exchange -def resolve_path(pattern: str): - env = dict(os.environ) - env.update({"VNPY_TEMP": str(TEMP_DIR)}) - return pattern.format(**env) + +def extract_vt_symbol(vt_symbol: str): + """ + :return: (symbol, exchange) + """ + symbol, exchange = vt_symbol.split('.') + return symbol, exchange + + +def generate_vt_symbol(symbol: str, exchange: "Exchange"): + return f'{symbol}.{exchange.value}' def _get_trader_dir(temp_name: str):