commit
1d527bc36e
106
.travis.yml
106
.travis.yml
@ -2,12 +2,27 @@ language: python
|
|||||||
|
|
||||||
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
|
||||||
|
|
||||||
|
cache: pip
|
||||||
|
|
||||||
|
git:
|
||||||
|
depth: 1
|
||||||
|
|
||||||
python:
|
python:
|
||||||
- "3.7"
|
- "3.7"
|
||||||
|
|
||||||
|
services:
|
||||||
|
- mongodb
|
||||||
|
- mysql
|
||||||
|
- postgresql
|
||||||
|
|
||||||
|
before_script:
|
||||||
|
- psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;"
|
||||||
|
- mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;'
|
||||||
|
|
||||||
script:
|
script:
|
||||||
# todo: use python unittest
|
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||||
- mkdir run; cd run; python ../tests/load_all.py
|
- cd tests; source travis_env.sh;
|
||||||
|
- python test_all.py
|
||||||
|
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
@ -19,81 +34,36 @@ matrix:
|
|||||||
script:
|
script:
|
||||||
- flake8
|
- flake8
|
||||||
|
|
||||||
- name: "pip install under Windows"
|
|
||||||
os: "windows"
|
|
||||||
# language : cpp is necessary for windows
|
|
||||||
language: "cpp"
|
|
||||||
env:
|
|
||||||
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
|
|
||||||
before_install:
|
|
||||||
- choco install python3 --version 3.7.2
|
|
||||||
install:
|
|
||||||
- python -m pip install --upgrade pip wheel setuptools
|
|
||||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
|
||||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
|
||||||
- pip install -r requirements.txt
|
|
||||||
- pip install .
|
|
||||||
|
|
||||||
- name: "pip install under Ubuntu: gcc-8"
|
- name: "pip install under Ubuntu: gcc-8"
|
||||||
|
addons:
|
||||||
|
apt:
|
||||||
|
sources:
|
||||||
|
- ubuntu-toolchain-r-test
|
||||||
|
packages:
|
||||||
|
- g++-8
|
||||||
before_install:
|
before_install:
|
||||||
# C++17
|
|
||||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
|
||||||
- sudo apt-get update -y
|
|
||||||
install:
|
|
||||||
# C++17
|
|
||||||
- sudo apt-get install -y gcc-8 g++-8
|
|
||||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
||||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
||||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
||||||
|
install:
|
||||||
# update pip & setuptools
|
# update pip & setuptools
|
||||||
- python -m pip install --upgrade pip wheel setuptools
|
- python -m pip install --upgrade pip wheel setuptools
|
||||||
# Linux install script
|
# Linux install script
|
||||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||||
- bash ./install.sh
|
- bash ./install.sh
|
||||||
|
|
||||||
- name: "pip install under Ubuntu: gcc-7"
|
- name: "sdist install under Ubuntu: gcc-7"
|
||||||
|
addons:
|
||||||
|
apt:
|
||||||
|
sources:
|
||||||
|
- ubuntu-toolchain-r-test
|
||||||
|
packages:
|
||||||
|
- g++-7
|
||||||
before_install:
|
before_install:
|
||||||
# C++17
|
|
||||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
|
||||||
- sudo apt-get update -y
|
|
||||||
install:
|
|
||||||
# C++17
|
|
||||||
- sudo apt-get install -y gcc-7 g++-7
|
|
||||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
|
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
|
||||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
|
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
|
||||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
|
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
|
||||||
# update pip & setuptools
|
|
||||||
- python -m pip install --upgrade pip wheel setuptools
|
|
||||||
# Linux install script
|
|
||||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
|
||||||
- bash ./install.sh
|
|
||||||
|
|
||||||
- name: "sdist install under Windows"
|
|
||||||
os: "windows"
|
|
||||||
# language : cpp is necessary for windows
|
|
||||||
language: "cpp"
|
|
||||||
env:
|
|
||||||
- PATH=/c/Python37:/c/Python37/Scripts:$PATH
|
|
||||||
before_install:
|
|
||||||
- choco install python3 --version 3.7.2
|
|
||||||
install:
|
install:
|
||||||
- python -m pip install --upgrade pip wheel setuptools
|
|
||||||
- python setup.py sdist
|
|
||||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
|
||||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
|
||||||
- pip install dist/`ls dist`
|
|
||||||
|
|
||||||
- name: "sdist install under Ubuntu: gcc-8"
|
|
||||||
before_install:
|
|
||||||
# C++17
|
|
||||||
- sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
|
||||||
- sudo apt-get update -y
|
|
||||||
install:
|
|
||||||
# C++17
|
|
||||||
- sudo apt-get install -y gcc-8 g++-8
|
|
||||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
|
||||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
|
||||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
|
||||||
# Linux install script
|
# Linux install script
|
||||||
- python -m pip install --upgrade pip wheel setuptools
|
- python -m pip install --upgrade pip wheel setuptools
|
||||||
- pushd /tmp
|
- pushd /tmp
|
||||||
@ -108,3 +78,17 @@ matrix:
|
|||||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||||
- python setup.py sdist
|
- python setup.py sdist
|
||||||
- pip install dist/`ls dist`
|
- pip install dist/`ls dist`
|
||||||
|
|
||||||
|
- name: "pip install under osx"
|
||||||
|
os: osx
|
||||||
|
language: shell # osx supports only shell
|
||||||
|
services: []
|
||||||
|
before_install: []
|
||||||
|
install:
|
||||||
|
- pip3 install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||||
|
- bash ./install_osx.sh
|
||||||
|
before_script: []
|
||||||
|
script:
|
||||||
|
- pip3 install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||||
|
- cd tests; source travis_env.sh;
|
||||||
|
- VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py
|
||||||
|
68
appveyor.yml
Normal file
68
appveyor.yml
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
clone_depth: 1
|
||||||
|
|
||||||
|
cache:
|
||||||
|
- '%LOCALAPPDATA%\pip\Cache'
|
||||||
|
|
||||||
|
configuration:
|
||||||
|
- pip
|
||||||
|
- sdist
|
||||||
|
|
||||||
|
image:
|
||||||
|
- Visual Studio 2017
|
||||||
|
|
||||||
|
services:
|
||||||
|
- mysql
|
||||||
|
- mongodb
|
||||||
|
- postgresql
|
||||||
|
|
||||||
|
environment:
|
||||||
|
PATH: C:\Python37-x64;C:\Python37-x64\Scripts;C:\Program Files\PostgreSQL\9.6\bin\;%PATH%
|
||||||
|
VNPY_TEST_MYSQL_DATABASE: vnpy
|
||||||
|
VNPY_TEST_MYSQL_HOST: localhost
|
||||||
|
VNPY_TEST_MYSQL_PORT: 3306
|
||||||
|
VNPY_TEST_MYSQL_USER: root
|
||||||
|
VNPY_TEST_MYSQL_PASSWORD: Password12!
|
||||||
|
|
||||||
|
VNPY_TEST_POSTGRESQL_DATABASE: vnpy
|
||||||
|
VNPY_TEST_POSTGRESQL_HOST: localhost
|
||||||
|
VNPY_TEST_POSTGRESQL_PORT: 5432
|
||||||
|
VNPY_TEST_POSTGRESQL_USER: postgres
|
||||||
|
VNPY_TEST_POSTGRESQL_PASSWORD: Password12!
|
||||||
|
|
||||||
|
VNPY_TEST_MONGODB_DATABASE: vnpy
|
||||||
|
VNPY_TEST_MONGODB_HOST: localhost
|
||||||
|
VNPY_TEST_MONGODB_PORT: 27017
|
||||||
|
|
||||||
|
MYSQL_PWD: Password12!
|
||||||
|
|
||||||
|
install:
|
||||||
|
- python -m pip install --upgrade pip wheel setuptools
|
||||||
|
before_build:
|
||||||
|
- ps: psql -d "postgresql://${ENV:VNPY_TEST_POSTGRESQL_USER}:${ENV:VNPY_TEST_POSTGRESQL_PASSWORD}@localhost" -c "create database vnpy;"
|
||||||
|
- ps: . "C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -u $ENV:VNPY_TEST_MYSQL_USER -e 'CREATE DATABASE vnpy;'
|
||||||
|
|
||||||
|
for:
|
||||||
|
- matrix:
|
||||||
|
only:
|
||||||
|
- configuration: pip
|
||||||
|
build_script:
|
||||||
|
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||||
|
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||||
|
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||||
|
- pip install -r requirements.txt
|
||||||
|
- pip install .
|
||||||
|
|
||||||
|
- matrix:
|
||||||
|
only:
|
||||||
|
- configuration: sdist
|
||||||
|
build_script:
|
||||||
|
- python setup.py sdist
|
||||||
|
- pip install psycopg2 mongoengine pymysql # we should support all database in test environment
|
||||||
|
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||||
|
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||||
|
- ps: $name=(ls dist).name; pip install "dist/$name"
|
||||||
|
|
||||||
|
test_script:
|
||||||
|
- cd tests
|
||||||
|
- python test_all.py
|
||||||
|
|
22
install.sh
22
install.sh
@ -1,27 +1,35 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
python=$1
|
||||||
|
pip=$2
|
||||||
|
prefix=$3
|
||||||
|
|
||||||
|
[[ -z $python ]] && python=python
|
||||||
|
[[ -z $pip ]] && pip=pip
|
||||||
|
[[ -z $prefix ]] && prefix=/usr
|
||||||
|
|
||||||
# Get and build ta-lib
|
# Get and build ta-lib
|
||||||
pushd /tmp
|
pushd /tmp
|
||||||
wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
|
wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
|
||||||
tar -xf ta-lib-0.4.0-src.tar.gz
|
tar -xf ta-lib-0.4.0-src.tar.gz
|
||||||
cd ta-lib
|
cd ta-lib
|
||||||
./configure --prefix=/usr
|
./configure --prefix=$prefix
|
||||||
make
|
make -j
|
||||||
sudo make install
|
sudo make install
|
||||||
popd
|
popd
|
||||||
|
|
||||||
# old versions of ta-lib imports numpy in setup.py
|
# old versions of ta-lib imports numpy in setup.py
|
||||||
pip install numpy
|
$pip install numpy
|
||||||
|
|
||||||
# Install extra packages
|
# Install extra packages
|
||||||
pip install ta-lib
|
$pip install ta-lib
|
||||||
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||||
|
|
||||||
# Install Python Modules
|
# Install Python Modules
|
||||||
pip install -r requirements.txt
|
$pip install -r requirements.txt
|
||||||
|
|
||||||
# Install local Chinese language environment
|
# Install local Chinese language environment
|
||||||
sudo locale-gen zh_CN.GB18030
|
sudo locale-gen zh_CN.GB18030
|
||||||
|
|
||||||
# Install vn.py
|
# Install vn.py
|
||||||
pip install .
|
$pip install .
|
3
install_osx.sh
Normal file
3
install_osx.sh
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
bash ./install.sh python3 pip3 /usr/local
|
123
setup.py
123
setup.py
@ -29,70 +29,79 @@ with open("vnpy/__init__.py", "rb") as f:
|
|||||||
version = str(ast.literal_eval(version_line))
|
version = str(ast.literal_eval(version_line))
|
||||||
|
|
||||||
if platform.uname().system == "Windows":
|
if platform.uname().system == "Windows":
|
||||||
compiler_flags = ["/MP", "/std:c++17", # standard
|
compiler_flags = [
|
||||||
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
|
"/MP", "/std:c++17", # standard
|
||||||
"/wd4819" # 936 code page
|
"/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization
|
||||||
]
|
"/wd4819" # 936 code page
|
||||||
|
]
|
||||||
extra_link_args = []
|
extra_link_args = []
|
||||||
else:
|
else:
|
||||||
compiler_flags = ["-std=c++17",
|
compiler_flags = [
|
||||||
"-Wno-delete-incomplete", "-Wno-sign-compare",
|
"-std=c++17", # standard
|
||||||
]
|
"-O3", # Optimization
|
||||||
|
"-Wno-delete-incomplete", "-Wno-sign-compare",
|
||||||
|
]
|
||||||
extra_link_args = ["-lstdc++"]
|
extra_link_args = ["-lstdc++"]
|
||||||
|
|
||||||
vnctpmd = Extension("vnpy.api.ctp.vnctpmd",
|
vnctpmd = Extension(
|
||||||
[
|
"vnpy.api.ctp.vnctpmd",
|
||||||
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
|
[
|
||||||
],
|
"vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp",
|
||||||
include_dirs=["vnpy/api/ctp/include",
|
],
|
||||||
"vnpy/api/ctp/vnctp", ],
|
include_dirs=["vnpy/api/ctp/include",
|
||||||
define_macros=[],
|
"vnpy/api/ctp/vnctp", ],
|
||||||
undef_macros=[],
|
define_macros=[],
|
||||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
undef_macros=[],
|
||||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||||
extra_compile_args=compiler_flags,
|
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||||
extra_link_args=extra_link_args,
|
extra_compile_args=compiler_flags,
|
||||||
depends=[],
|
extra_link_args=extra_link_args,
|
||||||
runtime_library_dirs=["$ORIGIN"],
|
depends=[],
|
||||||
language="cpp",
|
runtime_library_dirs=["$ORIGIN"],
|
||||||
)
|
language="cpp",
|
||||||
vnctptd = Extension("vnpy.api.ctp.vnctptd",
|
)
|
||||||
[
|
vnctptd = Extension(
|
||||||
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
|
"vnpy.api.ctp.vnctptd",
|
||||||
],
|
[
|
||||||
include_dirs=["vnpy/api/ctp/include",
|
"vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp",
|
||||||
"vnpy/api/ctp/vnctp", ],
|
],
|
||||||
define_macros=[],
|
include_dirs=["vnpy/api/ctp/include",
|
||||||
undef_macros=[],
|
"vnpy/api/ctp/vnctp", ],
|
||||||
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
define_macros=[],
|
||||||
libraries=["thostmduserapi", "thosttraderapi", ],
|
undef_macros=[],
|
||||||
extra_compile_args=compiler_flags,
|
library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"],
|
||||||
extra_link_args=extra_link_args,
|
libraries=["thostmduserapi", "thosttraderapi", ],
|
||||||
runtime_library_dirs=["$ORIGIN"],
|
extra_compile_args=compiler_flags,
|
||||||
depends=[],
|
extra_link_args=extra_link_args,
|
||||||
language="cpp",
|
runtime_library_dirs=["$ORIGIN"],
|
||||||
)
|
depends=[],
|
||||||
vnoes = Extension("vnpy.api.oes.vnoes",
|
language="cpp",
|
||||||
[
|
)
|
||||||
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
|
vnoes = Extension(
|
||||||
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
|
"vnpy.api.oes.vnoes",
|
||||||
"vnpy/api/oes/vnoes/generated_files/module.cpp",
|
[
|
||||||
],
|
"vnpy/api/oes/vnoes/generated_files/classes_1.cpp",
|
||||||
include_dirs=["vnpy/api/oes/include",
|
"vnpy/api/oes/vnoes/generated_files/classes_2.cpp",
|
||||||
"vnpy/api/oes/vnoes", ],
|
"vnpy/api/oes/vnoes/generated_files/module.cpp",
|
||||||
define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")],
|
],
|
||||||
undef_macros=[],
|
include_dirs=["vnpy/api/oes/include",
|
||||||
library_dirs=["vnpy/api/oes/libs"],
|
"vnpy/api/oes/vnoes", ],
|
||||||
libraries=["oes_api"],
|
define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")],
|
||||||
extra_compile_args=compiler_flags,
|
undef_macros=[],
|
||||||
extra_link_args=extra_link_args,
|
library_dirs=["vnpy/api/oes/libs"],
|
||||||
depends=[],
|
libraries=["oes_api"],
|
||||||
language="cpp",
|
extra_compile_args=compiler_flags,
|
||||||
)
|
extra_link_args=extra_link_args,
|
||||||
|
runtime_library_dirs=["$ORIGIN"],
|
||||||
|
depends=[],
|
||||||
|
language="cpp",
|
||||||
|
)
|
||||||
|
|
||||||
if platform.uname().system == "Windows":
|
if platform.system() == "Windows":
|
||||||
# use pre-built pyd for windows ( support python 3.7 only )
|
# use pre-built pyd for windows ( support python 3.7 only )
|
||||||
ext_modules = []
|
ext_modules = []
|
||||||
|
elif platform.system() == "Darwin":
|
||||||
|
ext_modules = []
|
||||||
else:
|
else:
|
||||||
ext_modules = [vnctptd, vnctpmd, vnoes]
|
ext_modules = [vnctptd, vnctpmd, vnoes]
|
||||||
|
|
||||||
|
1
tests/app/__init__.py
Normal file
1
tests/app/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .test_csv_loader import *
|
90
tests/app/test_csv_loader.py
Normal file
90
tests/app/test_csv_loader.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
Test if csv loader works fine
|
||||||
|
"""
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from vnpy.app.csv_loader import CsvLoaderEngine
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
|
||||||
|
|
||||||
|
class TestCsvLoader(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader
|
||||||
|
|
||||||
|
def test_load(self):
|
||||||
|
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||||
|
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||||
|
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||||
|
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||||
|
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||||
|
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||||
|
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryFile("w+t") as f:
|
||||||
|
f.write(data)
|
||||||
|
f.seek(0)
|
||||||
|
|
||||||
|
self.engine.load_by_handle(
|
||||||
|
f,
|
||||||
|
symbol="1",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
interval=Interval.MINUTE,
|
||||||
|
datetime_head="Datetime",
|
||||||
|
open_head="Open",
|
||||||
|
close_head="Close",
|
||||||
|
low_head="Low",
|
||||||
|
high_head="High",
|
||||||
|
volume_head="Volume",
|
||||||
|
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_load_duplicated(self):
|
||||||
|
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||||
|
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||||
|
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||||
|
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||||
|
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||||
|
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||||
|
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryFile("w+t") as f:
|
||||||
|
f.write(data)
|
||||||
|
f.seek(0)
|
||||||
|
|
||||||
|
self.engine.load_by_handle(
|
||||||
|
f,
|
||||||
|
symbol="1",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
interval=Interval.MINUTE,
|
||||||
|
datetime_head="Datetime",
|
||||||
|
open_head="Open",
|
||||||
|
close_head="Close",
|
||||||
|
low_head="Low",
|
||||||
|
high_head="High",
|
||||||
|
volume_head="Volume",
|
||||||
|
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
with tempfile.TemporaryFile("w+t") as f:
|
||||||
|
f.write(data)
|
||||||
|
f.seek(0)
|
||||||
|
|
||||||
|
self.engine.load_by_handle(
|
||||||
|
f,
|
||||||
|
symbol="1",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
interval=Interval.MINUTE,
|
||||||
|
datetime_head="Datetime",
|
||||||
|
open_head="Open",
|
||||||
|
close_head="Close",
|
||||||
|
low_head="Low",
|
||||||
|
high_head="High",
|
||||||
|
volume_head="Volume",
|
||||||
|
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -2,7 +2,7 @@ from time import time
|
|||||||
|
|
||||||
import rqdatac as rq
|
import rqdatac as rq
|
||||||
|
|
||||||
from vnpy.trader.database import DbBarData, DB
|
from vnpy.trader.database import DbBarData
|
||||||
|
|
||||||
USERNAME = ""
|
USERNAME = ""
|
||||||
PASSWORD = ""
|
PASSWORD = ""
|
||||||
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
|
|||||||
|
|
||||||
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
|
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
|
||||||
|
|
||||||
with DB.atomic():
|
bars = []
|
||||||
for ix, row in df.iterrows():
|
for ix, row in df.iterrows():
|
||||||
print(row.name)
|
bar = generate_bar_from_row(row, symbol, exchange)
|
||||||
bar = generate_bar_from_row(row, symbol, exchange)
|
bars.append(bar)
|
||||||
DbBarData.replace(bar.__data__).execute()
|
DbBarData.save_all(bars)
|
||||||
|
|
||||||
end = time()
|
end = time()
|
||||||
cost = (end - start) * 1000
|
cost = (end - start) * 1000
|
||||||
|
31
tests/test_all.py
Normal file
31
tests/test_all.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# tests/runner.py
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import app
|
||||||
|
# import your test modules
|
||||||
|
import test_import_all
|
||||||
|
import trader
|
||||||
|
|
||||||
|
# initialize the test suite
|
||||||
|
loader = unittest.TestLoader()
|
||||||
|
suite = unittest.TestSuite()
|
||||||
|
|
||||||
|
# add tests to the test suite
|
||||||
|
suite.addTests(loader.loadTestsFromModule(test_import_all))
|
||||||
|
suite.addTests(loader.loadTestsFromModule(trader))
|
||||||
|
suite.addTests(loader.loadTestsFromModule(app))
|
||||||
|
|
||||||
|
|
||||||
|
# initialize a runner, pass it your suite and run it
|
||||||
|
def main():
|
||||||
|
runner = unittest.TextTestRunner(verbosity=3)
|
||||||
|
result = runner.run(suite)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
result = main()
|
||||||
|
if result.failures or result.errors:
|
||||||
|
exit(1)
|
||||||
|
else:
|
||||||
|
exit(0)
|
@ -1,24 +1,45 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import unittest
|
import unittest
|
||||||
|
import platform
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
class ImportTest(unittest.TestCase):
|
class ImportTest(unittest.TestCase):
|
||||||
|
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
def test_import_all(self):
|
def test_import_all(self):
|
||||||
from vnpy.event import EventEngine
|
from vnpy.event import EventEngine
|
||||||
|
|
||||||
|
def test_import_main_engine(self):
|
||||||
from vnpy.trader.engine import MainEngine
|
from vnpy.trader.engine import MainEngine
|
||||||
|
|
||||||
|
def test_import_ui(self):
|
||||||
from vnpy.trader.ui import MainWindow, create_qapp
|
from vnpy.trader.ui import MainWindow, create_qapp
|
||||||
|
|
||||||
|
def test_import_bitmex_gateway(self):
|
||||||
from vnpy.gateway.bitmex import BitmexGateway
|
from vnpy.gateway.bitmex import BitmexGateway
|
||||||
|
|
||||||
|
def test_import_futu_gateway(self):
|
||||||
from vnpy.gateway.futu import FutuGateway
|
from vnpy.gateway.futu import FutuGateway
|
||||||
|
|
||||||
|
def test_import_ib_gateway(self):
|
||||||
from vnpy.gateway.ib import IbGateway
|
from vnpy.gateway.ib import IbGateway
|
||||||
|
|
||||||
|
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
|
||||||
|
def test_import_ctp_gateway(self):
|
||||||
from vnpy.gateway.ctp import CtpGateway
|
from vnpy.gateway.ctp import CtpGateway
|
||||||
|
|
||||||
|
def test_import_tiger_gateway(self):
|
||||||
from vnpy.gateway.tiger import TigerGateway
|
from vnpy.gateway.tiger import TigerGateway
|
||||||
|
|
||||||
|
@unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx")
|
||||||
|
def test_import_oes_gateway(self):
|
||||||
from vnpy.gateway.oes import OesGateway
|
from vnpy.gateway.oes import OesGateway
|
||||||
|
|
||||||
|
def test_import_cta_strategy_app(self):
|
||||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||||
|
|
||||||
|
def test_import_csv_loader_app(self):
|
||||||
from vnpy.app.csv_loader import CsvLoaderApp
|
from vnpy.app.csv_loader import CsvLoaderApp
|
||||||
|
|
||||||
|
|
2
tests/trader/__init__.py
Normal file
2
tests/trader/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .test_database import *
|
||||||
|
from .test_settings import *
|
128
tests/trader/test_database.py
Normal file
128
tests/trader/test_database.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
Test if database works fine
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
from vnpy.trader.database.database import Driver
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
|
||||||
|
os.environ['VNPY_TESTING'] = '1'
|
||||||
|
|
||||||
|
profiles = {
|
||||||
|
Driver.SQLITE: {
|
||||||
|
"driver": "sqlite",
|
||||||
|
"database": "test_db.db",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if 'VNPY_TEST_ONLY_SQLITE' not in os.environ:
|
||||||
|
profiles.update({
|
||||||
|
Driver.MYSQL: {
|
||||||
|
"driver": "mysql",
|
||||||
|
"database": os.environ['VNPY_TEST_MYSQL_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_MYSQL_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_MYSQL_PORT']),
|
||||||
|
"user": os.environ["VNPY_TEST_MYSQL_USER"],
|
||||||
|
"password": os.environ['VNPY_TEST_MYSQL_PASSWORD'],
|
||||||
|
},
|
||||||
|
Driver.POSTGRESQL: {
|
||||||
|
"driver": "postgresql",
|
||||||
|
"database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_POSTGRESQL_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']),
|
||||||
|
"user": os.environ["VNPY_TEST_POSTGRESQL_USER"],
|
||||||
|
"password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'],
|
||||||
|
},
|
||||||
|
Driver.MONGODB: {
|
||||||
|
"driver": "mongodb",
|
||||||
|
"database": os.environ['VNPY_TEST_MONGODB_DATABASE'],
|
||||||
|
"host": os.environ['VNPY_TEST_MONGODB_HOST'],
|
||||||
|
"port": int(os.environ['VNPY_TEST_MONGODB_PORT']),
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"authentication_source": "",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def now():
|
||||||
|
return datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
bar = BarData(
|
||||||
|
gateway_name="DB",
|
||||||
|
symbol="test_symbol",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
datetime=now(),
|
||||||
|
interval=Interval.MINUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
tick = TickData(
|
||||||
|
gateway_name="DB",
|
||||||
|
symbol="test_symbol",
|
||||||
|
exchange=Exchange.BITMEX,
|
||||||
|
datetime=now(),
|
||||||
|
name="DB_test_symbol",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabase(unittest.TestCase):
|
||||||
|
|
||||||
|
def connect(self, settings: dict):
|
||||||
|
from vnpy.trader.database.initialize import init # noqa
|
||||||
|
self.manager = init(settings)
|
||||||
|
|
||||||
|
def test_upsert_bar(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
|
||||||
|
def test_save_load_bar(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
# save first
|
||||||
|
self.manager.save_bar_data([bar])
|
||||||
|
|
||||||
|
# and load
|
||||||
|
results = self.manager.load_bar_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
interval=bar.interval,
|
||||||
|
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||||
|
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||||
|
)
|
||||||
|
count = len(results)
|
||||||
|
self.assertNotEqual(count, 0)
|
||||||
|
|
||||||
|
def test_upsert_tick(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
|
||||||
|
def test_save_load_tick(self):
|
||||||
|
for driver, settings in profiles.items():
|
||||||
|
with self.subTest(driver=driver, settings=settings):
|
||||||
|
self.connect(settings)
|
||||||
|
# save first
|
||||||
|
self.manager.save_tick_data([tick])
|
||||||
|
|
||||||
|
# and load
|
||||||
|
results = self.manager.load_tick_data(
|
||||||
|
symbol=bar.symbol,
|
||||||
|
exchange=bar.exchange,
|
||||||
|
start=bar.datetime - timedelta(seconds=1), # time is not accuracy
|
||||||
|
end=now() + timedelta(seconds=1), # time is not accuracy
|
||||||
|
)
|
||||||
|
count = len(results)
|
||||||
|
self.assertNotEqual(count, 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
25
tests/trader/test_settings.py
Normal file
25
tests/trader/test_settings.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
Test if database works fine
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from vnpy.trader.setting import SETTINGS, get_settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettings(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_get_settings(self):
|
||||||
|
SETTINGS['a'] = 1
|
||||||
|
got = get_settings()
|
||||||
|
self.assertIn('a', got)
|
||||||
|
self.assertEqual(got['a'], 1)
|
||||||
|
|
||||||
|
def test_get_settings_with_prefix(self):
|
||||||
|
SETTINGS['a.a'] = 1
|
||||||
|
got = get_settings()
|
||||||
|
self.assertIn('a', got)
|
||||||
|
self.assertEqual(got['a'], 1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
30
tests/travis_env.sh
Normal file
30
tests/travis_env.sh
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
[[ -z ${VNPY_TEST_MYSQL_DATABASE} ]] && VNPY_TEST_MYSQL_DATABASE=vnpy
|
||||||
|
[[ -z ${VNPY_TEST_MYSQL_HOST} ]] && VNPY_TEST_MYSQL_HOST=127.0.0.1
|
||||||
|
[[ -z ${VNPY_TEST_MYSQL_PORT} ]] && VNPY_TEST_MYSQL_PORT=3306
|
||||||
|
[[ -z ${VNPY_TEST_MYSQL_USER} ]] && VNPY_TEST_MYSQL_USER=root
|
||||||
|
[[ -z ${VNPY_TEST_MYSQL_PASSWORD} ]] && VNPY_TEST_MYSQL_PASSWORD=
|
||||||
|
[[ -z ${VNPY_TEST_POSTGRESQL_DATABASE} ]] && VNPY_TEST_POSTGRESQL_DATABASE=vnpy
|
||||||
|
[[ -z ${VNPY_TEST_POSTGRESQL_HOST} ]] && VNPY_TEST_POSTGRESQL_HOST=127.0.0.1
|
||||||
|
[[ -z ${VNPY_TEST_POSTGRESQL_PORT} ]] && VNPY_TEST_POSTGRESQL_PORT=5432
|
||||||
|
[[ -z ${VNPY_TEST_POSTGRESQL_USER} ]] && VNPY_TEST_POSTGRESQL_USER=postgres
|
||||||
|
[[ -z ${VNPY_TEST_POSTGRESQL_PASSWORD} ]] && VNPY_TEST_POSTGRESQL_PASSWORD=
|
||||||
|
[[ -z ${VNPY_TEST_MONGODB_DATABASE} ]] && VNPY_TEST_MONGODB_DATABASE=vnpy
|
||||||
|
[[ -z ${VNPY_TEST_MONGODB_HOST} ]] && VNPY_TEST_MONGODB_HOST=127.0.0.1
|
||||||
|
[[ -z ${VNPY_TEST_MONGODB_PORT} ]] && VNPY_TEST_MONGODB_PORT=27017
|
||||||
|
|
||||||
|
|
||||||
|
export VNPY_TEST_MYSQL_DATABASE
|
||||||
|
export VNPY_TEST_MYSQL_HOST
|
||||||
|
export VNPY_TEST_MYSQL_PORT
|
||||||
|
export VNPY_TEST_MYSQL_USER
|
||||||
|
export VNPY_TEST_MYSQL_PASSWORD
|
||||||
|
export VNPY_TEST_POSTGRESQL_DATABASE
|
||||||
|
export VNPY_TEST_POSTGRESQL_HOST
|
||||||
|
export VNPY_TEST_POSTGRESQL_PORT
|
||||||
|
export VNPY_TEST_POSTGRESQL_USER
|
||||||
|
export VNPY_TEST_POSTGRESQL_PASSWORD
|
||||||
|
export VNPY_TEST_MONGODB_DATABASE
|
||||||
|
export VNPY_TEST_MONGODB_HOST
|
||||||
|
export VNPY_TEST_MONGODB_PORT
|
@ -22,14 +22,13 @@ Sample csv file:
|
|||||||
|
|
||||||
import csv
|
import csv
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import TextIO
|
||||||
from peewee import chunked
|
|
||||||
|
|
||||||
from vnpy.event import EventEngine
|
from vnpy.event import EventEngine
|
||||||
from vnpy.trader.constant import Exchange, Interval
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
from vnpy.trader.database import DbBarData, DB
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||||
|
from vnpy.trader.object import BarData
|
||||||
|
|
||||||
APP_NAME = "CsvLoader"
|
APP_NAME = "CsvLoader"
|
||||||
|
|
||||||
@ -53,6 +52,59 @@ class CsvLoaderEngine(BaseEngine):
|
|||||||
self.high_head: str = ""
|
self.high_head: str = ""
|
||||||
self.volume_head: str = ""
|
self.volume_head: str = ""
|
||||||
|
|
||||||
|
def load_by_handle(
|
||||||
|
self,
|
||||||
|
f: TextIO,
|
||||||
|
symbol: str,
|
||||||
|
exchange: Exchange,
|
||||||
|
interval: Interval,
|
||||||
|
datetime_head: str,
|
||||||
|
open_head: str,
|
||||||
|
close_head: str,
|
||||||
|
low_head: str,
|
||||||
|
high_head: str,
|
||||||
|
volume_head: str,
|
||||||
|
datetime_format: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
load by text mode file handle
|
||||||
|
"""
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
|
||||||
|
bars = []
|
||||||
|
start = None
|
||||||
|
count = 0
|
||||||
|
for item in reader:
|
||||||
|
if datetime_format:
|
||||||
|
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||||
|
else:
|
||||||
|
dt = datetime.fromisoformat(item[datetime_head])
|
||||||
|
|
||||||
|
bar = BarData(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
|
datetime=dt,
|
||||||
|
interval=interval,
|
||||||
|
volume=item[volume_head],
|
||||||
|
open_price=item[open_head],
|
||||||
|
high_price=item[high_head],
|
||||||
|
low_price=item[low_head],
|
||||||
|
close_price=item[close_head],
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
|
||||||
|
bars.append(bar)
|
||||||
|
|
||||||
|
# do some statistics
|
||||||
|
count += 1
|
||||||
|
if not start:
|
||||||
|
start = bar.datetime
|
||||||
|
end = bar.datetime
|
||||||
|
|
||||||
|
# insert into database
|
||||||
|
database_manager.save_bar_data(bars)
|
||||||
|
return start, end, count
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
@ -65,49 +117,22 @@ class CsvLoaderEngine(BaseEngine):
|
|||||||
low_head: str,
|
low_head: str,
|
||||||
high_head: str,
|
high_head: str,
|
||||||
volume_head: str,
|
volume_head: str,
|
||||||
datetime_format: str
|
datetime_format: str,
|
||||||
):
|
):
|
||||||
""""""
|
"""
|
||||||
vt_symbol = f"{symbol}.{exchange.value}"
|
load by filename
|
||||||
|
"""
|
||||||
start = None
|
|
||||||
end = None
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
with open(file_path, "rt") as f:
|
with open(file_path, "rt") as f:
|
||||||
reader = csv.DictReader(f)
|
return self.load_by_handle(
|
||||||
|
f,
|
||||||
db_bars = []
|
symbol=symbol,
|
||||||
|
exchange=exchange,
|
||||||
for item in reader:
|
interval=interval,
|
||||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
datetime_head=datetime_head,
|
||||||
|
open_head=open_head,
|
||||||
db_bar = {
|
close_head=close_head,
|
||||||
"symbol": symbol,
|
low_head=low_head,
|
||||||
"exchange": exchange.value,
|
high_head=high_head,
|
||||||
"datetime": dt,
|
volume_head=volume_head,
|
||||||
"interval": interval.value,
|
datetime_format=datetime_format,
|
||||||
"volume": item[volume_head],
|
)
|
||||||
"open_price": item[open_head],
|
|
||||||
"high_price": item[high_head],
|
|
||||||
"low_price": item[low_head],
|
|
||||||
"close_price": item[close_head],
|
|
||||||
"vt_symbol": vt_symbol,
|
|
||||||
"gateway_name": "DB"
|
|
||||||
}
|
|
||||||
|
|
||||||
db_bars.append(db_bar)
|
|
||||||
|
|
||||||
# do some statistics
|
|
||||||
count += 1
|
|
||||||
if not start:
|
|
||||||
start = db_bar["datetime"]
|
|
||||||
|
|
||||||
end = db_bar["datetime"]
|
|
||||||
|
|
||||||
# Insert into DB
|
|
||||||
with DB.atomic():
|
|
||||||
for batch in chunked(db_bars, 50):
|
|
||||||
DbBarData.insert_many(batch).on_conflict_replace().execute()
|
|
||||||
|
|
||||||
return start, end, count
|
|
||||||
|
@ -12,8 +12,8 @@ from pandas import DataFrame
|
|||||||
|
|
||||||
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
from vnpy.trader.constant import (Direction, Offset, Exchange,
|
||||||
Interval, Status)
|
Interval, Status)
|
||||||
from vnpy.trader.database import DbBarData, DbTickData
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.object import OrderData, TradeData
|
from vnpy.trader.object import OrderData, TradeData, BarData, TickData
|
||||||
from vnpy.trader.utility import round_to_pricetick
|
from vnpy.trader.utility import round_to_pricetick
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -103,8 +103,8 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
self.strategy_class = None
|
self.strategy_class = None
|
||||||
self.strategy = None
|
self.strategy = None
|
||||||
self.tick = None
|
self.tick: TickData
|
||||||
self.bar = None
|
self.bar: BarData
|
||||||
self.datetime = None
|
self.datetime = None
|
||||||
|
|
||||||
self.interval = None
|
self.interval = None
|
||||||
@ -199,14 +199,16 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
if self.mode == BacktestingMode.BAR:
|
if self.mode == BacktestingMode.BAR:
|
||||||
self.history_data = load_bar_data(
|
self.history_data = load_bar_data(
|
||||||
self.vt_symbol,
|
self.symbol,
|
||||||
|
self.exchange,
|
||||||
self.interval,
|
self.interval,
|
||||||
self.start,
|
self.start,
|
||||||
self.end
|
self.end
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.history_data = load_tick_data(
|
self.history_data = load_tick_data(
|
||||||
self.vt_symbol,
|
self.symbol,
|
||||||
|
self.exchange,
|
||||||
self.start,
|
self.start,
|
||||||
self.end
|
self.end
|
||||||
)
|
)
|
||||||
@ -520,7 +522,7 @@ class BacktestingEngine:
|
|||||||
else:
|
else:
|
||||||
self.daily_results[d] = DailyResult(d, price)
|
self.daily_results[d] = DailyResult(d, price)
|
||||||
|
|
||||||
def new_bar(self, bar: DbBarData):
|
def new_bar(self, bar: BarData):
|
||||||
""""""
|
""""""
|
||||||
self.bar = bar
|
self.bar = bar
|
||||||
self.datetime = bar.datetime
|
self.datetime = bar.datetime
|
||||||
@ -531,7 +533,7 @@ class BacktestingEngine:
|
|||||||
|
|
||||||
self.update_daily_close(bar.close_price)
|
self.update_daily_close(bar.close_price)
|
||||||
|
|
||||||
def new_tick(self, tick: DbTickData):
|
def new_tick(self, tick: TickData):
|
||||||
""""""
|
""""""
|
||||||
self.tick = tick
|
self.tick = tick
|
||||||
self.datetime = tick.datetime
|
self.datetime = tick.datetime
|
||||||
@ -966,41 +968,26 @@ def optimize(
|
|||||||
|
|
||||||
@lru_cache(maxsize=10)
|
@lru_cache(maxsize=10)
|
||||||
def load_bar_data(
|
def load_bar_data(
|
||||||
vt_symbol: str,
|
symbol: str,
|
||||||
interval: str,
|
exchange: Exchange,
|
||||||
start: datetime,
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
end: datetime
|
end: datetime
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
s = (
|
return database_manager.load_bar_data(
|
||||||
DbBarData.select()
|
symbol, exchange, interval, start, end
|
||||||
.where(
|
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
|
||||||
& (DbBarData.interval == interval)
|
|
||||||
& (DbBarData.datetime >= start)
|
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_bar.to_bar() for db_bar in s]
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=10)
|
@lru_cache(maxsize=10)
|
||||||
def load_tick_data(
|
def load_tick_data(
|
||||||
vt_symbol: str,
|
symbol: str,
|
||||||
start: datetime,
|
exchange: Exchange,
|
||||||
|
start: datetime,
|
||||||
end: datetime
|
end: datetime
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
s = (
|
return database_manager.load_tick_data(
|
||||||
DbTickData.select()
|
symbol, exchange, start, end
|
||||||
.where(
|
|
||||||
(DbTickData.vt_symbol == vt_symbol)
|
|
||||||
& (DbTickData.datetime >= start)
|
|
||||||
& (DbTickData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbTickData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_tick.db_tick() for db_tick in s]
|
|
||||||
return data
|
|
@ -5,7 +5,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, List
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
@ -36,7 +36,7 @@ from vnpy.trader.constant import (
|
|||||||
Status
|
Status
|
||||||
)
|
)
|
||||||
from vnpy.trader.utility import load_json, save_json
|
from vnpy.trader.utility import load_json, save_json
|
||||||
from vnpy.trader.database import DbTickData, DbBarData
|
from vnpy.trader.database import database_manager
|
||||||
from vnpy.trader.setting import SETTINGS
|
from vnpy.trader.setting import SETTINGS
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -146,13 +146,12 @@ class CtaEngine(BaseEngine):
|
|||||||
self.write_log("RQData数据接口初始化成功")
|
self.write_log("RQData数据接口初始化成功")
|
||||||
|
|
||||||
def query_bar_from_rq(
|
def query_bar_from_rq(
|
||||||
self, vt_symbol: str, interval: Interval, start: datetime, end: datetime
|
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Query bar data from RQData.
|
Query bar data from RQData.
|
||||||
"""
|
"""
|
||||||
symbol, exchange_str = vt_symbol.split(".")
|
rq_symbol = to_rq_symbol(symbol, exchange)
|
||||||
rq_symbol = to_rq_symbol(vt_symbol)
|
|
||||||
if rq_symbol not in self.rq_symbols:
|
if rq_symbol not in self.rq_symbols:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -166,11 +165,11 @@ class CtaEngine(BaseEngine):
|
|||||||
end_date=end
|
end_date=end
|
||||||
)
|
)
|
||||||
|
|
||||||
data = []
|
data: List[BarData] = []
|
||||||
for ix, row in df.iterrows():
|
for ix, row in df.iterrows():
|
||||||
bar = BarData(
|
bar = BarData(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
exchange=Exchange(exchange_str),
|
exchange=exchange,
|
||||||
interval=interval,
|
interval=interval,
|
||||||
datetime=row.name.to_pydatetime(),
|
datetime=row.name.to_pydatetime(),
|
||||||
open_price=row["open"],
|
open_price=row["open"],
|
||||||
@ -529,46 +528,41 @@ class CtaEngine(BaseEngine):
|
|||||||
return self.engine_type
|
return self.engine_type
|
||||||
|
|
||||||
def load_bar(
|
def load_bar(
|
||||||
self, vt_symbol: str, days: int, interval: Interval, callback: Callable
|
self, symbol: str, exchange: Exchange, days: int, interval: Interval,
|
||||||
|
callback: Callable[[BarData], None]
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
start = end - timedelta(days)
|
start = end - timedelta(days)
|
||||||
|
|
||||||
# Query data from RQData by default, if not found, load from database.
|
# Query bars from RQData by default, if not found, load from database.
|
||||||
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
|
bars = self.query_bar_from_rq(symbol, exchange, interval, start, end)
|
||||||
if not data:
|
if not bars:
|
||||||
s = (
|
bars = database_manager.load_bar_data(
|
||||||
DbBarData.select()
|
symbol=symbol,
|
||||||
.where(
|
exchange=exchange,
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
interval=interval,
|
||||||
& (DbBarData.interval == interval.value)
|
start=start,
|
||||||
& (DbBarData.datetime >= start)
|
end=end,
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
data = [db_bar.to_bar() for db_bar in s]
|
|
||||||
|
|
||||||
for bar in data:
|
for bar in bars:
|
||||||
callback(bar)
|
callback(bar)
|
||||||
|
|
||||||
def load_tick(self, vt_symbol: str, days: int, callback: Callable):
|
def load_tick(self, symbol: str, exchange: Exchange, days: int,
|
||||||
|
callback: Callable[[TickData], None]):
|
||||||
""""""
|
""""""
|
||||||
end = datetime.now()
|
end = datetime.now()
|
||||||
start = end - timedelta(days)
|
start = end - timedelta(days)
|
||||||
|
|
||||||
s = (
|
ticks = database_manager.load_tick_data(
|
||||||
DbTickData.select()
|
symbol=symbol,
|
||||||
.where(
|
exchange=exchange,
|
||||||
(DbBarData.vt_symbol == vt_symbol)
|
start=start,
|
||||||
& (DbBarData.datetime >= start)
|
end=end,
|
||||||
& (DbBarData.datetime <= end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for tick in s:
|
for tick in ticks:
|
||||||
callback(tick)
|
callback(tick)
|
||||||
|
|
||||||
def call_strategy_func(
|
def call_strategy_func(
|
||||||
@ -757,7 +751,7 @@ class CtaEngine(BaseEngine):
|
|||||||
"""
|
"""
|
||||||
Load strategy class from certain folder.
|
Load strategy class from certain folder.
|
||||||
"""
|
"""
|
||||||
for dirpath, dirnames, filenames in os.walk(path):
|
for dirpath, dirnames, filenames in os.walk(str(path)):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename.endswith(".py"):
|
if filename.endswith(".py"):
|
||||||
strategy_module_name = ".".join(
|
strategy_module_name = ".".join(
|
||||||
@ -914,19 +908,19 @@ class CtaEngine(BaseEngine):
|
|||||||
self.main_engine.send_email(subject, msg)
|
self.main_engine.send_email(subject, msg)
|
||||||
|
|
||||||
|
|
||||||
def to_rq_symbol(vt_symbol: str):
|
def to_rq_symbol(symbol: str, exchange: Exchange):
|
||||||
"""
|
"""
|
||||||
CZCE product of RQData has symbol like "TA1905" while
|
CZCE product of RQData has symbol like "TA1905" while
|
||||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||||
"""
|
"""
|
||||||
symbol, exchange_str = vt_symbol.split(".")
|
if exchange is not Exchange.CZCE:
|
||||||
if exchange_str != "CZCE":
|
|
||||||
return symbol.upper()
|
return symbol.upper()
|
||||||
|
|
||||||
for count, word in enumerate(symbol):
|
for count, word in enumerate(symbol):
|
||||||
if word.isdigit():
|
if word.isdigit():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# noinspection PyUnboundLocalVariable
|
||||||
product = symbol[:count]
|
product = symbol[:count]
|
||||||
year = symbol[count]
|
year = symbol[count]
|
||||||
month = symbol[count + 1:]
|
month = symbol[count + 1:]
|
||||||
|
@ -1,268 +0,0 @@
|
|||||||
""""""
|
|
||||||
|
|
||||||
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
|
|
||||||
SqliteDatabase
|
|
||||||
|
|
||||||
from .constant import Exchange, Interval
|
|
||||||
from .object import BarData, TickData
|
|
||||||
from .setting import SETTINGS
|
|
||||||
from .utility import resolve_path
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
|
||||||
db_settings = SETTINGS['database']
|
|
||||||
driver = db_settings["driver"]
|
|
||||||
|
|
||||||
init_funcs = {
|
|
||||||
"sqlite": init_sqlite,
|
|
||||||
"mysql": init_mysql,
|
|
||||||
"postgresql": init_postgresql,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert driver in init_funcs
|
|
||||||
del db_settings['driver']
|
|
||||||
return init_funcs[driver](db_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def init_sqlite(settings: dict):
|
|
||||||
global DB
|
|
||||||
database = settings['database']
|
|
||||||
|
|
||||||
DB = SqliteDatabase(str(resolve_path(database)))
|
|
||||||
|
|
||||||
|
|
||||||
def init_mysql(settings: dict):
|
|
||||||
global DB
|
|
||||||
DB = MySQLDatabase(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
def init_postgresql(settings: dict):
|
|
||||||
global DB
|
|
||||||
DB = PostgresqlDatabase(**settings)
|
|
||||||
|
|
||||||
|
|
||||||
init()
|
|
||||||
|
|
||||||
|
|
||||||
class DbBarData(Model):
|
|
||||||
"""
|
|
||||||
Candlestick bar data for database storage.
|
|
||||||
|
|
||||||
Index is defined unique with vt_symbol, interval and datetime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
symbol = CharField()
|
|
||||||
exchange = CharField()
|
|
||||||
datetime = DateTimeField()
|
|
||||||
interval = CharField()
|
|
||||||
|
|
||||||
volume = FloatField()
|
|
||||||
open_price = FloatField()
|
|
||||||
high_price = FloatField()
|
|
||||||
low_price = FloatField()
|
|
||||||
close_price = FloatField()
|
|
||||||
|
|
||||||
vt_symbol = CharField()
|
|
||||||
gateway_name = CharField()
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
database = DB
|
|
||||||
indexes = ((("vt_symbol", "interval", "datetime"), True),)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_bar(bar: BarData):
|
|
||||||
"""
|
|
||||||
Generate DbBarData object from BarData.
|
|
||||||
"""
|
|
||||||
db_bar = DbBarData()
|
|
||||||
|
|
||||||
db_bar.symbol = bar.symbol
|
|
||||||
db_bar.exchange = bar.exchange.value
|
|
||||||
db_bar.datetime = bar.datetime
|
|
||||||
db_bar.interval = bar.interval.value
|
|
||||||
db_bar.volume = bar.volume
|
|
||||||
db_bar.open_price = bar.open_price
|
|
||||||
db_bar.high_price = bar.high_price
|
|
||||||
db_bar.low_price = bar.low_price
|
|
||||||
db_bar.close_price = bar.close_price
|
|
||||||
db_bar.vt_symbol = bar.vt_symbol
|
|
||||||
db_bar.gateway_name = "DB"
|
|
||||||
|
|
||||||
return db_bar
|
|
||||||
|
|
||||||
def to_bar(self):
|
|
||||||
"""
|
|
||||||
Generate BarData object from DbBarData.
|
|
||||||
"""
|
|
||||||
bar = BarData(
|
|
||||||
symbol=self.symbol,
|
|
||||||
exchange=Exchange(self.exchange),
|
|
||||||
datetime=self.datetime,
|
|
||||||
interval=Interval(self.interval),
|
|
||||||
volume=self.volume,
|
|
||||||
open_price=self.open_price,
|
|
||||||
high_price=self.high_price,
|
|
||||||
low_price=self.low_price,
|
|
||||||
close_price=self.close_price,
|
|
||||||
gateway_name=self.gateway_name,
|
|
||||||
)
|
|
||||||
return bar
|
|
||||||
|
|
||||||
|
|
||||||
class DbTickData(Model):
|
|
||||||
"""
|
|
||||||
Tick data for database storage.
|
|
||||||
|
|
||||||
Index is defined unique with vt_symbol, interval and datetime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
symbol = CharField()
|
|
||||||
exchange = CharField()
|
|
||||||
datetime = DateTimeField()
|
|
||||||
|
|
||||||
name = CharField()
|
|
||||||
volume = FloatField()
|
|
||||||
last_price = FloatField()
|
|
||||||
last_volume = FloatField()
|
|
||||||
limit_up = FloatField()
|
|
||||||
limit_down = FloatField()
|
|
||||||
|
|
||||||
open_price = FloatField()
|
|
||||||
high_price = FloatField()
|
|
||||||
low_price = FloatField()
|
|
||||||
close_price = FloatField()
|
|
||||||
|
|
||||||
bid_price_1 = FloatField()
|
|
||||||
bid_price_2 = FloatField()
|
|
||||||
bid_price_3 = FloatField()
|
|
||||||
bid_price_4 = FloatField()
|
|
||||||
bid_price_5 = FloatField()
|
|
||||||
|
|
||||||
ask_price_1 = FloatField()
|
|
||||||
ask_price_2 = FloatField()
|
|
||||||
ask_price_3 = FloatField()
|
|
||||||
ask_price_4 = FloatField()
|
|
||||||
ask_price_5 = FloatField()
|
|
||||||
|
|
||||||
bid_volume_1 = FloatField()
|
|
||||||
bid_volume_2 = FloatField()
|
|
||||||
bid_volume_3 = FloatField()
|
|
||||||
bid_volume_4 = FloatField()
|
|
||||||
bid_volume_5 = FloatField()
|
|
||||||
|
|
||||||
ask_volume_1 = FloatField()
|
|
||||||
ask_volume_2 = FloatField()
|
|
||||||
ask_volume_3 = FloatField()
|
|
||||||
ask_volume_4 = FloatField()
|
|
||||||
ask_volume_5 = FloatField()
|
|
||||||
|
|
||||||
vt_symbol = CharField()
|
|
||||||
gateway_name = CharField()
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
database = DB
|
|
||||||
indexes = ((("vt_symbol", "datetime"), True),)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_tick(tick: TickData):
|
|
||||||
"""
|
|
||||||
Generate DbTickData object from TickData.
|
|
||||||
"""
|
|
||||||
db_tick = DbTickData()
|
|
||||||
|
|
||||||
db_tick.symbol = tick.symbol
|
|
||||||
db_tick.exchange = tick.exchange.value
|
|
||||||
db_tick.datetime = tick.datetime
|
|
||||||
db_tick.name = tick.name
|
|
||||||
db_tick.volume = tick.volume
|
|
||||||
db_tick.last_price = tick.last_price
|
|
||||||
db_tick.last_volume = tick.last_volume
|
|
||||||
db_tick.limit_up = tick.limit_up
|
|
||||||
db_tick.limit_down = tick.limit_down
|
|
||||||
db_tick.open_price = tick.open_price
|
|
||||||
db_tick.high_price = tick.high_price
|
|
||||||
db_tick.low_price = tick.low_price
|
|
||||||
db_tick.pre_close = tick.pre_close
|
|
||||||
|
|
||||||
db_tick.bid_price_1 = tick.bid_price_1
|
|
||||||
db_tick.ask_price_1 = tick.ask_price_1
|
|
||||||
db_tick.bid_volume_1 = tick.bid_volume_1
|
|
||||||
db_tick.ask_volume_1 = tick.ask_volume_1
|
|
||||||
|
|
||||||
if tick.bid_price_2:
|
|
||||||
db_tick.bid_price_2 = tick.bid_price_2
|
|
||||||
db_tick.bid_price_3 = tick.bid_price_3
|
|
||||||
db_tick.bid_price_4 = tick.bid_price_4
|
|
||||||
db_tick.bid_price_5 = tick.bid_price_5
|
|
||||||
|
|
||||||
db_tick.ask_price_2 = tick.ask_price_2
|
|
||||||
db_tick.ask_price_3 = tick.ask_price_3
|
|
||||||
db_tick.ask_price_4 = tick.ask_price_4
|
|
||||||
db_tick.ask_price_5 = tick.ask_price_5
|
|
||||||
|
|
||||||
db_tick.bid_volume_2 = tick.bid_volume_2
|
|
||||||
db_tick.bid_volume_3 = tick.bid_volume_3
|
|
||||||
db_tick.bid_volume_4 = tick.bid_volume_4
|
|
||||||
db_tick.bid_volume_5 = tick.bid_volume_5
|
|
||||||
|
|
||||||
db_tick.ask_volume_2 = tick.ask_volume_2
|
|
||||||
db_tick.ask_volume_3 = tick.ask_volume_3
|
|
||||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
|
||||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
|
||||||
|
|
||||||
db_tick.vt_symbol = tick.vt_symbol
|
|
||||||
db_tick.gateway_name = "DB"
|
|
||||||
|
|
||||||
return tick
|
|
||||||
|
|
||||||
def to_tick(self):
|
|
||||||
"""
|
|
||||||
Generate TickData object from DbTickData.
|
|
||||||
"""
|
|
||||||
tick = TickData(
|
|
||||||
symbol=self.symbol,
|
|
||||||
exchange=Exchange(self.exchange),
|
|
||||||
datetime=self.datetime,
|
|
||||||
name=self.name,
|
|
||||||
volume=self.volume,
|
|
||||||
last_price=self.last_price,
|
|
||||||
last_volume=self.last_volume,
|
|
||||||
limit_up=self.limit_up,
|
|
||||||
limit_down=self.limit_down,
|
|
||||||
open_price=self.open_price,
|
|
||||||
high_price=self.high_price,
|
|
||||||
low_price=self.low_price,
|
|
||||||
pre_close=self.pre_close,
|
|
||||||
bid_price_1=self.bid_price_1,
|
|
||||||
ask_price_1=self.ask_price_1,
|
|
||||||
bid_volume_1=self.bid_volume_1,
|
|
||||||
ask_volume_1=self.ask_volume_1,
|
|
||||||
gateway_name=self.gateway_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.bid_price_2:
|
|
||||||
tick.bid_price_2 = self.bid_price_2
|
|
||||||
tick.bid_price_3 = self.bid_price_3
|
|
||||||
tick.bid_price_4 = self.bid_price_4
|
|
||||||
tick.bid_price_5 = self.bid_price_5
|
|
||||||
|
|
||||||
tick.ask_price_2 = self.ask_price_2
|
|
||||||
tick.ask_price_3 = self.ask_price_3
|
|
||||||
tick.ask_price_4 = self.ask_price_4
|
|
||||||
tick.ask_price_5 = self.ask_price_5
|
|
||||||
|
|
||||||
tick.bid_volume_2 = self.bid_volume_2
|
|
||||||
tick.bid_volume_3 = self.bid_volume_3
|
|
||||||
tick.bid_volume_4 = self.bid_volume_4
|
|
||||||
tick.bid_volume_5 = self.bid_volume_5
|
|
||||||
|
|
||||||
tick.ask_volume_2 = self.ask_volume_2
|
|
||||||
tick.ask_volume_3 = self.ask_volume_3
|
|
||||||
tick.ask_volume_4 = self.ask_volume_4
|
|
||||||
tick.ask_volume_5 = self.ask_volume_5
|
|
||||||
|
|
||||||
return tick
|
|
||||||
|
|
||||||
|
|
||||||
DB.connect()
|
|
||||||
DB.create_tables([DbBarData, DbTickData])
|
|
12
vnpy/trader/database/__init__.py
Normal file
12
vnpy/trader/database/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.database.database import BaseDatabaseManager
|
||||||
|
|
||||||
|
if "VNPY_TESTING" not in os.environ:
|
||||||
|
from vnpy.trader.setting import get_settings
|
||||||
|
from .initialize import init
|
||||||
|
|
||||||
|
settings = get_settings("database.")
|
||||||
|
database_manager: "BaseDatabaseManager" = init(settings=settings)
|
53
vnpy/trader/database/database.py
Normal file
53
vnpy/trader/database/database.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Sequence, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.constant import Interval, Exchange # noqa
|
||||||
|
from vnpy.trader.object import BarData, TickData # noqa
|
||||||
|
|
||||||
|
|
||||||
|
class Driver(Enum):
|
||||||
|
SQLITE = "sqlite"
|
||||||
|
MYSQL = "mysql"
|
||||||
|
POSTGRESQL = "postgresql"
|
||||||
|
MONGODB = "mongodb"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDatabaseManager(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: "Exchange",
|
||||||
|
interval: "Interval",
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
) -> Sequence["BarData"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_tick_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: "Exchange",
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
) -> Sequence["TickData"]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_bar_data(
|
||||||
|
self,
|
||||||
|
datas: Sequence["BarData"],
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_tick_data(
|
||||||
|
self,
|
||||||
|
datas: Sequence["TickData"],
|
||||||
|
):
|
||||||
|
pass
|
302
vnpy/trader/database/database_mongo.py
Normal file
302
vnpy/trader/database/database_mongo.py
Normal file
@ -0,0 +1,302 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
|
||||||
|
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
from .database import BaseDatabaseManager, Driver
|
||||||
|
|
||||||
|
|
||||||
|
def init(_: Driver, settings: dict):
|
||||||
|
database = settings["database"]
|
||||||
|
host = settings["host"]
|
||||||
|
port = settings["port"]
|
||||||
|
username = settings["user"]
|
||||||
|
password = settings["password"]
|
||||||
|
authentication_source = settings["authentication_source"]
|
||||||
|
if not username: # if username == '' or None, skip username
|
||||||
|
username = None
|
||||||
|
password = None
|
||||||
|
authentication_source = None
|
||||||
|
connect(
|
||||||
|
db=database,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
authentication_source=authentication_source,
|
||||||
|
)
|
||||||
|
return MongoManager()
|
||||||
|
|
||||||
|
|
||||||
|
class DbBarData(Document):
|
||||||
|
"""
|
||||||
|
Candlestick bar data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with datetime, interval, symbol
|
||||||
|
"""
|
||||||
|
|
||||||
|
symbol: str = StringField()
|
||||||
|
exchange: str = StringField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
interval: str = StringField()
|
||||||
|
|
||||||
|
volume: float = FloatField()
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
close_price: float = FloatField()
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"indexes": [
|
||||||
|
{"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bar(bar: BarData):
|
||||||
|
"""
|
||||||
|
Generate DbBarData object from BarData.
|
||||||
|
"""
|
||||||
|
db_bar = DbBarData()
|
||||||
|
|
||||||
|
db_bar.symbol = bar.symbol
|
||||||
|
db_bar.exchange = bar.exchange.value
|
||||||
|
db_bar.datetime = bar.datetime
|
||||||
|
db_bar.interval = bar.interval.value
|
||||||
|
db_bar.volume = bar.volume
|
||||||
|
db_bar.open_price = bar.open_price
|
||||||
|
db_bar.high_price = bar.high_price
|
||||||
|
db_bar.low_price = bar.low_price
|
||||||
|
db_bar.close_price = bar.close_price
|
||||||
|
|
||||||
|
return db_bar
|
||||||
|
|
||||||
|
def to_bar(self):
|
||||||
|
"""
|
||||||
|
Generate BarData object from DbBarData.
|
||||||
|
"""
|
||||||
|
bar = BarData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
interval=Interval(self.interval),
|
||||||
|
volume=self.volume,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
close_price=self.close_price,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
return bar
|
||||||
|
|
||||||
|
|
||||||
|
class DbTickData(Document):
|
||||||
|
"""
|
||||||
|
Tick data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with (datetime, symbol)
|
||||||
|
"""
|
||||||
|
|
||||||
|
symbol: str = StringField()
|
||||||
|
exchange: str = StringField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
|
||||||
|
name: str = StringField()
|
||||||
|
volume: float = FloatField()
|
||||||
|
last_price: float = FloatField()
|
||||||
|
last_volume: float = FloatField()
|
||||||
|
limit_up: float = FloatField()
|
||||||
|
limit_down: float = FloatField()
|
||||||
|
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
close_price: float = FloatField()
|
||||||
|
pre_close: float = FloatField()
|
||||||
|
|
||||||
|
bid_price_1: float = FloatField()
|
||||||
|
bid_price_2: float = FloatField()
|
||||||
|
bid_price_3: float = FloatField()
|
||||||
|
bid_price_4: float = FloatField()
|
||||||
|
bid_price_5: float = FloatField()
|
||||||
|
|
||||||
|
ask_price_1: float = FloatField()
|
||||||
|
ask_price_2: float = FloatField()
|
||||||
|
ask_price_3: float = FloatField()
|
||||||
|
ask_price_4: float = FloatField()
|
||||||
|
ask_price_5: float = FloatField()
|
||||||
|
|
||||||
|
bid_volume_1: float = FloatField()
|
||||||
|
bid_volume_2: float = FloatField()
|
||||||
|
bid_volume_3: float = FloatField()
|
||||||
|
bid_volume_4: float = FloatField()
|
||||||
|
bid_volume_5: float = FloatField()
|
||||||
|
|
||||||
|
ask_volume_1: float = FloatField()
|
||||||
|
ask_volume_2: float = FloatField()
|
||||||
|
ask_volume_3: float = FloatField()
|
||||||
|
ask_volume_4: float = FloatField()
|
||||||
|
ask_volume_5: float = FloatField()
|
||||||
|
|
||||||
|
meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_tick(tick: TickData):
|
||||||
|
"""
|
||||||
|
Generate DbTickData object from TickData.
|
||||||
|
"""
|
||||||
|
db_tick = DbTickData()
|
||||||
|
|
||||||
|
db_tick.symbol = tick.symbol
|
||||||
|
db_tick.exchange = tick.exchange.value
|
||||||
|
db_tick.datetime = tick.datetime
|
||||||
|
db_tick.name = tick.name
|
||||||
|
db_tick.volume = tick.volume
|
||||||
|
db_tick.last_price = tick.last_price
|
||||||
|
db_tick.last_volume = tick.last_volume
|
||||||
|
db_tick.limit_up = tick.limit_up
|
||||||
|
db_tick.limit_down = tick.limit_down
|
||||||
|
db_tick.open_price = tick.open_price
|
||||||
|
db_tick.high_price = tick.high_price
|
||||||
|
db_tick.low_price = tick.low_price
|
||||||
|
db_tick.pre_close = tick.pre_close
|
||||||
|
|
||||||
|
db_tick.bid_price_1 = tick.bid_price_1
|
||||||
|
db_tick.ask_price_1 = tick.ask_price_1
|
||||||
|
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||||
|
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||||
|
|
||||||
|
if tick.bid_price_2:
|
||||||
|
db_tick.bid_price_2 = tick.bid_price_2
|
||||||
|
db_tick.bid_price_3 = tick.bid_price_3
|
||||||
|
db_tick.bid_price_4 = tick.bid_price_4
|
||||||
|
db_tick.bid_price_5 = tick.bid_price_5
|
||||||
|
|
||||||
|
db_tick.ask_price_2 = tick.ask_price_2
|
||||||
|
db_tick.ask_price_3 = tick.ask_price_3
|
||||||
|
db_tick.ask_price_4 = tick.ask_price_4
|
||||||
|
db_tick.ask_price_5 = tick.ask_price_5
|
||||||
|
|
||||||
|
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||||
|
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||||
|
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||||
|
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||||
|
|
||||||
|
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||||
|
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||||
|
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||||
|
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||||
|
|
||||||
|
return db_tick
|
||||||
|
|
||||||
|
def to_tick(self):
|
||||||
|
"""
|
||||||
|
Generate TickData object from DbTickData.
|
||||||
|
"""
|
||||||
|
tick = TickData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
name=self.name,
|
||||||
|
volume=self.volume,
|
||||||
|
last_price=self.last_price,
|
||||||
|
last_volume=self.last_volume,
|
||||||
|
limit_up=self.limit_up,
|
||||||
|
limit_down=self.limit_down,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
pre_close=self.pre_close,
|
||||||
|
bid_price_1=self.bid_price_1,
|
||||||
|
ask_price_1=self.ask_price_1,
|
||||||
|
bid_volume_1=self.bid_volume_1,
|
||||||
|
ask_volume_1=self.ask_volume_1,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bid_price_2:
|
||||||
|
tick.bid_price_2 = self.bid_price_2
|
||||||
|
tick.bid_price_3 = self.bid_price_3
|
||||||
|
tick.bid_price_4 = self.bid_price_4
|
||||||
|
tick.bid_price_5 = self.bid_price_5
|
||||||
|
|
||||||
|
tick.ask_price_2 = self.ask_price_2
|
||||||
|
tick.ask_price_3 = self.ask_price_3
|
||||||
|
tick.ask_price_4 = self.ask_price_4
|
||||||
|
tick.ask_price_5 = self.ask_price_5
|
||||||
|
|
||||||
|
tick.bid_volume_2 = self.bid_volume_2
|
||||||
|
tick.bid_volume_3 = self.bid_volume_3
|
||||||
|
tick.bid_volume_4 = self.bid_volume_4
|
||||||
|
tick.bid_volume_5 = self.bid_volume_5
|
||||||
|
|
||||||
|
tick.ask_volume_2 = self.ask_volume_2
|
||||||
|
tick.ask_volume_3 = self.ask_volume_3
|
||||||
|
tick.ask_volume_4 = self.ask_volume_4
|
||||||
|
tick.ask_volume_5 = self.ask_volume_5
|
||||||
|
|
||||||
|
return tick
|
||||||
|
|
||||||
|
|
||||||
|
class MongoManager(BaseDatabaseManager):
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: Exchange,
|
||||||
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime,
|
||||||
|
) -> Sequence[BarData]:
|
||||||
|
s = DbBarData.objects(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange.value,
|
||||||
|
interval=interval.value,
|
||||||
|
datetime__gte=start,
|
||||||
|
datetime__lte=end,
|
||||||
|
)
|
||||||
|
data = [db_bar.to_bar() for db_bar in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def load_tick_data(
|
||||||
|
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||||
|
) -> Sequence[TickData]:
|
||||||
|
s = DbTickData.objects(
|
||||||
|
symbol=symbol,
|
||||||
|
exchange=exchange.value,
|
||||||
|
datetime__gte=start,
|
||||||
|
datetime__lte=end,
|
||||||
|
)
|
||||||
|
data = [db_tick.to_tick() for db_tick in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_update_param(d):
|
||||||
|
return {
|
||||||
|
"set__" + k: v.value if isinstance(v, Enum) else v
|
||||||
|
for k, v in d.__dict__.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def save_bar_data(self, datas: Sequence[BarData]):
|
||||||
|
for d in datas:
|
||||||
|
updates = self.to_update_param(d)
|
||||||
|
updates.pop("set__gateway_name")
|
||||||
|
updates.pop("set__vt_symbol")
|
||||||
|
(
|
||||||
|
DbBarData.objects(
|
||||||
|
symbol=d.symbol, interval=d.interval.value, datetime=d.datetime
|
||||||
|
).update_one(upsert=True, **updates)
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_tick_data(self, datas: Sequence[TickData]):
|
||||||
|
for d in datas:
|
||||||
|
updates = self.to_update_param(d)
|
||||||
|
updates.pop("set__gateway_name")
|
||||||
|
updates.pop("set__vt_symbol")
|
||||||
|
(
|
||||||
|
DbTickData.objects(
|
||||||
|
symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime
|
||||||
|
).update_one(upsert=True, **updates)
|
||||||
|
)
|
369
vnpy/trader/database/database_sql.py
Normal file
369
vnpy/trader/database/database_sql.py
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
""""""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Sequence, Type
|
||||||
|
|
||||||
|
from peewee import (
|
||||||
|
AutoField,
|
||||||
|
CharField,
|
||||||
|
Database,
|
||||||
|
DateTimeField,
|
||||||
|
FloatField,
|
||||||
|
Model,
|
||||||
|
MySQLDatabase,
|
||||||
|
PostgresqlDatabase,
|
||||||
|
SqliteDatabase,
|
||||||
|
chunked,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vnpy.trader.constant import Exchange, Interval
|
||||||
|
from vnpy.trader.object import BarData, TickData
|
||||||
|
from vnpy.trader.utility import get_file_path
|
||||||
|
from .database import BaseDatabaseManager, Driver
|
||||||
|
|
||||||
|
|
||||||
|
def init(driver: Driver, settings: dict):
|
||||||
|
init_funcs = {
|
||||||
|
Driver.SQLITE: init_sqlite,
|
||||||
|
Driver.MYSQL: init_mysql,
|
||||||
|
Driver.POSTGRESQL: init_postgresql,
|
||||||
|
}
|
||||||
|
assert driver in init_funcs
|
||||||
|
|
||||||
|
db = init_funcs[driver](settings)
|
||||||
|
bar, tick = init_models(db, driver)
|
||||||
|
return SqlManager(bar, tick)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sqlite(settings: dict):
|
||||||
|
database = settings["database"]
|
||||||
|
path = str(get_file_path(database))
|
||||||
|
db = SqliteDatabase(path)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def init_mysql(settings: dict):
|
||||||
|
keys = {"database", "user", "password", "host", "port"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
db = MySQLDatabase(**settings)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
def init_postgresql(settings: dict):
|
||||||
|
keys = {"database", "user", "password", "host", "port"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
db = PostgresqlDatabase(**settings)
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
class ModelBase(Model):
|
||||||
|
def to_dict(self):
|
||||||
|
return self.__data__
|
||||||
|
|
||||||
|
|
||||||
|
def init_models(db: Database, driver: Driver):
|
||||||
|
class DbBarData(ModelBase):
|
||||||
|
"""
|
||||||
|
Candlestick bar data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with datetime, interval, symbol
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = AutoField()
|
||||||
|
symbol: str = CharField()
|
||||||
|
exchange: str = CharField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
interval: str = CharField()
|
||||||
|
|
||||||
|
volume: float = FloatField()
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
close_price: float = FloatField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = db
|
||||||
|
indexes = ((("datetime", "interval", "symbol", "exchange"), True),)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_bar(bar: BarData):
|
||||||
|
"""
|
||||||
|
Generate DbBarData object from BarData.
|
||||||
|
"""
|
||||||
|
db_bar = DbBarData()
|
||||||
|
|
||||||
|
db_bar.symbol = bar.symbol
|
||||||
|
db_bar.exchange = bar.exchange.value
|
||||||
|
db_bar.datetime = bar.datetime
|
||||||
|
db_bar.interval = bar.interval.value
|
||||||
|
db_bar.volume = bar.volume
|
||||||
|
db_bar.open_price = bar.open_price
|
||||||
|
db_bar.high_price = bar.high_price
|
||||||
|
db_bar.low_price = bar.low_price
|
||||||
|
db_bar.close_price = bar.close_price
|
||||||
|
|
||||||
|
return db_bar
|
||||||
|
|
||||||
|
def to_bar(self):
|
||||||
|
"""
|
||||||
|
Generate BarData object from DbBarData.
|
||||||
|
"""
|
||||||
|
bar = BarData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
interval=Interval(self.interval),
|
||||||
|
volume=self.volume,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
close_price=self.close_price,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
return bar
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_all(objs: List["DbBarData"]):
|
||||||
|
"""
|
||||||
|
save a list of objects, update if exists.
|
||||||
|
"""
|
||||||
|
dicts = [i.to_dict() for i in objs]
|
||||||
|
with db.atomic():
|
||||||
|
if driver is Driver.POSTGRESQL:
|
||||||
|
for bar in dicts:
|
||||||
|
DbBarData.insert(bar).on_conflict(
|
||||||
|
update=bar,
|
||||||
|
conflict_target=(
|
||||||
|
DbBarData.datetime,
|
||||||
|
DbBarData.interval,
|
||||||
|
DbBarData.symbol,
|
||||||
|
DbBarData.exchange,
|
||||||
|
),
|
||||||
|
).execute()
|
||||||
|
else:
|
||||||
|
for c in chunked(dicts, 50):
|
||||||
|
DbBarData.insert_many(c).on_conflict_replace().execute()
|
||||||
|
|
||||||
|
class DbTickData(ModelBase):
|
||||||
|
"""
|
||||||
|
Tick data for database storage.
|
||||||
|
|
||||||
|
Index is defined unique with (datetime, symbol)
|
||||||
|
"""
|
||||||
|
|
||||||
|
id = AutoField()
|
||||||
|
|
||||||
|
symbol: str = CharField()
|
||||||
|
exchange: str = CharField()
|
||||||
|
datetime: datetime = DateTimeField()
|
||||||
|
|
||||||
|
name: str = CharField()
|
||||||
|
volume: float = FloatField()
|
||||||
|
last_price: float = FloatField()
|
||||||
|
last_volume: float = FloatField()
|
||||||
|
limit_up: float = FloatField()
|
||||||
|
limit_down: float = FloatField()
|
||||||
|
|
||||||
|
open_price: float = FloatField()
|
||||||
|
high_price: float = FloatField()
|
||||||
|
low_price: float = FloatField()
|
||||||
|
pre_close: float = FloatField()
|
||||||
|
|
||||||
|
bid_price_1: float = FloatField()
|
||||||
|
bid_price_2: float = FloatField(null=True)
|
||||||
|
bid_price_3: float = FloatField(null=True)
|
||||||
|
bid_price_4: float = FloatField(null=True)
|
||||||
|
bid_price_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
ask_price_1: float = FloatField()
|
||||||
|
ask_price_2: float = FloatField(null=True)
|
||||||
|
ask_price_3: float = FloatField(null=True)
|
||||||
|
ask_price_4: float = FloatField(null=True)
|
||||||
|
ask_price_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
bid_volume_1: float = FloatField()
|
||||||
|
bid_volume_2: float = FloatField(null=True)
|
||||||
|
bid_volume_3: float = FloatField(null=True)
|
||||||
|
bid_volume_4: float = FloatField(null=True)
|
||||||
|
bid_volume_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
ask_volume_1: float = FloatField()
|
||||||
|
ask_volume_2: float = FloatField(null=True)
|
||||||
|
ask_volume_3: float = FloatField(null=True)
|
||||||
|
ask_volume_4: float = FloatField(null=True)
|
||||||
|
ask_volume_5: float = FloatField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = db
|
||||||
|
indexes = ((("datetime", "symbol", "exchange"), True),)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_tick(tick: TickData):
|
||||||
|
"""
|
||||||
|
Generate DbTickData object from TickData.
|
||||||
|
"""
|
||||||
|
db_tick = DbTickData()
|
||||||
|
|
||||||
|
db_tick.symbol = tick.symbol
|
||||||
|
db_tick.exchange = tick.exchange.value
|
||||||
|
db_tick.datetime = tick.datetime
|
||||||
|
db_tick.name = tick.name
|
||||||
|
db_tick.volume = tick.volume
|
||||||
|
db_tick.last_price = tick.last_price
|
||||||
|
db_tick.last_volume = tick.last_volume
|
||||||
|
db_tick.limit_up = tick.limit_up
|
||||||
|
db_tick.limit_down = tick.limit_down
|
||||||
|
db_tick.open_price = tick.open_price
|
||||||
|
db_tick.high_price = tick.high_price
|
||||||
|
db_tick.low_price = tick.low_price
|
||||||
|
db_tick.pre_close = tick.pre_close
|
||||||
|
|
||||||
|
db_tick.bid_price_1 = tick.bid_price_1
|
||||||
|
db_tick.ask_price_1 = tick.ask_price_1
|
||||||
|
db_tick.bid_volume_1 = tick.bid_volume_1
|
||||||
|
db_tick.ask_volume_1 = tick.ask_volume_1
|
||||||
|
|
||||||
|
if tick.bid_price_2:
|
||||||
|
db_tick.bid_price_2 = tick.bid_price_2
|
||||||
|
db_tick.bid_price_3 = tick.bid_price_3
|
||||||
|
db_tick.bid_price_4 = tick.bid_price_4
|
||||||
|
db_tick.bid_price_5 = tick.bid_price_5
|
||||||
|
|
||||||
|
db_tick.ask_price_2 = tick.ask_price_2
|
||||||
|
db_tick.ask_price_3 = tick.ask_price_3
|
||||||
|
db_tick.ask_price_4 = tick.ask_price_4
|
||||||
|
db_tick.ask_price_5 = tick.ask_price_5
|
||||||
|
|
||||||
|
db_tick.bid_volume_2 = tick.bid_volume_2
|
||||||
|
db_tick.bid_volume_3 = tick.bid_volume_3
|
||||||
|
db_tick.bid_volume_4 = tick.bid_volume_4
|
||||||
|
db_tick.bid_volume_5 = tick.bid_volume_5
|
||||||
|
|
||||||
|
db_tick.ask_volume_2 = tick.ask_volume_2
|
||||||
|
db_tick.ask_volume_3 = tick.ask_volume_3
|
||||||
|
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||||
|
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||||
|
|
||||||
|
return db_tick
|
||||||
|
|
||||||
|
def to_tick(self):
|
||||||
|
"""
|
||||||
|
Generate TickData object from DbTickData.
|
||||||
|
"""
|
||||||
|
tick = TickData(
|
||||||
|
symbol=self.symbol,
|
||||||
|
exchange=Exchange(self.exchange),
|
||||||
|
datetime=self.datetime,
|
||||||
|
name=self.name,
|
||||||
|
volume=self.volume,
|
||||||
|
last_price=self.last_price,
|
||||||
|
last_volume=self.last_volume,
|
||||||
|
limit_up=self.limit_up,
|
||||||
|
limit_down=self.limit_down,
|
||||||
|
open_price=self.open_price,
|
||||||
|
high_price=self.high_price,
|
||||||
|
low_price=self.low_price,
|
||||||
|
pre_close=self.pre_close,
|
||||||
|
bid_price_1=self.bid_price_1,
|
||||||
|
ask_price_1=self.ask_price_1,
|
||||||
|
bid_volume_1=self.bid_volume_1,
|
||||||
|
ask_volume_1=self.ask_volume_1,
|
||||||
|
gateway_name="DB",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.bid_price_2:
|
||||||
|
tick.bid_price_2 = self.bid_price_2
|
||||||
|
tick.bid_price_3 = self.bid_price_3
|
||||||
|
tick.bid_price_4 = self.bid_price_4
|
||||||
|
tick.bid_price_5 = self.bid_price_5
|
||||||
|
|
||||||
|
tick.ask_price_2 = self.ask_price_2
|
||||||
|
tick.ask_price_3 = self.ask_price_3
|
||||||
|
tick.ask_price_4 = self.ask_price_4
|
||||||
|
tick.ask_price_5 = self.ask_price_5
|
||||||
|
|
||||||
|
tick.bid_volume_2 = self.bid_volume_2
|
||||||
|
tick.bid_volume_3 = self.bid_volume_3
|
||||||
|
tick.bid_volume_4 = self.bid_volume_4
|
||||||
|
tick.bid_volume_5 = self.bid_volume_5
|
||||||
|
|
||||||
|
tick.ask_volume_2 = self.ask_volume_2
|
||||||
|
tick.ask_volume_3 = self.ask_volume_3
|
||||||
|
tick.ask_volume_4 = self.ask_volume_4
|
||||||
|
tick.ask_volume_5 = self.ask_volume_5
|
||||||
|
|
||||||
|
return tick
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_all(objs: List["DbTickData"]):
|
||||||
|
dicts = [i.to_dict() for i in objs]
|
||||||
|
with db.atomic():
|
||||||
|
if driver is Driver.POSTGRESQL:
|
||||||
|
for tick in dicts:
|
||||||
|
DbTickData.insert(tick).on_conflict(
|
||||||
|
update=tick,
|
||||||
|
conflict_target=(
|
||||||
|
DbTickData.datetime,
|
||||||
|
DbTickData.symbol,
|
||||||
|
DbTickData.exchange,
|
||||||
|
),
|
||||||
|
).execute()
|
||||||
|
else:
|
||||||
|
for c in chunked(dicts, 50):
|
||||||
|
DbTickData.insert_many(c).on_conflict_replace().execute()
|
||||||
|
|
||||||
|
db.connect()
|
||||||
|
db.create_tables([DbBarData, DbTickData])
|
||||||
|
return DbBarData, DbTickData
|
||||||
|
|
||||||
|
|
||||||
|
class SqlManager(BaseDatabaseManager):
|
||||||
|
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
|
||||||
|
self.class_bar = class_bar
|
||||||
|
self.class_tick = class_tick
|
||||||
|
|
||||||
|
def load_bar_data(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
exchange: Exchange,
|
||||||
|
interval: Interval,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime,
|
||||||
|
) -> Sequence[BarData]:
|
||||||
|
s = (
|
||||||
|
self.class_bar.select()
|
||||||
|
.where(
|
||||||
|
(self.class_bar.symbol == symbol)
|
||||||
|
& (self.class_bar.exchange == exchange.value)
|
||||||
|
& (self.class_bar.interval == interval.value)
|
||||||
|
& (self.class_bar.datetime >= start)
|
||||||
|
& (self.class_bar.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(self.class_bar.datetime)
|
||||||
|
)
|
||||||
|
data = [db_bar.to_bar() for db_bar in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def load_tick_data(
|
||||||
|
self, symbol: str, exchange: Exchange, start: datetime, end: datetime
|
||||||
|
) -> Sequence[TickData]:
|
||||||
|
s = (
|
||||||
|
self.class_tick.select()
|
||||||
|
.where(
|
||||||
|
(self.class_tick.symbol == symbol)
|
||||||
|
& (self.class_tick.exchange == exchange.value)
|
||||||
|
& (self.class_tick.datetime >= start)
|
||||||
|
& (self.class_tick.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(self.class_tick.datetime)
|
||||||
|
)
|
||||||
|
|
||||||
|
data = [db_tick.to_tick() for db_tick in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
def save_bar_data(self, datas: Sequence[BarData]):
|
||||||
|
ds = [self.class_bar.from_bar(i) for i in datas]
|
||||||
|
self.class_bar.save_all(ds)
|
||||||
|
|
||||||
|
def save_tick_data(self, datas: Sequence[TickData]):
|
||||||
|
ds = [self.class_tick.from_tick(i) for i in datas]
|
||||||
|
self.class_tick.save_all(ds)
|
24
vnpy/trader/database/initialize.py
Normal file
24
vnpy/trader/database/initialize.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
""""""
|
||||||
|
from .database import BaseDatabaseManager, Driver
|
||||||
|
|
||||||
|
|
||||||
|
def init(settings: dict) -> BaseDatabaseManager:
|
||||||
|
driver = Driver(settings["driver"])
|
||||||
|
if driver is Driver.MONGODB:
|
||||||
|
return init_nosql(driver=driver, settings=settings)
|
||||||
|
else:
|
||||||
|
return init_sql(driver=driver, settings=settings)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sql(driver: Driver, settings: dict):
|
||||||
|
from .database_sql import init
|
||||||
|
keys = {'database', "host", "port", "user", "password"}
|
||||||
|
settings = {k: v for k, v in settings.items() if k in keys}
|
||||||
|
_database_manager = init(driver, settings)
|
||||||
|
return _database_manager
|
||||||
|
|
||||||
|
|
||||||
|
def init_nosql(driver: Driver, settings: dict):
|
||||||
|
from .database_mongo import init
|
||||||
|
_database_manager = init(driver, settings=settings)
|
||||||
|
return _database_manager
|
@ -24,16 +24,21 @@ SETTINGS = {
|
|||||||
|
|
||||||
"rqdata.username": "",
|
"rqdata.username": "",
|
||||||
"rqdata.password": "",
|
"rqdata.password": "",
|
||||||
"database": {
|
|
||||||
"driver": "sqlite", # sqlite, mysql, postgresql
|
"database.driver": "sqlite", # see database.Driver
|
||||||
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
|
"database.database": "database.db", # for sqlite, use this as filepath
|
||||||
"host": "localhost",
|
"database.host": "localhost",
|
||||||
"port": 3306,
|
"database.port": 3306,
|
||||||
"user": "root",
|
"database.user": "root",
|
||||||
"password": ""
|
"database.password": "",
|
||||||
}
|
"database.authentication_source": "admin", # for mongodb
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load global setting from json file.
|
# Load global setting from json file.
|
||||||
SETTING_FILENAME = "vt_setting.json"
|
SETTING_FILENAME = "vt_setting.json"
|
||||||
SETTINGS.update(load_json(SETTING_FILENAME))
|
SETTINGS.update(load_json(SETTING_FILENAME))
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings(prefix: str = ""):
|
||||||
|
prefix_length = len(prefix)
|
||||||
|
return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)}
|
||||||
|
@ -3,20 +3,28 @@ General utility functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import talib
|
import talib
|
||||||
|
|
||||||
from .object import BarData, TickData
|
from .object import BarData, TickData
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vnpy.trader.constant import Exchange
|
||||||
|
|
||||||
def resolve_path(pattern: str):
|
|
||||||
env = dict(os.environ)
|
def extract_vt_symbol(vt_symbol: str):
|
||||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
"""
|
||||||
return pattern.format(**env)
|
:return: (symbol, exchange)
|
||||||
|
"""
|
||||||
|
symbol, exchange = vt_symbol.split('.')
|
||||||
|
return symbol, exchange
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vt_symbol(symbol: str, exchange: "Exchange"):
|
||||||
|
return f'{symbol}.{exchange.value}'
|
||||||
|
|
||||||
|
|
||||||
def _get_trader_dir(temp_name: str):
|
def _get_trader_dir(temp_name: str):
|
||||||
|
Loading…
Reference in New Issue
Block a user