diff --git a/tests/trader/test_database.py b/tests/trader/test_database.py index df009baf..d1aedb89 100644 --- a/tests/trader/test_database.py +++ b/tests/trader/test_database.py @@ -3,48 +3,46 @@ Test if database works fine """ import os import unittest +from copy import copy from datetime import datetime, timedelta from vnpy.trader.constant import Exchange, Interval from vnpy.trader.database.database import Driver from vnpy.trader.object import BarData, TickData -os.environ['VNPY_TESTING'] = '1' +os.environ["VNPY_TESTING"] = "1" -profiles = { - Driver.SQLITE: { - "driver": "sqlite", - "database": "test_db.db", - } -} -if 'VNPY_TEST_ONLY_SQLITE' not in os.environ: - profiles.update({ - Driver.MYSQL: { - "driver": "mysql", - "database": os.environ['VNPY_TEST_MYSQL_DATABASE'], - "host": os.environ['VNPY_TEST_MYSQL_HOST'], - "port": int(os.environ['VNPY_TEST_MYSQL_PORT']), - "user": os.environ["VNPY_TEST_MYSQL_USER"], - "password": os.environ['VNPY_TEST_MYSQL_PASSWORD'], - }, - Driver.POSTGRESQL: { - "driver": "postgresql", - "database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'], - "host": os.environ['VNPY_TEST_POSTGRESQL_HOST'], - "port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']), - "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], - "password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'], - }, - Driver.MONGODB: { - "driver": "mongodb", - "database": os.environ['VNPY_TEST_MONGODB_DATABASE'], - "host": os.environ['VNPY_TEST_MONGODB_HOST'], - "port": int(os.environ['VNPY_TEST_MONGODB_PORT']), - "user": "", - "password": "", - "authentication_source": "", - }, - }) +profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}} +if "VNPY_TEST_ONLY_SQLITE" not in os.environ: + profiles.update( + { + Driver.MYSQL: { + "driver": "mysql", + "database": os.environ["VNPY_TEST_MYSQL_DATABASE"], + "host": os.environ["VNPY_TEST_MYSQL_HOST"], + "port": int(os.environ["VNPY_TEST_MYSQL_PORT"]), + "user": os.environ["VNPY_TEST_MYSQL_USER"], + "password": os.environ["VNPY_TEST_MYSQL_PASSWORD"], + }, + Driver.POSTGRESQL: { + "driver": "postgresql", + "database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"], + "host": os.environ["VNPY_TEST_POSTGRESQL_HOST"], + "port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]), + "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], + "password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"], + }, + Driver.MONGODB: { + "driver": "mongodb", + "database": os.environ["VNPY_TEST_MONGODB_DATABASE"], + "host": os.environ["VNPY_TEST_MONGODB_HOST"], + "port": int(os.environ["VNPY_TEST_MONGODB_PORT"]), + "user": "", + "password": "", + "authentication_source": "", + }, + } + ) def now(): @@ -69,17 +67,53 @@ tick = TickData( class TestDatabase(unittest.TestCase): - def connect(self, settings: dict): from vnpy.trader.database.initialize import init # noqa + self.manager = init(settings) + def assertBarCount(self, count, msg): + bars = self.manager.load_bar_data( + symbol=bar.symbol, + exchange=bar.exchange, + interval=bar.interval, + start=bar.datetime - timedelta(days=1), + end=now() + ) + self.assertEqual(count, len(bars), msg) + + def assertTickCount(self, count, msg): + ticks = self.manager.load_tick_data( + symbol=bar.symbol, + exchange=bar.exchange, + start=bar.datetime - timedelta(days=1), + end=now() + ) + self.assertEqual(count, len(ticks), msg) + + def assertNoData(self, msg): + self.assertBarCount(0, msg) + self.assertTickCount(0, msg) + + def setUp(self) -> None: + # clean all data first + for driver, settings in profiles.items(): + self.connect(settings) + self.manager.clean(bar.symbol) + self.assertNoData("Failed to clean data!") + self.manager = None + + def tearDown(self) -> None: + self.manager.clean(bar.symbol) + self.assertNoData("Failed to clean data!") + def test_upsert_bar(self): for driver, settings in profiles.items(): with self.subTest(driver=driver, settings=settings): self.connect(settings) self.manager.save_bar_data([bar]) self.manager.save_bar_data([bar]) + self.assertBarCount(1, "there should be only one item after upsert") def test_save_load_bar(self): for driver, settings in profiles.items(): @@ -88,16 +122,7 @@ class TestDatabase(unittest.TestCase): # save first self.manager.save_bar_data([bar]) - # and load - results = self.manager.load_bar_data( - symbol=bar.symbol, - exchange=bar.exchange, - interval=bar.interval, - start=bar.datetime - timedelta(seconds=1), # time is not accuracy - end=now() + timedelta(seconds=1), # time is not accuracy - ) - count = len(results) - self.assertNotEqual(count, 0) + self.assertBarCount(1, "there should be only one item after save") def test_upsert_tick(self): for driver, settings in profiles.items(): @@ -105,6 +130,7 @@ class TestDatabase(unittest.TestCase): self.connect(settings) self.manager.save_tick_data([tick]) self.manager.save_tick_data([tick]) + self.assertTickCount(1, "there should be only one item after upsert") def test_save_load_tick(self): for driver, settings in profiles.items(): @@ -113,16 +139,48 @@ class TestDatabase(unittest.TestCase): # save first self.manager.save_tick_data([tick]) - # and load - results = self.manager.load_tick_data( - symbol=bar.symbol, - exchange=bar.exchange, - start=bar.datetime - timedelta(seconds=1), # time is not accuracy - end=now() + timedelta(seconds=1), # time is not accuracy + self.assertTickCount(1, "there should be only one item after save") + + def test_newest_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + + # an older one + older_one = copy(bar) + older_one.volume = 123.0 + older_one.datetime = now() - timedelta(days=1) + + # and a newer one + newer_one = copy(bar) + newer_one.volume = 456.0 + newer_one.datetime = now() + self.manager.save_bar_data([older_one, newer_one]) + + got = self.manager.get_newest_bar_data( + bar.symbol, bar.exchange, bar.interval ) - count = len(results) - self.assertNotEqual(count, 0) + self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched") + + def test_newest_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + + # an older one + older_one = copy(tick) + older_one.volume = 123 + older_one.datetime = now() - timedelta(days=1) + + # and a newer one + newer_one = copy(tick) + older_one.volume = 456 + newer_one.datetime = now() + self.manager.save_tick_data([older_one, newer_one]) + + got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange) + self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/vnpy/trader/database/database.py b/vnpy/trader/database/database.py index 1b1728bb..c6920158 100644 --- a/vnpy/trader/database/database.py +++ b/vnpy/trader/database/database.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import Enum -from typing import Sequence, TYPE_CHECKING, Optional +from typing import Optional, Sequence, TYPE_CHECKING if TYPE_CHECKING: from vnpy.trader.constant import Interval, Exchange # noqa @@ -76,3 +76,10 @@ class BaseDatabaseManager(ABC): otherwise, return None """ pass + + @abstractmethod + def clean(self, symbol: str): + """ + delete all records for a symbol + """ + pass diff --git a/vnpy/trader/database/database_mongo.py b/vnpy/trader/database/database_mongo.py index 7f5bcb90..f1e7c4a9 100644 --- a/vnpy/trader/database/database_mongo.py +++ b/vnpy/trader/database/database_mongo.py @@ -322,8 +322,8 @@ class MongoManager(BaseDatabaseManager): .order_by("-datetime") .first() ) - if len(s): - return list(s)[0] + if s: + return s.to_bar() return None def get_newest_tick_data( @@ -334,6 +334,10 @@ class MongoManager(BaseDatabaseManager): .order_by("-datetime") .first() ) - if len(s): - return list(s)[0] + if s: + return s.to_tick() return None + + def clean(self, symbol: str): + DbTickData.objects(symbol=symbol).delete() + DbBarData.objects(symbol=symbol).delete() diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py index 61d6e2ff..37c2e609 100644 --- a/vnpy/trader/database/database_sql.py +++ b/vnpy/trader/database/database_sql.py @@ -1,6 +1,6 @@ """""" from datetime import datetime -from typing import List, Sequence, Type +from typing import List, Optional, Sequence, Type from peewee import ( AutoField, @@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager): def save_tick_data(self, datas: Sequence[TickData]): ds = [self.class_tick.from_tick(i) for i in datas] self.class_tick.save_all(ds) + + def get_newest_bar_data( + self, symbol: str, exchange: "Exchange", interval: "Interval" + ) -> Optional["BarData"]: + s = ( + self.class_bar.select() + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) + & (self.class_bar.interval == interval.value) + ) + .order_by(self.class_bar.datetime.desc()) + .first() + ) + if s: + return s.to_bar() + return None + + def get_newest_tick_data( + self, symbol: str, exchange: "Exchange" + ) -> Optional["TickData"]: + s = ( + self.class_tick.select() + .where( + (self.class_tick.symbol == symbol) + & (self.class_tick.exchange == exchange.value) + ) + .order_by(self.class_tick.datetime.desc()) + .first() + ) + if s: + return s.to_tick() + return None + + def clean(self, symbol: str): + self.class_bar.delete().where(self.class_bar.symbol == symbol).execute() + self.class_tick.delete().where(self.class_tick.symbol == symbol).execute() diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 89af4945..60a33485 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -4,7 +4,7 @@ General utility functions. import json from pathlib import Path -from typing import Callable, TYPE_CHECKING +from typing import Callable import numpy as np import talib