[Add] added support for mongodb
[Add] added tests for database and CsvLoaderEngine [Mod] changed .travis.yaml
This commit is contained in:
parent
f797152bd5
commit
5d0bf006c6
@ -1,13 +1,19 @@
|
|||||||
language: python
|
language: python
|
||||||
|
|
||||||
|
cache: pip
|
||||||
|
|
||||||
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
||||||
|
|
||||||
python:
|
python:
|
||||||
- "3.7"
|
- "3.7"
|
||||||
|
services:
|
||||||
|
- mongodb
|
||||||
|
- mysql
|
||||||
|
- postgresql
|
||||||
|
|
||||||
script:
|
script:
|
||||||
# todo: use python unittest
|
# todo: use python unittest
|
||||||
- mkdir run; cd run; python ../tests/load_all.py
|
- cd tests; source travis_env.sh; python test_all.py
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
1
tests/app/__init__.py
Normal file
1
tests/app/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .test_csv_loader import *
|
20
tests/test_all.py
Normal file
20
tests/test_all.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# tests/runner.py
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# import your test modules
|
||||||
|
import test_import_all
|
||||||
|
import trader
|
||||||
|
import app
|
||||||
|
|
||||||
|
# initialize the test suite
|
||||||
|
loader = unittest.TestLoader()
|
||||||
|
suite = unittest.TestSuite()
|
||||||
|
|
||||||
|
# add tests to the test suite
|
||||||
|
suite.addTests(loader.loadTestsFromModule(test_import_all))
|
||||||
|
suite.addTests(loader.loadTestsFromModule(trader))
|
||||||
|
suite.addTests(loader.loadTestsFromModule(app))
|
||||||
|
|
||||||
|
# initialize a runner, pass it your suite and run it
|
||||||
|
runner = unittest.TextTestRunner(verbosity=3)
|
||||||
|
result = runner.run(suite)
|
@ -2,23 +2,41 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
class ImportTest(unittest.TestCase):
|
class ImportTest(unittest.TestCase):
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def test_import_all(self):
|
def test_import_all(self):
|
||||||
from vnpy.event import EventEngine
|
from vnpy.event import EventEngine
|
||||||
|
|
||||||
|
def test_import_main_engine(self):
|
||||||
from vnpy.trader.engine import MainEngine
|
from vnpy.trader.engine import MainEngine
|
||||||
|
|
||||||
|
def test_import_ui(self):
|
||||||
from vnpy.trader.ui import MainWindow, create_qapp
|
from vnpy.trader.ui import MainWindow, create_qapp
|
||||||
|
|
||||||
|
def test_import_bitmex_gateway(self):
|
||||||
from vnpy.gateway.bitmex import BitmexGateway
|
from vnpy.gateway.bitmex import BitmexGateway
|
||||||
|
|
||||||
|
def test_import_futu_gateway(self):
|
||||||
from vnpy.gateway.futu import FutuGateway
|
from vnpy.gateway.futu import FutuGateway
|
||||||
|
|
||||||
|
def test_import_ib_gateway(self):
|
||||||
from vnpy.gateway.ib import IbGateway
|
from vnpy.gateway.ib import IbGateway
|
||||||
|
|
||||||
|
def test_import_ctp_gateway(self):
|
||||||
from vnpy.gateway.ctp import CtpGateway
|
from vnpy.gateway.ctp import CtpGateway
|
||||||
|
|
||||||
|
def test_import_tiger_gateway(self):
|
||||||
from vnpy.gateway.tiger import TigerGateway
|
from vnpy.gateway.tiger import TigerGateway
|
||||||
|
|
||||||
|
def test_import_oes_gateway(self):
|
||||||
from vnpy.gateway.oes import OesGateway
|
from vnpy.gateway.oes import OesGateway
|
||||||
|
|
||||||
|
def test_import_cta_strategy_app(self):
|
||||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||||
|
|
||||||
|
def test_import_csv_loader_app(self):
|
||||||
from vnpy.app.csv_loader import CsvLoaderApp
|
from vnpy.app.csv_loader import CsvLoaderApp
|
||||||
|
|
||||||
|
|
2
tests/trader/__init__.py
Normal file
2
tests/trader/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .test_database import *
|
||||||
|
from .test_settings import *
|
125
tests/trader/test_database.py
Normal file
125
tests/trader/test_database.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
Test if database works fine
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
from vnpy.trader.database.database import Driver
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
|
||||||
|
os.environ['VNPY_TESTING'] = '1'
|
||||||
|
|
||||||
|
profiles = {
|
||||||
|
Driver.SQLITE: {
|
||||||
|
"driver": "sqlite",
|
||||||
|
"database": "test_db.db",
|
||||||
|
},
|
||||||
|
Driver.MYSQL: {
|
||||||
|
"driver": "mysql",
|
||||||
|
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||||
|
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||||
|
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||||
|
},
|
||||||
|
Driver.POSTGRESQL: {
|
||||||
|
"driver": "postgresql",
|
||||||
|
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||||
|
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||||
|
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||||
|
},
|
||||||
|
Driver.MONGODB: {
|
||||||
|
"driver": "mongodb",
|
||||||
|
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"authentication_source": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def now():
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
bar = BarData(
|
||||||
|
gateway_name="DB",
|
||||||
|
symbol="test_symbol",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
datetime=now(),
|
||||||
|
interval=Interval.MINUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
tick = TickData(
|
||||||
|
gateway_name="DB",
|
||||||
|
symbol="test_symbol",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
datetime=now(),
|
||||||
|
name="DB_test_symbol",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabase(unittest.TestCase):
|
||||||
|
|
||||||
|
def connect(self, settings: dict):
|
||||||
|
from vnpy.trader.database.initialize import init # noqa
|
||||||
|
self.manager = init(settings)
|
||||||
|
|
||||||
|
def test_upsert_bar(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
|
||||||
|
def test_save_load_bar(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
# save first
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
|
||||||
|
# and load
|
||||||
|
results = self.manager.load_bar_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
interval=bar.interval,
|
||||||
|
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||||
|
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||||
|
)
|
||||||
|
count = len(results)
|
||||||
|
self.assertNotEqual(count, 0)
|
||||||
|
|
||||||
|
def test_upsert_tick(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
|
||||||
|
def test_save_load_tick(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
# save first
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
|
||||||
|
# and load
|
||||||
|
results = self.manager.load_tick_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||||
|
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||||
|
)
|
||||||
|
count = len(results)
|
||||||
|
self.assertNotEqual(count, 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
24
tests/trader/test_settings.py
Normal file
24
tests/trader/test_settings.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
Test if database works fine
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
from vnpy.trader.setting import SETTINGS, get_settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettings(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_get_settings(self):
|
||||||
|
SETTINGS['a'] = 1
|
||||||
|
got = get_settings()
|
||||||
|
self.assertIn('a', got)
|
||||||
|
self.assertEqual(got['a'], 1)
|
||||||
|
|
||||||
|
def test_get_settings_with_prefix(self):
|
||||||
|
SETTINGS['a.a'] = 1
|
||||||
|
got = get_settings()
|
||||||
|
self.assertIn('a', got)
|
||||||
|
self.assertEqual(got['a'], 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
16
tests/travis_env.sh
Normal file
16
tests/travis_env.sh
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
export VNPY_TEST_MYSQL_DATABASE=vnpy
|
||||||
|
export VNPY_TEST_MYSQL_HOST=127.0.0.1
|
||||||
|
export VNPY_TEST_MYSQL_PORT=3306
|
||||||
|
export VNPY_TEST_MYSQL_USER=root
|
||||||
|
export VNPY_TEST_MYSQL_PASSWORD=
|
||||||
|
export VNPY_TEST_POSTGRESQL_DATABASE=vnpy
|
||||||
|
export VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
|
||||||
|
export VNPY_TEST_POSTGRESQL_PORT=5432
|
||||||
|
export VNPY_TEST_POSTGRESQL_USER=postgres
|
||||||
|
export VNPY_TEST_POSTGRESQL_PASSWORD=
|
||||||
|
export VNPY_TEST_MONGODB_DATABASE=vnpy
|
||||||
|
export VNPY_TEST_MONGODB_HOST=127.0.0.1
|
||||||
|
export VNPY_TEST_MONGODB_PORT=27017
|
||||||
|
|
@ -26,8 +26,9 @@ from typing import TextIO
|
|||||||
|
|
||||||
from vnpy.event import EventEngine
|
from vnpy.event import EventEngine
|
||||||
from vnpy.trader.constant import Exchange, Interval
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
from vnpy.trader.database import DbBarData
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||||
|
from vnpy.trader.object import BarData
|
||||||
|
|
||||||
APP_NAME = "CsvLoader"
|
APP_NAME = "CsvLoader"
|
||||||
|
|
||||||
@ -70,7 +71,7 @@ class CsvLoaderEngine(BaseEngine):
|
|||||||
"""
|
"""
|
||||||
reader = csv.DictReader(f)
|
reader = csv.DictReader(f)
|
||||||
|
|
||||||
db_bars = []
|
bars = []
|
||||||
start = None
|
start = None
|
||||||
count = 0
|
count = 0
|
||||||
for item in reader:
|
for item in reader:
|
||||||
@ -79,29 +80,29 @@ class CsvLoaderEngine(BaseEngine):
|
|||||||
else:
|
else:
|
||||||
dt = datetime.fromisoformat(item[datetime_head])
|
dt = datetime.fromisoformat(item[datetime_head])
|
||||||
|
|
||||||
db_bar = DbBarData(
|
bar = BarData(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
exchange=exchange.value,
|
exchange=exchange,
|
||||||
datetime=dt,
|
datetime=dt,
|
||||||
interval=interval.value,
|
interval=interval,
|
||||||
volume=item[volume_head],
|
volume=item[volume_head],
|
||||||
open_price=item[open_head],
|
open_price=item[open_head],
|
||||||
high_price=item[high_head],
|
high_price=item[high_head],
|
||||||
low_price=item[low_head],
|
low_price=item[low_head],
|
||||||
close_price=item[close_head],
|
close_price=item[close_head],
|
||||||
|
gateway_name="DB",
|
||||||
)
|
)
|
||||||
|
|
||||||
db_bars.append(db_bar)
|
bars.append(bar)
|
||||||
|
|
||||||
# do some statistics
|
# do some statistics
|
||||||
count += 1
|
count += 1
|
||||||
if not start:
|
if not start:
|
||||||
start = db_bar.datetime
|
start = bar.datetime
|
||||||
end = db_bar.datetime
|
end = bar.datetime
|
||||||
|
|
||||||
# insert into database
|
# insert into database
|
||||||
DbBarData.save_all(db_bars)
|
database_manager.save_bar_data(bars)
|
||||||
|
|
||||||
return start, end, count
|
return start, end, count
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
|
@ -12,9 +12,9 @@ from pandas import DataFrame
|
|||||||
|
|
||||||
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
||||||
Interval, Status)
|
Interval, Status)
|
||||||
from vnpy.trader.database import DbBarData, DbTickData
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.object import OrderData, TradeData
|
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
|
||||||
from vnpy.trader.utility import round_to_pricetick
|
from vnpy.trader.utility import round_to_pricetick, extract_vt_symbol
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
BacktestingMode,
|
BacktestingMode,
|
||||||
@ -103,8 +103,8 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
self.strategy_class = None
|
self.strategy_class = None
|
||||||
self.strategy = None
|
self.strategy = None
|
||||||
self.tick = None
|
self.tick: TickData
|
||||||
self.bar = None
|
self.bar: BarData
|
||||||
self.datetime = None
|
self.datetime = None
|
||||||
|
|
||||||
self.interval = None
|
self.interval = None
|
||||||
@ -199,14 +199,16 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
if self.mode == BacktestingMode.BAR:
|
if self.mode == BacktestingMode.BAR:
|
||||||
self.history_data = load_bar_data(
|
self.history_data = load_bar_data(
|
||||||
self.vt_symbol,
|
self.symbol,
|
||||||
|
self.exchange,
|
||||||
self.interval,
|
self.interval,
|
||||||
self.start,
|
self.start,
|
||||||
self.end
|
self.end
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.history_data = load_tick_data(
|
self.history_data = load_tick_data(
|
||||||
self.vt_symbol,
|
self.symbol,
|
||||||
|
self.exchange,
|
||||||
self.start,
|
self.start,
|
||||||
self.end
|
self.end
|
||||||
)
|
)
|
||||||
@ -519,7 +521,7 @@ class BacktestingEngine:
|
|||||||
else:
|
else:
|
||||||
self.daily_results[d] = DailyResult(d, price)
|
self.daily_results[d] = DailyResult(d, price)
|
||||||
|
|
||||||
def new_bar(self, bar: DbBarData):
|
def new_bar(self, bar: BarData):
|
||||||
""""""
|
""""""
|
||||||
self.bar = bar
|
self.bar = bar
|
||||||
self.datetime = bar.datetime
|
self.datetime = bar.datetime
|
||||||
@ -530,7 +532,7 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
self.update_daily_close(bar.close_price)
|
self.update_daily_close(bar.close_price)
|
||||||
|
|
||||||
def new_tick(self, tick: DbTickData):
|
def new_tick(self, tick: TickData):
|
||||||
""""""
|
""""""
|
||||||
self.tick = tick
|
self.tick = tick
|
||||||
self.datetime = tick.datetime
|
self.datetime = tick.datetime
|
||||||
@ -965,41 +967,27 @@ def optimize(
|
|||||||
|
|
||||||
@lru_cache(maxsize=10)
|
@lru_cache(maxsize=10)
|
||||||
def load_bar_data(
|
def load_bar_data(
|
||||||
vt_symbol: str,
|
symbol: str,
|
||||||
interval: str,
|
exchange: Exchange,
|
||||||
start: datetime,
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
end: datetime
|
end: datetime
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
s = (
|
return database_manager.load_bar_data(
|
||||||
DbBarData.select()
|
symbol, exchange, interval, start, end
|
||||||
.where(
|
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
|
||||||
& (DbBarData.interval == interval)
|
|
||||||
& (DbBarData.datetime >= start)
|
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_bar.to_bar() for db_bar in s]
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=10)
|
@lru_cache(maxsize=10)
|
||||||
def load_tick_data(
|
def load_tick_data(
|
||||||
vt_symbol: str,
|
symbol: str,
|
||||||
start: datetime,
|
exchange: Exchange,
|
||||||
|
start: datetime,
|
||||||
end: datetime
|
end: datetime
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
s = (
|
return database_manager.load_tick_data(
|
||||||
DbTickData.select()
|
symbol, exchange, start, end
|
||||||
.where(
|
|
||||||
(DbTickData.vt_symbol == vt_symbol)
|
|
||||||
& (DbTickData.datetime >= start)
|
|
||||||
& (DbTickData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbTickData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_tick.db_tick() for db_tick in s]
|
|
||||||
return data
|
|
||||||
|
@ -5,7 +5,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, List
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
|
|||||||
Status
|
Status
|
||||||
)
|
)
|
||||||
from vnpy.trader.utility import load_json, save_json
|
from vnpy.trader.utility import load_json, save_json
|
||||||
from vnpy.trader.database import DbTickData, DbBarData
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.setting import SETTINGS
|
from vnpy.trader.setting import SETTINGS
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
|
|||||||
self.write_log("RQData数据接口初始化成功")
|
self.write_log("RQData数据接口初始化成功")
|
||||||
|
|
||||||
def query_bar_from_rq(
|
def query_bar_from_rq(
|
||||||
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
|
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Query bar data from RQData.
|
Query bar data from RQData.
|
||||||
"""
|
"""
|
||||||
symbol, exchange_str = vt_symbol.split(".")
|
rq_symbol = to_rq_symbol(symbol, exchange)
|
||||||
rq_symbol = to_rq_symbol(vt_symbol)
|
|
||||||
if rq_symbol not in self.rq_symbols:
|
if rq_symbol not in self.rq_symbols:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
|
|||||||
end_date=end
|
end_date=end
|
||||||
)
|
)
|
||||||
|
|
||||||
data = []
|
data: List[BarData] = []
|
||||||
for ix, row in df.iterrows():
|
for ix, row in df.iterrows():
|
||||||
bar = BarData(
|
bar = BarData(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
exchange=Exchange(exchange_str),
|
exchange=exchange,
|
||||||
interval=interval,
|
interval=interval,
|
||||||
datetime=row.name.to_pydatetime(),
|
datetime=row.name.to_pydatetime(),
|
||||||
open_price=row["open"],
|
open_price=row["open"],
|
||||||
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
|
|||||||
return self.engine_type
|
return self.engine_type
|
||||||
|
|
||||||
def load_bar(
|
def load_bar(
|
||||||
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
|
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
|
||||||
|
callback: Callable[[BarData], None]
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
start = end - timedelta(days)
|
start = end - timedelta(days)
|
||||||
|
|
||||||
# Query data from RQData by default, if not found, load from database.
|
# Query bars from RQData by default, if not found, load from database.
|
||||||
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
|
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
|
||||||
if not data:
|
if not bars:
|
||||||
s = (
|
bars = database_manager.load_bar_data(
|
||||||
DbBarData.select()
|
symbol=symbol,
|
||||||
.where(
|
exchange=exchange,
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
interval=interval,
|
||||||
& (DbBarData.interval == interval.value)
|
start=start,
|
||||||
& (DbBarData.datetime >= start)
|
end=end,
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_bar.to_bar() for db_bar in s]
|
|
||||||
|
|
||||||
for bar in data:
|
for bar in bars:
|
||||||
callback(bar)
|
callback(bar)
|
||||||
|
|
||||||
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
|
def load_tick(self, symbol: str, exchange: Exchange, days: int,
|
||||||
|
callback: Callable[[TickData], None]):
|
||||||
""""""
|
""""""
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
start = end - timedelta(days)
|
start = end - timedelta(days)
|
||||||
|
|
||||||
s = (
|
ticks = database_manager.load_tick_data(
|
||||||
DbTickData.select()
|
symbol=symbol,
|
||||||
.where(
|
exchange=exchange,
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
start=start,
|
||||||
& (DbBarData.datetime >= start)
|
end=end,
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for tick in s:
|
for tick in ticks:
|
||||||
callback(tick)
|
callback(tick)
|
||||||
|
|
||||||
def call_strategy_func(
|
def call_strategy_func(
|
||||||
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
|
|||||||
"""
|
"""
|
||||||
Load strategy class from certain folder.
|
Load strategy class from certain folder.
|
||||||
"""
|
"""
|
||||||
for dirpath, dirnames, filenames in os.walk(path):
|
for dirpath, dirnames, filenames in os.walk(str(path)):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename.endswith(".py"):
|
if filename.endswith(".py"):
|
||||||
strategy_module_name = ".".join(
|
strategy_module_name = ".".join(
|
||||||
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
|
|||||||
self.main_engine.send_email(subject, msg)
|
self.main_engine.send_email(subject, msg)
|
||||||
|
|
||||||
|
|
||||||
def to_rq_symbol(vt_symbol: str):
|
def to_rq_symbol(symbol: str, exchange: Exchange):
|
||||||
"""
|
"""
|
||||||
CZCE product of RQData has symbol like "TA1905" while
|
CZCE product of RQData has symbol like "TA1905" while
|
||||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||||
"""
|
"""
|
||||||
symbol, exchange_str = vt_symbol.split(".")
|
if exchange is not Exchange.CZCE:
|
||||||
if exchange_str != "CZCE":
|
|
||||||
return symbol.upper()
|
return symbol.upper()
|
||||||
|
|
||||||
for count, word in enumerate(symbol):
|
for count, word in enumerate(symbol):
|
||||||
if word.isdigit():
|
if word.isdigit():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
product = symbol[:count]
|
product = symbol[:count]
|
||||||
year = symbol[count]
|
year = symbol[count]
|
||||||
month = symbol[count + 1:]
|
month = symbol[count + 1:]
|
||||||
|
12
vnpy/trader/database/__init__.py
Normal file
12
vnpy/trader/database/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.database.database import BaseDatabaseManager
|
||||||
|
|
||||||
|
if "VNPY_TESTING" not in os.environ:
|
||||||
|
from vnpy.trader.setting import get_settings
|
||||||
|
from .initialize import init
|
||||||
|
|
||||||
|
settings = get_settings("database.")
|
||||||
|
database_manager: "BaseDatabaseManager" = init(settings=settings)
|
53
vnpy/trader/database/database.py
Normal file
53
vnpy/trader/database/database.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.constant import Interval, Exchange
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
|
||||||
|
|
||||||
|
class Driver(Enum):
|
||||||
|
SQLITE = "sqlite"
|
||||||
|
MYSQL = "mysql"
|
||||||
|
POSTGRESQL = "postgresql"
|
||||||
|
MONGODB = "mongodb"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDatabaseManager(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: "Exchange",
|
||||||
|
interval: "Interval",
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
) -> Sequence["BarData"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_tick_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: "Exchange",
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
) -> Sequence["TickData"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_bar_data(
|
||||||
|
self,
|
||||||
|
datas: Sequence["BarData"],
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_tick_data(
|
||||||
|
self,
|
||||||
|
datas: Sequence["TickData"],
|
||||||
|
):
|
||||||
|
pass
|
@ -1,99 +1,59 @@
|
|||||||
""""""
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import Sequence
|
||||||
|
|
||||||
from peewee import (
|
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
||||||
AutoField,
|
|
||||||
CharField,
|
|
||||||
Database,
|
|
||||||
DateTimeField,
|
|
||||||
FloatField,
|
|
||||||
Model,
|
|
||||||
MySQLDatabase,
|
|
||||||
PostgresqlDatabase,
|
|
||||||
SqliteDatabase,
|
|
||||||
chunked,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .constant import Exchange, Interval
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
from .object import BarData, TickData
|
from vnpy.trader.object import BarData, TickData
|
||||||
from .setting import SETTINGS
|
from .database import BaseDatabaseManager, Driver
|
||||||
from .utility import get_file_path
|
|
||||||
|
|
||||||
|
|
||||||
class Driver(Enum):
|
def init(_: Driver, settings: dict):
|
||||||
SQLITE = "sqlite"
|
|
||||||
MYSQL = "mysql"
|
|
||||||
POSTGRESQL = "postgresql"
|
|
||||||
|
|
||||||
|
|
||||||
_db: Database
|
|
||||||
_driver: Driver
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
|
||||||
global _driver
|
|
||||||
db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")}
|
|
||||||
_driver = Driver(db_settings["driver"])
|
|
||||||
|
|
||||||
init_funcs = {
|
|
||||||
Driver.SQLITE: init_sqlite,
|
|
||||||
Driver.MYSQL: init_mysql,
|
|
||||||
Driver.POSTGRESQL: init_postgresql,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert _driver in init_funcs
|
|
||||||
del db_settings["driver"]
|
|
||||||
return init_funcs[_driver](db_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def init_sqlite(settings: dict):
|
|
||||||
global _db
|
|
||||||
database = settings["database"]
|
database = settings["database"]
|
||||||
|
host = settings["host"]
|
||||||
_db = SqliteDatabase(str(get_file_path(database)))
|
port = settings["port"]
|
||||||
|
username = settings["user"]
|
||||||
|
password = settings["password"]
|
||||||
|
authentication_source = settings["authentication_source"]
|
||||||
|
if not username: # if username == '' or None, skip username
|
||||||
|
username = None
|
||||||
|
password = None
|
||||||
|
authentication_source = None
|
||||||
|
connect(
|
||||||
|
db=database,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
authentication_source=authentication_source,
|
||||||
|
)
|
||||||
|
return MongoManager()
|
||||||
|
|
||||||
|
|
||||||
def init_mysql(settings: dict):
|
class DbBarData(Document):
|
||||||
global _db
|
|
||||||
_db = MySQLDatabase(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
def init_postgresql(settings: dict):
|
|
||||||
global _db
|
|
||||||
_db = PostgresqlDatabase(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
init()
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(Model):
|
|
||||||
def to_dict(self):
|
|
||||||
return self.__data__
|
|
||||||
|
|
||||||
|
|
||||||
class DbBarData(ModelBase):
|
|
||||||
"""
|
"""
|
||||||
Candlestick bar data for database storage.
|
Candlestick bar data for database storage.
|
||||||
|
|
||||||
Index is defined unique with datetime, interval, symbol
|
Index is defined unique with datetime, interval, symbol
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id = AutoField()
|
symbol: str = StringField()
|
||||||
symbol = CharField()
|
exchange: str = StringField()
|
||||||
exchange = CharField()
|
datetime: datetime = DateTimeField()
|
||||||
datetime = DateTimeField()
|
interval: str = StringField()
|
||||||
interval = CharField()
|
|
||||||
|
|
||||||
volume = FloatField()
|
volume: float = FloatField()
|
||||||
open_price = FloatField()
|
open_price: float = FloatField()
|
||||||
high_price = FloatField()
|
high_price: float = FloatField()
|
||||||
low_price = FloatField()
|
low_price: float = FloatField()
|
||||||
close_price = FloatField()
|
close_price: float = FloatField()
|
||||||
|
|
||||||
class Meta:
|
meta = {
|
||||||
database = _db
|
"indexes": [
|
||||||
indexes = ((("datetime", "interval", "symbol"), True),)
|
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_bar(bar: BarData):
|
def from_bar(bar: BarData):
|
||||||
@ -132,80 +92,56 @@ class DbBarData(ModelBase):
|
|||||||
)
|
)
|
||||||
return bar
|
return bar
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save_all(objs: List["DbBarData"]):
|
|
||||||
"""
|
|
||||||
save a list of objects, update if exists.
|
|
||||||
"""
|
|
||||||
with _db.atomic():
|
|
||||||
if _driver is Driver.POSTGRESQL:
|
|
||||||
for bar in objs:
|
|
||||||
DbBarData.insert(bar.to_dict()).on_conflict(
|
|
||||||
update=bar.to_dict(),
|
|
||||||
conflict_target=(
|
|
||||||
DbBarData.datetime,
|
|
||||||
DbBarData.interval,
|
|
||||||
DbBarData.symbol,
|
|
||||||
),
|
|
||||||
).execute()
|
|
||||||
else:
|
|
||||||
for c in chunked(objs, 50):
|
|
||||||
DbBarData.insert_many(c).on_conflict_replace()
|
|
||||||
|
|
||||||
|
class DbTickData(Document):
|
||||||
class DbTickData(ModelBase):
|
|
||||||
"""
|
"""
|
||||||
Tick data for database storage.
|
Tick data for database storage.
|
||||||
|
|
||||||
Index is defined unique with (datetime, symbol)
|
Index is defined unique with (datetime, symbol)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id = AutoField()
|
symbol: str = StringField()
|
||||||
|
exchange: str = StringField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
|
||||||
symbol = CharField()
|
name: str = StringField()
|
||||||
exchange = CharField()
|
volume: float = FloatField()
|
||||||
datetime = DateTimeField()
|
last_price: float = FloatField()
|
||||||
|
last_volume: float = FloatField()
|
||||||
|
limit_up: float = FloatField()
|
||||||
|
limit_down: float = FloatField()
|
||||||
|
|
||||||
name = CharField()
|
open_price: float = FloatField()
|
||||||
volume = FloatField()
|
high_price: float = FloatField()
|
||||||
last_price = FloatField()
|
low_price: float = FloatField()
|
||||||
last_volume = FloatField()
|
close_price: float = FloatField()
|
||||||
limit_up = FloatField()
|
pre_close: float = FloatField()
|
||||||
limit_down = FloatField()
|
|
||||||
|
|
||||||
open_price = FloatField()
|
bid_price_1: float = FloatField()
|
||||||
high_price = FloatField()
|
bid_price_2: float = FloatField()
|
||||||
low_price = FloatField()
|
bid_price_3: float = FloatField()
|
||||||
close_price = FloatField()
|
bid_price_4: float = FloatField()
|
||||||
pre_close = FloatField()
|
bid_price_5: float = FloatField()
|
||||||
|
|
||||||
bid_price_1 = FloatField()
|
ask_price_1: float = FloatField()
|
||||||
bid_price_2 = FloatField()
|
ask_price_2: float = FloatField()
|
||||||
bid_price_3 = FloatField()
|
ask_price_3: float = FloatField()
|
||||||
bid_price_4 = FloatField()
|
ask_price_4: float = FloatField()
|
||||||
bid_price_5 = FloatField()
|
ask_price_5: float = FloatField()
|
||||||
|
|
||||||
ask_price_1 = FloatField()
|
bid_volume_1: float = FloatField()
|
||||||
ask_price_2 = FloatField()
|
bid_volume_2: float = FloatField()
|
||||||
ask_price_3 = FloatField()
|
bid_volume_3: float = FloatField()
|
||||||
ask_price_4 = FloatField()
|
bid_volume_4: float = FloatField()
|
||||||
ask_price_5 = FloatField()
|
bid_volume_5: float = FloatField()
|
||||||
|
|
||||||
bid_volume_1 = FloatField()
|
ask_volume_1: float = FloatField()
|
||||||
bid_volume_2 = FloatField()
|
ask_volume_2: float = FloatField()
|
||||||
bid_volume_3 = FloatField()
|
ask_volume_3: float = FloatField()
|
||||||
bid_volume_4 = FloatField()
|
ask_volume_4: float = FloatField()
|
||||||
bid_volume_5 = FloatField()
|
ask_volume_5: float = FloatField()
|
||||||
|
|
||||||
ask_volume_1 = FloatField()
|
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
|
||||||
ask_volume_2 = FloatField()
|
|
||||||
ask_volume_3 = FloatField()
|
|
||||||
ask_volume_4 = FloatField()
|
|
||||||
ask_volume_5 = FloatField()
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
database = _db
|
|
||||||
indexes = ((("datetime", "symbol"), True),)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_tick(tick: TickData):
|
def from_tick(tick: TickData):
|
||||||
@ -254,7 +190,7 @@ class DbTickData(ModelBase):
|
|||||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||||
|
|
||||||
return tick
|
return db_tick
|
||||||
|
|
||||||
def to_tick(self):
|
def to_tick(self):
|
||||||
"""
|
"""
|
||||||
@ -304,20 +240,63 @@ class DbTickData(ModelBase):
|
|||||||
|
|
||||||
return tick
|
return tick
|
||||||
|
|
||||||
|
|
||||||
|
class MongoManager(BaseDatabaseManager):
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: Exchange,
|
||||||
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime,
|
||||||
|
) -> Sequence[BarData]:
|
||||||
|
s = DbBarData.objects(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange.value,
|
||||||
|
interval=interval.value,
|
||||||
|
datetime__gte=start,
|
||||||
|
datetime__lte=end,
|
||||||
|
)
|
||||||
|
data = [db_bar.to_bar() for db_bar in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def load_tick_data(
|
||||||
|
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||||
|
) -> Sequence[TickData]:
|
||||||
|
s = DbTickData.objects(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange.value,
|
||||||
|
datetime__gte=start,
|
||||||
|
datetime__lte=end,
|
||||||
|
)
|
||||||
|
data = [db_tick.to_tick() for db_tick in s]
|
||||||
|
return data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_all(objs: List["DbTickData"]):
|
def to_update_param(d):
|
||||||
with _db.atomic():
|
return {
|
||||||
if _driver is Driver.POSTGRESQL:
|
"set__" + k: v.value if isinstance(v, Enum) else v
|
||||||
for bar in objs:
|
for k, v in d.__dict__.items()
|
||||||
DbTickData.insert(bar.to_dict()).on_conflict(
|
}
|
||||||
update=bar.to_dict(),
|
|
||||||
preserve=(DbTickData.id),
|
|
||||||
conflict_target=(DbTickData.datetime, DbTickData.symbol),
|
|
||||||
).execute()
|
|
||||||
else:
|
|
||||||
for c in chunked(objs, 50):
|
|
||||||
DbBarData.insert_many(c).on_conflict_replace()
|
|
||||||
|
|
||||||
|
def save_bar_data(self, datas: Sequence[BarData]):
|
||||||
|
for d in datas:
|
||||||
|
updates = self.to_update_param(d)
|
||||||
|
updates.pop("set__gateway_name")
|
||||||
|
updates.pop("set__vt_symbol")
|
||||||
|
(
|
||||||
|
DbBarData.objects(
|
||||||
|
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
|
||||||
|
).update_one(upsert=True, **updates)
|
||||||
|
)
|
||||||
|
|
||||||
_db.connect()
|
def save_tick_data(self, datas: Sequence[TickData]):
|
||||||
_db.create_tables([DbBarData, DbTickData])
|
for d in datas:
|
||||||
|
updates = self.to_update_param(d)
|
||||||
|
updates.pop("set__gateway_name")
|
||||||
|
updates.pop("set__vt_symbol")
|
||||||
|
(
|
||||||
|
DbTickData.objects(
|
||||||
|
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
|
||||||
|
).update_one(upsert=True, **updates)
|
||||||
|
)
|
370
vnpy/trader/database/database_sql.py
Normal file
370
vnpy/trader/database/database_sql.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
""""""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Sequence, Type
|
||||||
|
|
||||||
|
from peewee import (
|
||||||
|
AutoField,
|
||||||
|
CharField,
|
||||||
|
Database,
|
||||||
|
DateTimeField,
|
||||||
|
FloatField,
|
||||||
|
Model,
|
||||||
|
MySQLDatabase,
|
||||||
|
PostgresqlDatabase,
|
||||||
|
SqliteDatabase,
|
||||||
|
chunked,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
from vnpy.trader.utility import get_file_path
|
||||||
|
from .database import BaseDatabaseManager, Driver
|
||||||
|
|
||||||
|
|
||||||
|
def init(driver: Driver, settings: dict):
|
||||||
|
init_funcs = {
|
||||||
|
Driver.SQLITE: init_sqlite,
|
||||||
|
Driver.MYSQL: init_mysql,
|
||||||
|
Driver.POSTGRESQL: init_postgresql,
|
||||||
|
}
|
||||||
|
assert driver in init_funcs
|
||||||
|
|
||||||
|
db = init_funcs[driver](settings)
|
||||||
|
bar, tick = init_models(db, driver)
|
||||||
|
return SqlManager(bar, tick)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sqlite(settings: dict):
|
||||||
|
database = settings["database"]
|
||||||
|
path = str(get_file_path(database))
|
||||||
|
db = SqliteDatabase(path)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def init_mysql(settings: dict):
|
||||||
|
keys = {"database", "user", "password", "host", "port"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
db = MySQLDatabase(**settings)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def init_postgresql(settings: dict):
|
||||||
|
keys = {"database", "user", "password", "host", "port"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
db = PostgresqlDatabase(**settings)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
class ModelBase(Model):
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return self.__data__
|
||||||
|
|
||||||
|
|
||||||
|
def init_models(db: Database, driver: Driver):
|
||||||
|
class DbBarData(ModelBase):
|
||||||
|
"""
|
||||||
|
Candlestick bar data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with datetime, interval, symbol
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = AutoField()
|
||||||
|
symbol: str = CharField()
|
||||||
|
exchange: str = CharField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
interval: str = CharField()
|
||||||
|
|
||||||
|
volume: float = FloatField()
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
close_price: float = FloatField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = db
|
||||||
|
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bar(bar: BarData):
|
||||||
|
"""
|
||||||
|
Generate DbBarData object from BarData.
|
||||||
|
"""
|
||||||
|
db_bar = DbBarData()
|
||||||
|
|
||||||
|
db_bar.symbol = bar.symbol
|
||||||
|
db_bar.exchange = bar.exchange.value
|
||||||
|
db_bar.datetime = bar.datetime
|
||||||
|
db_bar.interval = bar.interval.value
|
||||||
|
db_bar.volume = bar.volume
|
||||||
|
db_bar.open_price = bar.open_price
|
||||||
|
db_bar.high_price = bar.high_price
|
||||||
|
db_bar.low_price = bar.low_price
|
||||||
|
db_bar.close_price = bar.close_price
|
||||||
|
|
||||||
|
return db_bar
|
||||||
|
|
||||||
|
def to_bar(self):
|
||||||
|
"""
|
||||||
|
Generate BarData object from DbBarData.
|
||||||
|
"""
|
||||||
|
bar = BarData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
interval=Interval(self.interval),
|
||||||
|
volume=self.volume,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
close_price=self.close_price,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
return bar
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_all(objs: List["DbBarData"]):
|
||||||
|
"""
|
||||||
|
save a list of objects, update if exists.
|
||||||
|
"""
|
||||||
|
dicts = [i.to_dict() for i in objs]
|
||||||
|
with db.atomic():
|
||||||
|
if driver is Driver.POSTGRESQL:
|
||||||
|
for bar in dicts:
|
||||||
|
DbBarData.insert(bar).on_conflict(
|
||||||
|
update=bar,
|
||||||
|
conflict_target=(
|
||||||
|
DbBarData.datetime,
|
||||||
|
DbBarData.interval,
|
||||||
|
DbBarData.symbol,
|
||||||
|
DbBarData.exchange,
|
||||||
|
),
|
||||||
|
).execute()
|
||||||
|
else:
|
||||||
|
for c in chunked(dicts, 50):
|
||||||
|
DbBarData.insert_many(c).on_conflict_replace().execute()
|
||||||
|
|
||||||
|
class DbTickData(ModelBase):
|
||||||
|
"""
|
||||||
|
Tick data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with (datetime, symbol)
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = AutoField()
|
||||||
|
|
||||||
|
symbol: str = CharField()
|
||||||
|
exchange: str = CharField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
|
||||||
|
name: str = CharField()
|
||||||
|
volume: float = FloatField()
|
||||||
|
last_price: float = FloatField()
|
||||||
|
last_volume: float = FloatField()
|
||||||
|
limit_up: float = FloatField()
|
||||||
|
limit_down: float = FloatField()
|
||||||
|
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
pre_close: float = FloatField()
|
||||||
|
|
||||||
|
bid_price_1: float = FloatField()
|
||||||
|
bid_price_2: float = FloatField(null=True)
|
||||||
|
bid_price_3: float = FloatField(null=True)
|
||||||
|
bid_price_4: float = FloatField(null=True)
|
||||||
|
bid_price_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
ask_price_1: float = FloatField()
|
||||||
|
ask_price_2: float = FloatField(null=True)
|
||||||
|
ask_price_3: float = FloatField(null=True)
|
||||||
|
ask_price_4: float = FloatField(null=True)
|
||||||
|
ask_price_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
bid_volume_1: float = FloatField()
|
||||||
|
bid_volume_2: float = FloatField(null=True)
|
||||||
|
bid_volume_3: float = FloatField(null=True)
|
||||||
|
bid_volume_4: float = FloatField(null=True)
|
||||||
|
bid_volume_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
ask_volume_1: float = FloatField()
|
||||||
|
ask_volume_2: float = FloatField(null=True)
|
||||||
|
ask_volume_3: float = FloatField(null=True)
|
||||||
|
ask_volume_4: float = FloatField(null=True)
|
||||||
|
ask_volume_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = db
|
||||||
|
indexes = ((("datetime", "symbol", "exchange"), True),)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_tick(tick: TickData):
|
||||||
|
"""
|
||||||
|
Generate DbTickData object from TickData.
|
||||||
|
"""
|
||||||
|
db_tick = DbTickData()
|
||||||
|
|
||||||
|
db_tick.symbol = tick.symbol
|
||||||
|
db_tick.exchange = tick.exchange.value
|
||||||
|
db_tick.datetime = tick.datetime
|
||||||
|
db_tick.name = tick.name
|
||||||
|
db_tick.volume = tick.volume
|
||||||
|
db_tick.last_price = tick.last_price
|
||||||
|
db_tick.last_volume = tick.last_volume
|
||||||
|
db_tick.limit_up = tick.limit_up
|
||||||
|
db_tick.limit_down = tick.limit_down
|
||||||
|
db_tick.open_price = tick.open_price
|
||||||
|
db_tick.high_price = tick.high_price
|
||||||
|
db_tick.low_price = tick.low_price
|
||||||
|
db_tick.pre_close = tick.pre_close
|
||||||
|
|
||||||
|
db_tick.bid_price_1 = tick.bid_price_1
|
||||||
|
db_tick.ask_price_1 = tick.ask_price_1
|
||||||
|
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||||
|
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||||
|
|
||||||
|
if tick.bid_price_2:
|
||||||
|
db_tick.bid_price_2 = tick.bid_price_2
|
||||||
|
db_tick.bid_price_3 = tick.bid_price_3
|
||||||
|
db_tick.bid_price_4 = tick.bid_price_4
|
||||||
|
db_tick.bid_price_5 = tick.bid_price_5
|
||||||
|
|
||||||
|
db_tick.ask_price_2 = tick.ask_price_2
|
||||||
|
db_tick.ask_price_3 = tick.ask_price_3
|
||||||
|
db_tick.ask_price_4 = tick.ask_price_4
|
||||||
|
db_tick.ask_price_5 = tick.ask_price_5
|
||||||
|
|
||||||
|
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||||
|
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||||
|
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||||
|
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||||
|
|
||||||
|
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||||
|
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||||
|
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||||
|
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||||
|
|
||||||
|
return db_tick
|
||||||
|
|
||||||
|
def to_tick(self):
|
||||||
|
"""
|
||||||
|
Generate TickData object from DbTickData.
|
||||||
|
"""
|
||||||
|
tick = TickData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
name=self.name,
|
||||||
|
volume=self.volume,
|
||||||
|
last_price=self.last_price,
|
||||||
|
last_volume=self.last_volume,
|
||||||
|
limit_up=self.limit_up,
|
||||||
|
limit_down=self.limit_down,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
pre_close=self.pre_close,
|
||||||
|
bid_price_1=self.bid_price_1,
|
||||||
|
ask_price_1=self.ask_price_1,
|
||||||
|
bid_volume_1=self.bid_volume_1,
|
||||||
|
ask_volume_1=self.ask_volume_1,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bid_price_2:
|
||||||
|
tick.bid_price_2 = self.bid_price_2
|
||||||
|
tick.bid_price_3 = self.bid_price_3
|
||||||
|
tick.bid_price_4 = self.bid_price_4
|
||||||
|
tick.bid_price_5 = self.bid_price_5
|
||||||
|
|
||||||
|
tick.ask_price_2 = self.ask_price_2
|
||||||
|
tick.ask_price_3 = self.ask_price_3
|
||||||
|
tick.ask_price_4 = self.ask_price_4
|
||||||
|
tick.ask_price_5 = self.ask_price_5
|
||||||
|
|
||||||
|
tick.bid_volume_2 = self.bid_volume_2
|
||||||
|
tick.bid_volume_3 = self.bid_volume_3
|
||||||
|
tick.bid_volume_4 = self.bid_volume_4
|
||||||
|
tick.bid_volume_5 = self.bid_volume_5
|
||||||
|
|
||||||
|
tick.ask_volume_2 = self.ask_volume_2
|
||||||
|
tick.ask_volume_3 = self.ask_volume_3
|
||||||
|
tick.ask_volume_4 = self.ask_volume_4
|
||||||
|
tick.ask_volume_5 = self.ask_volume_5
|
||||||
|
|
||||||
|
return tick
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_all(objs: List["DbTickData"]):
|
||||||
|
dicts = [i.to_dict() for i in objs]
|
||||||
|
with db.atomic():
|
||||||
|
if driver is Driver.POSTGRESQL:
|
||||||
|
for tick in dicts:
|
||||||
|
DbTickData.insert(tick).on_conflict(
|
||||||
|
update=tick,
|
||||||
|
conflict_target=(
|
||||||
|
DbTickData.datetime,
|
||||||
|
DbTickData.symbol,
|
||||||
|
DbTickData.exchange,
|
||||||
|
),
|
||||||
|
).execute()
|
||||||
|
else:
|
||||||
|
for c in chunked(dicts, 50):
|
||||||
|
DbTickData.insert_many(c).on_conflict_replace().execute()
|
||||||
|
|
||||||
|
db.connect()
|
||||||
|
db.create_tables([DbBarData, DbTickData])
|
||||||
|
return DbBarData, DbTickData
|
||||||
|
|
||||||
|
|
||||||
|
class SqlManager(BaseDatabaseManager):
|
||||||
|
|
||||||
|
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
|
||||||
|
self.class_bar = class_bar
|
||||||
|
self.class_tick = class_tick
|
||||||
|
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: Exchange,
|
||||||
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime,
|
||||||
|
) -> Sequence[BarData]:
|
||||||
|
s = (
|
||||||
|
self.class_bar.select()
|
||||||
|
.where(
|
||||||
|
(self.class_bar.symbol == symbol)
|
||||||
|
& (self.class_bar.exchange == exchange.value)
|
||||||
|
& (self.class_bar.interval == interval.value)
|
||||||
|
& (self.class_bar.datetime >= start)
|
||||||
|
& (self.class_bar.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(self.class_bar.datetime)
|
||||||
|
)
|
||||||
|
data = [db_bar.to_bar() for db_bar in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def load_tick_data(
|
||||||
|
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||||
|
) -> Sequence[TickData]:
|
||||||
|
s = (
|
||||||
|
self.class_tick.select()
|
||||||
|
.where(
|
||||||
|
(self.class_tick.symbol == symbol)
|
||||||
|
& (self.class_tick.exchange == exchange.value)
|
||||||
|
& (self.class_tick.datetime >= start)
|
||||||
|
& (self.class_tick.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(self.class_tick.datetime)
|
||||||
|
)
|
||||||
|
data = [db_tick.to_tick() for db_tick in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def save_bar_data(self, datas: Sequence[BarData]):
|
||||||
|
ds = [self.class_bar.from_bar(i) for i in datas]
|
||||||
|
self.class_bar.save_all(ds)
|
||||||
|
|
||||||
|
def save_tick_data(self, datas: Sequence[TickData]):
|
||||||
|
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||||
|
self.class_tick.save_all(ds)
|
24
vnpy/trader/database/initialize.py
Normal file
24
vnpy/trader/database/initialize.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
""""""
|
||||||
|
from .database import BaseDatabaseManager, Driver
|
||||||
|
|
||||||
|
|
||||||
|
def init(settings: dict) -> BaseDatabaseManager:
|
||||||
|
driver = Driver(settings["driver"])
|
||||||
|
if driver is Driver.MONGODB:
|
||||||
|
return init_nosql(driver=driver, settings=settings)
|
||||||
|
else:
|
||||||
|
return init_sql(driver=driver, settings=settings)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sql(driver: Driver, settings: dict):
|
||||||
|
from .database_sql import init
|
||||||
|
keys = {'database', "host", "port", "user", "password"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
_database_manager = init(driver, settings)
|
||||||
|
return _database_manager
|
||||||
|
|
||||||
|
|
||||||
|
def init_nosql(driver: Driver, settings: dict):
|
||||||
|
from .database_mongo import init
|
||||||
|
_database_manager = init(driver, settings=settings)
|
||||||
|
return _database_manager
|
@ -30,9 +30,15 @@ SETTINGS = {
|
|||||||
"database.host": "localhost",
|
"database.host": "localhost",
|
||||||
"database.port": 3306,
|
"database.port": 3306,
|
||||||
"database.user": "root",
|
"database.user": "root",
|
||||||
"database.password": ""
|
"database.password": "",
|
||||||
|
"database.authentication_source": "admin", # for mongodb
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load global setting from json file.
|
# Load global setting from json file.
|
||||||
SETTING_FILENAME = "vt_setting.json"
|
SETTING_FILENAME = "vt_setting.json"
|
||||||
SETTINGS.update(load_json(SETTING_FILENAME))
|
SETTINGS.update(load_json(SETTING_FILENAME))
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings(prefix: str = ""):
|
||||||
|
prefix_length = len(prefix)
|
||||||
|
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}
|
||||||
|
@ -3,20 +3,28 @@ General utility functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import talib
|
import talib
|
||||||
|
|
||||||
from .object import BarData, TickData
|
from .object import BarData, TickData
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.constant import Exchange
|
||||||
|
|
||||||
def resolve_path(pattern: str):
|
|
||||||
env = dict(os.environ)
|
def extract_vt_symbol(vt_symbol: str):
|
||||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
"""
|
||||||
return pattern.format(**env)
|
:return: (symbol, exchange)
|
||||||
|
"""
|
||||||
|
symbol, exchange = vt_symbol.split('.')
|
||||||
|
return symbol, exchange
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
|
||||||
|
return f'{symbol}.{exchange.value}'
|
||||||
|
|
||||||
|
|
||||||
def _get_trader_dir(temp_name: str):
|
def _get_trader_dir(temp_name: str):
|
||||||
|
Loading…
Reference in New Issue
Block a user