Merge pull request #1602 from nanoric/master

[Add] database
This commit is contained in:
vn.py 2019-04-16 11:19:56 +08:00 committed by GitHub
commit 1d527bc36e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1443 additions and 532 deletions

View File

@ -2,12 +2,27 @@ language: python
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
cache: pip
git:
depth: 1
python:
- "3.7"
services:
- mongodb
- mysql
- postgresql
before_script:
- psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;"
- mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;'
script:
# todo: use python unittest
- mkdir run; cd run; python ../tests/load_all.py
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- cd tests; source travis_env.sh;
- python test_all.py
matrix:
include:
@ -19,81 +34,36 @@ matrix:
script:
- flake8
- name: "pip install under Windows"
os: "windows"
# language : cpp is necessary for windows
language: "cpp"
env:
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
before_install:
- choco install python3 --version 3.7.2
install:
- python -m pip install --upgrade pip wheel setuptools
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt
- pip install .
- name: "pip install under Ubuntu: gcc-8"
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- g++-8
before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-8 g++-8
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
install:
# update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools
# Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh
- name: "pip install under Ubuntu: gcc-7"
- name: "sdist install under Ubuntu: gcc-7"
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- g++-7
before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-7 g++-7
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
# update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools
# Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh
- name: "sdist install under Windows"
os: "windows"
# language : cpp is necessary for windows
language: "cpp"
env:
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
before_install:
- choco install python3 --version 3.7.2
install:
- python -m pip install --upgrade pip wheel setuptools
- python setup.py sdist
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install dist/`ls dist`
- name: "sdist install under Ubuntu: gcc-8"
before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-8 g++-8
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
# Linux install script
- python -m pip install --upgrade pip wheel setuptools
- pushd /tmp
@ -108,3 +78,17 @@ matrix:
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- python setup.py sdist
- pip install dist/`ls dist`
- name: "pip install under osx"
os: osx
language: shell # osx supports only shell
services: []
before_install: []
install:
- pip3 install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install_osx.sh
before_script: []
script:
- pip3 install psycopg2 mongoengine pymysql # we should support all database in test environment
- cd tests; source travis_env.sh;
- VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py

68
appveyor.yml Normal file
View File

@ -0,0 +1,68 @@
clone_depth: 1
cache:
- '%LOCALAPPDATA%\pip\Cache'
configuration:
- pip
- sdist
image:
- Visual Studio 2017
services:
- mysql
- mongodb
- postgresql
environment:
PATH: C:\Python37-x64;C:\Python37-x64\Scripts;C:\Program Files\PostgreSQL\9.6\bin\;%PATH%
VNPY_TEST_MYSQL_DATABASE: vnpy
VNPY_TEST_MYSQL_HOST: localhost
VNPY_TEST_MYSQL_PORT: 3306
VNPY_TEST_MYSQL_USER: root
VNPY_TEST_MYSQL_PASSWORD: Password12!
VNPY_TEST_POSTGRESQL_DATABASE: vnpy
VNPY_TEST_POSTGRESQL_HOST: localhost
VNPY_TEST_POSTGRESQL_PORT: 5432
VNPY_TEST_POSTGRESQL_USER: postgres
VNPY_TEST_POSTGRESQL_PASSWORD: Password12!
VNPY_TEST_MONGODB_DATABASE: vnpy
VNPY_TEST_MONGODB_HOST: localhost
VNPY_TEST_MONGODB_PORT: 27017
MYSQL_PWD: Password12!
install:
- python -m pip install --upgrade pip wheel setuptools
before_build:
- ps: psql -d "postgresql://${ENV:VNPY_TEST_POSTGRESQL_USER}:${ENV:VNPY_TEST_POSTGRESQL_PASSWORD}@localhost" -c "create database vnpy;"
- ps: . "C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -u $ENV:VNPY_TEST_MYSQL_USER -e 'CREATE DATABASE vnpy;'
for:
- matrix:
only:
- configuration: pip
build_script:
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt
- pip install .
- matrix:
only:
- configuration: sdist
build_script:
- python setup.py sdist
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- ps: $name=(ls dist).name; pip install "dist/$name"
test_script:
- cd tests
- python test_all.py

View File

@ -1,27 +1,35 @@
#!/usr/bin/env bash
python=$1
pip=$2
prefix=$3
[[ -z $python ]] && python=python
[[ -z $pip ]] && pip=pip
[[ -z $prefix ]] && prefix=/usr
# Get and build ta-lib
pushd /tmp
wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
tar -xf ta-lib-0.4.0-src.tar.gz
cd ta-lib
./configure --prefix=/usr
make
./configure --prefix=$prefix
make -j
sudo make install
popd
# old versions of ta-lib imports numpy in setup.py
pip install numpy
$pip install numpy
# Install extra packages
pip install ta-lib
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
$pip install ta-lib
$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
# Install Python Modules
pip install -r requirements.txt
$pip install -r requirements.txt
# Install local Chinese language environment
sudo locale-gen zh_CN.GB18030
# Install vn.py
pip install .
$pip install .

3
install_osx.sh Normal file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env bash
bash ./install.sh python3 pip3 /usr/local

View File

@ -29,18 +29,22 @@ with open("vnpy/__init__.py", "rb") as f:
version = str(ast.literal_eval(version_line))
if platform.uname().system == "Windows":
compiler_flags = ["/MP", "/std:c++17", # standard
compiler_flags = [
"/MP", "/std:c++17", # standard
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
"/wd4819" # 936 code page
]
extra_link_args = []
else:
compiler_flags = ["-std=c++17",
compiler_flags = [
"-std=c++17", # standard
"-O3", # Optimization
"-Wno-delete-incomplete", "-Wno-sign-compare",
]
extra_link_args = ["-lstdc++"]
vnctpmd = Extension("vnpy.api.ctp.vnctpmd",
vnctpmd = Extension(
"vnpy.api.ctp.vnctpmd",
[
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
],
@ -55,8 +59,9 @@ vnctpmd = Extension("vnpy.api.ctp.vnctpmd",
depends=[],
runtime_library_dirs=["$ORIGIN"],
language="cpp",
)
vnctptd = Extension("vnpy.api.ctp.vnctptd",
)
vnctptd = Extension(
"vnpy.api.ctp.vnctptd",
[
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
],
@ -71,8 +76,9 @@ vnctptd = Extension("vnpy.api.ctp.vnctptd",
runtime_library_dirs=["$ORIGIN"],
depends=[],
language="cpp",
)
vnoes = Extension("vnpy.api.oes.vnoes",
)
vnoes = Extension(
"vnpy.api.oes.vnoes",
[
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
@ -86,13 +92,16 @@ vnoes = Extension("vnpy.api.oes.vnoes",
libraries=["oes_api"],
extra_compile_args=compiler_flags,
extra_link_args=extra_link_args,
runtime_library_dirs=["$ORIGIN"],
depends=[],
language="cpp",
)
)
if platform.uname().system == "Windows":
if platform.system() == "Windows":
# use pre-built pyd for windows ( support python 3.7 only )
ext_modules = []
elif platform.system() == "Darwin":
ext_modules = []
else:
ext_modules = [vnctptd, vnctpmd, vnoes]

1
tests/app/__init__.py Normal file
View File

@ -0,0 +1 @@
from .test_csv_loader import *

View File

@ -0,0 +1,90 @@
"""
Test if csv loader works fine
"""
import tempfile
import unittest
from vnpy.app.csv_loader import CsvLoaderEngine
from vnpy.trader.constant import Exchange, Interval
class TestCsvLoader(unittest.TestCase):
def setUp(self) -> None:
self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader
def test_load(self):
data = """"Datetime","Open","High","Low","Close","Volume"
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
"""
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
def test_load_duplicated(self):
data = """"Datetime","Open","High","Low","Close","Volume"
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
"""
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
if __name__ == "__main__":
unittest.main()

View File

@ -2,7 +2,7 @@ from time import time
import rqdatac as rq
from vnpy.trader.database import DbBarData, DB
from vnpy.trader.database import DbBarData
USERNAME = ""
PASSWORD = ""
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
with DB.atomic():
bars = []
for ix, row in df.iterrows():
print(row.name)
bar = generate_bar_from_row(row, symbol, exchange)
DbBarData.replace(bar.__data__).execute()
bars.append(bar)
DbBarData.save_all(bars)
end = time()
cost = (end - start) * 1000

31
tests/test_all.py Normal file
View File

@ -0,0 +1,31 @@
# tests/runner.py
import unittest
import app
# import your test modules
import test_import_all
import trader
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# add tests to the test suite
suite.addTests(loader.loadTestsFromModule(test_import_all))
suite.addTests(loader.loadTestsFromModule(trader))
suite.addTests(loader.loadTestsFromModule(app))
# initialize a runner, pass it your suite and run it
def main():
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
return result
if __name__ == '__main__':
result = main()
if result.failures or result.errors:
exit(1)
else:
exit(0)

View File

@ -1,24 +1,45 @@
# flake8: noqa
import unittest
import platform
# noinspection PyUnresolvedReferences
class ImportTest(unittest.TestCase):
# noinspection PyUnresolvedReferences
def test_import_all(self):
from vnpy.event import EventEngine
def test_import_main_engine(self):
from vnpy.trader.engine import MainEngine
def test_import_ui(self):
from vnpy.trader.ui import MainWindow, create_qapp
def test_import_bitmex_gateway(self):
from vnpy.gateway.bitmex import BitmexGateway
def test_import_futu_gateway(self):
from vnpy.gateway.futu import FutuGateway
def test_import_ib_gateway(self):
from vnpy.gateway.ib import IbGateway
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
def test_import_ctp_gateway(self):
from vnpy.gateway.ctp import CtpGateway
def test_import_tiger_gateway(self):
from vnpy.gateway.tiger import TigerGateway
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
def test_import_oes_gateway(self):
from vnpy.gateway.oes import OesGateway
def test_import_cta_strategy_app(self):
from vnpy.app.cta_strategy import CtaStrategyApp
def test_import_csv_loader_app(self):
from vnpy.app.csv_loader import CsvLoaderApp

2
tests/trader/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .test_database import *
from .test_settings import *

View File

@ -0,0 +1,128 @@
"""
Test if database works fine
"""
import os
import unittest
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1'
profiles = {
Driver.SQLITE: {
"driver": "sqlite",
"database": "test_db.db",
}
}
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
profiles.update({
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
"user": "",
"password": "",
"authentication_source": "",
},
})
def now():
return datetime.utcnow()
bar = BarData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
interval=Interval.MINUTE,
)
tick = TickData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
name="DB_test_symbol",
)
class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
def test_save_load_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_bar_data([bar])
# and load
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
def test_upsert_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
def test_save_load_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_tick_data([tick])
# and load
results = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,25 @@
"""
Test if database works fine
"""
import unittest
from vnpy.trader.setting import SETTINGS, get_settings
class TestSettings(unittest.TestCase):
def test_get_settings(self):
SETTINGS['a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
def test_get_settings_with_prefix(self):
SETTINGS['a.a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
if __name__ == '__main__':
unittest.main()

30
tests/travis_env.sh Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
[[ -z ${VNPY_TEST_MYSQL_DATABASE} ]] && VNPY_TEST_MYSQL_DATABASE=vnpy
[[ -z ${VNPY_TEST_MYSQL_HOST} ]] && VNPY_TEST_MYSQL_HOST=127.0.0.1
[[ -z ${VNPY_TEST_MYSQL_PORT} ]] && VNPY_TEST_MYSQL_PORT=3306
[[ -z ${VNPY_TEST_MYSQL_USER} ]] && VNPY_TEST_MYSQL_USER=root
[[ -z ${VNPY_TEST_MYSQL_PASSWORD} ]] && VNPY_TEST_MYSQL_PASSWORD=
[[ -z ${VNPY_TEST_POSTGRESQL_DATABASE} ]] && VNPY_TEST_POSTGRESQL_DATABASE=vnpy
[[ -z ${VNPY_TEST_POSTGRESQL_HOST} ]] && VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
[[ -z ${VNPY_TEST_POSTGRESQL_PORT} ]] && VNPY_TEST_POSTGRESQL_PORT=5432
[[ -z ${VNPY_TEST_POSTGRESQL_USER} ]] && VNPY_TEST_POSTGRESQL_USER=postgres
[[ -z ${VNPY_TEST_POSTGRESQL_PASSWORD} ]] && VNPY_TEST_POSTGRESQL_PASSWORD=
[[ -z ${VNPY_TEST_MONGODB_DATABASE} ]] && VNPY_TEST_MONGODB_DATABASE=vnpy
[[ -z ${VNPY_TEST_MONGODB_HOST} ]] && VNPY_TEST_MONGODB_HOST=127.0.0.1
[[ -z ${VNPY_TEST_MONGODB_PORT} ]] && VNPY_TEST_MONGODB_PORT=27017
export VNPY_TEST_MYSQL_DATABASE
export VNPY_TEST_MYSQL_HOST
export VNPY_TEST_MYSQL_PORT
export VNPY_TEST_MYSQL_USER
export VNPY_TEST_MYSQL_PASSWORD
export VNPY_TEST_POSTGRESQL_DATABASE
export VNPY_TEST_POSTGRESQL_HOST
export VNPY_TEST_POSTGRESQL_PORT
export VNPY_TEST_POSTGRESQL_USER
export VNPY_TEST_POSTGRESQL_PASSWORD
export VNPY_TEST_MONGODB_DATABASE
export VNPY_TEST_MONGODB_HOST
export VNPY_TEST_MONGODB_PORT

View File

@ -22,14 +22,13 @@ Sample csv file:
import csv
from datetime import datetime
from peewee import chunked
from typing import TextIO
from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData, DB
from vnpy.trader.database import database_manager
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import BarData
APP_NAME = "CsvLoader"
@ -53,6 +52,59 @@ class CsvLoaderEngine(BaseEngine):
self.high_head: str = ""
self.volume_head: str = ""
def load_by_handle(
self,
f: TextIO,
symbol: str,
exchange: Exchange,
interval: Interval,
datetime_head: str,
open_head: str,
close_head: str,
low_head: str,
high_head: str,
volume_head: str,
datetime_format: str,
):
"""
load by text mode file handle
"""
reader = csv.DictReader(f)
bars = []
start = None
count = 0
for item in reader:
if datetime_format:
dt = datetime.strptime(item[datetime_head], datetime_format)
else:
dt = datetime.fromisoformat(item[datetime_head])
bar = BarData(
symbol=symbol,
exchange=exchange,
datetime=dt,
interval=interval,
volume=item[volume_head],
open_price=item[open_head],
high_price=item[high_head],
low_price=item[low_head],
close_price=item[close_head],
gateway_name="DB",
)
bars.append(bar)
# do some statistics
count += 1
if not start:
start = bar.datetime
end = bar.datetime
# insert into database
database_manager.save_bar_data(bars)
return start, end, count
def load(
self,
file_path: str,
@ -65,49 +117,22 @@ class CsvLoaderEngine(BaseEngine):
low_head: str,
high_head: str,
volume_head: str,
datetime_format: str
datetime_format: str,
):
""""""
vt_symbol = f"{symbol}.{exchange.value}"
start = None
end = None
count = 0
"""
load by filename
"""
with open(file_path, "rt") as f:
reader = csv.DictReader(f)
db_bars = []
for item in reader:
dt = datetime.strptime(item[datetime_head], datetime_format)
db_bar = {
"symbol": symbol,
"exchange": exchange.value,
"datetime": dt,
"interval": interval.value,
"volume": item[volume_head],
"open_price": item[open_head],
"high_price": item[high_head],
"low_price": item[low_head],
"close_price": item[close_head],
"vt_symbol": vt_symbol,
"gateway_name": "DB"
}
db_bars.append(db_bar)
# do some statistics
count += 1
if not start:
start = db_bar["datetime"]
end = db_bar["datetime"]
# Insert into DB
with DB.atomic():
for batch in chunked(db_bars, 50):
DbBarData.insert_many(batch).on_conflict_replace().execute()
return start, end, count
return self.load_by_handle(
f,
symbol=symbol,
exchange=exchange,
interval=interval,
datetime_head=datetime_head,
open_head=open_head,
close_head=close_head,
low_head=low_head,
high_head=high_head,
volume_head=volume_head,
datetime_format=datetime_format,
)

View File

@ -12,8 +12,8 @@ from pandas import DataFrame
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
from vnpy.trader.database import DbBarData, DbTickData
from vnpy.trader.object import OrderData, TradeData
from vnpy.trader.database import database_manager
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
from vnpy.trader.utility import round_to_pricetick
from .base import (
@ -103,8 +103,8 @@ class BacktestingEngine:
self.strategy_class = None
self.strategy = None
self.tick = None
self.bar = None
self.tick: TickData
self.bar: BarData
self.datetime = None
self.interval = None
@ -199,14 +199,16 @@ class BacktestingEngine:
if self.mode == BacktestingMode.BAR:
self.history_data = load_bar_data(
self.vt_symbol,
self.symbol,
self.exchange,
self.interval,
self.start,
self.end
)
else:
self.history_data = load_tick_data(
self.vt_symbol,
self.symbol,
self.exchange,
self.start,
self.end
)
@ -520,7 +522,7 @@ class BacktestingEngine:
else:
self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: DbBarData):
def new_bar(self, bar: BarData):
""""""
self.bar = bar
self.datetime = bar.datetime
@ -531,7 +533,7 @@ class BacktestingEngine:
self.update_daily_close(bar.close_price)
def new_tick(self, tick: DbTickData):
def new_tick(self, tick: TickData):
""""""
self.tick = tick
self.datetime = tick.datetime
@ -966,41 +968,26 @@ def optimize(
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
return database_manager.load_bar_data(
symbol, exchange, interval, start, end
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
symbol: str,
exchange: Exchange,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
return database_manager.load_tick_data(
symbol, exchange, start, end
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -5,7 +5,7 @@ import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, List
from datetime import datetime, timedelta
from threading import Thread
from queue import Queue
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
Status
)
from vnpy.trader.utility import load_json, save_json
from vnpy.trader.database import DbTickData, DbBarData
from vnpy.trader.database import database_manager
from vnpy.trader.setting import SETTINGS
from .base import (
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
self.write_log("RQData数据接口初始化成功")
def query_bar_from_rq(
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
):
"""
Query bar data from RQData.
"""
symbol, exchange_str = vt_symbol.split(".")
rq_symbol = to_rq_symbol(vt_symbol)
rq_symbol = to_rq_symbol(symbol, exchange)
if rq_symbol not in self.rq_symbols:
return None
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
end_date=end
)
data = []
data: List[BarData] = []
for ix, row in df.iterrows():
bar = BarData(
symbol=symbol,
exchange=Exchange(exchange_str),
exchange=exchange,
interval=interval,
datetime=row.name.to_pydatetime(),
open_price=row["open"],
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
return self.engine_type
def load_bar(
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
callback: Callable[[BarData], None]
):
""""""
end = datetime.now()
start = end - timedelta(days)
# Query data from RQData by default, if not found, load from database.
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
if not data:
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval.value)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
# Query bars from RQData by default, if not found, load from database.
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
if not bars:
bars = database_manager.load_bar_data(
symbol=symbol,
exchange=exchange,
interval=interval,
start=start,
end=end,
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
for bar in data:
for bar in bars:
callback(bar)
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
def load_tick(self, symbol: str, exchange: Exchange, days: int,
callback: Callable[[TickData], None]):
""""""
end = datetime.now()
start = end - timedelta(days)
s = (
DbTickData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
ticks = database_manager.load_tick_data(
symbol=symbol,
exchange=exchange,
start=start,
end=end,
)
for tick in s:
for tick in ticks:
callback(tick)
def call_strategy_func(
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(path):
for dirpath, dirnames, filenames in os.walk(str(path)):
for filename in filenames:
if filename.endswith(".py"):
strategy_module_name = ".".join(
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
self.main_engine.send_email(subject, msg)
def to_rq_symbol(vt_symbol: str):
def to_rq_symbol(symbol: str, exchange: Exchange):
"""
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
symbol, exchange_str = vt_symbol.split(".")
if exchange_str != "CZCE":
if exchange is not Exchange.CZCE:
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
# noinspection PyUnboundLocalVariable
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]

View File

@ -1,268 +0,0 @@
""""""
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
SqliteDatabase
from .constant import Exchange, Interval
from .object import BarData, TickData
from .setting import SETTINGS
from .utility import resolve_path
def init():
db_settings = SETTINGS['database']
driver = db_settings["driver"]
init_funcs = {
"sqlite": init_sqlite,
"mysql": init_mysql,
"postgresql": init_postgresql,
}
assert driver in init_funcs
del db_settings['driver']
return init_funcs[driver](db_settings)
def init_sqlite(settings: dict):
global DB
database = settings['database']
DB = SqliteDatabase(str(resolve_path(database)))
def init_mysql(settings: dict):
global DB
DB = MySQLDatabase(**settings)
def init_postgresql(settings: dict):
global DB
DB = PostgresqlDatabase(**settings)
init()
class DbBarData(Model):
"""
Candlestick bar data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
interval = CharField()
volume = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "interval", "datetime"), True),)
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
db_bar.vt_symbol = bar.vt_symbol
db_bar.gateway_name = "DB"
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name=self.gateway_name,
)
return bar
class DbTickData(Model):
"""
Tick data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
name = CharField()
volume = FloatField()
last_price = FloatField()
last_volume = FloatField()
limit_up = FloatField()
limit_down = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
bid_price_1 = FloatField()
bid_price_2 = FloatField()
bid_price_3 = FloatField()
bid_price_4 = FloatField()
bid_price_5 = FloatField()
ask_price_1 = FloatField()
ask_price_2 = FloatField()
ask_price_3 = FloatField()
ask_price_4 = FloatField()
ask_price_5 = FloatField()
bid_volume_1 = FloatField()
bid_volume_2 = FloatField()
bid_volume_3 = FloatField()
bid_volume_4 = FloatField()
bid_volume_5 = FloatField()
ask_volume_1 = FloatField()
ask_volume_2 = FloatField()
ask_volume_3 = FloatField()
ask_volume_4 = FloatField()
ask_volume_5 = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "datetime"), True),)
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
db_tick.vt_symbol = tick.vt_symbol
db_tick.gateway_name = "DB"
return tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name=self.gateway_name,
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
DB.connect()
DB.create_tables([DbBarData, DbTickData])

View File

@ -0,0 +1,12 @@
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.database.database import BaseDatabaseManager
if "VNPY_TESTING" not in os.environ:
from vnpy.trader.setting import get_settings
from .initialize import init
settings = get_settings("database.")
database_manager: "BaseDatabaseManager" = init(settings=settings)

View File

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence, TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange # noqa
from vnpy.trader.object import BarData, TickData # noqa
class Driver(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
MONGODB = "mongodb"
class BaseDatabaseManager(ABC):
@abstractmethod
def load_bar_data(
self,
symbol: str,
exchange: "Exchange",
interval: "Interval",
start: datetime,
end: datetime
) -> Sequence["BarData"]:
pass
@abstractmethod
def load_tick_data(
self,
symbol: str,
exchange: "Exchange",
start: datetime,
end: datetime
) -> Sequence["TickData"]:
pass
@abstractmethod
def save_bar_data(
self,
datas: Sequence["BarData"],
):
pass
@abstractmethod
def save_tick_data(
self,
datas: Sequence["TickData"],
):
pass

View File

@ -0,0 +1,302 @@
from datetime import datetime
from enum import Enum
from typing import Sequence
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from .database import BaseDatabaseManager, Driver
def init(_: Driver, settings: dict):
database = settings["database"]
host = settings["host"]
port = settings["port"]
username = settings["user"]
password = settings["password"]
authentication_source = settings["authentication_source"]
if not username: # if username == '' or None, skip username
username = None
password = None
authentication_source = None
connect(
db=database,
host=host,
port=port,
username=username,
password=password,
authentication_source=authentication_source,
)
return MongoManager()
class DbBarData(Document):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
interval: str = StringField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
meta = {
"indexes": [
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
]
}
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name="DB",
)
return bar
class DbTickData(Document):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
name: str = StringField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
pre_close: float = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField()
bid_price_3: float = FloatField()
bid_price_4: float = FloatField()
bid_price_5: float = FloatField()
ask_price_1: float = FloatField()
ask_price_2: float = FloatField()
ask_price_3: float = FloatField()
ask_price_4: float = FloatField()
ask_price_5: float = FloatField()
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField()
bid_volume_3: float = FloatField()
bid_volume_4: float = FloatField()
bid_volume_5: float = FloatField()
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField()
ask_volume_3: float = FloatField()
ask_volume_4: float = FloatField()
ask_volume_5: float = FloatField()
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return db_tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name="DB",
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
class MongoManager(BaseDatabaseManager):
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = DbBarData.objects(
symbol=symbol,
exchange=exchange.value,
interval=interval.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = DbTickData.objects(
symbol=symbol,
exchange=exchange.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_tick.to_tick() for db_tick in s]
return data
@staticmethod
def to_update_param(d):
return {
"set__" + k: v.value if isinstance(v, Enum) else v
for k, v in d.__dict__.items()
}
def save_bar_data(self, datas: Sequence[BarData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbBarData.objects(
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)
def save_tick_data(self, datas: Sequence[TickData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbTickData.objects(
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)

View File

@ -0,0 +1,369 @@
""""""
from datetime import datetime
from typing import List, Sequence, Type
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.utility import get_file_path
from .database import BaseDatabaseManager, Driver
def init(driver: Driver, settings: dict):
init_funcs = {
Driver.SQLITE: init_sqlite,
Driver.MYSQL: init_mysql,
Driver.POSTGRESQL: init_postgresql,
}
assert driver in init_funcs
db = init_funcs[driver](settings)
bar, tick = init_models(db, driver)
return SqlManager(bar, tick)
def init_sqlite(settings: dict):
database = settings["database"]
path = str(get_file_path(database))
db = SqliteDatabase(path)
return db
def init_mysql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = MySQLDatabase(**settings)
return db
def init_postgresql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = PostgresqlDatabase(**settings)
return db
class ModelBase(Model):
def to_dict(self):
return self.__data__
def init_models(db: Database, driver: Driver):
class DbBarData(ModelBase):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
interval: str = CharField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
class Meta:
database = db
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name="DB",
)
return bar
@staticmethod
def save_all(objs: List["DbBarData"]):
"""
save a list of objects, update if exists.
"""
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for bar in dicts:
DbBarData.insert(bar).on_conflict(
update=bar,
conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol,
DbBarData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbBarData.insert_many(c).on_conflict_replace().execute()
class DbTickData(ModelBase):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
name: str = CharField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
pre_close: float = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField(null=True)
bid_price_3: float = FloatField(null=True)
bid_price_4: float = FloatField(null=True)
bid_price_5: float = FloatField(null=True)
ask_price_1: float = FloatField()
ask_price_2: float = FloatField(null=True)
ask_price_3: float = FloatField(null=True)
ask_price_4: float = FloatField(null=True)
ask_price_5: float = FloatField(null=True)
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField(null=True)
bid_volume_3: float = FloatField(null=True)
bid_volume_4: float = FloatField(null=True)
bid_volume_5: float = FloatField(null=True)
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField(null=True)
ask_volume_3: float = FloatField(null=True)
ask_volume_4: float = FloatField(null=True)
ask_volume_5: float = FloatField(null=True)
class Meta:
database = db
indexes = ((("datetime", "symbol", "exchange"), True),)
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return db_tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name="DB",
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
@staticmethod
def save_all(objs: List["DbTickData"]):
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for tick in dicts:
DbTickData.insert(tick).on_conflict(
update=tick,
conflict_target=(
DbTickData.datetime,
DbTickData.symbol,
DbTickData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbTickData.insert_many(c).on_conflict_replace().execute()
db.connect()
db.create_tables([DbBarData, DbTickData])
return DbBarData, DbTickData
class SqlManager(BaseDatabaseManager):
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
self.class_bar = class_bar
self.class_tick = class_tick
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start)
& (self.class_bar.datetime <= end)
)
.order_by(self.class_bar.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start)
& (self.class_tick.datetime <= end)
)
.order_by(self.class_tick.datetime)
)
data = [db_tick.to_tick() for db_tick in s]
return data
def save_bar_data(self, datas: Sequence[BarData]):
ds = [self.class_bar.from_bar(i) for i in datas]
self.class_bar.save_all(ds)
def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds)

View File

@ -0,0 +1,24 @@
""""""
from .database import BaseDatabaseManager, Driver
def init(settings: dict) -> BaseDatabaseManager:
driver = Driver(settings["driver"])
if driver is Driver.MONGODB:
return init_nosql(driver=driver, settings=settings)
else:
return init_sql(driver=driver, settings=settings)
def init_sql(driver: Driver, settings: dict):
from .database_sql import init
keys = {'database', "host", "port", "user", "password"}
settings = {k: v for k, v in settings.items() if k in keys}
_database_manager = init(driver, settings)
return _database_manager
def init_nosql(driver: Driver, settings: dict):
from .database_mongo import init
_database_manager = init(driver, settings=settings)
return _database_manager

View File

@ -24,16 +24,21 @@ SETTINGS = {
"rqdata.username": "",
"rqdata.password": "",
"database": {
"driver": "sqlite", # sqlite, mysql, postgresql
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
"host": "localhost",
"port": 3306,
"user": "root",
"password": ""
}
"database.driver": "sqlite", # see database.Driver
"database.database": "database.db", # for sqlite, use this as filepath
"database.host": "localhost",
"database.port": 3306,
"database.user": "root",
"database.password": "",
"database.authentication_source": "admin", # for mongodb
}
# Load global setting from json file.
SETTING_FILENAME = "vt_setting.json"
SETTINGS.update(load_json(SETTING_FILENAME))
def get_settings(prefix: str = ""):
prefix_length = len(prefix)
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}

View File

@ -3,20 +3,28 @@ General utility functions.
"""
import json
import os
from pathlib import Path
from typing import Callable
from typing import Callable, TYPE_CHECKING
import numpy as np
import talib
from .object import BarData, TickData
if TYPE_CHECKING:
from vnpy.trader.constant import Exchange
def resolve_path(pattern: str):
env = dict(os.environ)
env.update({"VNPY_TEMP": str(TEMP_DIR)})
return pattern.format(**env)
def extract_vt_symbol(vt_symbol: str):
"""
:return: (symbol, exchange)
"""
symbol, exchange = vt_symbol.split('.')
return symbol, exchange
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
return f'{symbol}.{exchange.value}'
def _get_trader_dir(temp_name: str):