[Fix] fix bugs for newly added functions
This commit is contained in:
parent
19e27ea031
commit
b72a1dc155
@ -3,48 +3,46 @@ Test if database works fine
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
from copy import copy
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database.database import Driver
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
|
||||
os.environ['VNPY_TESTING'] = '1'
|
||||
os.environ["VNPY_TESTING"] = "1"
|
||||
|
||||
profiles = {
|
||||
Driver.SQLITE: {
|
||||
"driver": "sqlite",
|
||||
"database": "test_db.db",
|
||||
}
|
||||
}
|
||||
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
|
||||
profiles.update({
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
})
|
||||
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
|
||||
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
|
||||
profiles.update(
|
||||
{
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def now():
|
||||
@ -69,17 +67,53 @@ tick = TickData(
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
|
||||
def connect(self, settings: dict):
|
||||
from vnpy.trader.database.initialize import init # noqa
|
||||
|
||||
self.manager = init(settings)
|
||||
|
||||
def assertBarCount(self, count, msg):
|
||||
bars = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(days=1),
|
||||
end=now()
|
||||
)
|
||||
self.assertEqual(count, len(bars), msg)
|
||||
|
||||
def assertTickCount(self, count, msg):
|
||||
ticks = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(days=1),
|
||||
end=now()
|
||||
)
|
||||
self.assertEqual(count, len(ticks), msg)
|
||||
|
||||
def assertNoData(self, msg):
|
||||
self.assertBarCount(0, msg)
|
||||
self.assertTickCount(0, msg)
|
||||
|
||||
def setUp(self) -> None:
|
||||
# clean all data first
|
||||
for driver, settings in profiles.items():
|
||||
self.connect(settings)
|
||||
self.manager.clean(bar.symbol)
|
||||
self.assertNoData("Failed to clean data!")
|
||||
self.manager = None
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.manager.clean(bar.symbol)
|
||||
self.assertNoData("Failed to clean data!")
|
||||
|
||||
def test_upsert_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_bar_data([bar])
|
||||
self.manager.save_bar_data([bar])
|
||||
self.assertBarCount(1, "there should be only one item after upsert")
|
||||
|
||||
def test_save_load_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -88,16 +122,7 @@ class TestDatabase(unittest.TestCase):
|
||||
# save first
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
self.assertBarCount(1, "there should be only one item after save")
|
||||
|
||||
def test_upsert_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -105,6 +130,7 @@ class TestDatabase(unittest.TestCase):
|
||||
self.connect(settings)
|
||||
self.manager.save_tick_data([tick])
|
||||
self.manager.save_tick_data([tick])
|
||||
self.assertTickCount(1, "there should be only one item after upsert")
|
||||
|
||||
def test_save_load_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -113,16 +139,48 @@ class TestDatabase(unittest.TestCase):
|
||||
# save first
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
self.assertTickCount(1, "there should be only one item after save")
|
||||
|
||||
def test_newest_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
|
||||
# an older one
|
||||
older_one = copy(bar)
|
||||
older_one.volume = 123.0
|
||||
older_one.datetime = now() - timedelta(days=1)
|
||||
|
||||
# and a newer one
|
||||
newer_one = copy(bar)
|
||||
newer_one.volume = 456.0
|
||||
newer_one.datetime = now()
|
||||
self.manager.save_bar_data([older_one, newer_one])
|
||||
|
||||
got = self.manager.get_newest_bar_data(
|
||||
bar.symbol, bar.exchange, bar.interval
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
|
||||
|
||||
def test_newest_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
|
||||
# an older one
|
||||
older_one = copy(tick)
|
||||
older_one.volume = 123
|
||||
older_one.datetime = now() - timedelta(days=1)
|
||||
|
||||
# and a newer one
|
||||
newer_one = copy(tick)
|
||||
older_one.volume = 456
|
||||
newer_one.datetime = now()
|
||||
self.manager.save_tick_data([older_one, newer_one])
|
||||
|
||||
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
|
||||
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence, TYPE_CHECKING, Optional
|
||||
from typing import Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Interval, Exchange # noqa
|
||||
@ -76,3 +76,10 @@ class BaseDatabaseManager(ABC):
|
||||
otherwise, return None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, symbol: str):
|
||||
"""
|
||||
delete all records for a symbol
|
||||
"""
|
||||
pass
|
||||
|
@ -322,8 +322,8 @@ class MongoManager(BaseDatabaseManager):
|
||||
.order_by("-datetime")
|
||||
.first()
|
||||
)
|
||||
if len(s):
|
||||
return list(s)[0]
|
||||
if s:
|
||||
return s.to_bar()
|
||||
return None
|
||||
|
||||
def get_newest_tick_data(
|
||||
@ -334,6 +334,10 @@ class MongoManager(BaseDatabaseManager):
|
||||
.order_by("-datetime")
|
||||
.first()
|
||||
)
|
||||
if len(s):
|
||||
return list(s)[0]
|
||||
if s:
|
||||
return s.to_tick()
|
||||
return None
|
||||
|
||||
def clean(self, symbol: str):
|
||||
DbTickData.objects(symbol=symbol).delete()
|
||||
DbBarData.objects(symbol=symbol).delete()
|
||||
|
@ -1,6 +1,6 @@
|
||||
""""""
|
||||
from datetime import datetime
|
||||
from typing import List, Sequence, Type
|
||||
from typing import List, Optional, Sequence, Type
|
||||
|
||||
from peewee import (
|
||||
AutoField,
|
||||
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||
self.class_tick.save_all(ds)
|
||||
|
||||
def get_newest_bar_data(
|
||||
self, symbol: str, exchange: "Exchange", interval: "Interval"
|
||||
) -> Optional["BarData"]:
|
||||
s = (
|
||||
self.class_bar.select()
|
||||
.where(
|
||||
(self.class_bar.symbol == symbol)
|
||||
& (self.class_bar.exchange == exchange.value)
|
||||
& (self.class_bar.interval == interval.value)
|
||||
)
|
||||
.order_by(self.class_bar.datetime.desc())
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_bar()
|
||||
return None
|
||||
|
||||
def get_newest_tick_data(
|
||||
self, symbol: str, exchange: "Exchange"
|
||||
) -> Optional["TickData"]:
|
||||
s = (
|
||||
self.class_tick.select()
|
||||
.where(
|
||||
(self.class_tick.symbol == symbol)
|
||||
& (self.class_tick.exchange == exchange.value)
|
||||
)
|
||||
.order_by(self.class_tick.datetime.desc())
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_tick()
|
||||
return None
|
||||
|
||||
def clean(self, symbol: str):
|
||||
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
|
||||
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()
|
||||
|
@ -4,7 +4,7 @@ General utility functions.
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
|
Loading…
Reference in New Issue
Block a user