[Add] added support for mongodb
[Add] added tests for database and CsvLoaderEngine [Mod] changed .travis.yaml
This commit is contained in:
parent
f797152bd5
commit
5d0bf006c6
@ -1,13 +1,19 @@
|
||||
language: python
|
||||
|
||||
cache: pip
|
||||
|
||||
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
||||
|
||||
python:
|
||||
- "3.7"
|
||||
services:
|
||||
- mongodb
|
||||
- mysql
|
||||
- postgresql
|
||||
|
||||
script:
|
||||
# todo: use python unittest
|
||||
- mkdir run; cd run; python ../tests/load_all.py
|
||||
- cd tests; source travis_env.sh; python test_all.py
|
||||
|
||||
matrix:
|
||||
include:
|
||||
|
1
tests/app/__init__.py
Normal file
1
tests/app/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .test_csv_loader import *
|
20
tests/test_all.py
Normal file
20
tests/test_all.py
Normal file
@ -0,0 +1,20 @@
|
||||
# tests/runner.py
|
||||
import unittest
|
||||
|
||||
# import your test modules
|
||||
import test_import_all
|
||||
import trader
|
||||
import app
|
||||
|
||||
# initialize the test suite
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# add tests to the test suite
|
||||
suite.addTests(loader.loadTestsFromModule(test_import_all))
|
||||
suite.addTests(loader.loadTestsFromModule(trader))
|
||||
suite.addTests(loader.loadTestsFromModule(app))
|
||||
|
||||
# initialize a runner, pass it your suite and run it
|
||||
runner = unittest.TextTestRunner(verbosity=3)
|
||||
result = runner.run(suite)
|
@ -2,23 +2,41 @@
|
||||
import unittest
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
class ImportTest(unittest.TestCase):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def test_import_all(self):
|
||||
from vnpy.event import EventEngine
|
||||
|
||||
def test_import_main_engine(self):
|
||||
from vnpy.trader.engine import MainEngine
|
||||
|
||||
def test_import_ui(self):
|
||||
from vnpy.trader.ui import MainWindow, create_qapp
|
||||
|
||||
def test_import_bitmex_gateway(self):
|
||||
from vnpy.gateway.bitmex import BitmexGateway
|
||||
|
||||
def test_import_futu_gateway(self):
|
||||
from vnpy.gateway.futu import FutuGateway
|
||||
|
||||
def test_import_ib_gateway(self):
|
||||
from vnpy.gateway.ib import IbGateway
|
||||
|
||||
def test_import_ctp_gateway(self):
|
||||
from vnpy.gateway.ctp import CtpGateway
|
||||
|
||||
def test_import_tiger_gateway(self):
|
||||
from vnpy.gateway.tiger import TigerGateway
|
||||
|
||||
def test_import_oes_gateway(self):
|
||||
from vnpy.gateway.oes import OesGateway
|
||||
|
||||
def test_import_cta_strategy_app(self):
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
|
||||
def test_import_csv_loader_app(self):
|
||||
from vnpy.app.csv_loader import CsvLoaderApp
|
||||
|
||||
|
2
tests/trader/__init__.py
Normal file
2
tests/trader/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .test_database import *
|
||||
from .test_settings import *
|
125
tests/trader/test_database.py
Normal file
125
tests/trader/test_database.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""
|
||||
Test if database works fine
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database.database import Driver
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
|
||||
os.environ['VNPY_TESTING'] = '1'
|
||||
|
||||
profiles = {
|
||||
Driver.SQLITE: {
|
||||
"driver": "sqlite",
|
||||
"database": "test_db.db",
|
||||
},
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
bar = BarData(
|
||||
gateway_name="DB",
|
||||
symbol="test_symbol",
|
||||
exchange=Exchange.BITMEX,
|
||||
datetime=now(),
|
||||
interval=Interval.MINUTE,
|
||||
)
|
||||
|
||||
tick = TickData(
|
||||
gateway_name="DB",
|
||||
symbol="test_symbol",
|
||||
exchange=Exchange.BITMEX,
|
||||
datetime=now(),
|
||||
name="DB_test_symbol",
|
||||
)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
|
||||
def connect(self, settings: dict):
|
||||
from vnpy.trader.database.initialize import init # noqa
|
||||
self.manager = init(settings)
|
||||
|
||||
def test_upsert_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_bar_data([bar])
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
def test_save_load_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
# save first
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
|
||||
def test_upsert_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_tick_data([tick])
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
def test_save_load_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
# save first
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
24
tests/trader/test_settings.py
Normal file
24
tests/trader/test_settings.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""
|
||||
Test if database works fine
|
||||
"""
|
||||
import unittest
|
||||
from vnpy.trader.setting import SETTINGS, get_settings
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
|
||||
def test_get_settings(self):
|
||||
SETTINGS['a'] = 1
|
||||
got = get_settings()
|
||||
self.assertIn('a', got)
|
||||
self.assertEqual(got['a'], 1)
|
||||
|
||||
def test_get_settings_with_prefix(self):
|
||||
SETTINGS['a.a'] = 1
|
||||
got = get_settings()
|
||||
self.assertIn('a', got)
|
||||
self.assertEqual(got['a'], 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
16
tests/travis_env.sh
Normal file
16
tests/travis_env.sh
Normal file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export VNPY_TEST_MYSQL_DATABASE=vnpy
|
||||
export VNPY_TEST_MYSQL_HOST=127.0.0.1
|
||||
export VNPY_TEST_MYSQL_PORT=3306
|
||||
export VNPY_TEST_MYSQL_USER=root
|
||||
export VNPY_TEST_MYSQL_PASSWORD=
|
||||
export VNPY_TEST_POSTGRESQL_DATABASE=vnpy
|
||||
export VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
|
||||
export VNPY_TEST_POSTGRESQL_PORT=5432
|
||||
export VNPY_TEST_POSTGRESQL_USER=postgres
|
||||
export VNPY_TEST_POSTGRESQL_PASSWORD=
|
||||
export VNPY_TEST_MONGODB_DATABASE=vnpy
|
||||
export VNPY_TEST_MONGODB_HOST=127.0.0.1
|
||||
export VNPY_TEST_MONGODB_PORT=27017
|
||||
|
@ -26,8 +26,9 @@ from typing import TextIO
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database import DbBarData
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.object import BarData
|
||||
|
||||
APP_NAME = "CsvLoader"
|
||||
|
||||
@ -70,7 +71,7 @@ class CsvLoaderEngine(BaseEngine):
|
||||
"""
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
bars = []
|
||||
start = None
|
||||
count = 0
|
||||
for item in reader:
|
||||
@ -79,29 +80,29 @@ class CsvLoaderEngine(BaseEngine):
|
||||
else:
|
||||
dt = datetime.fromisoformat(item[datetime_head])
|
||||
|
||||
db_bar = DbBarData(
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
exchange=exchange,
|
||||
datetime=dt,
|
||||
interval=interval.value,
|
||||
interval=interval,
|
||||
volume=item[volume_head],
|
||||
open_price=item[open_head],
|
||||
high_price=item[high_head],
|
||||
low_price=item[low_head],
|
||||
close_price=item[close_head],
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
db_bars.append(db_bar)
|
||||
bars.append(bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar.datetime
|
||||
end = db_bar.datetime
|
||||
start = bar.datetime
|
||||
end = bar.datetime
|
||||
|
||||
# insert into database
|
||||
DbBarData.save_all(db_bars)
|
||||
|
||||
database_manager.save_bar_data(bars)
|
||||
return start, end, count
|
||||
|
||||
def load(
|
||||
|
@ -12,9 +12,9 @@ from pandas import DataFrame
|
||||
|
||||
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
||||
Interval, Status)
|
||||
from vnpy.trader.database import DbBarData, DbTickData
|
||||
from vnpy.trader.object import OrderData, TradeData
|
||||
from vnpy.trader.utility import round_to_pricetick
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
|
||||
from vnpy.trader.utility import round_to_pricetick, extract_vt_symbol
|
||||
|
||||
from .base import (
|
||||
BacktestingMode,
|
||||
@ -103,8 +103,8 @@ class BacktestingEngine:
|
||||
|
||||
self.strategy_class = None
|
||||
self.strategy = None
|
||||
self.tick = None
|
||||
self.bar = None
|
||||
self.tick: TickData
|
||||
self.bar: BarData
|
||||
self.datetime = None
|
||||
|
||||
self.interval = None
|
||||
@ -199,14 +199,16 @@ class BacktestingEngine:
|
||||
|
||||
if self.mode == BacktestingMode.BAR:
|
||||
self.history_data = load_bar_data(
|
||||
self.vt_symbol,
|
||||
self.symbol,
|
||||
self.exchange,
|
||||
self.interval,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
else:
|
||||
self.history_data = load_tick_data(
|
||||
self.vt_symbol,
|
||||
self.symbol,
|
||||
self.exchange,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
@ -519,7 +521,7 @@ class BacktestingEngine:
|
||||
else:
|
||||
self.daily_results[d] = DailyResult(d, price)
|
||||
|
||||
def new_bar(self, bar: DbBarData):
|
||||
def new_bar(self, bar: BarData):
|
||||
""""""
|
||||
self.bar = bar
|
||||
self.datetime = bar.datetime
|
||||
@ -530,7 +532,7 @@ class BacktestingEngine:
|
||||
|
||||
self.update_daily_close(bar.close_price)
|
||||
|
||||
def new_tick(self, tick: DbTickData):
|
||||
def new_tick(self, tick: TickData):
|
||||
""""""
|
||||
self.tick = tick
|
||||
self.datetime = tick.datetime
|
||||
@ -965,41 +967,27 @@ def optimize(
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_bar_data(
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
return database_manager.load_bar_data(
|
||||
symbol, exchange, interval, start, end
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_tick_data(
|
||||
vt_symbol: str,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == vt_symbol)
|
||||
& (DbTickData.datetime >= start)
|
||||
& (DbTickData.datetime <= end)
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
return database_manager.load_tick_data(
|
||||
symbol, exchange, start, end
|
||||
)
|
||||
data = [db_tick.db_tick() for db_tick in s]
|
||||
return data
|
||||
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, List
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
|
||||
Status
|
||||
)
|
||||
from vnpy.trader.utility import load_json, save_json
|
||||
from vnpy.trader.database import DbTickData, DbBarData
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.setting import SETTINGS
|
||||
|
||||
from .base import (
|
||||
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
|
||||
def query_bar_from_rq(
|
||||
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
|
||||
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||
):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
rq_symbol = to_rq_symbol(vt_symbol)
|
||||
rq_symbol = to_rq_symbol(symbol, exchange)
|
||||
if rq_symbol not in self.rq_symbols:
|
||||
return None
|
||||
|
||||
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
|
||||
end_date=end
|
||||
)
|
||||
|
||||
data = []
|
||||
data: List[BarData] = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=Exchange(exchange_str),
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime=row.name.to_pydatetime(),
|
||||
open_price=row["open"],
|
||||
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
|
||||
return self.engine_type
|
||||
|
||||
def load_bar(
|
||||
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
|
||||
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
|
||||
callback: Callable[[BarData], None]
|
||||
):
|
||||
""""""
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
|
||||
# Query data from RQData by default, if not found, load from database.
|
||||
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
|
||||
if not data:
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval.value)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
# Query bars from RQData by default, if not found, load from database.
|
||||
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
|
||||
if not bars:
|
||||
bars = database_manager.load_bar_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
|
||||
for bar in data:
|
||||
for bar in bars:
|
||||
callback(bar)
|
||||
|
||||
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
|
||||
def load_tick(self, symbol: str, exchange: Exchange, days: int,
|
||||
callback: Callable[[TickData], None]):
|
||||
""""""
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
ticks = database_manager.load_tick_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for tick in s:
|
||||
for tick in ticks:
|
||||
callback(tick)
|
||||
|
||||
def call_strategy_func(
|
||||
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Load strategy class from certain folder.
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
for dirpath, dirnames, filenames in os.walk(str(path)):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
|
||||
self.main_engine.send_email(subject, msg)
|
||||
|
||||
|
||||
def to_rq_symbol(vt_symbol: str):
|
||||
def to_rq_symbol(symbol: str, exchange: Exchange):
|
||||
"""
|
||||
CZCE product of RQData has symbol like "TA1905" while
|
||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
if exchange_str != "CZCE":
|
||||
if exchange is not Exchange.CZCE:
|
||||
return symbol.upper()
|
||||
|
||||
for count, word in enumerate(symbol):
|
||||
if word.isdigit():
|
||||
break
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
product = symbol[:count]
|
||||
year = symbol[count]
|
||||
month = symbol[count + 1:]
|
||||
|
12
vnpy/trader/database/__init__.py
Normal file
12
vnpy/trader/database/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.database.database import BaseDatabaseManager
|
||||
|
||||
if "VNPY_TESTING" not in os.environ:
|
||||
from vnpy.trader.setting import get_settings
|
||||
from .initialize import init
|
||||
|
||||
settings = get_settings("database.")
|
||||
database_manager: "BaseDatabaseManager" = init(settings=settings)
|
53
vnpy/trader/database/database.py
Normal file
53
vnpy/trader/database/database.py
Normal file
@ -0,0 +1,53 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Interval, Exchange
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
|
||||
|
||||
class Driver(Enum):
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
MONGODB = "mongodb"
|
||||
|
||||
|
||||
class BaseDatabaseManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
interval: "Interval",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> Sequence["BarData"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_tick_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> Sequence["TickData"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_bar_data(
|
||||
self,
|
||||
datas: Sequence["BarData"],
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_tick_data(
|
||||
self,
|
||||
datas: Sequence["TickData"],
|
||||
):
|
||||
pass
|
@ -1,99 +1,59 @@
|
||||
""""""
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import Sequence
|
||||
|
||||
from peewee import (
|
||||
AutoField,
|
||||
CharField,
|
||||
Database,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
Model,
|
||||
MySQLDatabase,
|
||||
PostgresqlDatabase,
|
||||
SqliteDatabase,
|
||||
chunked,
|
||||
)
|
||||
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
||||
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData, TickData
|
||||
from .setting import SETTINGS
|
||||
from .utility import get_file_path
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
class Driver(Enum):
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
_db: Database
|
||||
_driver: Driver
|
||||
|
||||
|
||||
def init():
|
||||
global _driver
|
||||
db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")}
|
||||
_driver = Driver(db_settings["driver"])
|
||||
|
||||
init_funcs = {
|
||||
Driver.SQLITE: init_sqlite,
|
||||
Driver.MYSQL: init_mysql,
|
||||
Driver.POSTGRESQL: init_postgresql,
|
||||
}
|
||||
|
||||
assert _driver in init_funcs
|
||||
del db_settings["driver"]
|
||||
return init_funcs[_driver](db_settings)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
global _db
|
||||
def init(_: Driver, settings: dict):
|
||||
database = settings["database"]
|
||||
|
||||
_db = SqliteDatabase(str(get_file_path(database)))
|
||||
host = settings["host"]
|
||||
port = settings["port"]
|
||||
username = settings["user"]
|
||||
password = settings["password"]
|
||||
authentication_source = settings["authentication_source"]
|
||||
if not username: # if username == '' or None, skip username
|
||||
username = None
|
||||
password = None
|
||||
authentication_source = None
|
||||
connect(
|
||||
db=database,
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
authentication_source=authentication_source,
|
||||
)
|
||||
return MongoManager()
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
global _db
|
||||
_db = MySQLDatabase(**settings)
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
global _db
|
||||
_db = PostgresqlDatabase(**settings)
|
||||
|
||||
|
||||
init()
|
||||
|
||||
|
||||
class ModelBase(Model):
|
||||
def to_dict(self):
|
||||
return self.__data__
|
||||
|
||||
|
||||
class DbBarData(ModelBase):
|
||||
class DbBarData(Document):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with datetime, interval, symbol
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
interval = CharField()
|
||||
symbol: str = StringField()
|
||||
exchange: str = StringField()
|
||||
datetime: datetime = DateTimeField()
|
||||
interval: str = StringField()
|
||||
|
||||
volume = FloatField()
|
||||
open_price = FloatField()
|
||||
high_price = FloatField()
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
volume: float = FloatField()
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
|
||||
class Meta:
|
||||
database = _db
|
||||
indexes = ((("datetime", "interval", "symbol"), True),)
|
||||
meta = {
|
||||
"indexes": [
|
||||
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
|
||||
]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
@ -132,80 +92,56 @@ class DbBarData(ModelBase):
|
||||
)
|
||||
return bar
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbBarData"]):
|
||||
"""
|
||||
save a list of objects, update if exists.
|
||||
"""
|
||||
with _db.atomic():
|
||||
if _driver is Driver.POSTGRESQL:
|
||||
for bar in objs:
|
||||
DbBarData.insert(bar.to_dict()).on_conflict(
|
||||
update=bar.to_dict(),
|
||||
conflict_target=(
|
||||
DbBarData.datetime,
|
||||
DbBarData.interval,
|
||||
DbBarData.symbol,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(objs, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace()
|
||||
|
||||
|
||||
class DbTickData(ModelBase):
|
||||
class DbTickData(Document):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with (datetime, symbol)
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
symbol: str = StringField()
|
||||
exchange: str = StringField()
|
||||
datetime: datetime = DateTimeField()
|
||||
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
name: str = StringField()
|
||||
volume: float = FloatField()
|
||||
last_price: float = FloatField()
|
||||
last_volume: float = FloatField()
|
||||
limit_up: float = FloatField()
|
||||
limit_down: float = FloatField()
|
||||
|
||||
name = CharField()
|
||||
volume = FloatField()
|
||||
last_price = FloatField()
|
||||
last_volume = FloatField()
|
||||
limit_up = FloatField()
|
||||
limit_down = FloatField()
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
pre_close: float = FloatField()
|
||||
|
||||
open_price = FloatField()
|
||||
high_price = FloatField()
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
pre_close = FloatField()
|
||||
bid_price_1: float = FloatField()
|
||||
bid_price_2: float = FloatField()
|
||||
bid_price_3: float = FloatField()
|
||||
bid_price_4: float = FloatField()
|
||||
bid_price_5: float = FloatField()
|
||||
|
||||
bid_price_1 = FloatField()
|
||||
bid_price_2 = FloatField()
|
||||
bid_price_3 = FloatField()
|
||||
bid_price_4 = FloatField()
|
||||
bid_price_5 = FloatField()
|
||||
ask_price_1: float = FloatField()
|
||||
ask_price_2: float = FloatField()
|
||||
ask_price_3: float = FloatField()
|
||||
ask_price_4: float = FloatField()
|
||||
ask_price_5: float = FloatField()
|
||||
|
||||
ask_price_1 = FloatField()
|
||||
ask_price_2 = FloatField()
|
||||
ask_price_3 = FloatField()
|
||||
ask_price_4 = FloatField()
|
||||
ask_price_5 = FloatField()
|
||||
bid_volume_1: float = FloatField()
|
||||
bid_volume_2: float = FloatField()
|
||||
bid_volume_3: float = FloatField()
|
||||
bid_volume_4: float = FloatField()
|
||||
bid_volume_5: float = FloatField()
|
||||
|
||||
bid_volume_1 = FloatField()
|
||||
bid_volume_2 = FloatField()
|
||||
bid_volume_3 = FloatField()
|
||||
bid_volume_4 = FloatField()
|
||||
bid_volume_5 = FloatField()
|
||||
ask_volume_1: float = FloatField()
|
||||
ask_volume_2: float = FloatField()
|
||||
ask_volume_3: float = FloatField()
|
||||
ask_volume_4: float = FloatField()
|
||||
ask_volume_5: float = FloatField()
|
||||
|
||||
ask_volume_1 = FloatField()
|
||||
ask_volume_2 = FloatField()
|
||||
ask_volume_3 = FloatField()
|
||||
ask_volume_4 = FloatField()
|
||||
ask_volume_5 = FloatField()
|
||||
|
||||
class Meta:
|
||||
database = _db
|
||||
indexes = ((("datetime", "symbol"), True),)
|
||||
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
@ -254,7 +190,7 @@ class DbTickData(ModelBase):
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
return tick
|
||||
return db_tick
|
||||
|
||||
def to_tick(self):
|
||||
"""
|
||||
@ -304,20 +240,63 @@ class DbTickData(ModelBase):
|
||||
|
||||
return tick
|
||||
|
||||
|
||||
class MongoManager(BaseDatabaseManager):
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> Sequence[BarData]:
|
||||
s = DbBarData.objects(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
interval=interval.value,
|
||||
datetime__gte=start,
|
||||
datetime__lte=end,
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
) -> Sequence[TickData]:
|
||||
s = DbTickData.objects(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
datetime__gte=start,
|
||||
datetime__lte=end,
|
||||
)
|
||||
data = [db_tick.to_tick() for db_tick in s]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbTickData"]):
|
||||
with _db.atomic():
|
||||
if _driver is Driver.POSTGRESQL:
|
||||
for bar in objs:
|
||||
DbTickData.insert(bar.to_dict()).on_conflict(
|
||||
update=bar.to_dict(),
|
||||
preserve=(DbTickData.id),
|
||||
conflict_target=(DbTickData.datetime, DbTickData.symbol),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(objs, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace()
|
||||
def to_update_param(d):
|
||||
return {
|
||||
"set__" + k: v.value if isinstance(v, Enum) else v
|
||||
for k, v in d.__dict__.items()
|
||||
}
|
||||
|
||||
def save_bar_data(self, datas: Sequence[BarData]):
|
||||
for d in datas:
|
||||
updates = self.to_update_param(d)
|
||||
updates.pop("set__gateway_name")
|
||||
updates.pop("set__vt_symbol")
|
||||
(
|
||||
DbBarData.objects(
|
||||
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
|
||||
).update_one(upsert=True, **updates)
|
||||
)
|
||||
|
||||
_db.connect()
|
||||
_db.create_tables([DbBarData, DbTickData])
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
for d in datas:
|
||||
updates = self.to_update_param(d)
|
||||
updates.pop("set__gateway_name")
|
||||
updates.pop("set__vt_symbol")
|
||||
(
|
||||
DbTickData.objects(
|
||||
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
|
||||
).update_one(upsert=True, **updates)
|
||||
)
|
370
vnpy/trader/database/database_sql.py
Normal file
370
vnpy/trader/database/database_sql.py
Normal file
@ -0,0 +1,370 @@
|
||||
""""""
|
||||
from datetime import datetime
|
||||
from typing import List, Sequence, Type
|
||||
|
||||
from peewee import (
|
||||
AutoField,
|
||||
CharField,
|
||||
Database,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
Model,
|
||||
MySQLDatabase,
|
||||
PostgresqlDatabase,
|
||||
SqliteDatabase,
|
||||
chunked,
|
||||
)
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
from vnpy.trader.utility import get_file_path
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
def init(driver: Driver, settings: dict):
|
||||
init_funcs = {
|
||||
Driver.SQLITE: init_sqlite,
|
||||
Driver.MYSQL: init_mysql,
|
||||
Driver.POSTGRESQL: init_postgresql,
|
||||
}
|
||||
assert driver in init_funcs
|
||||
|
||||
db = init_funcs[driver](settings)
|
||||
bar, tick = init_models(db, driver)
|
||||
return SqlManager(bar, tick)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
database = settings["database"]
|
||||
path = str(get_file_path(database))
|
||||
db = SqliteDatabase(path)
|
||||
return db
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
keys = {"database", "user", "password", "host", "port"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
db = MySQLDatabase(**settings)
|
||||
return db
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
keys = {"database", "user", "password", "host", "port"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
db = PostgresqlDatabase(**settings)
|
||||
return db
|
||||
|
||||
|
||||
class ModelBase(Model):
|
||||
|
||||
def to_dict(self):
|
||||
return self.__data__
|
||||
|
||||
|
||||
def init_models(db: Database, driver: Driver):
|
||||
class DbBarData(ModelBase):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with datetime, interval, symbol
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
symbol: str = CharField()
|
||||
exchange: str = CharField()
|
||||
datetime: datetime = DateTimeField()
|
||||
interval: str = CharField()
|
||||
|
||||
volume: float = FloatField()
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
"""
|
||||
Generate DbBarData object from BarData.
|
||||
"""
|
||||
db_bar = DbBarData()
|
||||
|
||||
db_bar.symbol = bar.symbol
|
||||
db_bar.exchange = bar.exchange.value
|
||||
db_bar.datetime = bar.datetime
|
||||
db_bar.interval = bar.interval.value
|
||||
db_bar.volume = bar.volume
|
||||
db_bar.open_price = bar.open_price
|
||||
db_bar.high_price = bar.high_price
|
||||
db_bar.low_price = bar.low_price
|
||||
db_bar.close_price = bar.close_price
|
||||
|
||||
return db_bar
|
||||
|
||||
def to_bar(self):
|
||||
"""
|
||||
Generate BarData object from DbBarData.
|
||||
"""
|
||||
bar = BarData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
interval=Interval(self.interval),
|
||||
volume=self.volume,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name="DB",
|
||||
)
|
||||
return bar
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbBarData"]):
|
||||
"""
|
||||
save a list of objects, update if exists.
|
||||
"""
|
||||
dicts = [i.to_dict() for i in objs]
|
||||
with db.atomic():
|
||||
if driver is Driver.POSTGRESQL:
|
||||
for bar in dicts:
|
||||
DbBarData.insert(bar).on_conflict(
|
||||
update=bar,
|
||||
conflict_target=(
|
||||
DbBarData.datetime,
|
||||
DbBarData.interval,
|
||||
DbBarData.symbol,
|
||||
DbBarData.exchange,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(dicts, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace().execute()
|
||||
|
||||
class DbTickData(ModelBase):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with (datetime, symbol)
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
|
||||
symbol: str = CharField()
|
||||
exchange: str = CharField()
|
||||
datetime: datetime = DateTimeField()
|
||||
|
||||
name: str = CharField()
|
||||
volume: float = FloatField()
|
||||
last_price: float = FloatField()
|
||||
last_volume: float = FloatField()
|
||||
limit_up: float = FloatField()
|
||||
limit_down: float = FloatField()
|
||||
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
pre_close: float = FloatField()
|
||||
|
||||
bid_price_1: float = FloatField()
|
||||
bid_price_2: float = FloatField(null=True)
|
||||
bid_price_3: float = FloatField(null=True)
|
||||
bid_price_4: float = FloatField(null=True)
|
||||
bid_price_5: float = FloatField(null=True)
|
||||
|
||||
ask_price_1: float = FloatField()
|
||||
ask_price_2: float = FloatField(null=True)
|
||||
ask_price_3: float = FloatField(null=True)
|
||||
ask_price_4: float = FloatField(null=True)
|
||||
ask_price_5: float = FloatField(null=True)
|
||||
|
||||
bid_volume_1: float = FloatField()
|
||||
bid_volume_2: float = FloatField(null=True)
|
||||
bid_volume_3: float = FloatField(null=True)
|
||||
bid_volume_4: float = FloatField(null=True)
|
||||
bid_volume_5: float = FloatField(null=True)
|
||||
|
||||
ask_volume_1: float = FloatField()
|
||||
ask_volume_2: float = FloatField(null=True)
|
||||
ask_volume_3: float = FloatField(null=True)
|
||||
ask_volume_4: float = FloatField(null=True)
|
||||
ask_volume_5: float = FloatField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
indexes = ((("datetime", "symbol", "exchange"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
"""
|
||||
Generate DbTickData object from TickData.
|
||||
"""
|
||||
db_tick = DbTickData()
|
||||
|
||||
db_tick.symbol = tick.symbol
|
||||
db_tick.exchange = tick.exchange.value
|
||||
db_tick.datetime = tick.datetime
|
||||
db_tick.name = tick.name
|
||||
db_tick.volume = tick.volume
|
||||
db_tick.last_price = tick.last_price
|
||||
db_tick.last_volume = tick.last_volume
|
||||
db_tick.limit_up = tick.limit_up
|
||||
db_tick.limit_down = tick.limit_down
|
||||
db_tick.open_price = tick.open_price
|
||||
db_tick.high_price = tick.high_price
|
||||
db_tick.low_price = tick.low_price
|
||||
db_tick.pre_close = tick.pre_close
|
||||
|
||||
db_tick.bid_price_1 = tick.bid_price_1
|
||||
db_tick.ask_price_1 = tick.ask_price_1
|
||||
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||
|
||||
if tick.bid_price_2:
|
||||
db_tick.bid_price_2 = tick.bid_price_2
|
||||
db_tick.bid_price_3 = tick.bid_price_3
|
||||
db_tick.bid_price_4 = tick.bid_price_4
|
||||
db_tick.bid_price_5 = tick.bid_price_5
|
||||
|
||||
db_tick.ask_price_2 = tick.ask_price_2
|
||||
db_tick.ask_price_3 = tick.ask_price_3
|
||||
db_tick.ask_price_4 = tick.ask_price_4
|
||||
db_tick.ask_price_5 = tick.ask_price_5
|
||||
|
||||
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||
|
||||
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
return db_tick
|
||||
|
||||
def to_tick(self):
|
||||
"""
|
||||
Generate TickData object from DbTickData.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
name=self.name,
|
||||
volume=self.volume,
|
||||
last_price=self.last_price,
|
||||
last_volume=self.last_volume,
|
||||
limit_up=self.limit_up,
|
||||
limit_down=self.limit_down,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
pre_close=self.pre_close,
|
||||
bid_price_1=self.bid_price_1,
|
||||
ask_price_1=self.ask_price_1,
|
||||
bid_volume_1=self.bid_volume_1,
|
||||
ask_volume_1=self.ask_volume_1,
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
if self.bid_price_2:
|
||||
tick.bid_price_2 = self.bid_price_2
|
||||
tick.bid_price_3 = self.bid_price_3
|
||||
tick.bid_price_4 = self.bid_price_4
|
||||
tick.bid_price_5 = self.bid_price_5
|
||||
|
||||
tick.ask_price_2 = self.ask_price_2
|
||||
tick.ask_price_3 = self.ask_price_3
|
||||
tick.ask_price_4 = self.ask_price_4
|
||||
tick.ask_price_5 = self.ask_price_5
|
||||
|
||||
tick.bid_volume_2 = self.bid_volume_2
|
||||
tick.bid_volume_3 = self.bid_volume_3
|
||||
tick.bid_volume_4 = self.bid_volume_4
|
||||
tick.bid_volume_5 = self.bid_volume_5
|
||||
|
||||
tick.ask_volume_2 = self.ask_volume_2
|
||||
tick.ask_volume_3 = self.ask_volume_3
|
||||
tick.ask_volume_4 = self.ask_volume_4
|
||||
tick.ask_volume_5 = self.ask_volume_5
|
||||
|
||||
return tick
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbTickData"]):
|
||||
dicts = [i.to_dict() for i in objs]
|
||||
with db.atomic():
|
||||
if driver is Driver.POSTGRESQL:
|
||||
for tick in dicts:
|
||||
DbTickData.insert(tick).on_conflict(
|
||||
update=tick,
|
||||
conflict_target=(
|
||||
DbTickData.datetime,
|
||||
DbTickData.symbol,
|
||||
DbTickData.exchange,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(dicts, 50):
|
||||
DbTickData.insert_many(c).on_conflict_replace().execute()
|
||||
|
||||
db.connect()
|
||||
db.create_tables([DbBarData, DbTickData])
|
||||
return DbBarData, DbTickData
|
||||
|
||||
|
||||
class SqlManager(BaseDatabaseManager):
|
||||
|
||||
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
|
||||
self.class_bar = class_bar
|
||||
self.class_tick = class_tick
|
||||
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> Sequence[BarData]:
|
||||
s = (
|
||||
self.class_bar.select()
|
||||
.where(
|
||||
(self.class_bar.symbol == symbol)
|
||||
& (self.class_bar.exchange == exchange.value)
|
||||
& (self.class_bar.interval == interval.value)
|
||||
& (self.class_bar.datetime >= start)
|
||||
& (self.class_bar.datetime <= end)
|
||||
)
|
||||
.order_by(self.class_bar.datetime)
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
) -> Sequence[TickData]:
|
||||
s = (
|
||||
self.class_tick.select()
|
||||
.where(
|
||||
(self.class_tick.symbol == symbol)
|
||||
& (self.class_tick.exchange == exchange.value)
|
||||
& (self.class_tick.datetime >= start)
|
||||
& (self.class_tick.datetime <= end)
|
||||
)
|
||||
.order_by(self.class_tick.datetime)
|
||||
)
|
||||
data = [db_tick.to_tick() for db_tick in s]
|
||||
return data
|
||||
|
||||
def save_bar_data(self, datas: Sequence[BarData]):
|
||||
ds = [self.class_bar.from_bar(i) for i in datas]
|
||||
self.class_bar.save_all(ds)
|
||||
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||
self.class_tick.save_all(ds)
|
24
vnpy/trader/database/initialize.py
Normal file
24
vnpy/trader/database/initialize.py
Normal file
@ -0,0 +1,24 @@
|
||||
""""""
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
def init(settings: dict) -> BaseDatabaseManager:
|
||||
driver = Driver(settings["driver"])
|
||||
if driver is Driver.MONGODB:
|
||||
return init_nosql(driver=driver, settings=settings)
|
||||
else:
|
||||
return init_sql(driver=driver, settings=settings)
|
||||
|
||||
|
||||
def init_sql(driver: Driver, settings: dict):
|
||||
from .database_sql import init
|
||||
keys = {'database', "host", "port", "user", "password"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
_database_manager = init(driver, settings)
|
||||
return _database_manager
|
||||
|
||||
|
||||
def init_nosql(driver: Driver, settings: dict):
|
||||
from .database_mongo import init
|
||||
_database_manager = init(driver, settings=settings)
|
||||
return _database_manager
|
@ -30,9 +30,15 @@ SETTINGS = {
|
||||
"database.host": "localhost",
|
||||
"database.port": 3306,
|
||||
"database.user": "root",
|
||||
"database.password": ""
|
||||
"database.password": "",
|
||||
"database.authentication_source": "admin", # for mongodb
|
||||
}
|
||||
|
||||
# Load global setting from json file.
|
||||
SETTING_FILENAME = "vt_setting.json"
|
||||
SETTINGS.update(load_json(SETTING_FILENAME))
|
||||
|
||||
|
||||
def get_settings(prefix: str = ""):
|
||||
prefix_length = len(prefix)
|
||||
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}
|
||||
|
@ -3,20 +3,28 @@ General utility functions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
|
||||
from .object import BarData, TickData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Exchange
|
||||
|
||||
def resolve_path(pattern: str):
|
||||
env = dict(os.environ)
|
||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
||||
return pattern.format(**env)
|
||||
|
||||
def extract_vt_symbol(vt_symbol: str):
|
||||
"""
|
||||
:return: (symbol, exchange)
|
||||
"""
|
||||
symbol, exchange = vt_symbol.split('.')
|
||||
return symbol, exchange
|
||||
|
||||
|
||||
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
|
||||
return f'{symbol}.{exchange.value}'
|
||||
|
||||
|
||||
def _get_trader_dir(temp_name: str):
|
||||
|
Loading…
Reference in New Issue
Block a user