commit
1d527bc36e
106
.travis.yml
106
.travis.yml
@ -2,12 +2,27 @@ language: python
|
||||
|
||||
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
||||
|
||||
cache: pip
|
||||
|
||||
git:
|
||||
depth: 1
|
||||
|
||||
python:
|
||||
- "3.7"
|
||||
|
||||
services:
|
||||
- mongodb
|
||||
- mysql
|
||||
- postgresql
|
||||
|
||||
before_script:
|
||||
- psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;"
|
||||
- mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;'
|
||||
|
||||
script:
|
||||
# todo: use python unittest
|
||||
- mkdir run; cd run; python ../tests/load_all.py
|
||||
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- cd tests; source travis_env.sh;
|
||||
- python test_all.py
|
||||
|
||||
matrix:
|
||||
include:
|
||||
@ -19,81 +34,36 @@ matrix:
|
||||
script:
|
||||
- flake8
|
||||
|
||||
- name: "pip install under Windows"
|
||||
os: "windows"
|
||||
# language : cpp is necessary for windows
|
||||
language: "cpp"
|
||||
env:
|
||||
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
|
||||
before_install:
|
||||
- choco install python3 --version 3.7.2
|
||||
install:
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- pip install -r requirements.txt
|
||||
- pip install .
|
||||
|
||||
- name: "pip install under Ubuntu: gcc-8"
|
||||
addons:
|
||||
apt:
|
||||
sources:
|
||||
- ubuntu-toolchain-r-test
|
||||
packages:
|
||||
- g++-8
|
||||
before_install:
|
||||
# C++17
|
||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
||||
- sudo apt-get update -y
|
||||
install:
|
||||
# C++17
|
||||
- sudo apt-get install -y gcc-8 g++-8
|
||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
||||
install:
|
||||
# update pip & setuptools
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
# Linux install script
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- bash ./install.sh
|
||||
|
||||
- name: "pip install under Ubuntu: gcc-7"
|
||||
- name: "sdist install under Ubuntu: gcc-7"
|
||||
addons:
|
||||
apt:
|
||||
sources:
|
||||
- ubuntu-toolchain-r-test
|
||||
packages:
|
||||
- g++-7
|
||||
before_install:
|
||||
# C++17
|
||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
||||
- sudo apt-get update -y
|
||||
install:
|
||||
# C++17
|
||||
- sudo apt-get install -y gcc-7 g++-7
|
||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
|
||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
|
||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
|
||||
# update pip & setuptools
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
# Linux install script
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- bash ./install.sh
|
||||
|
||||
- name: "sdist install under Windows"
|
||||
os: "windows"
|
||||
# language : cpp is necessary for windows
|
||||
language: "cpp"
|
||||
env:
|
||||
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
|
||||
before_install:
|
||||
- choco install python3 --version 3.7.2
|
||||
install:
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
- python setup.py sdist
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- pip install dist/`ls dist`
|
||||
|
||||
- name: "sdist install under Ubuntu: gcc-8"
|
||||
before_install:
|
||||
# C++17
|
||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
||||
- sudo apt-get update -y
|
||||
install:
|
||||
# C++17
|
||||
- sudo apt-get install -y gcc-8 g++-8
|
||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
||||
# Linux install script
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
- pushd /tmp
|
||||
@ -108,3 +78,17 @@ matrix:
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- python setup.py sdist
|
||||
- pip install dist/`ls dist`
|
||||
|
||||
- name: "pip install under osx"
|
||||
os: osx
|
||||
language: shell # osx supports only shell
|
||||
services: []
|
||||
before_install: []
|
||||
install:
|
||||
- pip3 install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- bash ./install_osx.sh
|
||||
before_script: []
|
||||
script:
|
||||
- pip3 install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- cd tests; source travis_env.sh;
|
||||
- VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py
|
||||
|
68
appveyor.yml
Normal file
68
appveyor.yml
Normal file
@ -0,0 +1,68 @@
|
||||
clone_depth: 1
|
||||
|
||||
cache:
|
||||
- '%LOCALAPPDATA%\pip\Cache'
|
||||
|
||||
configuration:
|
||||
- pip
|
||||
- sdist
|
||||
|
||||
image:
|
||||
- Visual Studio 2017
|
||||
|
||||
services:
|
||||
- mysql
|
||||
- mongodb
|
||||
- postgresql
|
||||
|
||||
environment:
|
||||
PATH: C:\Python37-x64;C:\Python37-x64\Scripts;C:\Program Files\PostgreSQL\9.6\bin\;%PATH%
|
||||
VNPY_TEST_MYSQL_DATABASE: vnpy
|
||||
VNPY_TEST_MYSQL_HOST: localhost
|
||||
VNPY_TEST_MYSQL_PORT: 3306
|
||||
VNPY_TEST_MYSQL_USER: root
|
||||
VNPY_TEST_MYSQL_PASSWORD: Password12!
|
||||
|
||||
VNPY_TEST_POSTGRESQL_DATABASE: vnpy
|
||||
VNPY_TEST_POSTGRESQL_HOST: localhost
|
||||
VNPY_TEST_POSTGRESQL_PORT: 5432
|
||||
VNPY_TEST_POSTGRESQL_USER: postgres
|
||||
VNPY_TEST_POSTGRESQL_PASSWORD: Password12!
|
||||
|
||||
VNPY_TEST_MONGODB_DATABASE: vnpy
|
||||
VNPY_TEST_MONGODB_HOST: localhost
|
||||
VNPY_TEST_MONGODB_PORT: 27017
|
||||
|
||||
MYSQL_PWD: Password12!
|
||||
|
||||
install:
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
before_build:
|
||||
- ps: psql -d "postgresql://${ENV:VNPY_TEST_POSTGRESQL_USER}:${ENV:VNPY_TEST_POSTGRESQL_PASSWORD}@localhost" -c "create database vnpy;"
|
||||
- ps: . "C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -u $ENV:VNPY_TEST_MYSQL_USER -e 'CREATE DATABASE vnpy;'
|
||||
|
||||
for:
|
||||
- matrix:
|
||||
only:
|
||||
- configuration: pip
|
||||
build_script:
|
||||
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- pip install -r requirements.txt
|
||||
- pip install .
|
||||
|
||||
- matrix:
|
||||
only:
|
||||
- configuration: sdist
|
||||
build_script:
|
||||
- python setup.py sdist
|
||||
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- ps: $name=(ls dist).name; pip install "dist/$name"
|
||||
|
||||
test_script:
|
||||
- cd tests
|
||||
- python test_all.py
|
||||
|
22
install.sh
22
install.sh
@ -1,27 +1,35 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
python=$1
|
||||
pip=$2
|
||||
prefix=$3
|
||||
|
||||
[[ -z $python ]] && python=python
|
||||
[[ -z $pip ]] && pip=pip
|
||||
[[ -z $prefix ]] && prefix=/usr
|
||||
|
||||
# Get and build ta-lib
|
||||
pushd /tmp
|
||||
wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
|
||||
tar -xf ta-lib-0.4.0-src.tar.gz
|
||||
cd ta-lib
|
||||
./configure --prefix=/usr
|
||||
make
|
||||
./configure --prefix=$prefix
|
||||
make -j
|
||||
sudo make install
|
||||
popd
|
||||
|
||||
# old versions of ta-lib imports numpy in setup.py
|
||||
pip install numpy
|
||||
$pip install numpy
|
||||
|
||||
# Install extra packages
|
||||
pip install ta-lib
|
||||
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||
$pip install ta-lib
|
||||
$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||
|
||||
# Install Python Modules
|
||||
pip install -r requirements.txt
|
||||
$pip install -r requirements.txt
|
||||
|
||||
# Install local Chinese language environment
|
||||
sudo locale-gen zh_CN.GB18030
|
||||
|
||||
# Install vn.py
|
||||
pip install .
|
||||
$pip install .
|
3
install_osx.sh
Normal file
3
install_osx.sh
Normal file
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
bash ./install.sh python3 pip3 /usr/local
|
123
setup.py
123
setup.py
@ -29,70 +29,79 @@ with open("vnpy/__init__.py", "rb") as f:
|
||||
version = str(ast.literal_eval(version_line))
|
||||
|
||||
if platform.uname().system == "Windows":
|
||||
compiler_flags = ["/MP", "/std:c++17", # standard
|
||||
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
|
||||
"/wd4819" # 936 code page
|
||||
]
|
||||
compiler_flags = [
|
||||
"/MP", "/std:c++17", # standard
|
||||
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
|
||||
"/wd4819" # 936 code page
|
||||
]
|
||||
extra_link_args = []
|
||||
else:
|
||||
compiler_flags = ["-std=c++17",
|
||||
"-Wno-delete-incomplete", "-Wno-sign-compare",
|
||||
]
|
||||
compiler_flags = [
|
||||
"-std=c++17", # standard
|
||||
"-O3", # Optimization
|
||||
"-Wno-delete-incomplete", "-Wno-sign-compare",
|
||||
]
|
||||
extra_link_args = ["-lstdc++"]
|
||||
|
||||
vnctpmd = Extension("vnpy.api.ctp.vnctpmd",
|
||||
[
|
||||
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/ctp/include",
|
||||
"vnpy/api/ctp/vnctp", ],
|
||||
define_macros=[],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
depends=[],
|
||||
runtime_library_dirs=["$ORIGIN"],
|
||||
language="cpp",
|
||||
)
|
||||
vnctptd = Extension("vnpy.api.ctp.vnctptd",
|
||||
[
|
||||
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/ctp/include",
|
||||
"vnpy/api/ctp/vnctp", ],
|
||||
define_macros=[],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
runtime_library_dirs=["$ORIGIN"],
|
||||
depends=[],
|
||||
language="cpp",
|
||||
)
|
||||
vnoes = Extension("vnpy.api.oes.vnoes",
|
||||
[
|
||||
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
|
||||
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
|
||||
"vnpy/api/oes/vnoes/generated_files/module.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/oes/include",
|
||||
"vnpy/api/oes/vnoes", ],
|
||||
define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/oes/libs"],
|
||||
libraries=["oes_api"],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
depends=[],
|
||||
language="cpp",
|
||||
)
|
||||
vnctpmd = Extension(
|
||||
"vnpy.api.ctp.vnctpmd",
|
||||
[
|
||||
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/ctp/include",
|
||||
"vnpy/api/ctp/vnctp", ],
|
||||
define_macros=[],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
depends=[],
|
||||
runtime_library_dirs=["$ORIGIN"],
|
||||
language="cpp",
|
||||
)
|
||||
vnctptd = Extension(
|
||||
"vnpy.api.ctp.vnctptd",
|
||||
[
|
||||
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/ctp/include",
|
||||
"vnpy/api/ctp/vnctp", ],
|
||||
define_macros=[],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
runtime_library_dirs=["$ORIGIN"],
|
||||
depends=[],
|
||||
language="cpp",
|
||||
)
|
||||
vnoes = Extension(
|
||||
"vnpy.api.oes.vnoes",
|
||||
[
|
||||
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
|
||||
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
|
||||
"vnpy/api/oes/vnoes/generated_files/module.cpp",
|
||||
],
|
||||
include_dirs=["vnpy/api/oes/include",
|
||||
"vnpy/api/oes/vnoes", ],
|
||||
define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")],
|
||||
undef_macros=[],
|
||||
library_dirs=["vnpy/api/oes/libs"],
|
||||
libraries=["oes_api"],
|
||||
extra_compile_args=compiler_flags,
|
||||
extra_link_args=extra_link_args,
|
||||
runtime_library_dirs=["$ORIGIN"],
|
||||
depends=[],
|
||||
language="cpp",
|
||||
)
|
||||
|
||||
if platform.uname().system == "Windows":
|
||||
if platform.system() == "Windows":
|
||||
# use pre-built pyd for windows ( support python 3.7 only )
|
||||
ext_modules = []
|
||||
elif platform.system() == "Darwin":
|
||||
ext_modules = []
|
||||
else:
|
||||
ext_modules = [vnctptd, vnctpmd, vnoes]
|
||||
|
||||
|
1
tests/app/__init__.py
Normal file
1
tests/app/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .test_csv_loader import *
|
90
tests/app/test_csv_loader.py
Normal file
90
tests/app/test_csv_loader.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""
|
||||
Test if csv loader works fine
|
||||
"""
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from vnpy.app.csv_loader import CsvLoaderEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
|
||||
|
||||
class TestCsvLoader(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader
|
||||
|
||||
def test_load(self):
|
||||
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||
"""
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
def test_load_duplicated(self):
|
||||
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||
"""
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -2,7 +2,7 @@ from time import time
|
||||
|
||||
import rqdatac as rq
|
||||
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.database import DbBarData
|
||||
|
||||
USERNAME = ""
|
||||
PASSWORD = ""
|
||||
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
|
||||
|
||||
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
|
||||
|
||||
with DB.atomic():
|
||||
for ix, row in df.iterrows():
|
||||
print(row.name)
|
||||
bar = generate_bar_from_row(row, symbol, exchange)
|
||||
DbBarData.replace(bar.__data__).execute()
|
||||
bars = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = generate_bar_from_row(row, symbol, exchange)
|
||||
bars.append(bar)
|
||||
DbBarData.save_all(bars)
|
||||
|
||||
end = time()
|
||||
cost = (end - start) * 1000
|
||||
|
31
tests/test_all.py
Normal file
31
tests/test_all.py
Normal file
@ -0,0 +1,31 @@
|
||||
# tests/runner.py
|
||||
import unittest
|
||||
|
||||
import app
|
||||
# import your test modules
|
||||
import test_import_all
|
||||
import trader
|
||||
|
||||
# initialize the test suite
|
||||
loader = unittest.TestLoader()
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# add tests to the test suite
|
||||
suite.addTests(loader.loadTestsFromModule(test_import_all))
|
||||
suite.addTests(loader.loadTestsFromModule(trader))
|
||||
suite.addTests(loader.loadTestsFromModule(app))
|
||||
|
||||
|
||||
# initialize a runner, pass it your suite and run it
|
||||
def main():
|
||||
runner = unittest.TextTestRunner(verbosity=3)
|
||||
result = runner.run(suite)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
result = main()
|
||||
if result.failures or result.errors:
|
||||
exit(1)
|
||||
else:
|
||||
exit(0)
|
@ -1,24 +1,45 @@
|
||||
# flake8: noqa
|
||||
import unittest
|
||||
import platform
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
class ImportTest(unittest.TestCase):
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
def test_import_all(self):
|
||||
from vnpy.event import EventEngine
|
||||
|
||||
def test_import_main_engine(self):
|
||||
from vnpy.trader.engine import MainEngine
|
||||
|
||||
def test_import_ui(self):
|
||||
from vnpy.trader.ui import MainWindow, create_qapp
|
||||
|
||||
def test_import_bitmex_gateway(self):
|
||||
from vnpy.gateway.bitmex import BitmexGateway
|
||||
|
||||
def test_import_futu_gateway(self):
|
||||
from vnpy.gateway.futu import FutuGateway
|
||||
|
||||
def test_import_ib_gateway(self):
|
||||
from vnpy.gateway.ib import IbGateway
|
||||
|
||||
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
|
||||
def test_import_ctp_gateway(self):
|
||||
from vnpy.gateway.ctp import CtpGateway
|
||||
|
||||
def test_import_tiger_gateway(self):
|
||||
from vnpy.gateway.tiger import TigerGateway
|
||||
|
||||
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
|
||||
def test_import_oes_gateway(self):
|
||||
from vnpy.gateway.oes import OesGateway
|
||||
|
||||
def test_import_cta_strategy_app(self):
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
|
||||
def test_import_csv_loader_app(self):
|
||||
from vnpy.app.csv_loader import CsvLoaderApp
|
||||
|
||||
|
2
tests/trader/__init__.py
Normal file
2
tests/trader/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .test_database import *
|
||||
from .test_settings import *
|
128
tests/trader/test_database.py
Normal file
128
tests/trader/test_database.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Test if database works fine
|
||||
"""
|
||||
import os
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database.database import Driver
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
|
||||
os.environ['VNPY_TESTING'] = '1'
|
||||
|
||||
profiles = {
|
||||
Driver.SQLITE: {
|
||||
"driver": "sqlite",
|
||||
"database": "test_db.db",
|
||||
}
|
||||
}
|
||||
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
|
||||
profiles.update({
|
||||
Driver.MYSQL: {
|
||||
"driver": "mysql",
|
||||
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||
},
|
||||
Driver.POSTGRESQL: {
|
||||
"driver": "postgresql",
|
||||
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||
},
|
||||
Driver.MONGODB: {
|
||||
"driver": "mongodb",
|
||||
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||
"user": "",
|
||||
"password": "",
|
||||
"authentication_source": "",
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.utcnow()
|
||||
|
||||
|
||||
bar = BarData(
|
||||
gateway_name="DB",
|
||||
symbol="test_symbol",
|
||||
exchange=Exchange.BITMEX,
|
||||
datetime=now(),
|
||||
interval=Interval.MINUTE,
|
||||
)
|
||||
|
||||
tick = TickData(
|
||||
gateway_name="DB",
|
||||
symbol="test_symbol",
|
||||
exchange=Exchange.BITMEX,
|
||||
datetime=now(),
|
||||
name="DB_test_symbol",
|
||||
)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
|
||||
def connect(self, settings: dict):
|
||||
from vnpy.trader.database.initialize import init # noqa
|
||||
self.manager = init(settings)
|
||||
|
||||
def test_upsert_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_bar_data([bar])
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
def test_save_load_bar(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
# save first
|
||||
self.manager.save_bar_data([bar])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_bar_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
interval=bar.interval,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
|
||||
def test_upsert_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
self.manager.save_tick_data([tick])
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
def test_save_load_tick(self):
|
||||
for driver, settings in profiles.items():
|
||||
with self.subTest(driver=driver, settings=settings):
|
||||
self.connect(settings)
|
||||
# save first
|
||||
self.manager.save_tick_data([tick])
|
||||
|
||||
# and load
|
||||
results = self.manager.load_tick_data(
|
||||
symbol=bar.symbol,
|
||||
exchange=bar.exchange,
|
||||
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||
)
|
||||
count = len(results)
|
||||
self.assertNotEqual(count, 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
25
tests/trader/test_settings.py
Normal file
25
tests/trader/test_settings.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""
|
||||
Test if database works fine
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from vnpy.trader.setting import SETTINGS, get_settings
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
|
||||
def test_get_settings(self):
|
||||
SETTINGS['a'] = 1
|
||||
got = get_settings()
|
||||
self.assertIn('a', got)
|
||||
self.assertEqual(got['a'], 1)
|
||||
|
||||
def test_get_settings_with_prefix(self):
|
||||
SETTINGS['a.a'] = 1
|
||||
got = get_settings()
|
||||
self.assertIn('a', got)
|
||||
self.assertEqual(got['a'], 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
30
tests/travis_env.sh
Normal file
30
tests/travis_env.sh
Normal file
@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
[[ -z ${VNPY_TEST_MYSQL_DATABASE} ]] && VNPY_TEST_MYSQL_DATABASE=vnpy
|
||||
[[ -z ${VNPY_TEST_MYSQL_HOST} ]] && VNPY_TEST_MYSQL_HOST=127.0.0.1
|
||||
[[ -z ${VNPY_TEST_MYSQL_PORT} ]] && VNPY_TEST_MYSQL_PORT=3306
|
||||
[[ -z ${VNPY_TEST_MYSQL_USER} ]] && VNPY_TEST_MYSQL_USER=root
|
||||
[[ -z ${VNPY_TEST_MYSQL_PASSWORD} ]] && VNPY_TEST_MYSQL_PASSWORD=
|
||||
[[ -z ${VNPY_TEST_POSTGRESQL_DATABASE} ]] && VNPY_TEST_POSTGRESQL_DATABASE=vnpy
|
||||
[[ -z ${VNPY_TEST_POSTGRESQL_HOST} ]] && VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
|
||||
[[ -z ${VNPY_TEST_POSTGRESQL_PORT} ]] && VNPY_TEST_POSTGRESQL_PORT=5432
|
||||
[[ -z ${VNPY_TEST_POSTGRESQL_USER} ]] && VNPY_TEST_POSTGRESQL_USER=postgres
|
||||
[[ -z ${VNPY_TEST_POSTGRESQL_PASSWORD} ]] && VNPY_TEST_POSTGRESQL_PASSWORD=
|
||||
[[ -z ${VNPY_TEST_MONGODB_DATABASE} ]] && VNPY_TEST_MONGODB_DATABASE=vnpy
|
||||
[[ -z ${VNPY_TEST_MONGODB_HOST} ]] && VNPY_TEST_MONGODB_HOST=127.0.0.1
|
||||
[[ -z ${VNPY_TEST_MONGODB_PORT} ]] && VNPY_TEST_MONGODB_PORT=27017
|
||||
|
||||
|
||||
export VNPY_TEST_MYSQL_DATABASE
|
||||
export VNPY_TEST_MYSQL_HOST
|
||||
export VNPY_TEST_MYSQL_PORT
|
||||
export VNPY_TEST_MYSQL_USER
|
||||
export VNPY_TEST_MYSQL_PASSWORD
|
||||
export VNPY_TEST_POSTGRESQL_DATABASE
|
||||
export VNPY_TEST_POSTGRESQL_HOST
|
||||
export VNPY_TEST_POSTGRESQL_PORT
|
||||
export VNPY_TEST_POSTGRESQL_USER
|
||||
export VNPY_TEST_POSTGRESQL_PASSWORD
|
||||
export VNPY_TEST_MONGODB_DATABASE
|
||||
export VNPY_TEST_MONGODB_HOST
|
||||
export VNPY_TEST_MONGODB_PORT
|
@ -22,14 +22,13 @@ Sample csv file:
|
||||
|
||||
import csv
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import chunked
|
||||
from typing import TextIO
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
|
||||
from vnpy.trader.object import BarData
|
||||
|
||||
APP_NAME = "CsvLoader"
|
||||
|
||||
@ -53,6 +52,59 @@ class CsvLoaderEngine(BaseEngine):
|
||||
self.high_head: str = ""
|
||||
self.volume_head: str = ""
|
||||
|
||||
def load_by_handle(
|
||||
self,
|
||||
f: TextIO,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
datetime_head: str,
|
||||
open_head: str,
|
||||
close_head: str,
|
||||
low_head: str,
|
||||
high_head: str,
|
||||
volume_head: str,
|
||||
datetime_format: str,
|
||||
):
|
||||
"""
|
||||
load by text mode file handle
|
||||
"""
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
bars = []
|
||||
start = None
|
||||
count = 0
|
||||
for item in reader:
|
||||
if datetime_format:
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
else:
|
||||
dt = datetime.fromisoformat(item[datetime_head])
|
||||
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=dt,
|
||||
interval=interval,
|
||||
volume=item[volume_head],
|
||||
open_price=item[open_head],
|
||||
high_price=item[high_head],
|
||||
low_price=item[low_head],
|
||||
close_price=item[close_head],
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
bars.append(bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = bar.datetime
|
||||
end = bar.datetime
|
||||
|
||||
# insert into database
|
||||
database_manager.save_bar_data(bars)
|
||||
return start, end, count
|
||||
|
||||
def load(
|
||||
self,
|
||||
file_path: str,
|
||||
@ -65,49 +117,22 @@ class CsvLoaderEngine(BaseEngine):
|
||||
low_head: str,
|
||||
high_head: str,
|
||||
volume_head: str,
|
||||
datetime_format: str
|
||||
datetime_format: str,
|
||||
):
|
||||
""""""
|
||||
vt_symbol = f"{symbol}.{exchange.value}"
|
||||
|
||||
start = None
|
||||
end = None
|
||||
count = 0
|
||||
|
||||
"""
|
||||
load by filename
|
||||
"""
|
||||
with open(file_path, "rt") as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
|
||||
for item in reader:
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
|
||||
db_bar = {
|
||||
"symbol": symbol,
|
||||
"exchange": exchange.value,
|
||||
"datetime": dt,
|
||||
"interval": interval.value,
|
||||
"volume": item[volume_head],
|
||||
"open_price": item[open_head],
|
||||
"high_price": item[high_head],
|
||||
"low_price": item[low_head],
|
||||
"close_price": item[close_head],
|
||||
"vt_symbol": vt_symbol,
|
||||
"gateway_name": "DB"
|
||||
}
|
||||
|
||||
db_bars.append(db_bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar["datetime"]
|
||||
|
||||
end = db_bar["datetime"]
|
||||
|
||||
# Insert into DB
|
||||
with DB.atomic():
|
||||
for batch in chunked(db_bars, 50):
|
||||
DbBarData.insert_many(batch).on_conflict_replace().execute()
|
||||
|
||||
return start, end, count
|
||||
return self.load_by_handle(
|
||||
f,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime_head=datetime_head,
|
||||
open_head=open_head,
|
||||
close_head=close_head,
|
||||
low_head=low_head,
|
||||
high_head=high_head,
|
||||
volume_head=volume_head,
|
||||
datetime_format=datetime_format,
|
||||
)
|
||||
|
@ -12,8 +12,8 @@ from pandas import DataFrame
|
||||
|
||||
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
||||
Interval, Status)
|
||||
from vnpy.trader.database import DbBarData, DbTickData
|
||||
from vnpy.trader.object import OrderData, TradeData
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
|
||||
from vnpy.trader.utility import round_to_pricetick
|
||||
|
||||
from .base import (
|
||||
@ -103,8 +103,8 @@ class BacktestingEngine:
|
||||
|
||||
self.strategy_class = None
|
||||
self.strategy = None
|
||||
self.tick = None
|
||||
self.bar = None
|
||||
self.tick: TickData
|
||||
self.bar: BarData
|
||||
self.datetime = None
|
||||
|
||||
self.interval = None
|
||||
@ -199,14 +199,16 @@ class BacktestingEngine:
|
||||
|
||||
if self.mode == BacktestingMode.BAR:
|
||||
self.history_data = load_bar_data(
|
||||
self.vt_symbol,
|
||||
self.symbol,
|
||||
self.exchange,
|
||||
self.interval,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
else:
|
||||
self.history_data = load_tick_data(
|
||||
self.vt_symbol,
|
||||
self.symbol,
|
||||
self.exchange,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
@ -520,7 +522,7 @@ class BacktestingEngine:
|
||||
else:
|
||||
self.daily_results[d] = DailyResult(d, price)
|
||||
|
||||
def new_bar(self, bar: DbBarData):
|
||||
def new_bar(self, bar: BarData):
|
||||
""""""
|
||||
self.bar = bar
|
||||
self.datetime = bar.datetime
|
||||
@ -531,7 +533,7 @@ class BacktestingEngine:
|
||||
|
||||
self.update_daily_close(bar.close_price)
|
||||
|
||||
def new_tick(self, tick: DbTickData):
|
||||
def new_tick(self, tick: TickData):
|
||||
""""""
|
||||
self.tick = tick
|
||||
self.datetime = tick.datetime
|
||||
@ -966,41 +968,26 @@ def optimize(
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_bar_data(
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
return database_manager.load_bar_data(
|
||||
symbol, exchange, interval, start, end
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_tick_data(
|
||||
vt_symbol: str,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == vt_symbol)
|
||||
& (DbTickData.datetime >= start)
|
||||
& (DbTickData.datetime <= end)
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
return database_manager.load_tick_data(
|
||||
symbol, exchange, start, end
|
||||
)
|
||||
data = [db_tick.db_tick() for db_tick in s]
|
||||
return data
|
@ -5,7 +5,7 @@ import os
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, List
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
|
||||
Status
|
||||
)
|
||||
from vnpy.trader.utility import load_json, save_json
|
||||
from vnpy.trader.database import DbTickData, DbBarData
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.setting import SETTINGS
|
||||
|
||||
from .base import (
|
||||
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
|
||||
def query_bar_from_rq(
|
||||
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
|
||||
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||
):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
rq_symbol = to_rq_symbol(vt_symbol)
|
||||
rq_symbol = to_rq_symbol(symbol, exchange)
|
||||
if rq_symbol not in self.rq_symbols:
|
||||
return None
|
||||
|
||||
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
|
||||
end_date=end
|
||||
)
|
||||
|
||||
data = []
|
||||
data: List[BarData] = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = BarData(
|
||||
symbol=symbol,
|
||||
exchange=Exchange(exchange_str),
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime=row.name.to_pydatetime(),
|
||||
open_price=row["open"],
|
||||
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
|
||||
return self.engine_type
|
||||
|
||||
def load_bar(
|
||||
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
|
||||
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
|
||||
callback: Callable[[BarData], None]
|
||||
):
|
||||
""""""
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
|
||||
# Query data from RQData by default, if not found, load from database.
|
||||
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
|
||||
if not data:
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval.value)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
# Query bars from RQData by default, if not found, load from database.
|
||||
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
|
||||
if not bars:
|
||||
bars = database_manager.load_bar_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
|
||||
for bar in data:
|
||||
for bar in bars:
|
||||
callback(bar)
|
||||
|
||||
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
|
||||
def load_tick(self, symbol: str, exchange: Exchange, days: int,
|
||||
callback: Callable[[TickData], None]):
|
||||
""""""
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
ticks = database_manager.load_tick_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for tick in s:
|
||||
for tick in ticks:
|
||||
callback(tick)
|
||||
|
||||
def call_strategy_func(
|
||||
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Load strategy class from certain folder.
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
for dirpath, dirnames, filenames in os.walk(str(path)):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
|
||||
self.main_engine.send_email(subject, msg)
|
||||
|
||||
|
||||
def to_rq_symbol(vt_symbol: str):
|
||||
def to_rq_symbol(symbol: str, exchange: Exchange):
|
||||
"""
|
||||
CZCE product of RQData has symbol like "TA1905" while
|
||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
if exchange_str != "CZCE":
|
||||
if exchange is not Exchange.CZCE:
|
||||
return symbol.upper()
|
||||
|
||||
for count, word in enumerate(symbol):
|
||||
if word.isdigit():
|
||||
break
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
product = symbol[:count]
|
||||
year = symbol[count]
|
||||
month = symbol[count + 1:]
|
||||
|
@ -1,268 +0,0 @@
|
||||
""""""
|
||||
|
||||
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
|
||||
SqliteDatabase
|
||||
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData, TickData
|
||||
from .setting import SETTINGS
|
||||
from .utility import resolve_path
|
||||
|
||||
|
||||
def init():
|
||||
db_settings = SETTINGS['database']
|
||||
driver = db_settings["driver"]
|
||||
|
||||
init_funcs = {
|
||||
"sqlite": init_sqlite,
|
||||
"mysql": init_mysql,
|
||||
"postgresql": init_postgresql,
|
||||
}
|
||||
|
||||
assert driver in init_funcs
|
||||
del db_settings['driver']
|
||||
return init_funcs[driver](db_settings)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
global DB
|
||||
database = settings['database']
|
||||
|
||||
DB = SqliteDatabase(str(resolve_path(database)))
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
global DB
|
||||
DB = MySQLDatabase(**settings)
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
global DB
|
||||
DB = PostgresqlDatabase(**settings)
|
||||
|
||||
|
||||
init()
|
||||
|
||||
|
||||
class DbBarData(Model):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with vt_symbol, interval and datetime.
|
||||
"""
|
||||
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
interval = CharField()
|
||||
|
||||
volume = FloatField()
|
||||
open_price = FloatField()
|
||||
high_price = FloatField()
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
|
||||
vt_symbol = CharField()
|
||||
gateway_name = CharField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
indexes = ((("vt_symbol", "interval", "datetime"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
"""
|
||||
Generate DbBarData object from BarData.
|
||||
"""
|
||||
db_bar = DbBarData()
|
||||
|
||||
db_bar.symbol = bar.symbol
|
||||
db_bar.exchange = bar.exchange.value
|
||||
db_bar.datetime = bar.datetime
|
||||
db_bar.interval = bar.interval.value
|
||||
db_bar.volume = bar.volume
|
||||
db_bar.open_price = bar.open_price
|
||||
db_bar.high_price = bar.high_price
|
||||
db_bar.low_price = bar.low_price
|
||||
db_bar.close_price = bar.close_price
|
||||
db_bar.vt_symbol = bar.vt_symbol
|
||||
db_bar.gateway_name = "DB"
|
||||
|
||||
return db_bar
|
||||
|
||||
def to_bar(self):
|
||||
"""
|
||||
Generate BarData object from DbBarData.
|
||||
"""
|
||||
bar = BarData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
interval=Interval(self.interval),
|
||||
volume=self.volume,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
return bar
|
||||
|
||||
|
||||
class DbTickData(Model):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with vt_symbol, interval and datetime.
|
||||
"""
|
||||
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
|
||||
name = CharField()
|
||||
volume = FloatField()
|
||||
last_price = FloatField()
|
||||
last_volume = FloatField()
|
||||
limit_up = FloatField()
|
||||
limit_down = FloatField()
|
||||
|
||||
open_price = FloatField()
|
||||
high_price = FloatField()
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
|
||||
bid_price_1 = FloatField()
|
||||
bid_price_2 = FloatField()
|
||||
bid_price_3 = FloatField()
|
||||
bid_price_4 = FloatField()
|
||||
bid_price_5 = FloatField()
|
||||
|
||||
ask_price_1 = FloatField()
|
||||
ask_price_2 = FloatField()
|
||||
ask_price_3 = FloatField()
|
||||
ask_price_4 = FloatField()
|
||||
ask_price_5 = FloatField()
|
||||
|
||||
bid_volume_1 = FloatField()
|
||||
bid_volume_2 = FloatField()
|
||||
bid_volume_3 = FloatField()
|
||||
bid_volume_4 = FloatField()
|
||||
bid_volume_5 = FloatField()
|
||||
|
||||
ask_volume_1 = FloatField()
|
||||
ask_volume_2 = FloatField()
|
||||
ask_volume_3 = FloatField()
|
||||
ask_volume_4 = FloatField()
|
||||
ask_volume_5 = FloatField()
|
||||
|
||||
vt_symbol = CharField()
|
||||
gateway_name = CharField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
indexes = ((("vt_symbol", "datetime"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
"""
|
||||
Generate DbTickData object from TickData.
|
||||
"""
|
||||
db_tick = DbTickData()
|
||||
|
||||
db_tick.symbol = tick.symbol
|
||||
db_tick.exchange = tick.exchange.value
|
||||
db_tick.datetime = tick.datetime
|
||||
db_tick.name = tick.name
|
||||
db_tick.volume = tick.volume
|
||||
db_tick.last_price = tick.last_price
|
||||
db_tick.last_volume = tick.last_volume
|
||||
db_tick.limit_up = tick.limit_up
|
||||
db_tick.limit_down = tick.limit_down
|
||||
db_tick.open_price = tick.open_price
|
||||
db_tick.high_price = tick.high_price
|
||||
db_tick.low_price = tick.low_price
|
||||
db_tick.pre_close = tick.pre_close
|
||||
|
||||
db_tick.bid_price_1 = tick.bid_price_1
|
||||
db_tick.ask_price_1 = tick.ask_price_1
|
||||
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||
|
||||
if tick.bid_price_2:
|
||||
db_tick.bid_price_2 = tick.bid_price_2
|
||||
db_tick.bid_price_3 = tick.bid_price_3
|
||||
db_tick.bid_price_4 = tick.bid_price_4
|
||||
db_tick.bid_price_5 = tick.bid_price_5
|
||||
|
||||
db_tick.ask_price_2 = tick.ask_price_2
|
||||
db_tick.ask_price_3 = tick.ask_price_3
|
||||
db_tick.ask_price_4 = tick.ask_price_4
|
||||
db_tick.ask_price_5 = tick.ask_price_5
|
||||
|
||||
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||
|
||||
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
db_tick.vt_symbol = tick.vt_symbol
|
||||
db_tick.gateway_name = "DB"
|
||||
|
||||
return tick
|
||||
|
||||
def to_tick(self):
|
||||
"""
|
||||
Generate TickData object from DbTickData.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
name=self.name,
|
||||
volume=self.volume,
|
||||
last_price=self.last_price,
|
||||
last_volume=self.last_volume,
|
||||
limit_up=self.limit_up,
|
||||
limit_down=self.limit_down,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
pre_close=self.pre_close,
|
||||
bid_price_1=self.bid_price_1,
|
||||
ask_price_1=self.ask_price_1,
|
||||
bid_volume_1=self.bid_volume_1,
|
||||
ask_volume_1=self.ask_volume_1,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
|
||||
if self.bid_price_2:
|
||||
tick.bid_price_2 = self.bid_price_2
|
||||
tick.bid_price_3 = self.bid_price_3
|
||||
tick.bid_price_4 = self.bid_price_4
|
||||
tick.bid_price_5 = self.bid_price_5
|
||||
|
||||
tick.ask_price_2 = self.ask_price_2
|
||||
tick.ask_price_3 = self.ask_price_3
|
||||
tick.ask_price_4 = self.ask_price_4
|
||||
tick.ask_price_5 = self.ask_price_5
|
||||
|
||||
tick.bid_volume_2 = self.bid_volume_2
|
||||
tick.bid_volume_3 = self.bid_volume_3
|
||||
tick.bid_volume_4 = self.bid_volume_4
|
||||
tick.bid_volume_5 = self.bid_volume_5
|
||||
|
||||
tick.ask_volume_2 = self.ask_volume_2
|
||||
tick.ask_volume_3 = self.ask_volume_3
|
||||
tick.ask_volume_4 = self.ask_volume_4
|
||||
tick.ask_volume_5 = self.ask_volume_5
|
||||
|
||||
return tick
|
||||
|
||||
|
||||
DB.connect()
|
||||
DB.create_tables([DbBarData, DbTickData])
|
12
vnpy/trader/database/__init__.py
Normal file
12
vnpy/trader/database/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.database.database import BaseDatabaseManager
|
||||
|
||||
if "VNPY_TESTING" not in os.environ:
|
||||
from vnpy.trader.setting import get_settings
|
||||
from .initialize import init
|
||||
|
||||
settings = get_settings("database.")
|
||||
database_manager: "BaseDatabaseManager" = init(settings=settings)
|
53
vnpy/trader/database/database.py
Normal file
53
vnpy/trader/database/database.py
Normal file
@ -0,0 +1,53 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Interval, Exchange # noqa
|
||||
from vnpy.trader.object import BarData, TickData # noqa
|
||||
|
||||
|
||||
class Driver(Enum):
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
MONGODB = "mongodb"
|
||||
|
||||
|
||||
class BaseDatabaseManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
interval: "Interval",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> Sequence["BarData"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_tick_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: "Exchange",
|
||||
start: datetime,
|
||||
end: datetime
|
||||
) -> Sequence["TickData"]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_bar_data(
|
||||
self,
|
||||
datas: Sequence["BarData"],
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_tick_data(
|
||||
self,
|
||||
datas: Sequence["TickData"],
|
||||
):
|
||||
pass
|
302
vnpy/trader/database/database_mongo.py
Normal file
302
vnpy/trader/database/database_mongo.py
Normal file
@ -0,0 +1,302 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
def init(_: Driver, settings: dict):
|
||||
database = settings["database"]
|
||||
host = settings["host"]
|
||||
port = settings["port"]
|
||||
username = settings["user"]
|
||||
password = settings["password"]
|
||||
authentication_source = settings["authentication_source"]
|
||||
if not username: # if username == '' or None, skip username
|
||||
username = None
|
||||
password = None
|
||||
authentication_source = None
|
||||
connect(
|
||||
db=database,
|
||||
host=host,
|
||||
port=port,
|
||||
username=username,
|
||||
password=password,
|
||||
authentication_source=authentication_source,
|
||||
)
|
||||
return MongoManager()
|
||||
|
||||
|
||||
class DbBarData(Document):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with datetime, interval, symbol
|
||||
"""
|
||||
|
||||
symbol: str = StringField()
|
||||
exchange: str = StringField()
|
||||
datetime: datetime = DateTimeField()
|
||||
interval: str = StringField()
|
||||
|
||||
volume: float = FloatField()
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
|
||||
meta = {
|
||||
"indexes": [
|
||||
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
|
||||
]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
"""
|
||||
Generate DbBarData object from BarData.
|
||||
"""
|
||||
db_bar = DbBarData()
|
||||
|
||||
db_bar.symbol = bar.symbol
|
||||
db_bar.exchange = bar.exchange.value
|
||||
db_bar.datetime = bar.datetime
|
||||
db_bar.interval = bar.interval.value
|
||||
db_bar.volume = bar.volume
|
||||
db_bar.open_price = bar.open_price
|
||||
db_bar.high_price = bar.high_price
|
||||
db_bar.low_price = bar.low_price
|
||||
db_bar.close_price = bar.close_price
|
||||
|
||||
return db_bar
|
||||
|
||||
def to_bar(self):
|
||||
"""
|
||||
Generate BarData object from DbBarData.
|
||||
"""
|
||||
bar = BarData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
interval=Interval(self.interval),
|
||||
volume=self.volume,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name="DB",
|
||||
)
|
||||
return bar
|
||||
|
||||
|
||||
class DbTickData(Document):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with (datetime, symbol)
|
||||
"""
|
||||
|
||||
symbol: str = StringField()
|
||||
exchange: str = StringField()
|
||||
datetime: datetime = DateTimeField()
|
||||
|
||||
name: str = StringField()
|
||||
volume: float = FloatField()
|
||||
last_price: float = FloatField()
|
||||
last_volume: float = FloatField()
|
||||
limit_up: float = FloatField()
|
||||
limit_down: float = FloatField()
|
||||
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
pre_close: float = FloatField()
|
||||
|
||||
bid_price_1: float = FloatField()
|
||||
bid_price_2: float = FloatField()
|
||||
bid_price_3: float = FloatField()
|
||||
bid_price_4: float = FloatField()
|
||||
bid_price_5: float = FloatField()
|
||||
|
||||
ask_price_1: float = FloatField()
|
||||
ask_price_2: float = FloatField()
|
||||
ask_price_3: float = FloatField()
|
||||
ask_price_4: float = FloatField()
|
||||
ask_price_5: float = FloatField()
|
||||
|
||||
bid_volume_1: float = FloatField()
|
||||
bid_volume_2: float = FloatField()
|
||||
bid_volume_3: float = FloatField()
|
||||
bid_volume_4: float = FloatField()
|
||||
bid_volume_5: float = FloatField()
|
||||
|
||||
ask_volume_1: float = FloatField()
|
||||
ask_volume_2: float = FloatField()
|
||||
ask_volume_3: float = FloatField()
|
||||
ask_volume_4: float = FloatField()
|
||||
ask_volume_5: float = FloatField()
|
||||
|
||||
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
"""
|
||||
Generate DbTickData object from TickData.
|
||||
"""
|
||||
db_tick = DbTickData()
|
||||
|
||||
db_tick.symbol = tick.symbol
|
||||
db_tick.exchange = tick.exchange.value
|
||||
db_tick.datetime = tick.datetime
|
||||
db_tick.name = tick.name
|
||||
db_tick.volume = tick.volume
|
||||
db_tick.last_price = tick.last_price
|
||||
db_tick.last_volume = tick.last_volume
|
||||
db_tick.limit_up = tick.limit_up
|
||||
db_tick.limit_down = tick.limit_down
|
||||
db_tick.open_price = tick.open_price
|
||||
db_tick.high_price = tick.high_price
|
||||
db_tick.low_price = tick.low_price
|
||||
db_tick.pre_close = tick.pre_close
|
||||
|
||||
db_tick.bid_price_1 = tick.bid_price_1
|
||||
db_tick.ask_price_1 = tick.ask_price_1
|
||||
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||
|
||||
if tick.bid_price_2:
|
||||
db_tick.bid_price_2 = tick.bid_price_2
|
||||
db_tick.bid_price_3 = tick.bid_price_3
|
||||
db_tick.bid_price_4 = tick.bid_price_4
|
||||
db_tick.bid_price_5 = tick.bid_price_5
|
||||
|
||||
db_tick.ask_price_2 = tick.ask_price_2
|
||||
db_tick.ask_price_3 = tick.ask_price_3
|
||||
db_tick.ask_price_4 = tick.ask_price_4
|
||||
db_tick.ask_price_5 = tick.ask_price_5
|
||||
|
||||
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||
|
||||
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
return db_tick
|
||||
|
||||
def to_tick(self):
|
||||
"""
|
||||
Generate TickData object from DbTickData.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
name=self.name,
|
||||
volume=self.volume,
|
||||
last_price=self.last_price,
|
||||
last_volume=self.last_volume,
|
||||
limit_up=self.limit_up,
|
||||
limit_down=self.limit_down,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
pre_close=self.pre_close,
|
||||
bid_price_1=self.bid_price_1,
|
||||
ask_price_1=self.ask_price_1,
|
||||
bid_volume_1=self.bid_volume_1,
|
||||
ask_volume_1=self.ask_volume_1,
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
if self.bid_price_2:
|
||||
tick.bid_price_2 = self.bid_price_2
|
||||
tick.bid_price_3 = self.bid_price_3
|
||||
tick.bid_price_4 = self.bid_price_4
|
||||
tick.bid_price_5 = self.bid_price_5
|
||||
|
||||
tick.ask_price_2 = self.ask_price_2
|
||||
tick.ask_price_3 = self.ask_price_3
|
||||
tick.ask_price_4 = self.ask_price_4
|
||||
tick.ask_price_5 = self.ask_price_5
|
||||
|
||||
tick.bid_volume_2 = self.bid_volume_2
|
||||
tick.bid_volume_3 = self.bid_volume_3
|
||||
tick.bid_volume_4 = self.bid_volume_4
|
||||
tick.bid_volume_5 = self.bid_volume_5
|
||||
|
||||
tick.ask_volume_2 = self.ask_volume_2
|
||||
tick.ask_volume_3 = self.ask_volume_3
|
||||
tick.ask_volume_4 = self.ask_volume_4
|
||||
tick.ask_volume_5 = self.ask_volume_5
|
||||
|
||||
return tick
|
||||
|
||||
|
||||
class MongoManager(BaseDatabaseManager):
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> Sequence[BarData]:
|
||||
s = DbBarData.objects(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
interval=interval.value,
|
||||
datetime__gte=start,
|
||||
datetime__lte=end,
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
) -> Sequence[TickData]:
|
||||
s = DbTickData.objects(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
datetime__gte=start,
|
||||
datetime__lte=end,
|
||||
)
|
||||
data = [db_tick.to_tick() for db_tick in s]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def to_update_param(d):
|
||||
return {
|
||||
"set__" + k: v.value if isinstance(v, Enum) else v
|
||||
for k, v in d.__dict__.items()
|
||||
}
|
||||
|
||||
def save_bar_data(self, datas: Sequence[BarData]):
|
||||
for d in datas:
|
||||
updates = self.to_update_param(d)
|
||||
updates.pop("set__gateway_name")
|
||||
updates.pop("set__vt_symbol")
|
||||
(
|
||||
DbBarData.objects(
|
||||
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
|
||||
).update_one(upsert=True, **updates)
|
||||
)
|
||||
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
for d in datas:
|
||||
updates = self.to_update_param(d)
|
||||
updates.pop("set__gateway_name")
|
||||
updates.pop("set__vt_symbol")
|
||||
(
|
||||
DbTickData.objects(
|
||||
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
|
||||
).update_one(upsert=True, **updates)
|
||||
)
|
369
vnpy/trader/database/database_sql.py
Normal file
369
vnpy/trader/database/database_sql.py
Normal file
@ -0,0 +1,369 @@
|
||||
""""""
|
||||
from datetime import datetime
|
||||
from typing import List, Sequence, Type
|
||||
|
||||
from peewee import (
|
||||
AutoField,
|
||||
CharField,
|
||||
Database,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
Model,
|
||||
MySQLDatabase,
|
||||
PostgresqlDatabase,
|
||||
SqliteDatabase,
|
||||
chunked,
|
||||
)
|
||||
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.object import BarData, TickData
|
||||
from vnpy.trader.utility import get_file_path
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
def init(driver: Driver, settings: dict):
|
||||
init_funcs = {
|
||||
Driver.SQLITE: init_sqlite,
|
||||
Driver.MYSQL: init_mysql,
|
||||
Driver.POSTGRESQL: init_postgresql,
|
||||
}
|
||||
assert driver in init_funcs
|
||||
|
||||
db = init_funcs[driver](settings)
|
||||
bar, tick = init_models(db, driver)
|
||||
return SqlManager(bar, tick)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
database = settings["database"]
|
||||
path = str(get_file_path(database))
|
||||
db = SqliteDatabase(path)
|
||||
return db
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
keys = {"database", "user", "password", "host", "port"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
db = MySQLDatabase(**settings)
|
||||
return db
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
keys = {"database", "user", "password", "host", "port"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
db = PostgresqlDatabase(**settings)
|
||||
return db
|
||||
|
||||
|
||||
class ModelBase(Model):
|
||||
def to_dict(self):
|
||||
return self.__data__
|
||||
|
||||
|
||||
def init_models(db: Database, driver: Driver):
|
||||
class DbBarData(ModelBase):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with datetime, interval, symbol
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
symbol: str = CharField()
|
||||
exchange: str = CharField()
|
||||
datetime: datetime = DateTimeField()
|
||||
interval: str = CharField()
|
||||
|
||||
volume: float = FloatField()
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
close_price: float = FloatField()
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
"""
|
||||
Generate DbBarData object from BarData.
|
||||
"""
|
||||
db_bar = DbBarData()
|
||||
|
||||
db_bar.symbol = bar.symbol
|
||||
db_bar.exchange = bar.exchange.value
|
||||
db_bar.datetime = bar.datetime
|
||||
db_bar.interval = bar.interval.value
|
||||
db_bar.volume = bar.volume
|
||||
db_bar.open_price = bar.open_price
|
||||
db_bar.high_price = bar.high_price
|
||||
db_bar.low_price = bar.low_price
|
||||
db_bar.close_price = bar.close_price
|
||||
|
||||
return db_bar
|
||||
|
||||
def to_bar(self):
|
||||
"""
|
||||
Generate BarData object from DbBarData.
|
||||
"""
|
||||
bar = BarData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
interval=Interval(self.interval),
|
||||
volume=self.volume,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name="DB",
|
||||
)
|
||||
return bar
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbBarData"]):
|
||||
"""
|
||||
save a list of objects, update if exists.
|
||||
"""
|
||||
dicts = [i.to_dict() for i in objs]
|
||||
with db.atomic():
|
||||
if driver is Driver.POSTGRESQL:
|
||||
for bar in dicts:
|
||||
DbBarData.insert(bar).on_conflict(
|
||||
update=bar,
|
||||
conflict_target=(
|
||||
DbBarData.datetime,
|
||||
DbBarData.interval,
|
||||
DbBarData.symbol,
|
||||
DbBarData.exchange,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(dicts, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace().execute()
|
||||
|
||||
class DbTickData(ModelBase):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with (datetime, symbol)
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
|
||||
symbol: str = CharField()
|
||||
exchange: str = CharField()
|
||||
datetime: datetime = DateTimeField()
|
||||
|
||||
name: str = CharField()
|
||||
volume: float = FloatField()
|
||||
last_price: float = FloatField()
|
||||
last_volume: float = FloatField()
|
||||
limit_up: float = FloatField()
|
||||
limit_down: float = FloatField()
|
||||
|
||||
open_price: float = FloatField()
|
||||
high_price: float = FloatField()
|
||||
low_price: float = FloatField()
|
||||
pre_close: float = FloatField()
|
||||
|
||||
bid_price_1: float = FloatField()
|
||||
bid_price_2: float = FloatField(null=True)
|
||||
bid_price_3: float = FloatField(null=True)
|
||||
bid_price_4: float = FloatField(null=True)
|
||||
bid_price_5: float = FloatField(null=True)
|
||||
|
||||
ask_price_1: float = FloatField()
|
||||
ask_price_2: float = FloatField(null=True)
|
||||
ask_price_3: float = FloatField(null=True)
|
||||
ask_price_4: float = FloatField(null=True)
|
||||
ask_price_5: float = FloatField(null=True)
|
||||
|
||||
bid_volume_1: float = FloatField()
|
||||
bid_volume_2: float = FloatField(null=True)
|
||||
bid_volume_3: float = FloatField(null=True)
|
||||
bid_volume_4: float = FloatField(null=True)
|
||||
bid_volume_5: float = FloatField(null=True)
|
||||
|
||||
ask_volume_1: float = FloatField()
|
||||
ask_volume_2: float = FloatField(null=True)
|
||||
ask_volume_3: float = FloatField(null=True)
|
||||
ask_volume_4: float = FloatField(null=True)
|
||||
ask_volume_5: float = FloatField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = db
|
||||
indexes = ((("datetime", "symbol", "exchange"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
"""
|
||||
Generate DbTickData object from TickData.
|
||||
"""
|
||||
db_tick = DbTickData()
|
||||
|
||||
db_tick.symbol = tick.symbol
|
||||
db_tick.exchange = tick.exchange.value
|
||||
db_tick.datetime = tick.datetime
|
||||
db_tick.name = tick.name
|
||||
db_tick.volume = tick.volume
|
||||
db_tick.last_price = tick.last_price
|
||||
db_tick.last_volume = tick.last_volume
|
||||
db_tick.limit_up = tick.limit_up
|
||||
db_tick.limit_down = tick.limit_down
|
||||
db_tick.open_price = tick.open_price
|
||||
db_tick.high_price = tick.high_price
|
||||
db_tick.low_price = tick.low_price
|
||||
db_tick.pre_close = tick.pre_close
|
||||
|
||||
db_tick.bid_price_1 = tick.bid_price_1
|
||||
db_tick.ask_price_1 = tick.ask_price_1
|
||||
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||
|
||||
if tick.bid_price_2:
|
||||
db_tick.bid_price_2 = tick.bid_price_2
|
||||
db_tick.bid_price_3 = tick.bid_price_3
|
||||
db_tick.bid_price_4 = tick.bid_price_4
|
||||
db_tick.bid_price_5 = tick.bid_price_5
|
||||
|
||||
db_tick.ask_price_2 = tick.ask_price_2
|
||||
db_tick.ask_price_3 = tick.ask_price_3
|
||||
db_tick.ask_price_4 = tick.ask_price_4
|
||||
db_tick.ask_price_5 = tick.ask_price_5
|
||||
|
||||
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||
|
||||
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
return db_tick
|
||||
|
||||
def to_tick(self):
|
||||
"""
|
||||
Generate TickData object from DbTickData.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=self.symbol,
|
||||
exchange=Exchange(self.exchange),
|
||||
datetime=self.datetime,
|
||||
name=self.name,
|
||||
volume=self.volume,
|
||||
last_price=self.last_price,
|
||||
last_volume=self.last_volume,
|
||||
limit_up=self.limit_up,
|
||||
limit_down=self.limit_down,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
pre_close=self.pre_close,
|
||||
bid_price_1=self.bid_price_1,
|
||||
ask_price_1=self.ask_price_1,
|
||||
bid_volume_1=self.bid_volume_1,
|
||||
ask_volume_1=self.ask_volume_1,
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
if self.bid_price_2:
|
||||
tick.bid_price_2 = self.bid_price_2
|
||||
tick.bid_price_3 = self.bid_price_3
|
||||
tick.bid_price_4 = self.bid_price_4
|
||||
tick.bid_price_5 = self.bid_price_5
|
||||
|
||||
tick.ask_price_2 = self.ask_price_2
|
||||
tick.ask_price_3 = self.ask_price_3
|
||||
tick.ask_price_4 = self.ask_price_4
|
||||
tick.ask_price_5 = self.ask_price_5
|
||||
|
||||
tick.bid_volume_2 = self.bid_volume_2
|
||||
tick.bid_volume_3 = self.bid_volume_3
|
||||
tick.bid_volume_4 = self.bid_volume_4
|
||||
tick.bid_volume_5 = self.bid_volume_5
|
||||
|
||||
tick.ask_volume_2 = self.ask_volume_2
|
||||
tick.ask_volume_3 = self.ask_volume_3
|
||||
tick.ask_volume_4 = self.ask_volume_4
|
||||
tick.ask_volume_5 = self.ask_volume_5
|
||||
|
||||
return tick
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbTickData"]):
|
||||
dicts = [i.to_dict() for i in objs]
|
||||
with db.atomic():
|
||||
if driver is Driver.POSTGRESQL:
|
||||
for tick in dicts:
|
||||
DbTickData.insert(tick).on_conflict(
|
||||
update=tick,
|
||||
conflict_target=(
|
||||
DbTickData.datetime,
|
||||
DbTickData.symbol,
|
||||
DbTickData.exchange,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(dicts, 50):
|
||||
DbTickData.insert_many(c).on_conflict_replace().execute()
|
||||
|
||||
db.connect()
|
||||
db.create_tables([DbBarData, DbTickData])
|
||||
return DbBarData, DbTickData
|
||||
|
||||
|
||||
class SqlManager(BaseDatabaseManager):
|
||||
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
|
||||
self.class_bar = class_bar
|
||||
self.class_tick = class_tick
|
||||
|
||||
def load_bar_data(
|
||||
self,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> Sequence[BarData]:
|
||||
s = (
|
||||
self.class_bar.select()
|
||||
.where(
|
||||
(self.class_bar.symbol == symbol)
|
||||
& (self.class_bar.exchange == exchange.value)
|
||||
& (self.class_bar.interval == interval.value)
|
||||
& (self.class_bar.datetime >= start)
|
||||
& (self.class_bar.datetime <= end)
|
||||
)
|
||||
.order_by(self.class_bar.datetime)
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
def load_tick_data(
|
||||
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||
) -> Sequence[TickData]:
|
||||
s = (
|
||||
self.class_tick.select()
|
||||
.where(
|
||||
(self.class_tick.symbol == symbol)
|
||||
& (self.class_tick.exchange == exchange.value)
|
||||
& (self.class_tick.datetime >= start)
|
||||
& (self.class_tick.datetime <= end)
|
||||
)
|
||||
.order_by(self.class_tick.datetime)
|
||||
)
|
||||
|
||||
data = [db_tick.to_tick() for db_tick in s]
|
||||
return data
|
||||
|
||||
def save_bar_data(self, datas: Sequence[BarData]):
|
||||
ds = [self.class_bar.from_bar(i) for i in datas]
|
||||
self.class_bar.save_all(ds)
|
||||
|
||||
def save_tick_data(self, datas: Sequence[TickData]):
|
||||
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||
self.class_tick.save_all(ds)
|
24
vnpy/trader/database/initialize.py
Normal file
24
vnpy/trader/database/initialize.py
Normal file
@ -0,0 +1,24 @@
|
||||
""""""
|
||||
from .database import BaseDatabaseManager, Driver
|
||||
|
||||
|
||||
def init(settings: dict) -> BaseDatabaseManager:
|
||||
driver = Driver(settings["driver"])
|
||||
if driver is Driver.MONGODB:
|
||||
return init_nosql(driver=driver, settings=settings)
|
||||
else:
|
||||
return init_sql(driver=driver, settings=settings)
|
||||
|
||||
|
||||
def init_sql(driver: Driver, settings: dict):
|
||||
from .database_sql import init
|
||||
keys = {'database', "host", "port", "user", "password"}
|
||||
settings = {k: v for k, v in settings.items() if k in keys}
|
||||
_database_manager = init(driver, settings)
|
||||
return _database_manager
|
||||
|
||||
|
||||
def init_nosql(driver: Driver, settings: dict):
|
||||
from .database_mongo import init
|
||||
_database_manager = init(driver, settings=settings)
|
||||
return _database_manager
|
@ -24,16 +24,21 @@ SETTINGS = {
|
||||
|
||||
"rqdata.username": "",
|
||||
"rqdata.password": "",
|
||||
"database": {
|
||||
"driver": "sqlite", # sqlite, mysql, postgresql
|
||||
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
|
||||
"host": "localhost",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": ""
|
||||
}
|
||||
|
||||
"database.driver": "sqlite", # see database.Driver
|
||||
"database.database": "database.db", # for sqlite, use this as filepath
|
||||
"database.host": "localhost",
|
||||
"database.port": 3306,
|
||||
"database.user": "root",
|
||||
"database.password": "",
|
||||
"database.authentication_source": "admin", # for mongodb
|
||||
}
|
||||
|
||||
# Load global setting from json file.
|
||||
SETTING_FILENAME = "vt_setting.json"
|
||||
SETTINGS.update(load_json(SETTING_FILENAME))
|
||||
|
||||
|
||||
def get_settings(prefix: str = ""):
|
||||
prefix_length = len(prefix)
|
||||
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}
|
||||
|
@ -3,20 +3,28 @@ General utility functions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
|
||||
from .object import BarData, TickData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vnpy.trader.constant import Exchange
|
||||
|
||||
def resolve_path(pattern: str):
|
||||
env = dict(os.environ)
|
||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
||||
return pattern.format(**env)
|
||||
|
||||
def extract_vt_symbol(vt_symbol: str):
|
||||
"""
|
||||
:return: (symbol, exchange)
|
||||
"""
|
||||
symbol, exchange = vt_symbol.split('.')
|
||||
return symbol, exchange
|
||||
|
||||
|
||||
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
|
||||
return f'{symbol}.{exchange.value}'
|
||||
|
||||
|
||||
def _get_trader_dir(temp_name: str):
|
||||
|
Loading…
Reference in New Issue
Block a user