diff --git a/.travis.yml b/.travis.yml index 8622616d..0ef835b3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,12 +2,27 @@ language: python dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069) +cache: pip + +git: + depth: 1 + python: - "3.7" +services: + - mongodb + - mysql + - postgresql + +before_script: + - psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;" + - mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;' + script: - # todo: use python unittest - - mkdir run; cd run; python ../tests/load_all.py + - pip install psycopg2 mongoengine pymysql # we should support all database in test environment + - cd tests; source travis_env.sh; + - python test_all.py matrix: include: @@ -19,81 +34,36 @@ matrix: script: - flake8 - - name: "pip install under Windows" - os: "windows" - # language : cpp is necessary for windows - language: "cpp" - env: - - PATH=/c/Python37:/c/Python37/Scripts:$PATH - before_install: - - choco install python3 --version 3.7.2 - install: - - python -m pip install --upgrade pip wheel setuptools - - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl - - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - - pip install -r requirements.txt - - pip install . - - name: "pip install under Ubuntu: gcc-8" + addons: + apt: + sources: + - ubuntu-toolchain-r-test + packages: + - g++-8 before_install: - # C++17 - - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - - sudo apt-get update -y - install: - # C++17 - - sudo apt-get install -y gcc-8 g++-8 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90 + install: # update pip & setuptools - python -m pip install --upgrade pip wheel setuptools # Linux install script - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - bash ./install.sh - - name: "pip install under Ubuntu: gcc-7" + - name: "sdist install under Ubuntu: gcc-7" + addons: + apt: + sources: + - ubuntu-toolchain-r-test + packages: + - g++-7 before_install: - # C++17 - - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - - sudo apt-get update -y - install: - # C++17 - - sudo apt-get install -y gcc-7 g++-7 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90 - # update pip & setuptools - - python -m pip install --upgrade pip wheel setuptools - # Linux install script - - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - - bash ./install.sh - - - name: "sdist install under Windows" - os: "windows" - # language : cpp is necessary for windows - language: "cpp" - env: - - PATH=/c/Python37:/c/Python37/Scripts:$PATH - before_install: - - choco install python3 --version 3.7.2 install: - - python -m pip install --upgrade pip wheel setuptools - - python setup.py sdist - - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl - - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - - pip install dist/`ls dist` - - - name: "sdist install under Ubuntu: gcc-8" - before_install: - # C++17 - - sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test - - sudo apt-get update -y - install: - # C++17 - - sudo apt-get install -y gcc-8 g++-8 - - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90 - - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90 - - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90 # Linux install script - python -m pip install --upgrade pip wheel setuptools - pushd /tmp @@ -108,3 +78,17 @@ matrix: - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - python setup.py sdist - pip install dist/`ls dist` + + - name: "pip install under osx" + os: osx + language: shell # osx supports only shell + services: [] + before_install: [] + install: + - pip3 install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl + - bash ./install_osx.sh + before_script: [] + script: + - pip3 install psycopg2 mongoengine pymysql # we should support all database in test environment + - cd tests; source travis_env.sh; + - VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 00000000..437edefc --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,68 @@ +clone_depth: 1 + +cache: + - '%LOCALAPPDATA%\pip\Cache' + +configuration: + - pip + - sdist + +image: + - Visual Studio 2017 + +services: + - mysql + - mongodb + - postgresql + +environment: + PATH: C:\Python37-x64;C:\Python37-x64\Scripts;C:\Program Files\PostgreSQL\9.6\bin\;%PATH% + VNPY_TEST_MYSQL_DATABASE: vnpy + VNPY_TEST_MYSQL_HOST: localhost + VNPY_TEST_MYSQL_PORT: 3306 + VNPY_TEST_MYSQL_USER: root + VNPY_TEST_MYSQL_PASSWORD: Password12! + + VNPY_TEST_POSTGRESQL_DATABASE: vnpy + VNPY_TEST_POSTGRESQL_HOST: localhost + VNPY_TEST_POSTGRESQL_PORT: 5432 + VNPY_TEST_POSTGRESQL_USER: postgres + VNPY_TEST_POSTGRESQL_PASSWORD: Password12! + + VNPY_TEST_MONGODB_DATABASE: vnpy + VNPY_TEST_MONGODB_HOST: localhost + VNPY_TEST_MONGODB_PORT: 27017 + + MYSQL_PWD: Password12! + +install: + - python -m pip install --upgrade pip wheel setuptools +before_build: + - ps: psql -d "postgresql://${ENV:VNPY_TEST_POSTGRESQL_USER}:${ENV:VNPY_TEST_POSTGRESQL_PASSWORD}@localhost" -c "create database vnpy;" + - ps: . "C:\Program Files\MySQL\MySQL Server 5.7\bin\mysql" -u $ENV:VNPY_TEST_MYSQL_USER -e 'CREATE DATABASE vnpy;' + +for: + - matrix: + only: + - configuration: pip + build_script: + - pip install psycopg2 mongoengine pymysql # we should support all database in test environment + - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl + - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl + - pip install -r requirements.txt + - pip install . + + - matrix: + only: + - configuration: sdist + build_script: + - python setup.py sdist + - pip install psycopg2 mongoengine pymysql # we should support all database in test environment + - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl + - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl + - ps: $name=(ls dist).name; pip install "dist/$name" + +test_script: + - cd tests + - python test_all.py + diff --git a/install.sh b/install.sh index 2cf2043c..dc748ee5 100644 --- a/install.sh +++ b/install.sh @@ -1,27 +1,35 @@ #!/usr/bin/env bash +python=$1 +pip=$2 +prefix=$3 + +[[ -z $python ]] && python=python +[[ -z $pip ]] && pip=pip +[[ -z $prefix ]] && prefix=/usr + # Get and build ta-lib pushd /tmp wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz tar -xf ta-lib-0.4.0-src.tar.gz cd ta-lib -./configure --prefix=/usr -make +./configure --prefix=$prefix +make -j sudo make install popd # old versions of ta-lib imports numpy in setup.py -pip install numpy +$pip install numpy # Install extra packages -pip install ta-lib -pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl +$pip install ta-lib +$pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl # Install Python Modules -pip install -r requirements.txt +$pip install -r requirements.txt # Install local Chinese language environment sudo locale-gen zh_CN.GB18030 # Install vn.py -pip install . \ No newline at end of file +$pip install . \ No newline at end of file diff --git a/install_osx.sh b/install_osx.sh new file mode 100644 index 00000000..eb96cc53 --- /dev/null +++ b/install_osx.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +bash ./install.sh python3 pip3 /usr/local \ No newline at end of file diff --git a/setup.py b/setup.py index 28250618..3b0d1ca7 100644 --- a/setup.py +++ b/setup.py @@ -29,70 +29,79 @@ with open("vnpy/__init__.py", "rb") as f: version = str(ast.literal_eval(version_line)) if platform.uname().system == "Windows": - compiler_flags = ["/MP", "/std:c++17", # standard - "/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization - "/wd4819" # 936 code page - ] + compiler_flags = [ + "/MP", "/std:c++17", # standard + "/O2", "/Ob2", "/Oi", "/Ot", "/Oy", "/GL", # Optimization + "/wd4819" # 936 code page + ] extra_link_args = [] else: - compiler_flags = ["-std=c++17", - "-Wno-delete-incomplete", "-Wno-sign-compare", - ] + compiler_flags = [ + "-std=c++17", # standard + "-O3", # Optimization + "-Wno-delete-incomplete", "-Wno-sign-compare", + ] extra_link_args = ["-lstdc++"] -vnctpmd = Extension("vnpy.api.ctp.vnctpmd", - [ - "vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp", - ], - include_dirs=["vnpy/api/ctp/include", - "vnpy/api/ctp/vnctp", ], - define_macros=[], - undef_macros=[], - library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], - libraries=["thostmduserapi", "thosttraderapi", ], - extra_compile_args=compiler_flags, - extra_link_args=extra_link_args, - depends=[], - runtime_library_dirs=["$ORIGIN"], - language="cpp", - ) -vnctptd = Extension("vnpy.api.ctp.vnctptd", - [ - "vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp", - ], - include_dirs=["vnpy/api/ctp/include", - "vnpy/api/ctp/vnctp", ], - define_macros=[], - undef_macros=[], - library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], - libraries=["thostmduserapi", "thosttraderapi", ], - extra_compile_args=compiler_flags, - extra_link_args=extra_link_args, - runtime_library_dirs=["$ORIGIN"], - depends=[], - language="cpp", - ) -vnoes = Extension("vnpy.api.oes.vnoes", - [ - "vnpy/api/oes/vnoes/generated_files/classes_1.cpp", - "vnpy/api/oes/vnoes/generated_files/classes_2.cpp", - "vnpy/api/oes/vnoes/generated_files/module.cpp", - ], - include_dirs=["vnpy/api/oes/include", - "vnpy/api/oes/vnoes", ], - define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")], - undef_macros=[], - library_dirs=["vnpy/api/oes/libs"], - libraries=["oes_api"], - extra_compile_args=compiler_flags, - extra_link_args=extra_link_args, - depends=[], - language="cpp", - ) +vnctpmd = Extension( + "vnpy.api.ctp.vnctpmd", + [ + "vnpy/api/ctp/vnctp/vnctpmd/vnctpmd.cpp", + ], + include_dirs=["vnpy/api/ctp/include", + "vnpy/api/ctp/vnctp", ], + define_macros=[], + undef_macros=[], + library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], + libraries=["thostmduserapi", "thosttraderapi", ], + extra_compile_args=compiler_flags, + extra_link_args=extra_link_args, + depends=[], + runtime_library_dirs=["$ORIGIN"], + language="cpp", +) +vnctptd = Extension( + "vnpy.api.ctp.vnctptd", + [ + "vnpy/api/ctp/vnctp/vnctptd/vnctptd.cpp", + ], + include_dirs=["vnpy/api/ctp/include", + "vnpy/api/ctp/vnctp", ], + define_macros=[], + undef_macros=[], + library_dirs=["vnpy/api/ctp/libs", "vnpy/api/ctp"], + libraries=["thostmduserapi", "thosttraderapi", ], + extra_compile_args=compiler_flags, + extra_link_args=extra_link_args, + runtime_library_dirs=["$ORIGIN"], + depends=[], + language="cpp", +) +vnoes = Extension( + "vnpy.api.oes.vnoes", + [ + "vnpy/api/oes/vnoes/generated_files/classes_1.cpp", + "vnpy/api/oes/vnoes/generated_files/classes_2.cpp", + "vnpy/api/oes/vnoes/generated_files/module.cpp", + ], + include_dirs=["vnpy/api/oes/include", + "vnpy/api/oes/vnoes", ], + define_macros=[("BRIGAND_NO_BOOST_SUPPORT", "1")], + undef_macros=[], + library_dirs=["vnpy/api/oes/libs"], + libraries=["oes_api"], + extra_compile_args=compiler_flags, + extra_link_args=extra_link_args, + runtime_library_dirs=["$ORIGIN"], + depends=[], + language="cpp", +) -if platform.uname().system == "Windows": +if platform.system() == "Windows": # use pre-built pyd for windows ( support python 3.7 only ) ext_modules = [] +elif platform.system() == "Darwin": + ext_modules = [] else: ext_modules = [vnctptd, vnctpmd, vnoes] diff --git a/tests/app/__init__.py b/tests/app/__init__.py new file mode 100644 index 00000000..9aafa67d --- /dev/null +++ b/tests/app/__init__.py @@ -0,0 +1 @@ +from .test_csv_loader import * diff --git a/tests/app/test_csv_loader.py b/tests/app/test_csv_loader.py new file mode 100644 index 00000000..b52a2183 --- /dev/null +++ b/tests/app/test_csv_loader.py @@ -0,0 +1,90 @@ +""" +Test if csv loader works fine +""" +import tempfile +import unittest + +from vnpy.app.csv_loader import CsvLoaderEngine +from vnpy.trader.constant import Exchange, Interval + + +class TestCsvLoader(unittest.TestCase): + + def setUp(self) -> None: + self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader + + def test_load(self): + data = """"Datetime","Open","High","Low","Close","Volume" +2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489 +2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302 +2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203 +2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280 +2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250 +2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109 +""" + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + def test_load_duplicated(self): + data = """"Datetime","Open","High","Low","Close","Volume" +2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489 +2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302 +2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203 +2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280 +2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250 +2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109 +""" + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + with tempfile.TemporaryFile("w+t") as f: + f.write(data) + f.seek(0) + + self.engine.load_by_handle( + f, + symbol="1", + exchange=Exchange.BITMEX, + interval=Interval.MINUTE, + datetime_head="Datetime", + open_head="Open", + close_head="Close", + low_head="Low", + high_head="High", + volume_head="Volume", + datetime_format="%Y-%m-%d %H:%M:%S", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/backtesting/getdata.py b/tests/backtesting/getdata.py index 77436bcc..62d02465 100644 --- a/tests/backtesting/getdata.py +++ b/tests/backtesting/getdata.py @@ -2,7 +2,7 @@ from time import time import rqdatac as rq -from vnpy.trader.database import DbBarData, DB +from vnpy.trader.database import DbBarData USERNAME = "" PASSWORD = "" @@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol): df = rq.get_price(symbol, frequency="1m", fields=FIELDS) - with DB.atomic(): - for ix, row in df.iterrows(): - print(row.name) - bar = generate_bar_from_row(row, symbol, exchange) - DbBarData.replace(bar.__data__).execute() + bars = [] + for ix, row in df.iterrows(): + bar = generate_bar_from_row(row, symbol, exchange) + bars.append(bar) + DbBarData.save_all(bars) end = time() cost = (end - start) * 1000 diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 00000000..59b7683e --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,31 @@ +# tests/runner.py +import unittest + +import app +# import your test modules +import test_import_all +import trader + +# initialize the test suite +loader = unittest.TestLoader() +suite = unittest.TestSuite() + +# add tests to the test suite +suite.addTests(loader.loadTestsFromModule(test_import_all)) +suite.addTests(loader.loadTestsFromModule(trader)) +suite.addTests(loader.loadTestsFromModule(app)) + + +# initialize a runner, pass it your suite and run it +def main(): + runner = unittest.TextTestRunner(verbosity=3) + result = runner.run(suite) + return result + + +if __name__ == '__main__': + result = main() + if result.failures or result.errors: + exit(1) + else: + exit(0) diff --git a/tests/load_all.py b/tests/test_import_all.py similarity index 54% rename from tests/load_all.py rename to tests/test_import_all.py index feae9c12..38d240f8 100644 --- a/tests/load_all.py +++ b/tests/test_import_all.py @@ -1,24 +1,45 @@ # flake8: noqa import unittest +import platform +# noinspection PyUnresolvedReferences class ImportTest(unittest.TestCase): # noinspection PyUnresolvedReferences def test_import_all(self): from vnpy.event import EventEngine + def test_import_main_engine(self): from vnpy.trader.engine import MainEngine + + def test_import_ui(self): from vnpy.trader.ui import MainWindow, create_qapp + def test_import_bitmex_gateway(self): from vnpy.gateway.bitmex import BitmexGateway + + def test_import_futu_gateway(self): from vnpy.gateway.futu import FutuGateway + + def test_import_ib_gateway(self): from vnpy.gateway.ib import IbGateway + + @unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx") + def test_import_ctp_gateway(self): from vnpy.gateway.ctp import CtpGateway + + def test_import_tiger_gateway(self): from vnpy.gateway.tiger import TigerGateway + + @unittest.skipIf(platform.system() == "Darwin", "Not supported yet under osx") + def test_import_oes_gateway(self): from vnpy.gateway.oes import OesGateway + def test_import_cta_strategy_app(self): from vnpy.app.cta_strategy import CtaStrategyApp + + def test_import_csv_loader_app(self): from vnpy.app.csv_loader import CsvLoaderApp diff --git a/tests/trader/__init__.py b/tests/trader/__init__.py new file mode 100644 index 00000000..c5776fc1 --- /dev/null +++ b/tests/trader/__init__.py @@ -0,0 +1,2 @@ +from .test_database import * +from .test_settings import * diff --git a/tests/trader/test_database.py b/tests/trader/test_database.py new file mode 100644 index 00000000..df009baf --- /dev/null +++ b/tests/trader/test_database.py @@ -0,0 +1,128 @@ +""" +Test if database works fine +""" +import os +import unittest +from datetime import datetime, timedelta + +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.database.database import Driver +from vnpy.trader.object import BarData, TickData + +os.environ['VNPY_TESTING'] = '1' + +profiles = { + Driver.SQLITE: { + "driver": "sqlite", + "database": "test_db.db", + } +} +if 'VNPY_TEST_ONLY_SQLITE' not in os.environ: + profiles.update({ + Driver.MYSQL: { + "driver": "mysql", + "database": os.environ['VNPY_TEST_MYSQL_DATABASE'], + "host": os.environ['VNPY_TEST_MYSQL_HOST'], + "port": int(os.environ['VNPY_TEST_MYSQL_PORT']), + "user": os.environ["VNPY_TEST_MYSQL_USER"], + "password": os.environ['VNPY_TEST_MYSQL_PASSWORD'], + }, + Driver.POSTGRESQL: { + "driver": "postgresql", + "database": os.environ['VNPY_TEST_POSTGRESQL_DATABASE'], + "host": os.environ['VNPY_TEST_POSTGRESQL_HOST'], + "port": int(os.environ['VNPY_TEST_POSTGRESQL_PORT']), + "user": os.environ["VNPY_TEST_POSTGRESQL_USER"], + "password": os.environ['VNPY_TEST_POSTGRESQL_PASSWORD'], + }, + Driver.MONGODB: { + "driver": "mongodb", + "database": os.environ['VNPY_TEST_MONGODB_DATABASE'], + "host": os.environ['VNPY_TEST_MONGODB_HOST'], + "port": int(os.environ['VNPY_TEST_MONGODB_PORT']), + "user": "", + "password": "", + "authentication_source": "", + }, + }) + + +def now(): + return datetime.utcnow() + + +bar = BarData( + gateway_name="DB", + symbol="test_symbol", + exchange=Exchange.BITMEX, + datetime=now(), + interval=Interval.MINUTE, +) + +tick = TickData( + gateway_name="DB", + symbol="test_symbol", + exchange=Exchange.BITMEX, + datetime=now(), + name="DB_test_symbol", +) + + +class TestDatabase(unittest.TestCase): + + def connect(self, settings: dict): + from vnpy.trader.database.initialize import init # noqa + self.manager = init(settings) + + def test_upsert_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + self.manager.save_bar_data([bar]) + self.manager.save_bar_data([bar]) + + def test_save_load_bar(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + # save first + self.manager.save_bar_data([bar]) + + # and load + results = self.manager.load_bar_data( + symbol=bar.symbol, + exchange=bar.exchange, + interval=bar.interval, + start=bar.datetime - timedelta(seconds=1), # time is not accuracy + end=now() + timedelta(seconds=1), # time is not accuracy + ) + count = len(results) + self.assertNotEqual(count, 0) + + def test_upsert_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + self.manager.save_tick_data([tick]) + self.manager.save_tick_data([tick]) + + def test_save_load_tick(self): + for driver, settings in profiles.items(): + with self.subTest(driver=driver, settings=settings): + self.connect(settings) + # save first + self.manager.save_tick_data([tick]) + + # and load + results = self.manager.load_tick_data( + symbol=bar.symbol, + exchange=bar.exchange, + start=bar.datetime - timedelta(seconds=1), # time is not accuracy + end=now() + timedelta(seconds=1), # time is not accuracy + ) + count = len(results) + self.assertNotEqual(count, 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/trader/test_settings.py b/tests/trader/test_settings.py new file mode 100644 index 00000000..2d46e73e --- /dev/null +++ b/tests/trader/test_settings.py @@ -0,0 +1,25 @@ +""" +Test if database works fine +""" +import unittest + +from vnpy.trader.setting import SETTINGS, get_settings + + +class TestSettings(unittest.TestCase): + + def test_get_settings(self): + SETTINGS['a'] = 1 + got = get_settings() + self.assertIn('a', got) + self.assertEqual(got['a'], 1) + + def test_get_settings_with_prefix(self): + SETTINGS['a.a'] = 1 + got = get_settings() + self.assertIn('a', got) + self.assertEqual(got['a'], 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/travis_env.sh b/tests/travis_env.sh new file mode 100644 index 00000000..e2e59613 --- /dev/null +++ b/tests/travis_env.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +[[ -z ${VNPY_TEST_MYSQL_DATABASE} ]] && VNPY_TEST_MYSQL_DATABASE=vnpy +[[ -z ${VNPY_TEST_MYSQL_HOST} ]] && VNPY_TEST_MYSQL_HOST=127.0.0.1 +[[ -z ${VNPY_TEST_MYSQL_PORT} ]] && VNPY_TEST_MYSQL_PORT=3306 +[[ -z ${VNPY_TEST_MYSQL_USER} ]] && VNPY_TEST_MYSQL_USER=root +[[ -z ${VNPY_TEST_MYSQL_PASSWORD} ]] && VNPY_TEST_MYSQL_PASSWORD= +[[ -z ${VNPY_TEST_POSTGRESQL_DATABASE} ]] && VNPY_TEST_POSTGRESQL_DATABASE=vnpy +[[ -z ${VNPY_TEST_POSTGRESQL_HOST} ]] && VNPY_TEST_POSTGRESQL_HOST=127.0.0.1 +[[ -z ${VNPY_TEST_POSTGRESQL_PORT} ]] && VNPY_TEST_POSTGRESQL_PORT=5432 +[[ -z ${VNPY_TEST_POSTGRESQL_USER} ]] && VNPY_TEST_POSTGRESQL_USER=postgres +[[ -z ${VNPY_TEST_POSTGRESQL_PASSWORD} ]] && VNPY_TEST_POSTGRESQL_PASSWORD= +[[ -z ${VNPY_TEST_MONGODB_DATABASE} ]] && VNPY_TEST_MONGODB_DATABASE=vnpy +[[ -z ${VNPY_TEST_MONGODB_HOST} ]] && VNPY_TEST_MONGODB_HOST=127.0.0.1 +[[ -z ${VNPY_TEST_MONGODB_PORT} ]] && VNPY_TEST_MONGODB_PORT=27017 + + +export VNPY_TEST_MYSQL_DATABASE +export VNPY_TEST_MYSQL_HOST +export VNPY_TEST_MYSQL_PORT +export VNPY_TEST_MYSQL_USER +export VNPY_TEST_MYSQL_PASSWORD +export VNPY_TEST_POSTGRESQL_DATABASE +export VNPY_TEST_POSTGRESQL_HOST +export VNPY_TEST_POSTGRESQL_PORT +export VNPY_TEST_POSTGRESQL_USER +export VNPY_TEST_POSTGRESQL_PASSWORD +export VNPY_TEST_MONGODB_DATABASE +export VNPY_TEST_MONGODB_HOST +export VNPY_TEST_MONGODB_PORT diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index fa88afe7..a9ff1936 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -22,14 +22,13 @@ Sample csv file: import csv from datetime import datetime - -from peewee import chunked +from typing import TextIO from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData, DB +from vnpy.trader.database import database_manager from vnpy.trader.engine import BaseEngine, MainEngine - +from vnpy.trader.object import BarData APP_NAME = "CsvLoader" @@ -53,6 +52,59 @@ class CsvLoaderEngine(BaseEngine): self.high_head: str = "" self.volume_head: str = "" + def load_by_handle( + self, + f: TextIO, + symbol: str, + exchange: Exchange, + interval: Interval, + datetime_head: str, + open_head: str, + close_head: str, + low_head: str, + high_head: str, + volume_head: str, + datetime_format: str, + ): + """ + load by text mode file handle + """ + reader = csv.DictReader(f) + + bars = [] + start = None + count = 0 + for item in reader: + if datetime_format: + dt = datetime.strptime(item[datetime_head], datetime_format) + else: + dt = datetime.fromisoformat(item[datetime_head]) + + bar = BarData( + symbol=symbol, + exchange=exchange, + datetime=dt, + interval=interval, + volume=item[volume_head], + open_price=item[open_head], + high_price=item[high_head], + low_price=item[low_head], + close_price=item[close_head], + gateway_name="DB", + ) + + bars.append(bar) + + # do some statistics + count += 1 + if not start: + start = bar.datetime + end = bar.datetime + + # insert into database + database_manager.save_bar_data(bars) + return start, end, count + def load( self, file_path: str, @@ -65,49 +117,22 @@ class CsvLoaderEngine(BaseEngine): low_head: str, high_head: str, volume_head: str, - datetime_format: str + datetime_format: str, ): - """""" - vt_symbol = f"{symbol}.{exchange.value}" - - start = None - end = None - count = 0 - + """ + load by filename + """ with open(file_path, "rt") as f: - reader = csv.DictReader(f) - - db_bars = [] - - for item in reader: - dt = datetime.strptime(item[datetime_head], datetime_format) - - db_bar = { - "symbol": symbol, - "exchange": exchange.value, - "datetime": dt, - "interval": interval.value, - "volume": item[volume_head], - "open_price": item[open_head], - "high_price": item[high_head], - "low_price": item[low_head], - "close_price": item[close_head], - "vt_symbol": vt_symbol, - "gateway_name": "DB" - } - - db_bars.append(db_bar) - - # do some statistics - count += 1 - if not start: - start = db_bar["datetime"] - - end = db_bar["datetime"] - - # Insert into DB - with DB.atomic(): - for batch in chunked(db_bars, 50): - DbBarData.insert_many(batch).on_conflict_replace().execute() - - return start, end, count + return self.load_by_handle( + f, + symbol=symbol, + exchange=exchange, + interval=interval, + datetime_head=datetime_head, + open_head=open_head, + close_head=close_head, + low_head=low_head, + high_head=high_head, + volume_head=volume_head, + datetime_format=datetime_format, + ) diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index aeb5ddc1..01423bf5 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -12,8 +12,8 @@ from pandas import DataFrame from vnpy.trader.constant import (Direction, Offset, Exchange, Interval, Status) -from vnpy.trader.database import DbBarData, DbTickData -from vnpy.trader.object import OrderData, TradeData +from vnpy.trader.database import database_manager +from vnpy.trader.object import OrderData, TradeData, BarData, TickData from vnpy.trader.utility import round_to_pricetick from .base import ( @@ -103,8 +103,8 @@ class BacktestingEngine: self.strategy_class = None self.strategy = None - self.tick = None - self.bar = None + self.tick: TickData + self.bar: BarData self.datetime = None self.interval = None @@ -199,14 +199,16 @@ class BacktestingEngine: if self.mode == BacktestingMode.BAR: self.history_data = load_bar_data( - self.vt_symbol, + self.symbol, + self.exchange, self.interval, self.start, self.end ) else: self.history_data = load_tick_data( - self.vt_symbol, + self.symbol, + self.exchange, self.start, self.end ) @@ -520,7 +522,7 @@ class BacktestingEngine: else: self.daily_results[d] = DailyResult(d, price) - def new_bar(self, bar: DbBarData): + def new_bar(self, bar: BarData): """""" self.bar = bar self.datetime = bar.datetime @@ -531,7 +533,7 @@ class BacktestingEngine: self.update_daily_close(bar.close_price) - def new_tick(self, tick: DbTickData): + def new_tick(self, tick: TickData): """""" self.tick = tick self.datetime = tick.datetime @@ -966,41 +968,26 @@ def optimize( @lru_cache(maxsize=10) def load_bar_data( - vt_symbol: str, - interval: str, - start: datetime, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, end: datetime ): """""" - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + return database_manager.load_bar_data( + symbol, exchange, interval, start, end ) - data = [db_bar.to_bar() for db_bar in s] - return data @lru_cache(maxsize=10) def load_tick_data( - vt_symbol: str, - start: datetime, + symbol: str, + exchange: Exchange, + start: datetime, end: datetime ): """""" - s = ( - DbTickData.select() - .where( - (DbTickData.vt_symbol == vt_symbol) - & (DbTickData.datetime >= start) - & (DbTickData.datetime <= end) - ) - .order_by(DbTickData.datetime) + return database_manager.load_tick_data( + symbol, exchange, start, end ) - data = [db_tick.db_tick() for db_tick in s] - return data \ No newline at end of file diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 37e5973a..7824ff2d 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -5,7 +5,7 @@ import os import traceback from collections import defaultdict from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, List from datetime import datetime, timedelta from threading import Thread from queue import Queue @@ -36,7 +36,7 @@ from vnpy.trader.constant import ( Status ) from vnpy.trader.utility import load_json, save_json -from vnpy.trader.database import DbTickData, DbBarData +from vnpy.trader.database import database_manager from vnpy.trader.setting import SETTINGS from .base import ( @@ -146,13 +146,12 @@ class CtaEngine(BaseEngine): self.write_log("RQData数据接口初始化成功") def query_bar_from_rq( - self, vt_symbol: str, interval: Interval, start: datetime, end: datetime + self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime ): """ Query bar data from RQData. """ - symbol, exchange_str = vt_symbol.split(".") - rq_symbol = to_rq_symbol(vt_symbol) + rq_symbol = to_rq_symbol(symbol, exchange) if rq_symbol not in self.rq_symbols: return None @@ -166,11 +165,11 @@ class CtaEngine(BaseEngine): end_date=end ) - data = [] + data: List[BarData] = [] for ix, row in df.iterrows(): bar = BarData( symbol=symbol, - exchange=Exchange(exchange_str), + exchange=exchange, interval=interval, datetime=row.name.to_pydatetime(), open_price=row["open"], @@ -529,46 +528,41 @@ class CtaEngine(BaseEngine): return self.engine_type def load_bar( - self, vt_symbol: str, days: int, interval: Interval, callback: Callable + self, symbol: str, exchange: Exchange, days: int, interval: Interval, + callback: Callable[[BarData], None] ): """""" end = datetime.now() start = end - timedelta(days) - # Query data from RQData by default, if not found, load from database. - data = self.query_bar_from_rq(vt_symbol, interval, start, end) - if not data: - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval.value) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + # Query bars from RQData by default, if not found, load from database. + bars = self.query_bar_from_rq(symbol, exchange, interval, start, end) + if not bars: + bars = database_manager.load_bar_data( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end, ) - data = [db_bar.to_bar() for db_bar in s] - for bar in data: + for bar in bars: callback(bar) - def load_tick(self, vt_symbol: str, days: int, callback: Callable): + def load_tick(self, symbol: str, exchange: Exchange, days: int, + callback: Callable[[TickData], None]): """""" end = datetime.now() start = end - timedelta(days) - s = ( - DbTickData.select() - .where( - (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.datetime >= start) - & (DbBarData.datetime <= end) - ) - .order_by(DbBarData.datetime) + ticks = database_manager.load_tick_data( + symbol=symbol, + exchange=exchange, + start=start, + end=end, ) - for tick in s: + for tick in ticks: callback(tick) def call_strategy_func( @@ -757,7 +751,7 @@ class CtaEngine(BaseEngine): """ Load strategy class from certain folder. """ - for dirpath, dirnames, filenames in os.walk(path): + for dirpath, dirnames, filenames in os.walk(str(path)): for filename in filenames: if filename.endswith(".py"): strategy_module_name = ".".join( @@ -914,19 +908,19 @@ class CtaEngine(BaseEngine): self.main_engine.send_email(subject, msg) -def to_rq_symbol(vt_symbol: str): +def to_rq_symbol(symbol: str, exchange: Exchange): """ CZCE product of RQData has symbol like "TA1905" while vt symbol is "TA905.CZCE" so need to add "1" in symbol. """ - symbol, exchange_str = vt_symbol.split(".") - if exchange_str != "CZCE": + if exchange is not Exchange.CZCE: return symbol.upper() for count, word in enumerate(symbol): if word.isdigit(): break - + + # noinspection PyUnboundLocalVariable product = symbol[:count] year = symbol[count] month = symbol[count + 1:] diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py deleted file mode 100644 index 1e90dcbc..00000000 --- a/vnpy/trader/database.py +++ /dev/null @@ -1,268 +0,0 @@ -"""""" - -from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \ - SqliteDatabase - -from .constant import Exchange, Interval -from .object import BarData, TickData -from .setting import SETTINGS -from .utility import resolve_path - - -def init(): - db_settings = SETTINGS['database'] - driver = db_settings["driver"] - - init_funcs = { - "sqlite": init_sqlite, - "mysql": init_mysql, - "postgresql": init_postgresql, - } - - assert driver in init_funcs - del db_settings['driver'] - return init_funcs[driver](db_settings) - - -def init_sqlite(settings: dict): - global DB - database = settings['database'] - - DB = SqliteDatabase(str(resolve_path(database))) - - -def init_mysql(settings: dict): - global DB - DB = MySQLDatabase(**settings) - - -def init_postgresql(settings: dict): - global DB - DB = PostgresqlDatabase(**settings) - - -init() - - -class DbBarData(Model): - """ - Candlestick bar data for database storage. - - Index is defined unique with vt_symbol, interval and datetime. - """ - - symbol = CharField() - exchange = CharField() - datetime = DateTimeField() - interval = CharField() - - volume = FloatField() - open_price = FloatField() - high_price = FloatField() - low_price = FloatField() - close_price = FloatField() - - vt_symbol = CharField() - gateway_name = CharField() - - class Meta: - database = DB - indexes = ((("vt_symbol", "interval", "datetime"), True),) - - @staticmethod - def from_bar(bar: BarData): - """ - Generate DbBarData object from BarData. - """ - db_bar = DbBarData() - - db_bar.symbol = bar.symbol - db_bar.exchange = bar.exchange.value - db_bar.datetime = bar.datetime - db_bar.interval = bar.interval.value - db_bar.volume = bar.volume - db_bar.open_price = bar.open_price - db_bar.high_price = bar.high_price - db_bar.low_price = bar.low_price - db_bar.close_price = bar.close_price - db_bar.vt_symbol = bar.vt_symbol - db_bar.gateway_name = "DB" - - return db_bar - - def to_bar(self): - """ - Generate BarData object from DbBarData. - """ - bar = BarData( - symbol=self.symbol, - exchange=Exchange(self.exchange), - datetime=self.datetime, - interval=Interval(self.interval), - volume=self.volume, - open_price=self.open_price, - high_price=self.high_price, - low_price=self.low_price, - close_price=self.close_price, - gateway_name=self.gateway_name, - ) - return bar - - -class DbTickData(Model): - """ - Tick data for database storage. - - Index is defined unique with vt_symbol, interval and datetime. - """ - - symbol = CharField() - exchange = CharField() - datetime = DateTimeField() - - name = CharField() - volume = FloatField() - last_price = FloatField() - last_volume = FloatField() - limit_up = FloatField() - limit_down = FloatField() - - open_price = FloatField() - high_price = FloatField() - low_price = FloatField() - close_price = FloatField() - - bid_price_1 = FloatField() - bid_price_2 = FloatField() - bid_price_3 = FloatField() - bid_price_4 = FloatField() - bid_price_5 = FloatField() - - ask_price_1 = FloatField() - ask_price_2 = FloatField() - ask_price_3 = FloatField() - ask_price_4 = FloatField() - ask_price_5 = FloatField() - - bid_volume_1 = FloatField() - bid_volume_2 = FloatField() - bid_volume_3 = FloatField() - bid_volume_4 = FloatField() - bid_volume_5 = FloatField() - - ask_volume_1 = FloatField() - ask_volume_2 = FloatField() - ask_volume_3 = FloatField() - ask_volume_4 = FloatField() - ask_volume_5 = FloatField() - - vt_symbol = CharField() - gateway_name = CharField() - - class Meta: - database = DB - indexes = ((("vt_symbol", "datetime"), True),) - - @staticmethod - def from_tick(tick: TickData): - """ - Generate DbTickData object from TickData. - """ - db_tick = DbTickData() - - db_tick.symbol = tick.symbol - db_tick.exchange = tick.exchange.value - db_tick.datetime = tick.datetime - db_tick.name = tick.name - db_tick.volume = tick.volume - db_tick.last_price = tick.last_price - db_tick.last_volume = tick.last_volume - db_tick.limit_up = tick.limit_up - db_tick.limit_down = tick.limit_down - db_tick.open_price = tick.open_price - db_tick.high_price = tick.high_price - db_tick.low_price = tick.low_price - db_tick.pre_close = tick.pre_close - - db_tick.bid_price_1 = tick.bid_price_1 - db_tick.ask_price_1 = tick.ask_price_1 - db_tick.bid_volume_1 = tick.bid_volume_1 - db_tick.ask_volume_1 = tick.ask_volume_1 - - if tick.bid_price_2: - db_tick.bid_price_2 = tick.bid_price_2 - db_tick.bid_price_3 = tick.bid_price_3 - db_tick.bid_price_4 = tick.bid_price_4 - db_tick.bid_price_5 = tick.bid_price_5 - - db_tick.ask_price_2 = tick.ask_price_2 - db_tick.ask_price_3 = tick.ask_price_3 - db_tick.ask_price_4 = tick.ask_price_4 - db_tick.ask_price_5 = tick.ask_price_5 - - db_tick.bid_volume_2 = tick.bid_volume_2 - db_tick.bid_volume_3 = tick.bid_volume_3 - db_tick.bid_volume_4 = tick.bid_volume_4 - db_tick.bid_volume_5 = tick.bid_volume_5 - - db_tick.ask_volume_2 = tick.ask_volume_2 - db_tick.ask_volume_3 = tick.ask_volume_3 - db_tick.ask_volume_4 = tick.ask_volume_4 - db_tick.ask_volume_5 = tick.ask_volume_5 - - db_tick.vt_symbol = tick.vt_symbol - db_tick.gateway_name = "DB" - - return tick - - def to_tick(self): - """ - Generate TickData object from DbTickData. - """ - tick = TickData( - symbol=self.symbol, - exchange=Exchange(self.exchange), - datetime=self.datetime, - name=self.name, - volume=self.volume, - last_price=self.last_price, - last_volume=self.last_volume, - limit_up=self.limit_up, - limit_down=self.limit_down, - open_price=self.open_price, - high_price=self.high_price, - low_price=self.low_price, - pre_close=self.pre_close, - bid_price_1=self.bid_price_1, - ask_price_1=self.ask_price_1, - bid_volume_1=self.bid_volume_1, - ask_volume_1=self.ask_volume_1, - gateway_name=self.gateway_name, - ) - - if self.bid_price_2: - tick.bid_price_2 = self.bid_price_2 - tick.bid_price_3 = self.bid_price_3 - tick.bid_price_4 = self.bid_price_4 - tick.bid_price_5 = self.bid_price_5 - - tick.ask_price_2 = self.ask_price_2 - tick.ask_price_3 = self.ask_price_3 - tick.ask_price_4 = self.ask_price_4 - tick.ask_price_5 = self.ask_price_5 - - tick.bid_volume_2 = self.bid_volume_2 - tick.bid_volume_3 = self.bid_volume_3 - tick.bid_volume_4 = self.bid_volume_4 - tick.bid_volume_5 = self.bid_volume_5 - - tick.ask_volume_2 = self.ask_volume_2 - tick.ask_volume_3 = self.ask_volume_3 - tick.ask_volume_4 = self.ask_volume_4 - tick.ask_volume_5 = self.ask_volume_5 - - return tick - - -DB.connect() -DB.create_tables([DbBarData, DbTickData]) diff --git a/vnpy/trader/database/__init__.py b/vnpy/trader/database/__init__.py new file mode 100644 index 00000000..16cb96b8 --- /dev/null +++ b/vnpy/trader/database/__init__.py @@ -0,0 +1,12 @@ +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vnpy.trader.database.database import BaseDatabaseManager + +if "VNPY_TESTING" not in os.environ: + from vnpy.trader.setting import get_settings + from .initialize import init + + settings = get_settings("database.") + database_manager: "BaseDatabaseManager" = init(settings=settings) diff --git a/vnpy/trader/database/database.py b/vnpy/trader/database/database.py new file mode 100644 index 00000000..ec765e8d --- /dev/null +++ b/vnpy/trader/database/database.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from enum import Enum +from typing import Sequence, TYPE_CHECKING + +if TYPE_CHECKING: + from vnpy.trader.constant import Interval, Exchange # noqa + from vnpy.trader.object import BarData, TickData # noqa + + +class Driver(Enum): + SQLITE = "sqlite" + MYSQL = "mysql" + POSTGRESQL = "postgresql" + MONGODB = "mongodb" + + +class BaseDatabaseManager(ABC): + + @abstractmethod + def load_bar_data( + self, + symbol: str, + exchange: "Exchange", + interval: "Interval", + start: datetime, + end: datetime + ) -> Sequence["BarData"]: + pass + + @abstractmethod + def load_tick_data( + self, + symbol: str, + exchange: "Exchange", + start: datetime, + end: datetime + ) -> Sequence["TickData"]: + pass + + @abstractmethod + def save_bar_data( + self, + datas: Sequence["BarData"], + ): + pass + + @abstractmethod + def save_tick_data( + self, + datas: Sequence["TickData"], + ): + pass diff --git a/vnpy/trader/database/database_mongo.py b/vnpy/trader/database/database_mongo.py new file mode 100644 index 00000000..ef5038e1 --- /dev/null +++ b/vnpy/trader/database/database_mongo.py @@ -0,0 +1,302 @@ +from datetime import datetime +from enum import Enum +from typing import Sequence + +from mongoengine import DateTimeField, Document, FloatField, StringField, connect + +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.object import BarData, TickData +from .database import BaseDatabaseManager, Driver + + +def init(_: Driver, settings: dict): + database = settings["database"] + host = settings["host"] + port = settings["port"] + username = settings["user"] + password = settings["password"] + authentication_source = settings["authentication_source"] + if not username: # if username == '' or None, skip username + username = None + password = None + authentication_source = None + connect( + db=database, + host=host, + port=port, + username=username, + password=password, + authentication_source=authentication_source, + ) + return MongoManager() + + +class DbBarData(Document): + """ + Candlestick bar data for database storage. + + Index is defined unique with datetime, interval, symbol + """ + + symbol: str = StringField() + exchange: str = StringField() + datetime: datetime = DateTimeField() + interval: str = StringField() + + volume: float = FloatField() + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() + + meta = { + "indexes": [ + {"fields": ("datetime", "interval", "symbol", "exchange"), "unique": True} + ] + } + + @staticmethod + def from_bar(bar: BarData): + """ + Generate DbBarData object from BarData. + """ + db_bar = DbBarData() + + db_bar.symbol = bar.symbol + db_bar.exchange = bar.exchange.value + db_bar.datetime = bar.datetime + db_bar.interval = bar.interval.value + db_bar.volume = bar.volume + db_bar.open_price = bar.open_price + db_bar.high_price = bar.high_price + db_bar.low_price = bar.low_price + db_bar.close_price = bar.close_price + + return db_bar + + def to_bar(self): + """ + Generate BarData object from DbBarData. + """ + bar = BarData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + interval=Interval(self.interval), + volume=self.volume, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + close_price=self.close_price, + gateway_name="DB", + ) + return bar + + +class DbTickData(Document): + """ + Tick data for database storage. + + Index is defined unique with (datetime, symbol) + """ + + symbol: str = StringField() + exchange: str = StringField() + datetime: datetime = DateTimeField() + + name: str = StringField() + volume: float = FloatField() + last_price: float = FloatField() + last_volume: float = FloatField() + limit_up: float = FloatField() + limit_down: float = FloatField() + + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() + pre_close: float = FloatField() + + bid_price_1: float = FloatField() + bid_price_2: float = FloatField() + bid_price_3: float = FloatField() + bid_price_4: float = FloatField() + bid_price_5: float = FloatField() + + ask_price_1: float = FloatField() + ask_price_2: float = FloatField() + ask_price_3: float = FloatField() + ask_price_4: float = FloatField() + ask_price_5: float = FloatField() + + bid_volume_1: float = FloatField() + bid_volume_2: float = FloatField() + bid_volume_3: float = FloatField() + bid_volume_4: float = FloatField() + bid_volume_5: float = FloatField() + + ask_volume_1: float = FloatField() + ask_volume_2: float = FloatField() + ask_volume_3: float = FloatField() + ask_volume_4: float = FloatField() + ask_volume_5: float = FloatField() + + meta = {"indexes": [{"fields": ("datetime", "symbol", "exchange"), "unique": True}]} + + @staticmethod + def from_tick(tick: TickData): + """ + Generate DbTickData object from TickData. + """ + db_tick = DbTickData() + + db_tick.symbol = tick.symbol + db_tick.exchange = tick.exchange.value + db_tick.datetime = tick.datetime + db_tick.name = tick.name + db_tick.volume = tick.volume + db_tick.last_price = tick.last_price + db_tick.last_volume = tick.last_volume + db_tick.limit_up = tick.limit_up + db_tick.limit_down = tick.limit_down + db_tick.open_price = tick.open_price + db_tick.high_price = tick.high_price + db_tick.low_price = tick.low_price + db_tick.pre_close = tick.pre_close + + db_tick.bid_price_1 = tick.bid_price_1 + db_tick.ask_price_1 = tick.ask_price_1 + db_tick.bid_volume_1 = tick.bid_volume_1 + db_tick.ask_volume_1 = tick.ask_volume_1 + + if tick.bid_price_2: + db_tick.bid_price_2 = tick.bid_price_2 + db_tick.bid_price_3 = tick.bid_price_3 + db_tick.bid_price_4 = tick.bid_price_4 + db_tick.bid_price_5 = tick.bid_price_5 + + db_tick.ask_price_2 = tick.ask_price_2 + db_tick.ask_price_3 = tick.ask_price_3 + db_tick.ask_price_4 = tick.ask_price_4 + db_tick.ask_price_5 = tick.ask_price_5 + + db_tick.bid_volume_2 = tick.bid_volume_2 + db_tick.bid_volume_3 = tick.bid_volume_3 + db_tick.bid_volume_4 = tick.bid_volume_4 + db_tick.bid_volume_5 = tick.bid_volume_5 + + db_tick.ask_volume_2 = tick.ask_volume_2 + db_tick.ask_volume_3 = tick.ask_volume_3 + db_tick.ask_volume_4 = tick.ask_volume_4 + db_tick.ask_volume_5 = tick.ask_volume_5 + + return db_tick + + def to_tick(self): + """ + Generate TickData object from DbTickData. + """ + tick = TickData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + name=self.name, + volume=self.volume, + last_price=self.last_price, + last_volume=self.last_volume, + limit_up=self.limit_up, + limit_down=self.limit_down, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + pre_close=self.pre_close, + bid_price_1=self.bid_price_1, + ask_price_1=self.ask_price_1, + bid_volume_1=self.bid_volume_1, + ask_volume_1=self.ask_volume_1, + gateway_name="DB", + ) + + if self.bid_price_2: + tick.bid_price_2 = self.bid_price_2 + tick.bid_price_3 = self.bid_price_3 + tick.bid_price_4 = self.bid_price_4 + tick.bid_price_5 = self.bid_price_5 + + tick.ask_price_2 = self.ask_price_2 + tick.ask_price_3 = self.ask_price_3 + tick.ask_price_4 = self.ask_price_4 + tick.ask_price_5 = self.ask_price_5 + + tick.bid_volume_2 = self.bid_volume_2 + tick.bid_volume_3 = self.bid_volume_3 + tick.bid_volume_4 = self.bid_volume_4 + tick.bid_volume_5 = self.bid_volume_5 + + tick.ask_volume_2 = self.ask_volume_2 + tick.ask_volume_3 = self.ask_volume_3 + tick.ask_volume_4 = self.ask_volume_4 + tick.ask_volume_5 = self.ask_volume_5 + + return tick + + +class MongoManager(BaseDatabaseManager): + def load_bar_data( + self, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, + end: datetime, + ) -> Sequence[BarData]: + s = DbBarData.objects( + symbol=symbol, + exchange=exchange.value, + interval=interval.value, + datetime__gte=start, + datetime__lte=end, + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + def load_tick_data( + self, symbol: str, exchange: Exchange, start: datetime, end: datetime + ) -> Sequence[TickData]: + s = DbTickData.objects( + symbol=symbol, + exchange=exchange.value, + datetime__gte=start, + datetime__lte=end, + ) + data = [db_tick.to_tick() for db_tick in s] + return data + + @staticmethod + def to_update_param(d): + return { + "set__" + k: v.value if isinstance(v, Enum) else v + for k, v in d.__dict__.items() + } + + def save_bar_data(self, datas: Sequence[BarData]): + for d in datas: + updates = self.to_update_param(d) + updates.pop("set__gateway_name") + updates.pop("set__vt_symbol") + ( + DbBarData.objects( + symbol=d.symbol, interval=d.interval.value, datetime=d.datetime + ).update_one(upsert=True, **updates) + ) + + def save_tick_data(self, datas: Sequence[TickData]): + for d in datas: + updates = self.to_update_param(d) + updates.pop("set__gateway_name") + updates.pop("set__vt_symbol") + ( + DbTickData.objects( + symbol=d.symbol, exchange=d.exchange.value, datetime=d.datetime + ).update_one(upsert=True, **updates) + ) diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py new file mode 100644 index 00000000..61d6e2ff --- /dev/null +++ b/vnpy/trader/database/database_sql.py @@ -0,0 +1,369 @@ +"""""" +from datetime import datetime +from typing import List, Sequence, Type + +from peewee import ( + AutoField, + CharField, + Database, + DateTimeField, + FloatField, + Model, + MySQLDatabase, + PostgresqlDatabase, + SqliteDatabase, + chunked, +) + +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.object import BarData, TickData +from vnpy.trader.utility import get_file_path +from .database import BaseDatabaseManager, Driver + + +def init(driver: Driver, settings: dict): + init_funcs = { + Driver.SQLITE: init_sqlite, + Driver.MYSQL: init_mysql, + Driver.POSTGRESQL: init_postgresql, + } + assert driver in init_funcs + + db = init_funcs[driver](settings) + bar, tick = init_models(db, driver) + return SqlManager(bar, tick) + + +def init_sqlite(settings: dict): + database = settings["database"] + path = str(get_file_path(database)) + db = SqliteDatabase(path) + return db + + +def init_mysql(settings: dict): + keys = {"database", "user", "password", "host", "port"} + settings = {k: v for k, v in settings.items() if k in keys} + db = MySQLDatabase(**settings) + return db + + +def init_postgresql(settings: dict): + keys = {"database", "user", "password", "host", "port"} + settings = {k: v for k, v in settings.items() if k in keys} + db = PostgresqlDatabase(**settings) + return db + + +class ModelBase(Model): + def to_dict(self): + return self.__data__ + + +def init_models(db: Database, driver: Driver): + class DbBarData(ModelBase): + """ + Candlestick bar data for database storage. + + Index is defined unique with datetime, interval, symbol + """ + + id = AutoField() + symbol: str = CharField() + exchange: str = CharField() + datetime: datetime = DateTimeField() + interval: str = CharField() + + volume: float = FloatField() + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + close_price: float = FloatField() + + class Meta: + database = db + indexes = ((("datetime", "interval", "symbol", "exchange"), True),) + + @staticmethod + def from_bar(bar: BarData): + """ + Generate DbBarData object from BarData. + """ + db_bar = DbBarData() + + db_bar.symbol = bar.symbol + db_bar.exchange = bar.exchange.value + db_bar.datetime = bar.datetime + db_bar.interval = bar.interval.value + db_bar.volume = bar.volume + db_bar.open_price = bar.open_price + db_bar.high_price = bar.high_price + db_bar.low_price = bar.low_price + db_bar.close_price = bar.close_price + + return db_bar + + def to_bar(self): + """ + Generate BarData object from DbBarData. + """ + bar = BarData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + interval=Interval(self.interval), + volume=self.volume, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + close_price=self.close_price, + gateway_name="DB", + ) + return bar + + @staticmethod + def save_all(objs: List["DbBarData"]): + """ + save a list of objects, update if exists. + """ + dicts = [i.to_dict() for i in objs] + with db.atomic(): + if driver is Driver.POSTGRESQL: + for bar in dicts: + DbBarData.insert(bar).on_conflict( + update=bar, + conflict_target=( + DbBarData.datetime, + DbBarData.interval, + DbBarData.symbol, + DbBarData.exchange, + ), + ).execute() + else: + for c in chunked(dicts, 50): + DbBarData.insert_many(c).on_conflict_replace().execute() + + class DbTickData(ModelBase): + """ + Tick data for database storage. + + Index is defined unique with (datetime, symbol) + """ + + id = AutoField() + + symbol: str = CharField() + exchange: str = CharField() + datetime: datetime = DateTimeField() + + name: str = CharField() + volume: float = FloatField() + last_price: float = FloatField() + last_volume: float = FloatField() + limit_up: float = FloatField() + limit_down: float = FloatField() + + open_price: float = FloatField() + high_price: float = FloatField() + low_price: float = FloatField() + pre_close: float = FloatField() + + bid_price_1: float = FloatField() + bid_price_2: float = FloatField(null=True) + bid_price_3: float = FloatField(null=True) + bid_price_4: float = FloatField(null=True) + bid_price_5: float = FloatField(null=True) + + ask_price_1: float = FloatField() + ask_price_2: float = FloatField(null=True) + ask_price_3: float = FloatField(null=True) + ask_price_4: float = FloatField(null=True) + ask_price_5: float = FloatField(null=True) + + bid_volume_1: float = FloatField() + bid_volume_2: float = FloatField(null=True) + bid_volume_3: float = FloatField(null=True) + bid_volume_4: float = FloatField(null=True) + bid_volume_5: float = FloatField(null=True) + + ask_volume_1: float = FloatField() + ask_volume_2: float = FloatField(null=True) + ask_volume_3: float = FloatField(null=True) + ask_volume_4: float = FloatField(null=True) + ask_volume_5: float = FloatField(null=True) + + class Meta: + database = db + indexes = ((("datetime", "symbol", "exchange"), True),) + + @staticmethod + def from_tick(tick: TickData): + """ + Generate DbTickData object from TickData. + """ + db_tick = DbTickData() + + db_tick.symbol = tick.symbol + db_tick.exchange = tick.exchange.value + db_tick.datetime = tick.datetime + db_tick.name = tick.name + db_tick.volume = tick.volume + db_tick.last_price = tick.last_price + db_tick.last_volume = tick.last_volume + db_tick.limit_up = tick.limit_up + db_tick.limit_down = tick.limit_down + db_tick.open_price = tick.open_price + db_tick.high_price = tick.high_price + db_tick.low_price = tick.low_price + db_tick.pre_close = tick.pre_close + + db_tick.bid_price_1 = tick.bid_price_1 + db_tick.ask_price_1 = tick.ask_price_1 + db_tick.bid_volume_1 = tick.bid_volume_1 + db_tick.ask_volume_1 = tick.ask_volume_1 + + if tick.bid_price_2: + db_tick.bid_price_2 = tick.bid_price_2 + db_tick.bid_price_3 = tick.bid_price_3 + db_tick.bid_price_4 = tick.bid_price_4 + db_tick.bid_price_5 = tick.bid_price_5 + + db_tick.ask_price_2 = tick.ask_price_2 + db_tick.ask_price_3 = tick.ask_price_3 + db_tick.ask_price_4 = tick.ask_price_4 + db_tick.ask_price_5 = tick.ask_price_5 + + db_tick.bid_volume_2 = tick.bid_volume_2 + db_tick.bid_volume_3 = tick.bid_volume_3 + db_tick.bid_volume_4 = tick.bid_volume_4 + db_tick.bid_volume_5 = tick.bid_volume_5 + + db_tick.ask_volume_2 = tick.ask_volume_2 + db_tick.ask_volume_3 = tick.ask_volume_3 + db_tick.ask_volume_4 = tick.ask_volume_4 + db_tick.ask_volume_5 = tick.ask_volume_5 + + return db_tick + + def to_tick(self): + """ + Generate TickData object from DbTickData. + """ + tick = TickData( + symbol=self.symbol, + exchange=Exchange(self.exchange), + datetime=self.datetime, + name=self.name, + volume=self.volume, + last_price=self.last_price, + last_volume=self.last_volume, + limit_up=self.limit_up, + limit_down=self.limit_down, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + pre_close=self.pre_close, + bid_price_1=self.bid_price_1, + ask_price_1=self.ask_price_1, + bid_volume_1=self.bid_volume_1, + ask_volume_1=self.ask_volume_1, + gateway_name="DB", + ) + + if self.bid_price_2: + tick.bid_price_2 = self.bid_price_2 + tick.bid_price_3 = self.bid_price_3 + tick.bid_price_4 = self.bid_price_4 + tick.bid_price_5 = self.bid_price_5 + + tick.ask_price_2 = self.ask_price_2 + tick.ask_price_3 = self.ask_price_3 + tick.ask_price_4 = self.ask_price_4 + tick.ask_price_5 = self.ask_price_5 + + tick.bid_volume_2 = self.bid_volume_2 + tick.bid_volume_3 = self.bid_volume_3 + tick.bid_volume_4 = self.bid_volume_4 + tick.bid_volume_5 = self.bid_volume_5 + + tick.ask_volume_2 = self.ask_volume_2 + tick.ask_volume_3 = self.ask_volume_3 + tick.ask_volume_4 = self.ask_volume_4 + tick.ask_volume_5 = self.ask_volume_5 + + return tick + + @staticmethod + def save_all(objs: List["DbTickData"]): + dicts = [i.to_dict() for i in objs] + with db.atomic(): + if driver is Driver.POSTGRESQL: + for tick in dicts: + DbTickData.insert(tick).on_conflict( + update=tick, + conflict_target=( + DbTickData.datetime, + DbTickData.symbol, + DbTickData.exchange, + ), + ).execute() + else: + for c in chunked(dicts, 50): + DbTickData.insert_many(c).on_conflict_replace().execute() + + db.connect() + db.create_tables([DbBarData, DbTickData]) + return DbBarData, DbTickData + + +class SqlManager(BaseDatabaseManager): + def __init__(self, class_bar: Type[Model], class_tick: Type[Model]): + self.class_bar = class_bar + self.class_tick = class_tick + + def load_bar_data( + self, + symbol: str, + exchange: Exchange, + interval: Interval, + start: datetime, + end: datetime, + ) -> Sequence[BarData]: + s = ( + self.class_bar.select() + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) + & (self.class_bar.interval == interval.value) + & (self.class_bar.datetime >= start) + & (self.class_bar.datetime <= end) + ) + .order_by(self.class_bar.datetime) + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + def load_tick_data( + self, symbol: str, exchange: Exchange, start: datetime, end: datetime + ) -> Sequence[TickData]: + s = ( + self.class_tick.select() + .where( + (self.class_tick.symbol == symbol) + & (self.class_tick.exchange == exchange.value) + & (self.class_tick.datetime >= start) + & (self.class_tick.datetime <= end) + ) + .order_by(self.class_tick.datetime) + ) + + data = [db_tick.to_tick() for db_tick in s] + return data + + def save_bar_data(self, datas: Sequence[BarData]): + ds = [self.class_bar.from_bar(i) for i in datas] + self.class_bar.save_all(ds) + + def save_tick_data(self, datas: Sequence[TickData]): + ds = [self.class_tick.from_tick(i) for i in datas] + self.class_tick.save_all(ds) diff --git a/vnpy/trader/database/initialize.py b/vnpy/trader/database/initialize.py new file mode 100644 index 00000000..a15b484c --- /dev/null +++ b/vnpy/trader/database/initialize.py @@ -0,0 +1,24 @@ +"""""" +from .database import BaseDatabaseManager, Driver + + +def init(settings: dict) -> BaseDatabaseManager: + driver = Driver(settings["driver"]) + if driver is Driver.MONGODB: + return init_nosql(driver=driver, settings=settings) + else: + return init_sql(driver=driver, settings=settings) + + +def init_sql(driver: Driver, settings: dict): + from .database_sql import init + keys = {'database', "host", "port", "user", "password"} + settings = {k: v for k, v in settings.items() if k in keys} + _database_manager = init(driver, settings) + return _database_manager + + +def init_nosql(driver: Driver, settings: dict): + from .database_mongo import init + _database_manager = init(driver, settings=settings) + return _database_manager diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index b27dbf96..1dd8b0ff 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -24,16 +24,21 @@ SETTINGS = { "rqdata.username": "", "rqdata.password": "", - "database": { - "driver": "sqlite", # sqlite, mysql, postgresql - "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath - "host": "localhost", - "port": 3306, - "user": "root", - "password": "" - } + + "database.driver": "sqlite", # see database.Driver + "database.database": "database.db", # for sqlite, use this as filepath + "database.host": "localhost", + "database.port": 3306, + "database.user": "root", + "database.password": "", + "database.authentication_source": "admin", # for mongodb } # Load global setting from json file. SETTING_FILENAME = "vt_setting.json" SETTINGS.update(load_json(SETTING_FILENAME)) + + +def get_settings(prefix: str = ""): + prefix_length = len(prefix) + return {k[prefix_length:]: v for k, v in SETTINGS.items() if k.startswith(prefix)} diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 216dcbae..69fa6991 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,20 +3,28 @@ General utility functions. """ import json -import os from pathlib import Path -from typing import Callable +from typing import Callable, TYPE_CHECKING import numpy as np import talib from .object import BarData, TickData +if TYPE_CHECKING: + from vnpy.trader.constant import Exchange -def resolve_path(pattern: str): - env = dict(os.environ) - env.update({"VNPY_TEMP": str(TEMP_DIR)}) - return pattern.format(**env) + +def extract_vt_symbol(vt_symbol: str): + """ + :return: (symbol, exchange) + """ + symbol, exchange = vt_symbol.split('.') + return symbol, exchange + + +def generate_vt_symbol(symbol: str, exchange: "Exchange"): + return f'{symbol}.{exchange.value}' def _get_trader_dir(temp_name: str):