diff --git a/requirements.txt b/requirements.txt index 2b64ee1d..ee935ce6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ PyQt5<5.12 +pyqtgraph dataclasses; python_version<="3.6" qdarkstyle requests diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 87fb3eae..1e90dcbc 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,13 +1,47 @@ """""" -from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField +from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \ + SqliteDatabase from .constant import Exchange, Interval from .object import BarData, TickData -from .utility import get_file_path +from .setting import SETTINGS +from .utility import resolve_path -DB_NAME = "database.db" -DB = SqliteDatabase(str(get_file_path(DB_NAME))) + +def init(): + db_settings = SETTINGS['database'] + driver = db_settings["driver"] + + init_funcs = { + "sqlite": init_sqlite, + "mysql": init_mysql, + "postgresql": init_postgresql, + } + + assert driver in init_funcs + del db_settings['driver'] + return init_funcs[driver](db_settings) + + +def init_sqlite(settings: dict): + global DB + database = settings['database'] + + DB = SqliteDatabase(str(resolve_path(database))) + + +def init_mysql(settings: dict): + global DB + DB = MySQLDatabase(**settings) + + +def init_postgresql(settings: dict): + global DB + DB = PostgresqlDatabase(**settings) + + +init() class DbBarData(Model): diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index ae2011f9..b27dbf96 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -23,10 +23,17 @@ SETTINGS = { "email.receiver": "", "rqdata.username": "", - "rqdata.password": "" + "rqdata.password": "", + "database": { + "driver": "sqlite", # sqlite, mysql, postgresql + "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath + "host": "localhost", + "port": 3306, + "user": "root", + "password": "" + } } - # Load global setting from json file. SETTING_FILENAME = "vt_setting.json" SETTINGS.update(load_json(SETTING_FILENAME)) diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 1f40aa77..216dcbae 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,6 +3,7 @@ General utility functions. """ import json +import os from pathlib import Path from typing import Callable @@ -12,6 +13,12 @@ import talib from .object import BarData, TickData +def resolve_path(pattern: str): + env = dict(os.environ) + env.update({"VNPY_TEMP": str(TEMP_DIR)}) + return pattern.format(**env) + + def _get_trader_dir(temp_name: str): """ Get path where trader is running in.