Merge pull request #1609 from nanoric/newest_data
[Add] DatabaseManager.get_newest_xxx_data
This commit is contained in:
commit
fe0e17763f
@ -75,6 +75,7 @@ matrix:
|
||||
- sudo make install
|
||||
- popd
|
||||
- pip install numpy
|
||||
- pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- python setup.py sdist
|
||||
- pip install dist/`ls dist`
|
||||
|
@ -47,10 +47,7 @@ for:
|
||||
- configuration: pip
|
||||
build_script:
|
||||
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- pip install -r requirements.txt
|
||||
- pip install .
|
||||
- call install.bat
|
||||
|
||||
- matrix:
|
||||
only:
|
||||
@ -58,6 +55,7 @@ for:
|
||||
build_script:
|
||||
- python setup.py sdist
|
||||
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- ps: $name=(ls dist).name; pip install "dist/$name"
|
||||
|
@ -1,3 +1,6 @@
|
||||
:: rqdatac
|
||||
pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
|
||||
|
||||
::Install talib and ibapi
|
||||
pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
|
@ -22,6 +22,7 @@ popd
|
||||
$pip install numpy
|
||||
|
||||
# Install extra packages
|
||||
$pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
|
||||
$pip install ta-lib
|
||||
$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||
|
||||
|
@ -11,6 +11,7 @@ matplotlib
|
||||
seaborn
|
||||
futu-api
|
||||
tigeropen
|
||||
rqdatac
|
||||
ta-lib
|
||||
ibapi
|
||||
mongoengine
|
||||
|
1
setup.py
1
setup.py
@ -119,6 +119,7 @@ install_requires = [
|
||||
"seaborn",
|
||||
"futu-api",
|
||||
"tigeropen",
|
||||
"rqdatac",
|
||||
"ta-lib",
|
||||
"ibapi"
|
||||
]
|
||||
|
@ -3,48 +3,46 @@ Test if database works fine
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
from copy import copy
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database.database import Driver
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
|
||||
os.environ['VNPY_TESTING'] = '1'
|
||||
os.environ["VNPY_TESTING"] = "1"
|
||||
|
||||
profiles = {
|
||||
Driver.SQLITE: {
|
||||
"driver": "sqlite",
|
||||
"database": "test_db.db",
|
||||
}
|
||||
}
|
||||
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
|
||||
profiles.update({
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
})
|
||||
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
|
||||
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
|
||||
profiles.update(
|
||||
{
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
|
||||
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
|
||||
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def now():
|
||||
@ -72,14 +70,51 @@ class TestDatabase(unittest.TestCase):
|
||||
|
||||
def connect(self, settings: dict):
|
||||
from vnpy.trader.database.initialize import init # noqa
|
||||
|
||||
self.manager = init(settings)
|
||||
|
||||
def assertBarCount(self, count, msg):
|
||||
bars = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(days=1),
|
||||
end=now()
|
||||
)
|
||||
self.assertEqual(count, len(bars), msg)
|
||||
|
||||
def assertTickCount(self, count, msg):
|
||||
ticks = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(days=1),
|
||||
end=now()
|
||||
)
|
||||
self.assertEqual(count, len(ticks), msg)
|
||||
|
||||
def assertNoData(self, msg):
|
||||
self.assertBarCount(0, msg)
|
||||
self.assertTickCount(0, msg)
|
||||
|
||||
def setUp(self) -> None:
|
||||
# clean all data first
|
||||
for driver, settings in profiles.items():
|
||||
self.connect(settings)
|
||||
self.manager.clean(bar.symbol)
|
||||
self.assertNoData("Failed to clean data!")
|
||||
self.manager = None
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.manager.clean(bar.symbol)
|
||||
self.assertNoData("Failed to clean data!")
|
||||
|
||||
def test_upsert_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_bar_data([bar])
|
||||
self.manager.save_bar_data([bar])
|
||||
self.assertBarCount(1, "there should be only one item after upsert")
|
||||
|
||||
def test_save_load_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -88,16 +123,7 @@ class TestDatabase(unittest.TestCase):
|
||||
# save first
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
self.assertBarCount(1, "there should be only one item after save")
|
||||
|
||||
def test_upsert_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -105,6 +131,7 @@ class TestDatabase(unittest.TestCase):
|
||||
self.connect(settings)
|
||||
self.manager.save_tick_data([tick])
|
||||
self.manager.save_tick_data([tick])
|
||||
self.assertTickCount(1, "there should be only one item after upsert")
|
||||
|
||||
def test_save_load_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
@ -113,16 +140,58 @@ class TestDatabase(unittest.TestCase):
|
||||
# save first
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
self.assertTickCount(1, "there should be only one item after save")
|
||||
|
||||
def test_newest_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
got = self.manager.get_newest_bar_data(bar.symbol, bar.exchange, bar.interval)
|
||||
self.assertIsNone(
|
||||
got,
|
||||
"database is empty, but return value for newest_bar_data() is not a None"
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
|
||||
# an older one
|
||||
older_one = copy(bar)
|
||||
older_one.volume = 123.0
|
||||
older_one.datetime = now() - timedelta(days=1)
|
||||
|
||||
# and a newer one
|
||||
newer_one = copy(bar)
|
||||
newer_one.volume = 456.0
|
||||
newer_one.datetime = now()
|
||||
self.manager.save_bar_data([older_one, newer_one])
|
||||
|
||||
got = self.manager.get_newest_bar_data(
|
||||
bar.symbol, bar.exchange, bar.interval
|
||||
)
|
||||
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
|
||||
|
||||
def test_newest_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
|
||||
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
|
||||
self.assertIsNone(
|
||||
got,
|
||||
"database is empty, but return value for newest_tick_data() is not a None"
|
||||
)
|
||||
# an older one
|
||||
older_one = copy(tick)
|
||||
older_one.volume = 123
|
||||
older_one.datetime = now() - timedelta(days=1)
|
||||
|
||||
# and a newer one
|
||||
newer_one = copy(tick)
|
||||
older_one.volume = 456
|
||||
newer_one.datetime = now()
|
||||
self.manager.save_tick_data([older_one, newer_one])
|
||||
|
||||
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
|
||||
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence, TYPE_CHECKING
|
||||
from typing import Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Interval, Exchange # noqa
|
||||
@ -51,3 +51,35 @@ class BaseDatabaseManager(ABC):
|
||||
datas: Sequence["TickData"],
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_newest_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
interval: "Interval"
|
||||
) -> Optional["BarData"]:
|
||||
"""
|
||||
If there is data in database, return the one with greatest datetime(newest one)
|
||||
otherwise, return None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_newest_tick_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
) -> Optional["TickData"]:
|
||||
"""
|
||||
If there is data in database, return the one with greatest datetime(newest one)
|
||||
otherwise, return None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clean(self, symbol: str):
|
||||
"""
|
||||
delete all records for a symbol
|
||||
"""
|
||||
pass
|
||||
|
@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
||||
|
||||
@ -54,8 +54,10 @@ class DbBarData(Document):
|
||||
|
||||
meta = {
|
||||
"indexes": [
|
||||
{"fields": ("datetime", "interval", "symbol",
|
||||
"exchange"), "unique": True}
|
||||
{
|
||||
"fields": ("datetime", "interval", "symbol", "exchange"),
|
||||
"unique": True,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -145,8 +147,14 @@ class DbTickData(Document):
|
||||
ask_volume_4: float = FloatField()
|
||||
ask_volume_5: float = FloatField()
|
||||
|
||||
meta = {"indexes": [
|
||||
{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
|
||||
meta = {
|
||||
"indexes": [
|
||||
{
|
||||
"fields": ("datetime", "symbol", "exchange"),
|
||||
"unique": True,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
@ -305,3 +313,31 @@ class MongoManager(BaseDatabaseManager):
|
||||
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
|
||||
).update_one(upsert=True, **updates)
|
||||
)
|
||||
|
||||
def get_newest_bar_data(
|
||||
self, symbol: str, exchange: "Exchange", interval: "Interval"
|
||||
) -> Optional["BarData"]:
|
||||
s = (
|
||||
DbBarData.objects(symbol=symbol, exchange=exchange.value)
|
||||
.order_by("-datetime")
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_bar()
|
||||
return None
|
||||
|
||||
def get_newest_tick_data(
|
||||
self, symbol: str, exchange: "Exchange"
|
||||
) -> Optional["TickData"]:
|
||||
s = (
|
||||
DbTickData.objects(symbol=symbol, exchange=exchange.value)
|
||||
.order_by("-datetime")
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_tick()
|
||||
return None
|
||||
|
||||
def clean(self, symbol: str):
|
||||
DbTickData.objects(symbol=symbol).delete()
|
||||
DbBarData.objects(symbol=symbol).delete()
|
||||
|
@ -1,6 +1,6 @@
|
||||
""""""
|
||||
from datetime import datetime
|
||||
from typing import List, Sequence, Type
|
||||
from typing import List, Optional, Sequence, Type
|
||||
|
||||
from peewee import (
|
||||
AutoField,
|
||||
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||
self.class_tick.save_all(ds)
|
||||
|
||||
def get_newest_bar_data(
|
||||
self, symbol: str, exchange: "Exchange", interval: "Interval"
|
||||
) -> Optional["BarData"]:
|
||||
s = (
|
||||
self.class_bar.select()
|
||||
.where(
|
||||
(self.class_bar.symbol == symbol)
|
||||
& (self.class_bar.exchange == exchange.value)
|
||||
& (self.class_bar.interval == interval.value)
|
||||
)
|
||||
.order_by(self.class_bar.datetime.desc())
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_bar()
|
||||
return None
|
||||
|
||||
def get_newest_tick_data(
|
||||
self, symbol: str, exchange: "Exchange"
|
||||
) -> Optional["TickData"]:
|
||||
s = (
|
||||
self.class_tick.select()
|
||||
.where(
|
||||
(self.class_tick.symbol == symbol)
|
||||
& (self.class_tick.exchange == exchange.value)
|
||||
)
|
||||
.order_by(self.class_tick.datetime.desc())
|
||||
.first()
|
||||
)
|
||||
if s:
|
||||
return s.to_tick()
|
||||
return None
|
||||
|
||||
def clean(self, symbol: str):
|
||||
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
|
||||
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()
|
||||
|
Loading…
Reference in New Issue
Block a user