[Fix] fix bugs for newly added functions
This commit is contained in:
parent
19e27ea031
commit
b72a1dc155
@ -3,48 +3,46 @@ Test if database works fine
|
|||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
from copy import copy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from vnpy.trader.constant import Exchange, Interval
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
from vnpy.trader.database.database import Driver
|
from vnpy.trader.database.database import Driver
|
||||||
from vnpy.trader.object import BarData, TickData
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
|
||||||
os.environ['VNPY_TESTING'] = '1'
|
os.environ["VNPY_TESTING"] = "1"
|
||||||
|
|
||||||
profiles = {
|
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
|
||||||
Driver.SQLITE: {
|
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
|
||||||
"driver": "sqlite",
|
profiles.update(
|
||||||
"database": "test_db.db",
|
{
|
||||||
}
|
Driver.MYSQL: {
|
||||||
}
|
"driver": "mysql",
|
||||||
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
|
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
|
||||||
profiles.update({
|
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
|
||||||
Driver.MYSQL: {
|
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
|
||||||
"driver": "mysql",
|
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||||
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
|
||||||
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
},
|
||||||
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
Driver.POSTGRESQL: {
|
||||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
"driver": "postgresql",
|
||||||
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
|
||||||
},
|
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
|
||||||
Driver.POSTGRESQL: {
|
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
|
||||||
"driver": "postgresql",
|
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||||
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
|
||||||
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
},
|
||||||
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
Driver.MONGODB: {
|
||||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
"driver": "mongodb",
|
||||||
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
|
||||||
},
|
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
|
||||||
Driver.MONGODB: {
|
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
|
||||||
"driver": "mongodb",
|
"user": "",
|
||||||
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
"password": "",
|
||||||
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
"authentication_source": "",
|
||||||
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
},
|
||||||
"user": "",
|
}
|
||||||
"password": "",
|
)
|
||||||
"authentication_source": "",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def now():
|
def now():
|
||||||
@ -69,17 +67,53 @@ tick = TickData(
|
|||||||
|
|
||||||
|
|
||||||
class TestDatabase(unittest.TestCase):
|
class TestDatabase(unittest.TestCase):
|
||||||
|
|
||||||
def connect(self, settings: dict):
|
def connect(self, settings: dict):
|
||||||
from vnpy.trader.database.initialize import init # noqa
|
from vnpy.trader.database.initialize import init # noqa
|
||||||
|
|
||||||
self.manager = init(settings)
|
self.manager = init(settings)
|
||||||
|
|
||||||
|
def assertBarCount(self, count, msg):
|
||||||
|
bars = self.manager.load_bar_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
interval=bar.interval,
|
||||||
|
start=bar.datetime - timedelta(days=1),
|
||||||
|
end=now()
|
||||||
|
)
|
||||||
|
self.assertEqual(count, len(bars), msg)
|
||||||
|
|
||||||
|
def assertTickCount(self, count, msg):
|
||||||
|
ticks = self.manager.load_tick_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
start=bar.datetime - timedelta(days=1),
|
||||||
|
end=now()
|
||||||
|
)
|
||||||
|
self.assertEqual(count, len(ticks), msg)
|
||||||
|
|
||||||
|
def assertNoData(self, msg):
|
||||||
|
self.assertBarCount(0, msg)
|
||||||
|
self.assertTickCount(0, msg)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
# clean all data first
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
self.connect(settings)
|
||||||
|
self.manager.clean(bar.symbol)
|
||||||
|
self.assertNoData("Failed to clean data!")
|
||||||
|
self.manager = None
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.manager.clean(bar.symbol)
|
||||||
|
self.assertNoData("Failed to clean data!")
|
||||||
|
|
||||||
def test_upsert_bar(self):
|
def test_upsert_bar(self):
|
||||||
for driver, settings in profiles.items():
|
for driver, settings in profiles.items():
|
||||||
with self.subTest(driver=driver, settings=settings):
|
with self.subTest(driver=driver, settings=settings):
|
||||||
self.connect(settings)
|
self.connect(settings)
|
||||||
self.manager.save_bar_data([bar])
|
self.manager.save_bar_data([bar])
|
||||||
self.manager.save_bar_data([bar])
|
self.manager.save_bar_data([bar])
|
||||||
|
self.assertBarCount(1, "there should be only one item after upsert")
|
||||||
|
|
||||||
def test_save_load_bar(self):
|
def test_save_load_bar(self):
|
||||||
for driver, settings in profiles.items():
|
for driver, settings in profiles.items():
|
||||||
@ -88,16 +122,7 @@ class TestDatabase(unittest.TestCase):
|
|||||||
# save first
|
# save first
|
||||||
self.manager.save_bar_data([bar])
|
self.manager.save_bar_data([bar])
|
||||||
|
|
||||||
# and load
|
self.assertBarCount(1, "there should be only one item after save")
|
||||||
results = self.manager.load_bar_data(
|
|
||||||
symbol=bar.symbol,
|
|
||||||
exchange=bar.exchange,
|
|
||||||
interval=bar.interval,
|
|
||||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
|
||||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
|
||||||
)
|
|
||||||
count = len(results)
|
|
||||||
self.assertNotEqual(count, 0)
|
|
||||||
|
|
||||||
def test_upsert_tick(self):
|
def test_upsert_tick(self):
|
||||||
for driver, settings in profiles.items():
|
for driver, settings in profiles.items():
|
||||||
@ -105,6 +130,7 @@ class TestDatabase(unittest.TestCase):
|
|||||||
self.connect(settings)
|
self.connect(settings)
|
||||||
self.manager.save_tick_data([tick])
|
self.manager.save_tick_data([tick])
|
||||||
self.manager.save_tick_data([tick])
|
self.manager.save_tick_data([tick])
|
||||||
|
self.assertTickCount(1, "there should be only one item after upsert")
|
||||||
|
|
||||||
def test_save_load_tick(self):
|
def test_save_load_tick(self):
|
||||||
for driver, settings in profiles.items():
|
for driver, settings in profiles.items():
|
||||||
@ -113,16 +139,48 @@ class TestDatabase(unittest.TestCase):
|
|||||||
# save first
|
# save first
|
||||||
self.manager.save_tick_data([tick])
|
self.manager.save_tick_data([tick])
|
||||||
|
|
||||||
# and load
|
self.assertTickCount(1, "there should be only one item after save")
|
||||||
results = self.manager.load_tick_data(
|
|
||||||
symbol=bar.symbol,
|
def test_newest_bar(self):
|
||||||
exchange=bar.exchange,
|
for driver, settings in profiles.items():
|
||||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
with self.subTest(driver=driver, settings=settings):
|
||||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
self.connect(settings)
|
||||||
|
|
||||||
|
# an older one
|
||||||
|
older_one = copy(bar)
|
||||||
|
older_one.volume = 123.0
|
||||||
|
older_one.datetime = now() - timedelta(days=1)
|
||||||
|
|
||||||
|
# and a newer one
|
||||||
|
newer_one = copy(bar)
|
||||||
|
newer_one.volume = 456.0
|
||||||
|
newer_one.datetime = now()
|
||||||
|
self.manager.save_bar_data([older_one, newer_one])
|
||||||
|
|
||||||
|
got = self.manager.get_newest_bar_data(
|
||||||
|
bar.symbol, bar.exchange, bar.interval
|
||||||
)
|
)
|
||||||
count = len(results)
|
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
|
||||||
self.assertNotEqual(count, 0)
|
|
||||||
|
def test_newest_tick(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
|
||||||
|
# an older one
|
||||||
|
older_one = copy(tick)
|
||||||
|
older_one.volume = 123
|
||||||
|
older_one.datetime = now() - timedelta(days=1)
|
||||||
|
|
||||||
|
# and a newer one
|
||||||
|
newer_one = copy(tick)
|
||||||
|
older_one.volume = 456
|
||||||
|
newer_one.datetime = now()
|
||||||
|
self.manager.save_tick_data([older_one, newer_one])
|
||||||
|
|
||||||
|
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
|
||||||
|
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Sequence, TYPE_CHECKING, Optional
|
from typing import Optional, Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vnpy.trader.constant import Interval, Exchange # noqa
|
from vnpy.trader.constant import Interval, Exchange # noqa
|
||||||
@ -76,3 +76,10 @@ class BaseDatabaseManager(ABC):
|
|||||||
otherwise, return None
|
otherwise, return None
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clean(self, symbol: str):
|
||||||
|
"""
|
||||||
|
delete all records for a symbol
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
@ -322,8 +322,8 @@ class MongoManager(BaseDatabaseManager):
|
|||||||
.order_by("-datetime")
|
.order_by("-datetime")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if len(s):
|
if s:
|
||||||
return list(s)[0]
|
return s.to_bar()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_newest_tick_data(
|
def get_newest_tick_data(
|
||||||
@ -334,6 +334,10 @@ class MongoManager(BaseDatabaseManager):
|
|||||||
.order_by("-datetime")
|
.order_by("-datetime")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if len(s):
|
if s:
|
||||||
return list(s)[0]
|
return s.to_tick()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def clean(self, symbol: str):
|
||||||
|
DbTickData.objects(symbol=symbol).delete()
|
||||||
|
DbBarData.objects(symbol=symbol).delete()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
""""""
|
""""""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Sequence, Type
|
from typing import List, Optional, Sequence, Type
|
||||||
|
|
||||||
from peewee import (
|
from peewee import (
|
||||||
AutoField,
|
AutoField,
|
||||||
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
|
|||||||
def save_tick_data(self, datas: Sequence[TickData]):
|
def save_tick_data(self, datas: Sequence[TickData]):
|
||||||
ds = [self.class_tick.from_tick(i) for i in datas]
|
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||||
self.class_tick.save_all(ds)
|
self.class_tick.save_all(ds)
|
||||||
|
|
||||||
|
def get_newest_bar_data(
|
||||||
|
self, symbol: str, exchange: "Exchange", interval: "Interval"
|
||||||
|
) -> Optional["BarData"]:
|
||||||
|
s = (
|
||||||
|
self.class_bar.select()
|
||||||
|
.where(
|
||||||
|
(self.class_bar.symbol == symbol)
|
||||||
|
& (self.class_bar.exchange == exchange.value)
|
||||||
|
& (self.class_bar.interval == interval.value)
|
||||||
|
)
|
||||||
|
.order_by(self.class_bar.datetime.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if s:
|
||||||
|
return s.to_bar()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_newest_tick_data(
|
||||||
|
self, symbol: str, exchange: "Exchange"
|
||||||
|
) -> Optional["TickData"]:
|
||||||
|
s = (
|
||||||
|
self.class_tick.select()
|
||||||
|
.where(
|
||||||
|
(self.class_tick.symbol == symbol)
|
||||||
|
& (self.class_tick.exchange == exchange.value)
|
||||||
|
)
|
||||||
|
.order_by(self.class_tick.datetime.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if s:
|
||||||
|
return s.to_tick()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def clean(self, symbol: str):
|
||||||
|
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
|
||||||
|
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()
|
||||||
|
@ -4,7 +4,7 @@ General utility functions.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, TYPE_CHECKING
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import talib
|
import talib
|
||||||
|
Loading…
Reference in New Issue
Block a user