[Fix] fix bugs for newly added functions

This commit is contained in:
nanoric 2019-04-18 00:21:30 -04:00
parent 19e27ea031
commit b72a1dc155
5 changed files with 168 additions and 62 deletions

View File

@ -3,48 +3,46 @@ Test if database works fine
""" """
import os import os
import unittest import unittest
from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1' os.environ["VNPY_TESTING"] = "1"
profiles = { profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
Driver.SQLITE: { if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
"driver": "sqlite", profiles.update(
"database": "test_db.db", {
}
}
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
profiles.update({
Driver.MYSQL: { Driver.MYSQL: {
"driver": "mysql", "driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'], "database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
"host": os.environ['VNPY_TEST_MYSQL_HOST'], "host": os.environ["VNPY_TEST_MYSQL_HOST"],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']), "port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
"user": os.environ["VNPY_TEST_MYSQL_USER"], "user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'], "password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
}, },
Driver.POSTGRESQL: { Driver.POSTGRESQL: {
"driver": "postgresql", "driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'], "database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'], "host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']), "port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"], "user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'], "password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
}, },
Driver.MONGODB: { Driver.MONGODB: {
"driver": "mongodb", "driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'], "database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
"host": os.environ['VNPY_TEST_MONGODB_HOST'], "host": os.environ["VNPY_TEST_MONGODB_HOST"],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']), "port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
"user": "", "user": "",
"password": "", "password": "",
"authentication_source": "", "authentication_source": "",
}, },
}) }
)
def now(): def now():
@ -69,17 +67,53 @@ tick = TickData(
class TestDatabase(unittest.TestCase): class TestDatabase(unittest.TestCase):
def connect(self, settings: dict): def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings) self.manager = init(settings)
def assertBarCount(self, count, msg):
bars = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(bars), msg)
def assertTickCount(self, count, msg):
ticks = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(ticks), msg)
def assertNoData(self, msg):
self.assertBarCount(0, msg)
self.assertTickCount(0, msg)
def setUp(self) -> None:
# clean all data first
for driver, settings in profiles.items():
self.connect(settings)
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
self.manager = None
def tearDown(self) -> None:
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
def test_upsert_bar(self): def test_upsert_bar(self):
for driver, settings in profiles.items(): for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings): with self.subTest(driver=driver, settings=settings):
self.connect(settings) self.connect(settings)
self.manager.save_bar_data([bar]) self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar]) self.manager.save_bar_data([bar])
self.assertBarCount(1, "there should be only one item after upsert")
def test_save_load_bar(self): def test_save_load_bar(self):
for driver, settings in profiles.items(): for driver, settings in profiles.items():
@ -88,16 +122,7 @@ class TestDatabase(unittest.TestCase):
# save first # save first
self.manager.save_bar_data([bar]) self.manager.save_bar_data([bar])
# and load self.assertBarCount(1, "there should be only one item after save")
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
def test_upsert_tick(self): def test_upsert_tick(self):
for driver, settings in profiles.items(): for driver, settings in profiles.items():
@ -105,6 +130,7 @@ class TestDatabase(unittest.TestCase):
self.connect(settings) self.connect(settings)
self.manager.save_tick_data([tick]) self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick]) self.manager.save_tick_data([tick])
self.assertTickCount(1, "there should be only one item after upsert")
def test_save_load_tick(self): def test_save_load_tick(self):
for driver, settings in profiles.items(): for driver, settings in profiles.items():
@ -113,16 +139,48 @@ class TestDatabase(unittest.TestCase):
# save first # save first
self.manager.save_tick_data([tick]) self.manager.save_tick_data([tick])
# and load self.assertTickCount(1, "there should be only one item after save")
results = self.manager.load_tick_data(
symbol=bar.symbol, def test_newest_bar(self):
exchange=bar.exchange, for driver, settings in profiles.items():
start=bar.datetime - timedelta(seconds=1), # time is not accuracy with self.subTest(driver=driver, settings=settings):
end=now() + timedelta(seconds=1), # time is not accuracy self.connect(settings)
# an older one
older_one = copy(bar)
older_one.volume = 123.0
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(bar)
newer_one.volume = 456.0
newer_one.datetime = now()
self.manager.save_bar_data([older_one, newer_one])
got = self.manager.get_newest_bar_data(
bar.symbol, bar.exchange, bar.interval
) )
count = len(results) self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
self.assertNotEqual(count, 0)
def test_newest_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# an older one
older_one = copy(tick)
older_one.volume = 123
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(tick)
older_one.volume = 456
newer_one.datetime = now()
self.manager.save_tick_data([older_one, newer_one])
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Sequence, TYPE_CHECKING, Optional from typing import Optional, Sequence, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange # noqa from vnpy.trader.constant import Interval, Exchange # noqa
@ -76,3 +76,10 @@ class BaseDatabaseManager(ABC):
otherwise, return None otherwise, return None
""" """
pass pass
@abstractmethod
def clean(self, symbol: str):
"""
delete all records for a symbol
"""
pass

View File

@ -322,8 +322,8 @@ class MongoManager(BaseDatabaseManager):
.order_by("-datetime") .order_by("-datetime")
.first() .first()
) )
if len(s): if s:
return list(s)[0] return s.to_bar()
return None return None
def get_newest_tick_data( def get_newest_tick_data(
@ -334,6 +334,10 @@ class MongoManager(BaseDatabaseManager):
.order_by("-datetime") .order_by("-datetime")
.first() .first()
) )
if len(s): if s:
return list(s)[0] return s.to_tick()
return None return None
def clean(self, symbol: str):
DbTickData.objects(symbol=symbol).delete()
DbBarData.objects(symbol=symbol).delete()

View File

@ -1,6 +1,6 @@
"""""" """"""
from datetime import datetime from datetime import datetime
from typing import List, Sequence, Type from typing import List, Optional, Sequence, Type
from peewee import ( from peewee import (
AutoField, AutoField,
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
def save_tick_data(self, datas: Sequence[TickData]): def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas] ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds) self.class_tick.save_all(ds)
def get_newest_bar_data(
self, symbol: str, exchange: "Exchange", interval: "Interval"
) -> Optional["BarData"]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
)
.order_by(self.class_bar.datetime.desc())
.first()
)
if s:
return s.to_bar()
return None
def get_newest_tick_data(
self, symbol: str, exchange: "Exchange"
) -> Optional["TickData"]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
)
.order_by(self.class_tick.datetime.desc())
.first()
)
if s:
return s.to_tick()
return None
def clean(self, symbol: str):
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()

View File

@ -4,7 +4,7 @@ General utility functions.
import json import json
from pathlib import Path from pathlib import Path
from typing import Callable, TYPE_CHECKING from typing import Callable
import numpy as np import numpy as np
import talib import talib