vnpy/tests/trader/test_database.py

198 lines
6.8 KiB
Python
Raw Normal View History

"""
Test if database works fine
"""
import os
import unittest
from copy import copy
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ["VNPY_TESTING"] = "1"
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
profiles.update(
{
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
"user": "",
"password": "",
"authentication_source": "",
},
}
)
def now():
return datetime.utcnow()
bar = BarData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
interval=Interval.MINUTE,
)
tick = TickData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
name="DB_test_symbol",
)
class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def assertBarCount(self, count, msg):
bars = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(bars), msg)
def assertTickCount(self, count, msg):
ticks = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(ticks), msg)
def assertNoData(self, msg):
self.assertBarCount(0, msg)
self.assertTickCount(0, msg)
def setUp(self) -> None:
# clean all data first
for driver, settings in profiles.items():
self.connect(settings)
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
self.manager = None
def tearDown(self) -> None:
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
self.assertBarCount(1, "there should be only one item after upsert")
def test_save_load_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_bar_data([bar])
self.assertBarCount(1, "there should be only one item after save")
def test_upsert_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
self.assertTickCount(1, "there should be only one item after upsert")
def test_save_load_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_tick_data([tick])
self.assertTickCount(1, "there should be only one item after save")
def test_newest_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
got = self.manager.get_newest_bar_data(bar.symbol, bar.exchange, bar.interval)
self.assertIsNone(
got,
"database is empty, but return value for newest_bar_data() is not a None"
)
# an older one
older_one = copy(bar)
older_one.volume = 123.0
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(bar)
newer_one.volume = 456.0
newer_one.datetime = now()
self.manager.save_bar_data([older_one, newer_one])
got = self.manager.get_newest_bar_data(
bar.symbol, bar.exchange, bar.interval
)
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
def test_newest_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertIsNone(
got,
"database is empty, but return value for newest_tick_data() is not a None"
)
# an older one
older_one = copy(tick)
older_one.volume = 123
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(tick)
older_one.volume = 456
newer_one.datetime = now()
self.manager.save_tick_data([older_one, newer_one])
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
if __name__ == "__main__":
unittest.main()