diff --git a/.travis.yml b/.travis.yml index 0ef835b3..696fb7da 100644 --- a/.travis.yml +++ b/.travis.yml @@ -75,6 +75,7 @@ matrix: - sudo make install - popd - pip install numpy + - pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - python setup.py sdist - pip install dist/`ls dist` diff --git a/appveyor.yml b/appveyor.yml index 437edefc..0bd0d088 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -47,10 +47,7 @@ for: - configuration: pip build_script: - pip install psycopg2 mongoengine pymysql # we should support all database in test environment - - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl - - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - - pip install -r requirements.txt - - pip install . + - call install.bat - matrix: only: @@ -58,6 +55,7 @@ for: build_script: - python setup.py sdist - pip install psycopg2 mongoengine pymysql # we should support all database in test environment + - pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - ps: $name=(ls dist).name; pip install "dist/$name" diff --git a/install.bat b/install.bat index e45576f0..5a0afb30 100644 --- a/install.bat +++ b/install.bat @@ -1,3 +1,6 @@ +:: rqdatac +pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac + ::Install talib and ibapi pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl diff --git a/install.sh b/install.sh index dc748ee5..0ca3b777 100644 --- a/install.sh +++ b/install.sh @@ -22,6 +22,7 @@ popd $pip install numpy # Install extra packages +$pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac $pip install ta-lib $pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl diff --git a/requirements.txt b/requirements.txt index 308b49c4..c6959ed5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ matplotlib seaborn futu-api tigeropen +rqdatac ta-lib ibapi mongoengine diff --git a/setup.py b/setup.py index 3b0d1ca7..8a00a931 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ install_requires = [ "seaborn", "futu-api", "tigeropen", + "rqdatac", "ta-lib", "ibapi" ] diff --git a/tests/trader/test_database.py b/tests/trader/test_database.py index df009baf..02e9bfc1 100644 --- a/tests/trader/test_database.py +++ b/tests/trader/test_database.py @@ -3,48 +3,46 @@ Test if database works fine """ import os import unittest +from copy import copy from datetime import datetime, timedelta from vnpy.trader.constant import Exchange, Interval from vnpy.trader.database.database import Driver from vnpy.trader.object import BarData, TickData -os.environ['VNPY_TESTING'] = '1' +os.environ["VNPY_TESTING"] = "1" -profiles = { - Driver.SQLITE: { - "driver": "sqlite", - "database": "test_db.db", - } -} -if 'VNPY_TEST_ONLY_SQLITE' not in os.environ: - profiles.update({ - Driver.MYSQL: { - "driver": "mysql", - "database": os.environ['VNPY_TEST_MYSQL_DATABASE'], - "host": os.environ['VNPY_TEST_MYSQL_HOST'], - "port": int(os.environ['VNPY_TEST_MYSQL_PORT']), - "user": os.environ["VNPY_TEST_MYSQL_USER"], - "password": os.environ['VNPY_TEST_MYSQL_PASSWORD'], - }, - Driver.POSTGRESQL: { - "driver": "postgresql", - "database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'], - "host": os.environ['VNPY_TEST_POSTGRESQL_HOST'], - "port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']), - "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], - "password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'], - }, - Driver.MONGODB: { - "driver": "mongodb", - "database": os.environ['VNPY_TEST_MONGODB_DATABASE'], - "host": os.environ['VNPY_TEST_MONGODB_HOST'], - "port": int(os.environ['VNPY_TEST_MONGODB_PORT']), - "user": "", - "password": "", - "authentication_source": "", - }, - }) +profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}} +if "VNPY_TEST_ONLY_SQLITE" not in os.environ: + profiles.update( + { + Driver.MYSQL: { + "driver": "mysql", + "database": os.environ["VNPY_TEST_MYSQL_DATABASE"], + "host": os.environ["VNPY_TEST_MYSQL_HOST"], + "port": int(os.environ["VNPY_TEST_MYSQL_PORT"]), + "user": os.environ["VNPY_TEST_MYSQL_USER"], + "password": os.environ["VNPY_TEST_MYSQL_PASSWORD"], + }, + Driver.POSTGRESQL: { + "driver": "postgresql", + "database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"], + "host": os.environ["VNPY_TEST_POSTGRESQL_HOST"], + "port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]), + "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], + "password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"], + }, + Driver.MONGODB: { + "driver": "mongodb", + "database": os.environ["VNPY_TEST_MONGODB_DATABASE"], + "host": os.environ["VNPY_TEST_MONGODB_HOST"], + "port": int(os.environ["VNPY_TEST_MONGODB_PORT"]), + "user": "", + "password": "", + "authentication_source": "", + }, + } + ) def now(): @@ -72,14 +70,51 @@ class TestDatabase(unittest.TestCase): def connect(self, settings: dict): from vnpy.trader.database.initialize import init # noqa + self.manager = init(settings) + def assertBarCount(self, count, msg): + bars = self.manager.load_bar_data( + symbol=bar.symbol, + exchange=bar.exchange, + interval=bar.interval, + start=bar.datetime - timedelta(days=1), + end=now() + ) + self.assertEqual(count, len(bars), msg) + + def assertTickCount(self, count, msg): + ticks = self.manager.load_tick_data( + symbol=bar.symbol, + exchange=bar.exchange, + start=bar.datetime - timedelta(days=1), + end=now() + ) + self.assertEqual(count, len(ticks), msg) + + def assertNoData(self, msg): + self.assertBarCount(0, msg) + self.assertTickCount(0, msg) + + def setUp(self) -> None: + # clean all data first + for driver, settings in profiles.items(): + self.connect(settings) + self.manager.clean(bar.symbol) + self.assertNoData("Failed to clean data!") + self.manager = None + + def tearDown(self) -> None: + self.manager.clean(bar.symbol) + self.assertNoData("Failed to clean data!") + def test_upsert_bar(self): for driver, settings in profiles.items(): with self.subTest(driver=driver, settings=settings): self.connect(settings) self.manager.save_bar_data([bar]) self.manager.save_bar_data([bar]) + self.assertBarCount(1, "there should be only one item after upsert") def test_save_load_bar(self): for driver, settings in profiles.items(): @@ -88,16 +123,7 @@ class TestDatabase(unittest.TestCase): # save first self.manager.save_bar_data([bar]) - # and load - results = self.manager.load_bar_data( - symbol=bar.symbol, - exchange=bar.exchange, - interval=bar.interval, - start=bar.datetime - timedelta(seconds=1), # time is not accuracy - end=now() + timedelta(seconds=1), # time is not accuracy - ) - count = len(results) - self.assertNotEqual(count, 0) + self.assertBarCount(1, "there should be only one item after save") def test_upsert_tick(self): for driver, settings in profiles.items(): @@ -105,6 +131,7 @@ class TestDatabase(unittest.TestCase): self.connect(settings) self.manager.save_tick_data([tick]) self.manager.save_tick_data([tick]) + self.assertTickCount(1, "there should be only one item after upsert") def test_save_load_tick(self): for driver, settings in profiles.items(): @@ -113,16 +140,58 @@ class TestDatabase(unittest.TestCase): # save first self.manager.save_tick_data([tick]) - # and load - results = self.manager.load_tick_data( - symbol=bar.symbol, - exchange=bar.exchange, - start=bar.datetime - timedelta(seconds=1), # time is not accuracy - end=now() + timedelta(seconds=1), # time is not accuracy + self.assertTickCount(1, "there should be only one item after save") + + def test_newest_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + got = self.manager.get_newest_bar_data(bar.symbol, bar.exchange, bar.interval) + self.assertIsNone( + got, + "database is empty, but return value for newest_bar_data() is not a None" ) - count = len(results) - self.assertNotEqual(count, 0) + + # an older one + older_one = copy(bar) + older_one.volume = 123.0 + older_one.datetime = now() - timedelta(days=1) + + # and a newer one + newer_one = copy(bar) + newer_one.volume = 456.0 + newer_one.datetime = now() + self.manager.save_bar_data([older_one, newer_one]) + + got = self.manager.get_newest_bar_data( + bar.symbol, bar.exchange, bar.interval + ) + self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched") + + def test_newest_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + + got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange) + self.assertIsNone( + got, + "database is empty, but return value for newest_tick_data() is not a None" + ) + # an older one + older_one = copy(tick) + older_one.volume = 123 + older_one.datetime = now() - timedelta(days=1) + + # and a newer one + newer_one = copy(tick) + older_one.volume = 456 + newer_one.datetime = now() + self.manager.save_tick_data([older_one, newer_one]) + + got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange) + self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/vnpy/trader/database/database.py b/vnpy/trader/database/database.py index ec765e8d..c6920158 100644 --- a/vnpy/trader/database/database.py +++ b/vnpy/trader/database/database.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import Enum -from typing import Sequence, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING if TYPE_CHECKING: from vnpy.trader.constant import Interval, Exchange # noqa @@ -51,3 +51,35 @@ class BaseDatabaseManager(ABC): datas: Sequence["TickData"], ): pass + + @abstractmethod + def get_newest_bar_data( + self, + symbol: str, + exchange: "Exchange", + interval: "Interval" + ) -> Optional["BarData"]: + """ + If there is data in database, return the one with greatest datetime(newest one) + otherwise, return None + """ + pass + + @abstractmethod + def get_newest_tick_data( + self, + symbol: str, + exchange: "Exchange", + ) -> Optional["TickData"]: + """ + If there is data in database, return the one with greatest datetime(newest one) + otherwise, return None + """ + pass + + @abstractmethod + def clean(self, symbol: str): + """ + delete all records for a symbol + """ + pass diff --git a/vnpy/trader/database/database_mongo.py b/vnpy/trader/database/database_mongo.py index cdf68000..f1e7c4a9 100644 --- a/vnpy/trader/database/database_mongo.py +++ b/vnpy/trader/database/database_mongo.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Sequence +from typing import Sequence, Optional from mongoengine import DateTimeField, Document, FloatField, StringField, connect @@ -54,8 +54,10 @@ class DbBarData(Document): meta = { "indexes": [ - {"fields": ("datetime", "interval", "symbol", - "exchange"), "unique": True} + { + "fields": ("datetime", "interval", "symbol", "exchange"), + "unique": True, + } ] } @@ -145,8 +147,14 @@ class DbTickData(Document): ask_volume_4: float = FloatField() ask_volume_5: float = FloatField() - meta = {"indexes": [ - {"fields": ("datetime", "symbol", "exchange"), "unique": True}]} + meta = { + "indexes": [ + { + "fields": ("datetime", "symbol", "exchange"), + "unique": True, + } + ], + } @staticmethod def from_tick(tick: TickData): @@ -305,3 +313,31 @@ class MongoManager(BaseDatabaseManager): symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime ).update_one(upsert=True, **updates) ) + + def get_newest_bar_data( + self, symbol: str, exchange: "Exchange", interval: "Interval" + ) -> Optional["BarData"]: + s = ( + DbBarData.objects(symbol=symbol, exchange=exchange.value) + .order_by("-datetime") + .first() + ) + if s: + return s.to_bar() + return None + + def get_newest_tick_data( + self, symbol: str, exchange: "Exchange" + ) -> Optional["TickData"]: + s = ( + DbTickData.objects(symbol=symbol, exchange=exchange.value) + .order_by("-datetime") + .first() + ) + if s: + return s.to_tick() + return None + + def clean(self, symbol: str): + DbTickData.objects(symbol=symbol).delete() + DbBarData.objects(symbol=symbol).delete() diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py index 61d6e2ff..37c2e609 100644 --- a/vnpy/trader/database/database_sql.py +++ b/vnpy/trader/database/database_sql.py @@ -1,6 +1,6 @@ """""" from datetime import datetime -from typing import List, Sequence, Type +from typing import List, Optional, Sequence, Type from peewee import ( AutoField, @@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager): def save_tick_data(self, datas: Sequence[TickData]): ds = [self.class_tick.from_tick(i) for i in datas] self.class_tick.save_all(ds) + + def get_newest_bar_data( + self, symbol: str, exchange: "Exchange", interval: "Interval" + ) -> Optional["BarData"]: + s = ( + self.class_bar.select() + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) + & (self.class_bar.interval == interval.value) + ) + .order_by(self.class_bar.datetime.desc()) + .first() + ) + if s: + return s.to_bar() + return None + + def get_newest_tick_data( + self, symbol: str, exchange: "Exchange" + ) -> Optional["TickData"]: + s = ( + self.class_tick.select() + .where( + (self.class_tick.symbol == symbol) + & (self.class_tick.exchange == exchange.value) + ) + .order_by(self.class_tick.datetime.desc()) + .first() + ) + if s: + return s.to_tick() + return None + + def clean(self, symbol: str): + self.class_bar.delete().where(self.class_bar.symbol == symbol).execute() + self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()