Merge pull request #1609 from nanoric/newest_data

[Add] DatabaseManager.get_newest_xxx_data
This commit is contained in:
vn.py 2019-04-18 13:27:41 +08:00 committed by GitHub
commit fe0e17763f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 65 deletions

View File

@ -75,6 +75,7 @@ matrix:
- sudo make install
- popd
- pip install numpy
- pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- python setup.py sdist
- pip install dist/`ls dist`

View File

@ -47,10 +47,7 @@ for:
- configuration: pip
build_script:
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt
- pip install .
- call install.bat
- matrix:
only:
@ -58,6 +55,7 @@ for:
build_script:
- python setup.py sdist
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- ps: $name=(ls dist).name; pip install "dist/$name"

View File

@ -1,3 +1,6 @@
:: rqdatac
pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
::Install talib and ibapi
pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl

View File

@ -22,6 +22,7 @@ popd
$pip install numpy
# Install extra packages
$pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
$pip install ta-lib
$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl

View File

@ -11,6 +11,7 @@ matplotlib
seaborn
futu-api
tigeropen
rqdatac
ta-lib
ibapi
mongoengine

View File

@ -119,6 +119,7 @@ install_requires = [
"seaborn",
"futu-api",
"tigeropen",
"rqdatac",
"ta-lib",
"ibapi"
]

View File

@ -3,48 +3,46 @@ Test if database works fine
"""
import os
import unittest
from copy import copy
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1'
os.environ["VNPY_TESTING"] = "1"
profiles = {
Driver.SQLITE: {
"driver": "sqlite",
"database": "test_db.db",
}
}
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
profiles.update({
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
"user": "",
"password": "",
"authentication_source": "",
},
})
profiles = {Driver.SQLITE: {"driver": "sqlite", "database": "test_db.db"}}
if "VNPY_TEST_ONLY_SQLITE" not in os.environ:
profiles.update(
{
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ["VNPY_TEST_MYSQL_DATABASE"],
"host": os.environ["VNPY_TEST_MYSQL_HOST"],
"port": int(os.environ["VNPY_TEST_MYSQL_PORT"]),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ["VNPY_TEST_MYSQL_PASSWORD"],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ["VNPY_TEST_POSTGRESQL_DATABASE"],
"host": os.environ["VNPY_TEST_POSTGRESQL_HOST"],
"port": int(os.environ["VNPY_TEST_POSTGRESQL_PORT"]),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ["VNPY_TEST_POSTGRESQL_PASSWORD"],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ["VNPY_TEST_MONGODB_DATABASE"],
"host": os.environ["VNPY_TEST_MONGODB_HOST"],
"port": int(os.environ["VNPY_TEST_MONGODB_PORT"]),
"user": "",
"password": "",
"authentication_source": "",
},
}
)
def now():
@ -72,14 +70,51 @@ class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def assertBarCount(self, count, msg):
bars = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(bars), msg)
def assertTickCount(self, count, msg):
ticks = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(days=1),
end=now()
)
self.assertEqual(count, len(ticks), msg)
def assertNoData(self, msg):
self.assertBarCount(0, msg)
self.assertTickCount(0, msg)
def setUp(self) -> None:
# clean all data first
for driver, settings in profiles.items():
self.connect(settings)
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
self.manager = None
def tearDown(self) -> None:
self.manager.clean(bar.symbol)
self.assertNoData("Failed to clean data!")
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
self.assertBarCount(1, "there should be only one item after upsert")
def test_save_load_bar(self):
for driver, settings in profiles.items():
@ -88,16 +123,7 @@ class TestDatabase(unittest.TestCase):
# save first
self.manager.save_bar_data([bar])
# and load
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
self.assertBarCount(1, "there should be only one item after save")
def test_upsert_tick(self):
for driver, settings in profiles.items():
@ -105,6 +131,7 @@ class TestDatabase(unittest.TestCase):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
self.assertTickCount(1, "there should be only one item after upsert")
def test_save_load_tick(self):
for driver, settings in profiles.items():
@ -113,16 +140,58 @@ class TestDatabase(unittest.TestCase):
# save first
self.manager.save_tick_data([tick])
# and load
results = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
self.assertTickCount(1, "there should be only one item after save")
def test_newest_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
got = self.manager.get_newest_bar_data(bar.symbol, bar.exchange, bar.interval)
self.assertIsNone(
got,
"database is empty, but return value for newest_bar_data() is not a None"
)
count = len(results)
self.assertNotEqual(count, 0)
# an older one
older_one = copy(bar)
older_one.volume = 123.0
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(bar)
newer_one.volume = 456.0
newer_one.datetime = now()
self.manager.save_bar_data([older_one, newer_one])
got = self.manager.get_newest_bar_data(
bar.symbol, bar.exchange, bar.interval
)
self.assertEqual(got.volume, newer_one.volume, "the newest bar we got mismatched")
def test_newest_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertIsNone(
got,
"database is empty, but return value for newest_tick_data() is not a None"
)
# an older one
older_one = copy(tick)
older_one.volume = 123
older_one.datetime = now() - timedelta(days=1)
# and a newer one
newer_one = copy(tick)
older_one.volume = 456
newer_one.datetime = now()
self.manager.save_tick_data([older_one, newer_one])
got = self.manager.get_newest_tick_data(tick.symbol, tick.exchange)
self.assertEqual(got.volume, newer_one.volume, "the newest tick we got mismatched")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence, TYPE_CHECKING
from typing import Optional, Sequence, TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange # noqa
@ -51,3 +51,35 @@ class BaseDatabaseManager(ABC):
datas: Sequence["TickData"],
):
pass
@abstractmethod
def get_newest_bar_data(
self,
symbol: str,
exchange: "Exchange",
interval: "Interval"
) -> Optional["BarData"]:
"""
If there is data in database, return the one with greatest datetime(newest one)
otherwise, return None
"""
pass
@abstractmethod
def get_newest_tick_data(
self,
symbol: str,
exchange: "Exchange",
) -> Optional["TickData"]:
"""
If there is data in database, return the one with greatest datetime(newest one)
otherwise, return None
"""
pass
@abstractmethod
def clean(self, symbol: str):
"""
delete all records for a symbol
"""
pass

View File

@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Sequence
from typing import Sequence, Optional
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
@ -54,8 +54,10 @@ class DbBarData(Document):
meta = {
"indexes": [
{"fields": ("datetime", "interval", "symbol",
"exchange"), "unique": True}
{
"fields": ("datetime", "interval", "symbol", "exchange"),
"unique": True,
}
]
}
@ -145,8 +147,14 @@ class DbTickData(Document):
ask_volume_4: float = FloatField()
ask_volume_5: float = FloatField()
meta = {"indexes": [
{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
meta = {
"indexes": [
{
"fields": ("datetime", "symbol", "exchange"),
"unique": True,
}
],
}
@staticmethod
def from_tick(tick: TickData):
@ -305,3 +313,31 @@ class MongoManager(BaseDatabaseManager):
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)
def get_newest_bar_data(
self, symbol: str, exchange: "Exchange", interval: "Interval"
) -> Optional["BarData"]:
s = (
DbBarData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime")
.first()
)
if s:
return s.to_bar()
return None
def get_newest_tick_data(
self, symbol: str, exchange: "Exchange"
) -> Optional["TickData"]:
s = (
DbTickData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime")
.first()
)
if s:
return s.to_tick()
return None
def clean(self, symbol: str):
DbTickData.objects(symbol=symbol).delete()
DbBarData.objects(symbol=symbol).delete()

View File

@ -1,6 +1,6 @@
""""""
from datetime import datetime
from typing import List, Sequence, Type
from typing import List, Optional, Sequence, Type
from peewee import (
AutoField,
@ -367,3 +367,40 @@ class SqlManager(BaseDatabaseManager):
def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds)
def get_newest_bar_data(
self, symbol: str, exchange: "Exchange", interval: "Interval"
) -> Optional["BarData"]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
)
.order_by(self.class_bar.datetime.desc())
.first()
)
if s:
return s.to_bar()
return None
def get_newest_tick_data(
self, symbol: str, exchange: "Exchange"
) -> Optional["TickData"]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
)
.order_by(self.class_tick.datetime.desc())
.first()
)
if s:
return s.to_tick()
return None
def clean(self, symbol: str):
self.class_bar.delete().where(self.class_bar.symbol == symbol).execute()
self.class_tick.delete().where(self.class_tick.symbol == symbol).execute()