[Add] added support for mongodb

[Add] added tests for database and CsvLoaderEngine
[Mod] changed .travis.yaml
This commit is contained in:
nanoric 2019-04-13 23:02:30 -04:00
parent f797152bd5
commit 5d0bf006c6
18 changed files with 892 additions and 245 deletions

View File

@ -1,13 +1,19 @@
language: python
cache: pip
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
python:
- "3.7"
services:
- mongodb
- mysql
- postgresql
script:
# todo: use python unittest
- mkdir run; cd run; python ../tests/load_all.py
- cd tests; source travis_env.sh; python test_all.py
matrix:
include:

1
tests/app/__init__.py Normal file
View File

@ -0,0 +1 @@
from .test_csv_loader import *

20
tests/test_all.py Normal file
View File

@ -0,0 +1,20 @@
# tests/runner.py
import unittest
# import your test modules
import test_import_all
import trader
import app
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# add tests to the test suite
suite.addTests(loader.loadTestsFromModule(test_import_all))
suite.addTests(loader.loadTestsFromModule(trader))
suite.addTests(loader.loadTestsFromModule(app))
# initialize a runner, pass it your suite and run it
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)

View File

@ -2,23 +2,41 @@
import unittest
# noinspection PyUnresolvedReferences
class ImportTest(unittest.TestCase):
# noinspection PyUnresolvedReferences
def test_import_all(self):
from vnpy.event import EventEngine
def test_import_main_engine(self):
from vnpy.trader.engine import MainEngine
def test_import_ui(self):
from vnpy.trader.ui import MainWindow, create_qapp
def test_import_bitmex_gateway(self):
from vnpy.gateway.bitmex import BitmexGateway
def test_import_futu_gateway(self):
from vnpy.gateway.futu import FutuGateway
def test_import_ib_gateway(self):
from vnpy.gateway.ib import IbGateway
def test_import_ctp_gateway(self):
from vnpy.gateway.ctp import CtpGateway
def test_import_tiger_gateway(self):
from vnpy.gateway.tiger import TigerGateway
def test_import_oes_gateway(self):
from vnpy.gateway.oes import OesGateway
def test_import_cta_strategy_app(self):
from vnpy.app.cta_strategy import CtaStrategyApp
def test_import_csv_loader_app(self):
from vnpy.app.csv_loader import CsvLoaderApp

2
tests/trader/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .test_database import *
from .test_settings import *

View File

@ -0,0 +1,125 @@
"""
Test if database works fine
"""
import os
import unittest
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1'
profiles = {
Driver.SQLITE: {
"driver": "sqlite",
"database": "test_db.db",
},
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
"user": "",
"password": "",
"authentication_source": "",
},
}
def now():
return datetime.utcnow()
bar = BarData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
interval=Interval.MINUTE,
)
tick = TickData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
name="DB_test_symbol",
)
class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
def test_save_load_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_bar_data([bar])
# and load
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
def test_upsert_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
def test_save_load_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_tick_data([tick])
# and load
results = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,24 @@
"""
Test if database works fine
"""
import unittest
from vnpy.trader.setting import SETTINGS, get_settings
class TestSettings(unittest.TestCase):
def test_get_settings(self):
SETTINGS['a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
def test_get_settings_with_prefix(self):
SETTINGS['a.a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
if __name__ == '__main__':
unittest.main()

16
tests/travis_env.sh Normal file
View File

@ -0,0 +1,16 @@
#!/usr/bin/env bash
export VNPY_TEST_MYSQL_DATABASE=vnpy
export VNPY_TEST_MYSQL_HOST=127.0.0.1
export VNPY_TEST_MYSQL_PORT=3306
export VNPY_TEST_MYSQL_USER=root
export VNPY_TEST_MYSQL_PASSWORD=
export VNPY_TEST_POSTGRESQL_DATABASE=vnpy
export VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
export VNPY_TEST_POSTGRESQL_PORT=5432
export VNPY_TEST_POSTGRESQL_USER=postgres
export VNPY_TEST_POSTGRESQL_PASSWORD=
export VNPY_TEST_MONGODB_DATABASE=vnpy
export VNPY_TEST_MONGODB_HOST=127.0.0.1
export VNPY_TEST_MONGODB_PORT=27017

View File

@ -26,8 +26,9 @@ from typing import TextIO
from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData
from vnpy.trader.database import database_manager
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import BarData
APP_NAME = "CsvLoader"
@ -70,7 +71,7 @@ class CsvLoaderEngine(BaseEngine):
"""
reader = csv.DictReader(f)
db_bars = []
bars = []
start = None
count = 0
for item in reader:
@ -79,29 +80,29 @@ class CsvLoaderEngine(BaseEngine):
else:
dt = datetime.fromisoformat(item[datetime_head])
db_bar = DbBarData(
bar = BarData(
symbol=symbol,
exchange=exchange.value,
exchange=exchange,
datetime=dt,
interval=interval.value,
interval=interval,
volume=item[volume_head],
open_price=item[open_head],
high_price=item[high_head],
low_price=item[low_head],
close_price=item[close_head],
gateway_name="DB",
)
db_bars.append(db_bar)
bars.append(bar)
# do some statistics
count += 1
if not start:
start = db_bar.datetime
end = db_bar.datetime
start = bar.datetime
end = bar.datetime
# insert into database
DbBarData.save_all(db_bars)
database_manager.save_bar_data(bars)
return start, end, count
def load(

View File

@ -12,9 +12,9 @@ from pandas import DataFrame
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
from vnpy.trader.database import DbBarData, DbTickData
from vnpy.trader.object import OrderData, TradeData
from vnpy.trader.utility import round_to_pricetick
from vnpy.trader.database import database_manager
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
from vnpy.trader.utility import round_to_pricetick, extract_vt_symbol
from .base import (
BacktestingMode,
@ -103,8 +103,8 @@ class BacktestingEngine:
self.strategy_class = None
self.strategy = None
self.tick = None
self.bar = None
self.tick: TickData
self.bar: BarData
self.datetime = None
self.interval = None
@ -199,14 +199,16 @@ class BacktestingEngine:
if self.mode == BacktestingMode.BAR:
self.history_data = load_bar_data(
self.vt_symbol,
self.symbol,
self.exchange,
self.interval,
self.start,
self.end
)
else:
self.history_data = load_tick_data(
self.vt_symbol,
self.symbol,
self.exchange,
self.start,
self.end
)
@ -519,7 +521,7 @@ class BacktestingEngine:
else:
self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: DbBarData):
def new_bar(self, bar: BarData):
""""""
self.bar = bar
self.datetime = bar.datetime
@ -530,7 +532,7 @@ class BacktestingEngine:
self.update_daily_close(bar.close_price)
def new_tick(self, tick: DbTickData):
def new_tick(self, tick: TickData):
""""""
self.tick = tick
self.datetime = tick.datetime
@ -965,41 +967,27 @@ def optimize(
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
return database_manager.load_bar_data(
symbol, exchange, interval, start, end
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
return database_manager.load_tick_data(
symbol, exchange, start, end
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -5,7 +5,7 @@ import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, List
from datetime import datetime, timedelta
from threading import Thread
from queue import Queue
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
Status
)
from vnpy.trader.utility import load_json, save_json
from vnpy.trader.database import DbTickData, DbBarData
from vnpy.trader.database import database_manager
from vnpy.trader.setting import SETTINGS
from .base import (
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
self.write_log("RQData数据接口初始化成功")
def query_bar_from_rq(
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
):
"""
Query bar data from RQData.
"""
symbol, exchange_str = vt_symbol.split(".")
rq_symbol = to_rq_symbol(vt_symbol)
rq_symbol = to_rq_symbol(symbol, exchange)
if rq_symbol not in self.rq_symbols:
return None
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
end_date=end
)
data = []
data: List[BarData] = []
for ix, row in df.iterrows():
bar = BarData(
symbol=symbol,
exchange=Exchange(exchange_str),
exchange=exchange,
interval=interval,
datetime=row.name.to_pydatetime(),
open_price=row["open"],
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
return self.engine_type
def load_bar(
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
callback: Callable[[BarData], None]
):
""""""
end = datetime.now()
start = end - timedelta(days)
# Query data from RQData by default, if not found, load from database.
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
if not data:
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval.value)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
# Query bars from RQData by default, if not found, load from database.
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
if not bars:
bars = database_manager.load_bar_data(
symbol=symbol,
exchange=exchange,
interval=interval,
start=start,
end=end,
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
for bar in data:
for bar in bars:
callback(bar)
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
def load_tick(self, symbol: str, exchange: Exchange, days: int,
callback: Callable[[TickData], None]):
""""""
end = datetime.now()
start = end - timedelta(days)
s = (
DbTickData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
ticks = database_manager.load_tick_data(
symbol=symbol,
exchange=exchange,
start=start,
end=end,
)
for tick in s:
for tick in ticks:
callback(tick)
def call_strategy_func(
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for dirpath, dirnames, filenames in os.walk(str(path)):
for filename in filenames:
if filename.endswith(".py"):
strategy_module_name = ".".join(
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
self.main_engine.send_email(subject, msg)
def to_rq_symbol(vt_symbol: str):
def to_rq_symbol(symbol: str, exchange: Exchange):
"""
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
symbol, exchange_str = vt_symbol.split(".")
if exchange_str != "CZCE":
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]

View File

@ -0,0 +1,12 @@
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.database.database import BaseDatabaseManager
if "VNPY_TESTING" not in os.environ:
from vnpy.trader.setting import get_settings
from .initialize import init
settings = get_settings("database.")
database_manager: "BaseDatabaseManager" = init(settings=settings)

View File

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence, TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange
from vnpy.trader.object import BarData, TickData
class Driver(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
MONGODB = "mongodb"
class BaseDatabaseManager(ABC):
@abstractmethod
def load_bar_data(
self,
symbol: str,
exchange: "Exchange",
interval: "Interval",
start: datetime,
end: datetime
) -> Sequence["BarData"]:
pass
@abstractmethod
def load_tick_data(
self,
symbol: str,
exchange: "Exchange",
start: datetime,
end: datetime
) -> Sequence["TickData"]:
pass
@abstractmethod
def save_bar_data(
self,
datas: Sequence["BarData"],
):
pass
@abstractmethod
def save_tick_data(
self,
datas: Sequence["TickData"],
):
pass

View File

@ -1,99 +1,59 @@
""""""
from datetime import datetime
from enum import Enum
from typing import List
from typing import Sequence
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
from .constant import Exchange, Interval
from .object import BarData, TickData
from .setting import SETTINGS
from .utility import get_file_path
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from .database import BaseDatabaseManager, Driver
class Driver(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
_db: Database
_driver: Driver
def init():
global _driver
db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")}
_driver = Driver(db_settings["driver"])
init_funcs = {
Driver.SQLITE: init_sqlite,
Driver.MYSQL: init_mysql,
Driver.POSTGRESQL: init_postgresql,
}
assert _driver in init_funcs
del db_settings["driver"]
return init_funcs[_driver](db_settings)
def init_sqlite(settings: dict):
global _db
def init(_: Driver, settings: dict):
database = settings["database"]
_db = SqliteDatabase(str(get_file_path(database)))
host = settings["host"]
port = settings["port"]
username = settings["user"]
password = settings["password"]
authentication_source = settings["authentication_source"]
if not username: # if username == '' or None, skip username
username = None
password = None
authentication_source = None
connect(
db=database,
host=host,
port=port,
username=username,
password=password,
authentication_source=authentication_source,
)
return MongoManager()
def init_mysql(settings: dict):
global _db
_db = MySQLDatabase(**settings)
def init_postgresql(settings: dict):
global _db
_db = PostgresqlDatabase(**settings)
init()
class ModelBase(Model):
def to_dict(self):
return self.__data__
class DbBarData(ModelBase):
class DbBarData(Document):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
id = AutoField()
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
interval = CharField()
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
interval: str = StringField()
volume = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
class Meta:
database = _db
indexes = ((("datetime", "interval", "symbol"), True),)
meta = {
"indexes": [
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
]
}
@staticmethod
def from_bar(bar: BarData):
@ -132,80 +92,56 @@ class DbBarData(ModelBase):
)
return bar
@staticmethod
def save_all(objs: List["DbBarData"]):
"""
save a list of objects, update if exists.
"""
with _db.atomic():
if _driver is Driver.POSTGRESQL:
for bar in objs:
DbBarData.insert(bar.to_dict()).on_conflict(
update=bar.to_dict(),
conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol,
),
).execute()
else:
for c in chunked(objs, 50):
DbBarData.insert_many(c).on_conflict_replace()
class DbTickData(ModelBase):
class DbTickData(Document):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
id = AutoField()
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
name: str = StringField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
name = CharField()
volume = FloatField()
last_price = FloatField()
last_volume = FloatField()
limit_up = FloatField()
limit_down = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
pre_close: float = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
pre_close = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField()
bid_price_3: float = FloatField()
bid_price_4: float = FloatField()
bid_price_5: float = FloatField()
bid_price_1 = FloatField()
bid_price_2 = FloatField()
bid_price_3 = FloatField()
bid_price_4 = FloatField()
bid_price_5 = FloatField()
ask_price_1: float = FloatField()
ask_price_2: float = FloatField()
ask_price_3: float = FloatField()
ask_price_4: float = FloatField()
ask_price_5: float = FloatField()
ask_price_1 = FloatField()
ask_price_2 = FloatField()
ask_price_3 = FloatField()
ask_price_4 = FloatField()
ask_price_5 = FloatField()
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField()
bid_volume_3: float = FloatField()
bid_volume_4: float = FloatField()
bid_volume_5: float = FloatField()
bid_volume_1 = FloatField()
bid_volume_2 = FloatField()
bid_volume_3 = FloatField()
bid_volume_4 = FloatField()
bid_volume_5 = FloatField()
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField()
ask_volume_3: float = FloatField()
ask_volume_4: float = FloatField()
ask_volume_5: float = FloatField()
ask_volume_1 = FloatField()
ask_volume_2 = FloatField()
ask_volume_3 = FloatField()
ask_volume_4 = FloatField()
ask_volume_5 = FloatField()
class Meta:
database = _db
indexes = ((("datetime", "symbol"), True),)
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
@staticmethod
def from_tick(tick: TickData):
@ -254,7 +190,7 @@ class DbTickData(ModelBase):
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return tick
return db_tick
def to_tick(self):
"""
@ -304,20 +240,63 @@ class DbTickData(ModelBase):
return tick
class MongoManager(BaseDatabaseManager):
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = DbBarData.objects(
symbol=symbol,
exchange=exchange.value,
interval=interval.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = DbTickData.objects(
symbol=symbol,
exchange=exchange.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_tick.to_tick() for db_tick in s]
return data
@staticmethod
def save_all(objs: List["DbTickData"]):
with _db.atomic():
if _driver is Driver.POSTGRESQL:
for bar in objs:
DbTickData.insert(bar.to_dict()).on_conflict(
update=bar.to_dict(),
preserve=(DbTickData.id),
conflict_target=(DbTickData.datetime, DbTickData.symbol),
).execute()
else:
for c in chunked(objs, 50):
DbBarData.insert_many(c).on_conflict_replace()
def to_update_param(d):
return {
"set__" + k: v.value if isinstance(v, Enum) else v
for k, v in d.__dict__.items()
}
def save_bar_data(self, datas: Sequence[BarData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbBarData.objects(
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)
_db.connect()
_db.create_tables([DbBarData, DbTickData])
def save_tick_data(self, datas: Sequence[TickData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbTickData.objects(
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)

View File

@ -0,0 +1,370 @@
""""""
from datetime import datetime
from typing import List, Sequence, Type
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.utility import get_file_path
from .database import BaseDatabaseManager, Driver
def init(driver: Driver, settings: dict):
init_funcs = {
Driver.SQLITE: init_sqlite,
Driver.MYSQL: init_mysql,
Driver.POSTGRESQL: init_postgresql,
}
assert driver in init_funcs
db = init_funcs[driver](settings)
bar, tick = init_models(db, driver)
return SqlManager(bar, tick)
def init_sqlite(settings: dict):
database = settings["database"]
path = str(get_file_path(database))
db = SqliteDatabase(path)
return db
def init_mysql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = MySQLDatabase(**settings)
return db
def init_postgresql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = PostgresqlDatabase(**settings)
return db
class ModelBase(Model):
def to_dict(self):
return self.__data__
def init_models(db: Database, driver: Driver):
class DbBarData(ModelBase):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
interval: str = CharField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
class Meta:
database = db
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name="DB",
)
return bar
@staticmethod
def save_all(objs: List["DbBarData"]):
"""
save a list of objects, update if exists.
"""
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for bar in dicts:
DbBarData.insert(bar).on_conflict(
update=bar,
conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol,
DbBarData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbBarData.insert_many(c).on_conflict_replace().execute()
class DbTickData(ModelBase):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
name: str = CharField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
pre_close: float = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField(null=True)
bid_price_3: float = FloatField(null=True)
bid_price_4: float = FloatField(null=True)
bid_price_5: float = FloatField(null=True)
ask_price_1: float = FloatField()
ask_price_2: float = FloatField(null=True)
ask_price_3: float = FloatField(null=True)
ask_price_4: float = FloatField(null=True)
ask_price_5: float = FloatField(null=True)
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField(null=True)
bid_volume_3: float = FloatField(null=True)
bid_volume_4: float = FloatField(null=True)
bid_volume_5: float = FloatField(null=True)
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField(null=True)
ask_volume_3: float = FloatField(null=True)
ask_volume_4: float = FloatField(null=True)
ask_volume_5: float = FloatField(null=True)
class Meta:
database = db
indexes = ((("datetime", "symbol", "exchange"), True),)
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return db_tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name="DB",
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
@staticmethod
def save_all(objs: List["DbTickData"]):
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for tick in dicts:
DbTickData.insert(tick).on_conflict(
update=tick,
conflict_target=(
DbTickData.datetime,
DbTickData.symbol,
DbTickData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbTickData.insert_many(c).on_conflict_replace().execute()
db.connect()
db.create_tables([DbBarData, DbTickData])
return DbBarData, DbTickData
class SqlManager(BaseDatabaseManager):
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
self.class_bar = class_bar
self.class_tick = class_tick
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start)
& (self.class_bar.datetime <= end)
)
.order_by(self.class_bar.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start)
& (self.class_tick.datetime <= end)
)
.order_by(self.class_tick.datetime)
)
data = [db_tick.to_tick() for db_tick in s]
return data
def save_bar_data(self, datas: Sequence[BarData]):
ds = [self.class_bar.from_bar(i) for i in datas]
self.class_bar.save_all(ds)
def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds)

View File

@ -0,0 +1,24 @@
""""""
from .database import BaseDatabaseManager, Driver
def init(settings: dict) -> BaseDatabaseManager:
driver = Driver(settings["driver"])
if driver is Driver.MONGODB:
return init_nosql(driver=driver, settings=settings)
else:
return init_sql(driver=driver, settings=settings)
def init_sql(driver: Driver, settings: dict):
from .database_sql import init
keys = {'database', "host", "port", "user", "password"}
settings = {k: v for k, v in settings.items() if k in keys}
_database_manager = init(driver, settings)
return _database_manager
def init_nosql(driver: Driver, settings: dict):
from .database_mongo import init
_database_manager = init(driver, settings=settings)
return _database_manager

View File

@ -30,9 +30,15 @@ SETTINGS = {
"database.host": "localhost",
"database.port": 3306,
"database.user": "root",
"database.password": ""
"database.password": "",
"database.authentication_source": "admin", # for mongodb
}
# Load global setting from json file.
SETTING_FILENAME = "vt_setting.json"
SETTINGS.update(load_json(SETTING_FILENAME))
def get_settings(prefix: str = ""):
prefix_length = len(prefix)
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}

View File

@ -3,20 +3,28 @@ General utility functions.
"""
import json
import os
from pathlib import Path
from typing import Callable
from typing import Callable, TYPE_CHECKING
import numpy as np
import talib
from .object import BarData, TickData
if TYPE_CHECKING:
from vnpy.trader.constant import Exchange
def resolve_path(pattern: str):
env = dict(os.environ)
env.update({"VNPY_TEMP": str(TEMP_DIR)})
return pattern.format(**env)
def extract_vt_symbol(vt_symbol: str):
"""
:return: (symbol, exchange)
"""
symbol, exchange = vt_symbol.split('.')
return symbol, exchange
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
return f'{symbol}.{exchange.value}'
def _get_trader_dir(temp_name: str):