198 lines
6.8 KiB
Python
198 lines
6.8 KiB
Python
"""
|
|
Test that the database drivers work correctly.
|
|
"""
|
|
import os
|
|
import unittest
|
|
from copy import copy
|
|
from datetime import datetime, timedelta
|
|
|
|
from vnpy.trader.constant import Exchange, Interval
|
|
from vnpy.trader.database.database import Driver
|
|
from vnpy.trader.object import BarData, TickData
|
|
|
|
# Set VNPY_TESTING for this process.
# NOTE(review): presumably consulted by vnpy itself -- confirm.
os.environ["VNPY_TESTING"] = "1"

# Connection settings per driver.  The SQLite profile uses only literal
# values, so it is always present.
profiles = {Driver.SQLITE: dict(driver="sqlite", database="test_db.db")}

# The remaining drivers read their connection details from environment
# variables; defining VNPY_TEST_ONLY_SQLITE (with any value) leaves them out.
# A missing variable raises KeyError while this module is being imported.
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
    profiles[Driver.MYSQL] = dict(
        driver="mysql",
        database=os.environ["VNPY_TEST_MYSQL_DATABASE"],
        host=os.environ["VNPY_TEST_MYSQL_HOST"],
        port=int(os.environ["VNPY_TEST_MYSQL_PORT"]),
        user=os.environ["VNPY_TEST_MYSQL_USER"],
        password=os.environ["VNPY_TEST_MYSQL_PASSWORD"],
    )
    profiles[Driver.POSTGRESQL] = dict(
        driver="postgresql",
        database=os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
        host=os.environ["VNPY_TEST_POSTGRESQL_HOST"],
        port=int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
        user=os.environ["VNPY_TEST_POSTGRESQL_USER"],
        password=os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
    )
    # MongoDB is configured with empty credentials.
    profiles[Driver.MONGODB] = dict(
        driver="mongodb",
        database=os.environ["VNPY_TEST_MONGODB_DATABASE"],
        host=os.environ["VNPY_TEST_MONGODB_HOST"],
        port=int(os.environ["VNPY_TEST_MONGODB_PORT"]),
        user="",
        password="",
        authentication_source="",
    )
|
|
|
|
|
|
def now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated since Python 3.12, so the value is
    obtained from an aware ``datetime.now(timezone.utc)`` and the ``tzinfo``
    is then dropped.  The result is the same kind of naive UTC timestamp
    that ``utcnow()`` produced, so existing comparisons keep working.
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)
|
|
|
|
|
|
# Field values shared by both fixtures below.
_fixture_fields = dict(
    gateway_name="DB",
    symbol="test_symbol",
    exchange=Exchange.BITMEX,
)

# Module-level bar fixture; its timestamp is taken once, at import time.
bar = BarData(
    interval=Interval.MINUTE,
    datetime=now(),
    **_fixture_fields,
)

# Module-level tick fixture; it gets its own now() call, just after the bar's.
tick = TickData(
    name="DB_test_symbol",
    datetime=now(),
    **_fixture_fields,
)
|
|
|
|
|
|
class TestDatabase(unittest.TestCase):
    """Run the same save / load / upsert / newest scenarios on every driver.

    Each test iterates over the module-level ``profiles`` mapping and runs
    its scenario once per driver inside a ``subTest``, so a failure on one
    driver is reported without hiding the results of the others.
    """

    def connect(self, settings: dict) -> None:
        """Create a database manager from *settings* and store it on ``self.manager``."""
        # Imported at call time rather than at module level.
        # NOTE(review): presumably so VNPY_TESTING is set first -- confirm.
        from vnpy.trader.database.initialize import init  # noqa

        self.manager = init(settings)

    def assertBarCount(self, count, msg):
        """Assert that exactly *count* bars of the ``bar`` fixture's series are stored.

        The query window runs from one day before the fixture's timestamp
        up to the current time.
        """
        bars = self.manager.load_bar_data(
            symbol=bar.symbol,
            exchange=bar.exchange,
            interval=bar.interval,
            start=bar.datetime - timedelta(days=1),
            end=now(),
        )
        self.assertEqual(count, len(bars), msg)

    def assertTickCount(self, count, msg):
        """Assert that exactly *count* ticks of the ``tick`` fixture's series are stored.

        The query is built from the ``tick`` fixture itself (the original
        borrowed the ``bar`` fixture's fields, which only worked because both
        fixtures share symbol and exchange).
        """
        ticks = self.manager.load_tick_data(
            symbol=tick.symbol,
            exchange=tick.exchange,
            start=tick.datetime - timedelta(days=1),
            end=now(),
        )
        self.assertEqual(count, len(ticks), msg)

    def assertNoData(self, msg):
        """Assert that neither bars nor ticks are stored for the test symbol."""
        self.assertBarCount(0, msg)
        self.assertTickCount(0, msg)

    def _clean_all_profiles(self) -> None:
        """Connect to every configured driver, remove the test symbol and verify."""
        for settings in profiles.values():
            self.connect(settings)
            self.manager.clean(bar.symbol)
            self.assertNoData("Failed to clean data!")
        # Force each test to connect explicitly before touching the manager.
        self.manager = None

    def setUp(self) -> None:
        """Start every test from an empty state on all drivers."""
        self._clean_all_profiles()

    def tearDown(self) -> None:
        """Remove the test data from all drivers, not only the last one used.

        Every test writes to each driver in turn, so cleaning only the
        manager that happens to be connected last would leave rows behind in
        the other databases after the final test.
        """
        self._clean_all_profiles()

    def test_upsert_bar(self):
        """Saving the same bar twice must leave a single stored bar."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)
                self.manager.save_bar_data([bar])
                self.manager.save_bar_data([bar])
                self.assertBarCount(1, "there should be only one item after upsert")

    def test_save_load_bar(self):
        """A saved bar must be returned by a subsequent load."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)
                # save first
                self.manager.save_bar_data([bar])

                self.assertBarCount(1, "there should be only one item after save")

    def test_upsert_tick(self):
        """Saving the same tick twice must leave a single stored tick."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)
                self.manager.save_tick_data([tick])
                self.manager.save_tick_data([tick])
                self.assertTickCount(1, "there should be only one item after upsert")

    def test_save_load_tick(self):
        """A saved tick must be returned by a subsequent load."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)
                # save first
                self.manager.save_tick_data([tick])

                self.assertTickCount(1, "there should be only one item after save")

    def test_newest_bar(self):
        """``get_newest_bar_data`` returns None when empty, else the latest bar."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)
                got = self.manager.get_newest_bar_data(bar.symbol, bar.exchange, bar.interval)
                self.assertIsNone(
                    got,
                    "database is empty, but return value for newest_bar_data() is not a None"
                )

                # an older one
                older_one = copy(bar)
                older_one.volume = 123.0
                older_one.datetime = now() - timedelta(days=1)

                # and a newer one; the distinct volumes tell the two apart
                newer_one = copy(bar)
                newer_one.volume = 456.0
                newer_one.datetime = now()
                self.manager.save_bar_data([older_one, newer_one])

                got = self.manager.get_newest_bar_data(
                    bar.symbol, bar.exchange, bar.interval
                )
                # Fail with a clear message instead of an AttributeError below.
                self.assertIsNotNone(got, "no newest bar returned after saving bars")
                self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")

    def test_newest_tick(self):
        """``get_newest_tick_data`` returns None when empty, else the latest tick."""
        for driver, settings in profiles.items():
            with self.subTest(driver=driver, settings=settings):
                self.connect(settings)

                got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
                self.assertIsNone(
                    got,
                    "database is empty, but return value for newest_tick_data() is not a None"
                )
                # an older one
                older_one = copy(tick)
                older_one.volume = 123
                older_one.datetime = now() - timedelta(days=1)

                # and a newer one; the original assigned 456 to older_one here,
                # leaving the newer tick with its default volume
                newer_one = copy(tick)
                newer_one.volume = 456
                newer_one.datetime = now()
                self.manager.save_tick_data([older_one, newer_one])

                got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
                # Fail with a clear message instead of an AttributeError below.
                self.assertIsNotNone(got, "no newest tick returned after saving ticks")
                self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
|
|
|
|
|
|
# Script entry point: hand control to unittest's command-line runner when this
# file is executed directly.
if __name__ == "__main__":
    unittest.main()
|