Merge pull request #1602 from nanoric/master

[Add] database
This commit is contained in:
vn.py 2019-04-16 11:19:56 +08:00 committed by GitHub
commit 1d527bc36e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1443 additions and 532 deletions

View File

@ -2,12 +2,27 @@ language: python
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069) dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
cache: pip
git:
depth: 1
python: python:
- "3.7" - "3.7"
services:
- mongodb
- mysql
- postgresql
before_script:
- psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;"
- mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;'
script: script:
# todo: use python unittest - pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- mkdir run; cd run; python ../tests/load_all.py - cd tests; source travis_env.sh;
- python test_all.py
matrix: matrix:
include: include:
@ -19,81 +34,36 @@ matrix:
script: script:
- flake8 - flake8
- name: "pip install under Windows"
os: "windows"
# language : cpp is necessary for windows
language: "cpp"
env:
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
before_install:
- choco install python3 --version 3.7.2
install:
- python -m pip install --upgrade pip wheel setuptools
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt
- pip install .
- name: "pip install under Ubuntu: gcc-8" - name: "pip install under Ubuntu: gcc-8"
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- g++-8
before_install: before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-8 g++-8
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
install:
# update pip & setuptools # update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade pip wheel setuptools
# Linux install script # Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh - bash ./install.sh
- name: "pip install under Ubuntu: gcc-7" - name: "sdist install under Ubuntu: gcc-7"
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- g++-7
before_install: before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-7 g++-7
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
# update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools
# Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh
- name: "sdist install under Windows"
os: "windows"
# language : cpp is necessary for windows
language: "cpp"
env:
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
before_install:
- choco install python3 --version 3.7.2
install: install:
- python -m pip install --upgrade pip wheel setuptools
- python setup.py sdist
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install dist/`ls dist`
- name: "sdist install under Ubuntu: gcc-8"
before_install:
# C++17
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
- sudo apt-get update -y
install:
# C++17
- sudo apt-get install -y gcc-8 g++-8
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
# Linux install script # Linux install script
- python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade pip wheel setuptools
- pushd /tmp - pushd /tmp
@ -108,3 +78,17 @@ matrix:
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- python setup.py sdist - python setup.py sdist
- pip install dist/`ls dist` - pip install dist/`ls dist`
- name: "pip install under osx"
os: osx
language: shell # osx supports only shell
services: []
before_install: []
install:
- pip3 install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install_osx.sh
before_script: []
script:
- pip3 install psycopg2 mongoengine pymysql # we should support all database in test environment
- cd tests; source travis_env.sh;
- VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py

68
appveyor.yml Normal file
View File

@ -0,0 +1,68 @@
clone_depth: 1
cache:
- '%LOCALAPPDATA%\pip\Cache'
configuration:
- pip
- sdist
image:
- Visual Studio 2017
services:
- mysql
- mongodb
- postgresql
environment:
PATH: C:\Python37-x64;C:\Python37-x64\Scripts;C:\Program Files\PostgreSQL\9.6\bin\;%PATH%
VNPY_TEST_MYSQL_DATABASE: vnpy
VNPY_TEST_MYSQL_HOST: localhost
VNPY_TEST_MYSQL_PORT: 3306
VNPY_TEST_MYSQL_USER: root
VNPY_TEST_MYSQL_PASSWORD: Password12!
VNPY_TEST_POSTGRESQL_DATABASE: vnpy
VNPY_TEST_POSTGRESQL_HOST: localhost
VNPY_TEST_POSTGRESQL_PORT: 5432
VNPY_TEST_POSTGRESQL_USER: postgres
VNPY_TEST_POSTGRESQL_PASSWORD: Password12!
VNPY_TEST_MONGODB_DATABASE: vnpy
VNPY_TEST_MONGODB_HOST: localhost
VNPY_TEST_MONGODB_PORT: 27017
MYSQL_PWD: Password12!
install:
- python -m pip install --upgrade pip wheel setuptools
before_build:
- ps: psql -d "postgresql://${ENV:VNPY_TEST_POSTGRESQL_USER}:${ENV:VNPY_TEST_POSTGRESQL_PASSWORD}@localhost" -c "create database vnpy;"
- ps: . "C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -u $ENV:VNPY_TEST_MYSQL_USER -e 'CREATE DATABASE vnpy;'
for:
- matrix:
only:
- configuration: pip
build_script:
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt
- pip install .
- matrix:
only:
- configuration: sdist
build_script:
- python setup.py sdist
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- ps: $name=(ls dist).name; pip install "dist/$name"
test_script:
- cd tests
- python test_all.py

View File

@ -1,27 +1,35 @@
#!/usr/bin/env bash #!/usr/bin/env bash
python=$1
pip=$2
prefix=$3
[[ -z $python ]] && python=python
[[ -z $pip ]] && pip=pip
[[ -z $prefix ]] && prefix=/usr
# Get and build ta-lib # Get and build ta-lib
pushd /tmp pushd /tmp
wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
tar -xf ta-lib-0.4.0-src.tar.gz tar -xf ta-lib-0.4.0-src.tar.gz
cd ta-lib cd ta-lib
./configure --prefix=/usr ./configure --prefix=$prefix
make make -j
sudo make install sudo make install
popd popd
# old versions of ta-lib imports numpy in setup.py # old versions of ta-lib imports numpy in setup.py
pip install numpy $pip install numpy
# Install extra packages # Install extra packages
pip install ta-lib $pip install ta-lib
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl $pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
# Install Python Modules # Install Python Modules
pip install -r requirements.txt $pip install -r requirements.txt
# Install local Chinese language environment # Install local Chinese language environment
sudo locale-gen zh_CN.GB18030 sudo locale-gen zh_CN.GB18030
# Install vn.py # Install vn.py
pip install . $pip install .

3
install_osx.sh Normal file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env bash
bash ./install.sh python3 pip3 /usr/local

123
setup.py
View File

@ -29,70 +29,79 @@ with open("vnpy/__init__.py", "rb") as f:
version = str(ast.literal_eval(version_line)) version = str(ast.literal_eval(version_line))
if platform.uname().system == "Windows": if platform.uname().system == "Windows":
compiler_flags = ["/MP", "/std:c++17", # standard compiler_flags = [
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization "/MP", "/std:c++17", # standard
"/wd4819" # 936 code page "/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
] "/wd4819" # 936 code page
]
extra_link_args = [] extra_link_args = []
else: else:
compiler_flags = ["-std=c++17", compiler_flags = [
"-Wno-delete-incomplete", "-Wno-sign-compare", "-std=c++17", # standard
] "-O3", # Optimization
"-Wno-delete-incomplete", "-Wno-sign-compare",
]
extra_link_args = ["-lstdc++"] extra_link_args = ["-lstdc++"]
vnctpmd = Extension("vnpy.api.ctp.vnctpmd", vnctpmd = Extension(
[ "vnpy.api.ctp.vnctpmd",
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp", [
], "vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
include_dirs=["vnpy/api/ctp/include", ],
"vnpy/api/ctp/vnctp", ], include_dirs=["vnpy/api/ctp/include",
define_macros=[], "vnpy/api/ctp/vnctp", ],
undef_macros=[], define_macros=[],
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], undef_macros=[],
libraries=["thostmduserapi", "thosttraderapi", ], library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
extra_compile_args=compiler_flags, libraries=["thostmduserapi", "thosttraderapi", ],
extra_link_args=extra_link_args, extra_compile_args=compiler_flags,
depends=[], extra_link_args=extra_link_args,
runtime_library_dirs=["$ORIGIN"], depends=[],
language="cpp", runtime_library_dirs=["$ORIGIN"],
) language="cpp",
vnctptd = Extension("vnpy.api.ctp.vnctptd", )
[ vnctptd = Extension(
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp", "vnpy.api.ctp.vnctptd",
], [
include_dirs=["vnpy/api/ctp/include", "vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
"vnpy/api/ctp/vnctp", ], ],
define_macros=[], include_dirs=["vnpy/api/ctp/include",
undef_macros=[], "vnpy/api/ctp/vnctp", ],
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], define_macros=[],
libraries=["thostmduserapi", "thosttraderapi", ], undef_macros=[],
extra_compile_args=compiler_flags, library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
extra_link_args=extra_link_args, libraries=["thostmduserapi", "thosttraderapi", ],
runtime_library_dirs=["$ORIGIN"], extra_compile_args=compiler_flags,
depends=[], extra_link_args=extra_link_args,
language="cpp", runtime_library_dirs=["$ORIGIN"],
) depends=[],
vnoes = Extension("vnpy.api.oes.vnoes", language="cpp",
[ )
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp", vnoes = Extension(
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp", "vnpy.api.oes.vnoes",
"vnpy/api/oes/vnoes/generated_files/module.cpp", [
], "vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
include_dirs=["vnpy/api/oes/include", "vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
"vnpy/api/oes/vnoes", ], "vnpy/api/oes/vnoes/generated_files/module.cpp",
define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")], ],
undef_macros=[], include_dirs=["vnpy/api/oes/include",
library_dirs=["vnpy/api/oes/libs"], "vnpy/api/oes/vnoes", ],
libraries=["oes_api"], define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")],
extra_compile_args=compiler_flags, undef_macros=[],
extra_link_args=extra_link_args, library_dirs=["vnpy/api/oes/libs"],
depends=[], libraries=["oes_api"],
language="cpp", extra_compile_args=compiler_flags,
) extra_link_args=extra_link_args,
runtime_library_dirs=["$ORIGIN"],
depends=[],
language="cpp",
)
if platform.uname().system == "Windows": if platform.system() == "Windows":
# use pre-built pyd for windows ( support python 3.7 only ) # use pre-built pyd for windows ( support python 3.7 only )
ext_modules = [] ext_modules = []
elif platform.system() == "Darwin":
ext_modules = []
else: else:
ext_modules = [vnctptd, vnctpmd, vnoes] ext_modules = [vnctptd, vnctpmd, vnoes]

1
tests/app/__init__.py Normal file
View File

@ -0,0 +1 @@
from .test_csv_loader import *

View File

@ -0,0 +1,90 @@
"""
Test if csv loader works fine
"""
import tempfile
import unittest
from vnpy.app.csv_loader import CsvLoaderEngine
from vnpy.trader.constant import Exchange, Interval
class TestCsvLoader(unittest.TestCase):
def setUp(self) -> None:
self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader
def test_load(self):
data = """"Datetime","Open","High","Low","Close","Volume"
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
"""
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
def test_load_duplicated(self):
data = """"Datetime","Open","High","Low","Close","Volume"
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
"""
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
with tempfile.TemporaryFile("w+t") as f:
f.write(data)
f.seek(0)
self.engine.load_by_handle(
f,
symbol="1",
exchange=Exchange.BITMEX,
interval=Interval.MINUTE,
datetime_head="Datetime",
open_head="Open",
close_head="Close",
low_head="Low",
high_head="High",
volume_head="Volume",
datetime_format="%Y-%m-%d %H:%M:%S",
)
if __name__ == "__main__":
unittest.main()

View File

@ -2,7 +2,7 @@ from time import time
import rqdatac as rq import rqdatac as rq
from vnpy.trader.database import DbBarData, DB from vnpy.trader.database import DbBarData
USERNAME = "" USERNAME = ""
PASSWORD = "" PASSWORD = ""
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
df = rq.get_price(symbol, frequency="1m", fields=FIELDS) df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
with DB.atomic(): bars = []
for ix, row in df.iterrows(): for ix, row in df.iterrows():
print(row.name) bar = generate_bar_from_row(row, symbol, exchange)
bar = generate_bar_from_row(row, symbol, exchange) bars.append(bar)
DbBarData.replace(bar.__data__).execute() DbBarData.save_all(bars)
end = time() end = time()
cost = (end - start) * 1000 cost = (end - start) * 1000

31
tests/test_all.py Normal file
View File

@ -0,0 +1,31 @@
# tests/runner.py
import unittest
import app
# import your test modules
import test_import_all
import trader
# initialize the test suite
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# add tests to the test suite
suite.addTests(loader.loadTestsFromModule(test_import_all))
suite.addTests(loader.loadTestsFromModule(trader))
suite.addTests(loader.loadTestsFromModule(app))
# initialize a runner, pass it your suite and run it
def main():
runner = unittest.TextTestRunner(verbosity=3)
result = runner.run(suite)
return result
if __name__ == '__main__':
result = main()
if result.failures or result.errors:
exit(1)
else:
exit(0)

View File

@ -1,24 +1,45 @@
# flake8: noqa # flake8: noqa
import unittest import unittest
import platform
# noinspection PyUnresolvedReferences
class ImportTest(unittest.TestCase): class ImportTest(unittest.TestCase):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
def test_import_all(self): def test_import_all(self):
from vnpy.event import EventEngine from vnpy.event import EventEngine
def test_import_main_engine(self):
from vnpy.trader.engine import MainEngine from vnpy.trader.engine import MainEngine
def test_import_ui(self):
from vnpy.trader.ui import MainWindow, create_qapp from vnpy.trader.ui import MainWindow, create_qapp
def test_import_bitmex_gateway(self):
from vnpy.gateway.bitmex import BitmexGateway from vnpy.gateway.bitmex import BitmexGateway
def test_import_futu_gateway(self):
from vnpy.gateway.futu import FutuGateway from vnpy.gateway.futu import FutuGateway
def test_import_ib_gateway(self):
from vnpy.gateway.ib import IbGateway from vnpy.gateway.ib import IbGateway
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
def test_import_ctp_gateway(self):
from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.ctp import CtpGateway
def test_import_tiger_gateway(self):
from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.tiger import TigerGateway
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
def test_import_oes_gateway(self):
from vnpy.gateway.oes import OesGateway from vnpy.gateway.oes import OesGateway
def test_import_cta_strategy_app(self):
from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.cta_strategy import CtaStrategyApp
def test_import_csv_loader_app(self):
from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.csv_loader import CsvLoaderApp

2
tests/trader/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .test_database import *
from .test_settings import *

View File

@ -0,0 +1,128 @@
"""
Test if database works fine
"""
import os
import unittest
from datetime import datetime, timedelta
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database.database import Driver
from vnpy.trader.object import BarData, TickData
os.environ['VNPY_TESTING'] = '1'
profiles = {
Driver.SQLITE: {
"driver": "sqlite",
"database": "test_db.db",
}
}
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
profiles.update({
Driver.MYSQL: {
"driver": "mysql",
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
"user": os.environ["VNPY_TEST_MYSQL_USER"],
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
},
Driver.POSTGRESQL: {
"driver": "postgresql",
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
},
Driver.MONGODB: {
"driver": "mongodb",
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
"user": "",
"password": "",
"authentication_source": "",
},
})
def now():
return datetime.utcnow()
bar = BarData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
interval=Interval.MINUTE,
)
tick = TickData(
gateway_name="DB",
symbol="test_symbol",
exchange=Exchange.BITMEX,
datetime=now(),
name="DB_test_symbol",
)
class TestDatabase(unittest.TestCase):
def connect(self, settings: dict):
from vnpy.trader.database.initialize import init # noqa
self.manager = init(settings)
def test_upsert_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_bar_data([bar])
self.manager.save_bar_data([bar])
def test_save_load_bar(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_bar_data([bar])
# and load
results = self.manager.load_bar_data(
symbol=bar.symbol,
exchange=bar.exchange,
interval=bar.interval,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
def test_upsert_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
self.manager.save_tick_data([tick])
self.manager.save_tick_data([tick])
def test_save_load_tick(self):
for driver, settings in profiles.items():
with self.subTest(driver=driver, settings=settings):
self.connect(settings)
# save first
self.manager.save_tick_data([tick])
# and load
results = self.manager.load_tick_data(
symbol=bar.symbol,
exchange=bar.exchange,
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
end=now() + timedelta(seconds=1), # time is not accuracy
)
count = len(results)
self.assertNotEqual(count, 0)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,25 @@
"""
Test if database works fine
"""
import unittest
from vnpy.trader.setting import SETTINGS, get_settings
class TestSettings(unittest.TestCase):
def test_get_settings(self):
SETTINGS['a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
def test_get_settings_with_prefix(self):
SETTINGS['a.a'] = 1
got = get_settings()
self.assertIn('a', got)
self.assertEqual(got['a'], 1)
if __name__ == '__main__':
unittest.main()

30
tests/travis_env.sh Normal file
View File

@ -0,0 +1,30 @@
#!/usr/bin/env bash
[[ -z ${VNPY_TEST_MYSQL_DATABASE} ]] && VNPY_TEST_MYSQL_DATABASE=vnpy
[[ -z ${VNPY_TEST_MYSQL_HOST} ]] && VNPY_TEST_MYSQL_HOST=127.0.0.1
[[ -z ${VNPY_TEST_MYSQL_PORT} ]] && VNPY_TEST_MYSQL_PORT=3306
[[ -z ${VNPY_TEST_MYSQL_USER} ]] && VNPY_TEST_MYSQL_USER=root
[[ -z ${VNPY_TEST_MYSQL_PASSWORD} ]] && VNPY_TEST_MYSQL_PASSWORD=
[[ -z ${VNPY_TEST_POSTGRESQL_DATABASE} ]] && VNPY_TEST_POSTGRESQL_DATABASE=vnpy
[[ -z ${VNPY_TEST_POSTGRESQL_HOST} ]] && VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
[[ -z ${VNPY_TEST_POSTGRESQL_PORT} ]] && VNPY_TEST_POSTGRESQL_PORT=5432
[[ -z ${VNPY_TEST_POSTGRESQL_USER} ]] && VNPY_TEST_POSTGRESQL_USER=postgres
[[ -z ${VNPY_TEST_POSTGRESQL_PASSWORD} ]] && VNPY_TEST_POSTGRESQL_PASSWORD=
[[ -z ${VNPY_TEST_MONGODB_DATABASE} ]] && VNPY_TEST_MONGODB_DATABASE=vnpy
[[ -z ${VNPY_TEST_MONGODB_HOST} ]] && VNPY_TEST_MONGODB_HOST=127.0.0.1
[[ -z ${VNPY_TEST_MONGODB_PORT} ]] && VNPY_TEST_MONGODB_PORT=27017
export VNPY_TEST_MYSQL_DATABASE
export VNPY_TEST_MYSQL_HOST
export VNPY_TEST_MYSQL_PORT
export VNPY_TEST_MYSQL_USER
export VNPY_TEST_MYSQL_PASSWORD
export VNPY_TEST_POSTGRESQL_DATABASE
export VNPY_TEST_POSTGRESQL_HOST
export VNPY_TEST_POSTGRESQL_PORT
export VNPY_TEST_POSTGRESQL_USER
export VNPY_TEST_POSTGRESQL_PASSWORD
export VNPY_TEST_MONGODB_DATABASE
export VNPY_TEST_MONGODB_HOST
export VNPY_TEST_MONGODB_PORT

View File

@ -22,14 +22,13 @@ Sample csv file:
import csv import csv
from datetime import datetime from datetime import datetime
from typing import TextIO
from peewee import chunked
from vnpy.event import EventEngine from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData, DB from vnpy.trader.database import database_manager
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import BarData
APP_NAME = "CsvLoader" APP_NAME = "CsvLoader"
@ -53,6 +52,59 @@ class CsvLoaderEngine(BaseEngine):
self.high_head: str = "" self.high_head: str = ""
self.volume_head: str = "" self.volume_head: str = ""
def load_by_handle(
self,
f: TextIO,
symbol: str,
exchange: Exchange,
interval: Interval,
datetime_head: str,
open_head: str,
close_head: str,
low_head: str,
high_head: str,
volume_head: str,
datetime_format: str,
):
"""
load by text mode file handle
"""
reader = csv.DictReader(f)
bars = []
start = None
count = 0
for item in reader:
if datetime_format:
dt = datetime.strptime(item[datetime_head], datetime_format)
else:
dt = datetime.fromisoformat(item[datetime_head])
bar = BarData(
symbol=symbol,
exchange=exchange,
datetime=dt,
interval=interval,
volume=item[volume_head],
open_price=item[open_head],
high_price=item[high_head],
low_price=item[low_head],
close_price=item[close_head],
gateway_name="DB",
)
bars.append(bar)
# do some statistics
count += 1
if not start:
start = bar.datetime
end = bar.datetime
# insert into database
database_manager.save_bar_data(bars)
return start, end, count
def load( def load(
self, self,
file_path: str, file_path: str,
@ -65,49 +117,22 @@ class CsvLoaderEngine(BaseEngine):
low_head: str, low_head: str,
high_head: str, high_head: str,
volume_head: str, volume_head: str,
datetime_format: str datetime_format: str,
): ):
"""""" """
vt_symbol = f"{symbol}.{exchange.value}" load by filename
"""
start = None
end = None
count = 0
with open(file_path, "rt") as f: with open(file_path, "rt") as f:
reader = csv.DictReader(f) return self.load_by_handle(
f,
db_bars = [] symbol=symbol,
exchange=exchange,
for item in reader: interval=interval,
dt = datetime.strptime(item[datetime_head], datetime_format) datetime_head=datetime_head,
open_head=open_head,
db_bar = { close_head=close_head,
"symbol": symbol, low_head=low_head,
"exchange": exchange.value, high_head=high_head,
"datetime": dt, volume_head=volume_head,
"interval": interval.value, datetime_format=datetime_format,
"volume": item[volume_head], )
"open_price": item[open_head],
"high_price": item[high_head],
"low_price": item[low_head],
"close_price": item[close_head],
"vt_symbol": vt_symbol,
"gateway_name": "DB"
}
db_bars.append(db_bar)
# do some statistics
count += 1
if not start:
start = db_bar["datetime"]
end = db_bar["datetime"]
# Insert into DB
with DB.atomic():
for batch in chunked(db_bars, 50):
DbBarData.insert_many(batch).on_conflict_replace().execute()
return start, end, count

View File

@ -12,8 +12,8 @@ from pandas import DataFrame
from vnpy.trader.constant import (Direction, Offset, Exchange, from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status) Interval, Status)
from vnpy.trader.database import DbBarData, DbTickData from vnpy.trader.database import database_manager
from vnpy.trader.object import OrderData, TradeData from vnpy.trader.object import OrderData, TradeData, BarData, TickData
from vnpy.trader.utility import round_to_pricetick from vnpy.trader.utility import round_to_pricetick
from .base import ( from .base import (
@ -103,8 +103,8 @@ class BacktestingEngine:
self.strategy_class = None self.strategy_class = None
self.strategy = None self.strategy = None
self.tick = None self.tick: TickData
self.bar = None self.bar: BarData
self.datetime = None self.datetime = None
self.interval = None self.interval = None
@ -199,14 +199,16 @@ class BacktestingEngine:
if self.mode == BacktestingMode.BAR: if self.mode == BacktestingMode.BAR:
self.history_data = load_bar_data( self.history_data = load_bar_data(
self.vt_symbol, self.symbol,
self.exchange,
self.interval, self.interval,
self.start, self.start,
self.end self.end
) )
else: else:
self.history_data = load_tick_data( self.history_data = load_tick_data(
self.vt_symbol, self.symbol,
self.exchange,
self.start, self.start,
self.end self.end
) )
@ -520,7 +522,7 @@ class BacktestingEngine:
else: else:
self.daily_results[d] = DailyResult(d, price) self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: DbBarData): def new_bar(self, bar: BarData):
"""""" """"""
self.bar = bar self.bar = bar
self.datetime = bar.datetime self.datetime = bar.datetime
@ -531,7 +533,7 @@ class BacktestingEngine:
self.update_daily_close(bar.close_price) self.update_daily_close(bar.close_price)
def new_tick(self, tick: DbTickData): def new_tick(self, tick: TickData):
"""""" """"""
self.tick = tick self.tick = tick
self.datetime = tick.datetime self.datetime = tick.datetime
@ -966,41 +968,26 @@ def optimize(
@lru_cache(maxsize=10) @lru_cache(maxsize=10)
def load_bar_data( def load_bar_data(
vt_symbol: str, symbol: str,
interval: str, exchange: Exchange,
start: datetime, interval: Interval,
start: datetime,
end: datetime end: datetime
): ):
"""""" """"""
s = ( return database_manager.load_bar_data(
DbBarData.select() symbol, exchange, interval, start, end
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
) )
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10) @lru_cache(maxsize=10)
def load_tick_data( def load_tick_data(
vt_symbol: str, symbol: str,
start: datetime, exchange: Exchange,
start: datetime,
end: datetime end: datetime
): ):
"""""" """"""
s = ( return database_manager.load_tick_data(
DbTickData.select() symbol, exchange, start, end
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
)
.order_by(DbTickData.datetime)
) )
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -5,7 +5,7 @@ import os
import traceback import traceback
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable, List
from datetime import datetime, timedelta from datetime import datetime, timedelta
from threading import Thread from threading import Thread
from queue import Queue from queue import Queue
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
Status Status
) )
from vnpy.trader.utility import load_json, save_json from vnpy.trader.utility import load_json, save_json
from vnpy.trader.database import DbTickData, DbBarData from vnpy.trader.database import database_manager
from vnpy.trader.setting import SETTINGS from vnpy.trader.setting import SETTINGS
from .base import ( from .base import (
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
self.write_log("RQData数据接口初始化成功") self.write_log("RQData数据接口初始化成功")
def query_bar_from_rq( def query_bar_from_rq(
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
): ):
""" """
Query bar data from RQData. Query bar data from RQData.
""" """
symbol, exchange_str = vt_symbol.split(".") rq_symbol = to_rq_symbol(symbol, exchange)
rq_symbol = to_rq_symbol(vt_symbol)
if rq_symbol not in self.rq_symbols: if rq_symbol not in self.rq_symbols:
return None return None
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
end_date=end end_date=end
) )
data = [] data: List[BarData] = []
for ix, row in df.iterrows(): for ix, row in df.iterrows():
bar = BarData( bar = BarData(
symbol=symbol, symbol=symbol,
exchange=Exchange(exchange_str), exchange=exchange,
interval=interval, interval=interval,
datetime=row.name.to_pydatetime(), datetime=row.name.to_pydatetime(),
open_price=row["open"], open_price=row["open"],
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
return self.engine_type return self.engine_type
def load_bar( def load_bar(
self, vt_symbol: str, days: int, interval: Interval, callback: Callable self, symbol: str, exchange: Exchange, days: int, interval: Interval,
callback: Callable[[BarData], None]
): ):
"""""" """"""
end = datetime.now() end = datetime.now()
start = end - timedelta(days) start = end - timedelta(days)
# Query data from RQData by default, if not found, load from database. # Query bars from RQData by default, if not found, load from database.
data = self.query_bar_from_rq(vt_symbol, interval, start, end) bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
if not data: if not bars:
s = ( bars = database_manager.load_bar_data(
DbBarData.select() symbol=symbol,
.where( exchange=exchange,
(DbBarData.vt_symbol == vt_symbol) interval=interval,
& (DbBarData.interval == interval.value) start=start,
& (DbBarData.datetime >= start) end=end,
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
) )
data = [db_bar.to_bar() for db_bar in s]
for bar in data: for bar in bars:
callback(bar) callback(bar)
def load_tick(self, vt_symbol: str, days: int, callback: Callable): def load_tick(self, symbol: str, exchange: Exchange, days: int,
callback: Callable[[TickData], None]):
"""""" """"""
end = datetime.now() end = datetime.now()
start = end - timedelta(days) start = end - timedelta(days)
s = ( ticks = database_manager.load_tick_data(
DbTickData.select() symbol=symbol,
.where( exchange=exchange,
(DbBarData.vt_symbol == vt_symbol) start=start,
& (DbBarData.datetime >= start) end=end,
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
) )
for tick in s: for tick in ticks:
callback(tick) callback(tick)
def call_strategy_func( def call_strategy_func(
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
""" """
Load strategy class from certain folder. Load strategy class from certain folder.
""" """
for dirpath, dirnames, filenames in os.walk(path): for dirpath, dirnames, filenames in os.walk(str(path)):
for filename in filenames: for filename in filenames:
if filename.endswith(".py"): if filename.endswith(".py"):
strategy_module_name = ".".join( strategy_module_name = ".".join(
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
self.main_engine.send_email(subject, msg) self.main_engine.send_email(subject, msg)
def to_rq_symbol(vt_symbol: str): def to_rq_symbol(symbol: str, exchange: Exchange):
""" """
CZCE product of RQData has symbol like "TA1905" while CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol. vt symbol is "TA905.CZCE" so need to add "1" in symbol.
""" """
symbol, exchange_str = vt_symbol.split(".") if exchange is not Exchange.CZCE:
if exchange_str != "CZCE":
return symbol.upper() return symbol.upper()
for count, word in enumerate(symbol): for count, word in enumerate(symbol):
if word.isdigit(): if word.isdigit():
break break
# noinspection PyUnboundLocalVariable
product = symbol[:count] product = symbol[:count]
year = symbol[count] year = symbol[count]
month = symbol[count + 1:] month = symbol[count + 1:]

View File

@ -1,268 +0,0 @@
""""""
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
SqliteDatabase
from .constant import Exchange, Interval
from .object import BarData, TickData
from .setting import SETTINGS
from .utility import resolve_path
def init():
db_settings = SETTINGS['database']
driver = db_settings["driver"]
init_funcs = {
"sqlite": init_sqlite,
"mysql": init_mysql,
"postgresql": init_postgresql,
}
assert driver in init_funcs
del db_settings['driver']
return init_funcs[driver](db_settings)
def init_sqlite(settings: dict):
global DB
database = settings['database']
DB = SqliteDatabase(str(resolve_path(database)))
def init_mysql(settings: dict):
global DB
DB = MySQLDatabase(**settings)
def init_postgresql(settings: dict):
global DB
DB = PostgresqlDatabase(**settings)
init()
class DbBarData(Model):
"""
Candlestick bar data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
interval = CharField()
volume = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "interval", "datetime"), True),)
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
db_bar.vt_symbol = bar.vt_symbol
db_bar.gateway_name = "DB"
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name=self.gateway_name,
)
return bar
class DbTickData(Model):
"""
Tick data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
"""
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
name = CharField()
volume = FloatField()
last_price = FloatField()
last_volume = FloatField()
limit_up = FloatField()
limit_down = FloatField()
open_price = FloatField()
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
bid_price_1 = FloatField()
bid_price_2 = FloatField()
bid_price_3 = FloatField()
bid_price_4 = FloatField()
bid_price_5 = FloatField()
ask_price_1 = FloatField()
ask_price_2 = FloatField()
ask_price_3 = FloatField()
ask_price_4 = FloatField()
ask_price_5 = FloatField()
bid_volume_1 = FloatField()
bid_volume_2 = FloatField()
bid_volume_3 = FloatField()
bid_volume_4 = FloatField()
bid_volume_5 = FloatField()
ask_volume_1 = FloatField()
ask_volume_2 = FloatField()
ask_volume_3 = FloatField()
ask_volume_4 = FloatField()
ask_volume_5 = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "datetime"), True),)
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
db_tick.vt_symbol = tick.vt_symbol
db_tick.gateway_name = "DB"
return tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name=self.gateway_name,
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
DB.connect()
DB.create_tables([DbBarData, DbTickData])

View File

@ -0,0 +1,12 @@
import os
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.database.database import BaseDatabaseManager
if "VNPY_TESTING" not in os.environ:
from vnpy.trader.setting import get_settings
from .initialize import init
settings = get_settings("database.")
database_manager: "BaseDatabaseManager" = init(settings=settings)

View File

@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Sequence, TYPE_CHECKING
if TYPE_CHECKING:
from vnpy.trader.constant import Interval, Exchange # noqa
from vnpy.trader.object import BarData, TickData # noqa
class Driver(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
MONGODB = "mongodb"
class BaseDatabaseManager(ABC):
@abstractmethod
def load_bar_data(
self,
symbol: str,
exchange: "Exchange",
interval: "Interval",
start: datetime,
end: datetime
) -> Sequence["BarData"]:
pass
@abstractmethod
def load_tick_data(
self,
symbol: str,
exchange: "Exchange",
start: datetime,
end: datetime
) -> Sequence["TickData"]:
pass
@abstractmethod
def save_bar_data(
self,
datas: Sequence["BarData"],
):
pass
@abstractmethod
def save_tick_data(
self,
datas: Sequence["TickData"],
):
pass

View File

@ -0,0 +1,302 @@
from datetime import datetime
from enum import Enum
from typing import Sequence
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from .database import BaseDatabaseManager, Driver
def init(_: Driver, settings: dict):
database = settings["database"]
host = settings["host"]
port = settings["port"]
username = settings["user"]
password = settings["password"]
authentication_source = settings["authentication_source"]
if not username: # if username == '' or None, skip username
username = None
password = None
authentication_source = None
connect(
db=database,
host=host,
port=port,
username=username,
password=password,
authentication_source=authentication_source,
)
return MongoManager()
class DbBarData(Document):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
interval: str = StringField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
meta = {
"indexes": [
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
]
}
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name="DB",
)
return bar
class DbTickData(Document):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
symbol: str = StringField()
exchange: str = StringField()
datetime: datetime = DateTimeField()
name: str = StringField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
pre_close: float = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField()
bid_price_3: float = FloatField()
bid_price_4: float = FloatField()
bid_price_5: float = FloatField()
ask_price_1: float = FloatField()
ask_price_2: float = FloatField()
ask_price_3: float = FloatField()
ask_price_4: float = FloatField()
ask_price_5: float = FloatField()
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField()
bid_volume_3: float = FloatField()
bid_volume_4: float = FloatField()
bid_volume_5: float = FloatField()
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField()
ask_volume_3: float = FloatField()
ask_volume_4: float = FloatField()
ask_volume_5: float = FloatField()
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return db_tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name="DB",
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
class MongoManager(BaseDatabaseManager):
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = DbBarData.objects(
symbol=symbol,
exchange=exchange.value,
interval=interval.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = DbTickData.objects(
symbol=symbol,
exchange=exchange.value,
datetime__gte=start,
datetime__lte=end,
)
data = [db_tick.to_tick() for db_tick in s]
return data
@staticmethod
def to_update_param(d):
return {
"set__" + k: v.value if isinstance(v, Enum) else v
for k, v in d.__dict__.items()
}
def save_bar_data(self, datas: Sequence[BarData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbBarData.objects(
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)
def save_tick_data(self, datas: Sequence[TickData]):
for d in datas:
updates = self.to_update_param(d)
updates.pop("set__gateway_name")
updates.pop("set__vt_symbol")
(
DbTickData.objects(
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
).update_one(upsert=True, **updates)
)

View File

@ -0,0 +1,369 @@
""""""
from datetime import datetime
from typing import List, Sequence, Type
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.utility import get_file_path
from .database import BaseDatabaseManager, Driver
def init(driver: Driver, settings: dict):
init_funcs = {
Driver.SQLITE: init_sqlite,
Driver.MYSQL: init_mysql,
Driver.POSTGRESQL: init_postgresql,
}
assert driver in init_funcs
db = init_funcs[driver](settings)
bar, tick = init_models(db, driver)
return SqlManager(bar, tick)
def init_sqlite(settings: dict):
database = settings["database"]
path = str(get_file_path(database))
db = SqliteDatabase(path)
return db
def init_mysql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = MySQLDatabase(**settings)
return db
def init_postgresql(settings: dict):
keys = {"database", "user", "password", "host", "port"}
settings = {k: v for k, v in settings.items() if k in keys}
db = PostgresqlDatabase(**settings)
return db
class ModelBase(Model):
def to_dict(self):
return self.__data__
def init_models(db: Database, driver: Driver):
class DbBarData(ModelBase):
"""
Candlestick bar data for database storage.
Index is defined unique with datetime, interval, symbol
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
interval: str = CharField()
volume: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
close_price: float = FloatField()
class Meta:
database = db
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
@staticmethod
def from_bar(bar: BarData):
"""
Generate DbBarData object from BarData.
"""
db_bar = DbBarData()
db_bar.symbol = bar.symbol
db_bar.exchange = bar.exchange.value
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
return db_bar
def to_bar(self):
"""
Generate BarData object from DbBarData.
"""
bar = BarData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name="DB",
)
return bar
@staticmethod
def save_all(objs: List["DbBarData"]):
"""
save a list of objects, update if exists.
"""
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for bar in dicts:
DbBarData.insert(bar).on_conflict(
update=bar,
conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol,
DbBarData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbBarData.insert_many(c).on_conflict_replace().execute()
class DbTickData(ModelBase):
"""
Tick data for database storage.
Index is defined unique with (datetime, symbol)
"""
id = AutoField()
symbol: str = CharField()
exchange: str = CharField()
datetime: datetime = DateTimeField()
name: str = CharField()
volume: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
limit_down: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
pre_close: float = FloatField()
bid_price_1: float = FloatField()
bid_price_2: float = FloatField(null=True)
bid_price_3: float = FloatField(null=True)
bid_price_4: float = FloatField(null=True)
bid_price_5: float = FloatField(null=True)
ask_price_1: float = FloatField()
ask_price_2: float = FloatField(null=True)
ask_price_3: float = FloatField(null=True)
ask_price_4: float = FloatField(null=True)
ask_price_5: float = FloatField(null=True)
bid_volume_1: float = FloatField()
bid_volume_2: float = FloatField(null=True)
bid_volume_3: float = FloatField(null=True)
bid_volume_4: float = FloatField(null=True)
bid_volume_5: float = FloatField(null=True)
ask_volume_1: float = FloatField()
ask_volume_2: float = FloatField(null=True)
ask_volume_3: float = FloatField(null=True)
ask_volume_4: float = FloatField(null=True)
ask_volume_5: float = FloatField(null=True)
class Meta:
database = db
indexes = ((("datetime", "symbol", "exchange"), True),)
@staticmethod
def from_tick(tick: TickData):
"""
Generate DbTickData object from TickData.
"""
db_tick = DbTickData()
db_tick.symbol = tick.symbol
db_tick.exchange = tick.exchange.value
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
db_tick.limit_down = tick.limit_down
db_tick.open_price = tick.open_price
db_tick.high_price = tick.high_price
db_tick.low_price = tick.low_price
db_tick.pre_close = tick.pre_close
db_tick.bid_price_1 = tick.bid_price_1
db_tick.ask_price_1 = tick.ask_price_1
db_tick.bid_volume_1 = tick.bid_volume_1
db_tick.ask_volume_1 = tick.ask_volume_1
if tick.bid_price_2:
db_tick.bid_price_2 = tick.bid_price_2
db_tick.bid_price_3 = tick.bid_price_3
db_tick.bid_price_4 = tick.bid_price_4
db_tick.bid_price_5 = tick.bid_price_5
db_tick.ask_price_2 = tick.ask_price_2
db_tick.ask_price_3 = tick.ask_price_3
db_tick.ask_price_4 = tick.ask_price_4
db_tick.ask_price_5 = tick.ask_price_5
db_tick.bid_volume_2 = tick.bid_volume_2
db_tick.bid_volume_3 = tick.bid_volume_3
db_tick.bid_volume_4 = tick.bid_volume_4
db_tick.bid_volume_5 = tick.bid_volume_5
db_tick.ask_volume_2 = tick.ask_volume_2
db_tick.ask_volume_3 = tick.ask_volume_3
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
return db_tick
def to_tick(self):
"""
Generate TickData object from DbTickData.
"""
tick = TickData(
symbol=self.symbol,
exchange=Exchange(self.exchange),
datetime=self.datetime,
name=self.name,
volume=self.volume,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
limit_down=self.limit_down,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
pre_close=self.pre_close,
bid_price_1=self.bid_price_1,
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name="DB",
)
if self.bid_price_2:
tick.bid_price_2 = self.bid_price_2
tick.bid_price_3 = self.bid_price_3
tick.bid_price_4 = self.bid_price_4
tick.bid_price_5 = self.bid_price_5
tick.ask_price_2 = self.ask_price_2
tick.ask_price_3 = self.ask_price_3
tick.ask_price_4 = self.ask_price_4
tick.ask_price_5 = self.ask_price_5
tick.bid_volume_2 = self.bid_volume_2
tick.bid_volume_3 = self.bid_volume_3
tick.bid_volume_4 = self.bid_volume_4
tick.bid_volume_5 = self.bid_volume_5
tick.ask_volume_2 = self.ask_volume_2
tick.ask_volume_3 = self.ask_volume_3
tick.ask_volume_4 = self.ask_volume_4
tick.ask_volume_5 = self.ask_volume_5
return tick
@staticmethod
def save_all(objs: List["DbTickData"]):
dicts = [i.to_dict() for i in objs]
with db.atomic():
if driver is Driver.POSTGRESQL:
for tick in dicts:
DbTickData.insert(tick).on_conflict(
update=tick,
conflict_target=(
DbTickData.datetime,
DbTickData.symbol,
DbTickData.exchange,
),
).execute()
else:
for c in chunked(dicts, 50):
DbTickData.insert_many(c).on_conflict_replace().execute()
db.connect()
db.create_tables([DbBarData, DbTickData])
return DbBarData, DbTickData
class SqlManager(BaseDatabaseManager):
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
self.class_bar = class_bar
self.class_tick = class_tick
def load_bar_data(
self,
symbol: str,
exchange: Exchange,
interval: Interval,
start: datetime,
end: datetime,
) -> Sequence[BarData]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start)
& (self.class_bar.datetime <= end)
)
.order_by(self.class_bar.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
def load_tick_data(
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
) -> Sequence[TickData]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start)
& (self.class_tick.datetime <= end)
)
.order_by(self.class_tick.datetime)
)
data = [db_tick.to_tick() for db_tick in s]
return data
def save_bar_data(self, datas: Sequence[BarData]):
ds = [self.class_bar.from_bar(i) for i in datas]
self.class_bar.save_all(ds)
def save_tick_data(self, datas: Sequence[TickData]):
ds = [self.class_tick.from_tick(i) for i in datas]
self.class_tick.save_all(ds)

View File

@ -0,0 +1,24 @@
""""""
from .database import BaseDatabaseManager, Driver
def init(settings: dict) -> BaseDatabaseManager:
driver = Driver(settings["driver"])
if driver is Driver.MONGODB:
return init_nosql(driver=driver, settings=settings)
else:
return init_sql(driver=driver, settings=settings)
def init_sql(driver: Driver, settings: dict):
from .database_sql import init
keys = {'database', "host", "port", "user", "password"}
settings = {k: v for k, v in settings.items() if k in keys}
_database_manager = init(driver, settings)
return _database_manager
def init_nosql(driver: Driver, settings: dict):
from .database_mongo import init
_database_manager = init(driver, settings=settings)
return _database_manager

View File

@ -24,16 +24,21 @@ SETTINGS = {
"rqdata.username": "", "rqdata.username": "",
"rqdata.password": "", "rqdata.password": "",
"database": {
"driver": "sqlite", # sqlite, mysql, postgresql "database.driver": "sqlite", # see database.Driver
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath "database.database": "database.db", # for sqlite, use this as filepath
"host": "localhost", "database.host": "localhost",
"port": 3306, "database.port": 3306,
"user": "root", "database.user": "root",
"password": "" "database.password": "",
} "database.authentication_source": "admin", # for mongodb
} }
# Load global setting from json file. # Load global setting from json file.
SETTING_FILENAME = "vt_setting.json" SETTING_FILENAME = "vt_setting.json"
SETTINGS.update(load_json(SETTING_FILENAME)) SETTINGS.update(load_json(SETTING_FILENAME))
def get_settings(prefix: str = ""):
prefix_length = len(prefix)
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}

View File

@ -3,20 +3,28 @@ General utility functions.
""" """
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable, TYPE_CHECKING
import numpy as np import numpy as np
import talib import talib
from .object import BarData, TickData from .object import BarData, TickData
if TYPE_CHECKING:
from vnpy.trader.constant import Exchange
def resolve_path(pattern: str):
env = dict(os.environ) def extract_vt_symbol(vt_symbol: str):
env.update({"VNPY_TEMP": str(TEMP_DIR)}) """
return pattern.format(**env) :return: (symbol, exchange)
"""
symbol, exchange = vt_symbol.split('.')
return symbol, exchange
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
return f'{symbol}.{exchange.value}'
def _get_trader_dir(temp_name: str): def _get_trader_dir(temp_name: str):