[Fix] fix bugs for newly added functions

This commit is contained in:
nanoric 2019-04-18 00:21:30 -04:00
parent 19e27ea031
commit b72a1dc155
5 changed files with 168 additions and 62 deletions

View File

@ -3,48 +3,46 @@ Test if database works fine
"""
import os
import unittest
from copy import copy
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1'
os.environ["VNPY_TESTING"] = "1"
profiles = {
Driver.SQLITE: {
"driver": "sqlite",
"database": "test_db.db",
}
}
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
profiles.update({
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
"user": "",
"password": "",
"authentication_source": "",
},
})
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
profiles.update(
{
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
"user": "",
"password": "",
"authentication_source": "",
},
}
)
def now():
@ -69,17 +67,53 @@ tick = TickData(
class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def assertBarCount(self, count, msg):
bars = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(bars), msg)
def assertTickCount(self, count, msg):
ticks = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(ticks), msg)
def assertNoData(self, msg):
self.assertBarCount(0, msg)
self.assertTickCount(0, msg)
def setUp(self) -> None:
# clean all data first
for driver, settings in profiles.items():
self.connect(settings)
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
self.manager = None
def tearDown(self) -> None:
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
self.assertBarCount(1, "there should be only one item after upsert")
def test_save_load_bar(self):
for driver, settings in profiles.items():
@ -88,16 +122,7 @@ class TestDatabase(unittest.TestCase):
# save first
self.manager.save_bar_data([bar])
# and load
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
self.assertBarCount(1, "there should be only one item after save")
def test_upsert_tick(self):
for driver, settings in profiles.items():
@ -105,6 +130,7 @@ class TestDatabase(unittest.TestCase):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
self.assertTickCount(1, "there should be only one item after upsert")
def test_save_load_tick(self):
for driver, settings in profiles.items():
@ -113,16 +139,48 @@ class TestDatabase(unittest.TestCase):
# save first
self.manager.save_tick_data([tick])
# and load
results = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
self.assertTickCount(1, "there should be only one item after save")
def test_newest_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# an older one
older_one = copy(bar)
older_one.volume = 123.0
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(bar)
newer_one.volume = 456.0
newer_one.datetime = now()
self.manager.save_bar_data([older_one, newer_one])
got = self.manager.get_newest_bar_data(
bar.symbol, bar.exchange, bar.interval
)
count = len(results)
self.assertNotEqual(count, 0)
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
def test_newest_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# an older one
older_one = copy(tick)
older_one.volume = 123
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(tick)
older_one.volume = 456
newer_one.datetime = now()
self.manager.save_tick_data([older_one, newer_one])
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence, TYPE_CHECKING, Optional
from typing import Optional, Sequence, TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange # noqa
@ -76,3 +76,10 @@ class BaseDatabaseManager(ABC):
otherwise, return None
"""
pass
@abstractmethod
def clean(self, symbol: str):
"""
delete all records for a symbol
"""
pass

View File

@ -322,8 +322,8 @@ class MongoManager(BaseDatabaseManager):
.order_by("-datetime")
.first()
)
if len(s):
return list(s)[0]
if s:
return s.to_bar()
return None
def get_newest_tick_data(
@ -334,6 +334,10 @@ class MongoManager(BaseDatabaseManager):
.order_by("-datetime")
.first()
)
if len(s):
return list(s)[0]
if s:
return s.to_tick()
return None
def clean(self, symbol: str):
DbTickData.objects(symbol=symbol).delete()
DbBarData.objects(symbol=symbol).delete()

View File

@ -1,6 +1,6 @@
""""""
from datetime import datetime
from typing import List, Sequence, Type
from typing import List, Optional, Sequence, Type
from peewee import (
AutoField,
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds)
def get_newest_bar_data(
self, symbol: str, exchange: "Exchange", interval: "Interval"
) -> Optional["BarData"]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
)
.order_by(self.class_bar.datetime.desc())
.first()
)
if s:
return s.to_bar()
return None
def get_newest_tick_data(
self, symbol: str, exchange: "Exchange"
) -> Optional["TickData"]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
)
.order_by(self.class_tick.datetime.desc())
.first()
)
if s:
return s.to_tick()
return None
def clean(self, symbol: str):
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()

View File

@ -4,7 +4,7 @@ General utility functions.
import json
from pathlib import Path
from typing import Callable, TYPE_CHECKING
from typing import Callable
import numpy as np
import talib