[Add] support for PostgreSQL

This commit is contained in:
nanoric 2019-04-12 04:41:56 -04:00
parent 5c5e6e9e7e
commit 1074a26b77
5 changed files with 278 additions and 110 deletions

View File

@ -0,0 +1,90 @@
"""
Test if csv loader works fine
"""
import tempfile
import unittest
from vnpy.app.csv_loader import CsvLoaderEngine
from vnpy.trader.constant import Exchange, Interval
class TestCsvLoader(unittest.TestCase):
    """Unit tests for CsvLoaderEngine.load_by_handle."""

    # Sample OHLCV minute-bar data in the CSV layout the loader expects.
    CSV_DATA = """"Datetime","Open","High","Low","Close","Volume"
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
"""

    def setUp(self) -> None:
        # The loader does not use the main/event engine in these tests.
        self.engine = CsvLoaderEngine(None, None)  # no engine is necessary for CsvLoader

    def _load_sample(self):
        """Write the sample CSV to a temporary file and load it via the engine."""
        with tempfile.TemporaryFile("w+t") as f:
            f.write(self.CSV_DATA)
            f.seek(0)
            self.engine.load_by_handle(
                f,
                symbol="1",
                exchange=Exchange.BITMEX,
                interval=Interval.MINUTE,
                datetime_head="Datetime",
                open_head="Open",
                close_head="Close",
                low_head="Low",
                high_head="High",
                volume_head="Volume",
                datetime_format="%Y-%m-%d %H:%M:%S",
            )

    def test_load(self):
        """A well-formed CSV loads without error."""
        self._load_sample()

    def test_load_duplicated(self):
        """Loading identical data twice must upsert rather than fail on duplicates."""
        self._load_sample()
        self._load_sample()
if __name__ == "__main__":
    # Allow running this test module directly (outside a test runner).
    unittest.main()

View File

@ -2,7 +2,7 @@ from time import time
import rqdatac as rq
from vnpy.trader.database import DbBarData, DB
from vnpy.trader.database import DbBarData
USERNAME = ""
PASSWORD = ""
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
with DB.atomic():
for ix, row in df.iterrows():
print(row.name)
bar = generate_bar_from_row(row, symbol, exchange)
DbBarData.replace(bar.__data__).execute()
bars = []
for ix, row in df.iterrows():
bar = generate_bar_from_row(row, symbol, exchange)
bars.append(bar)
DbBarData.save_all(bars)
end = time()
cost = (end - start) * 1000

View File

@ -22,15 +22,13 @@ Sample csv file:
import csv
from datetime import datetime
from peewee import chunked
from typing import TextIO
from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData, DB
from vnpy.trader.database import DbBarData
from vnpy.trader.engine import BaseEngine, MainEngine
APP_NAME = "CsvLoader"
@ -41,17 +39,70 @@ class CsvLoaderEngine(BaseEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.file_path: str = ""
self.file_path: str = ''
self.symbol: str = ""
self.exchange: Exchange = Exchange.SSE
self.interval: Interval = Interval.MINUTE
self.datetime_head: str = ""
self.open_head: str = ""
self.close_head: str = ""
self.low_head: str = ""
self.high_head: str = ""
self.volume_head: str = ""
self.datetime_head: str = ''
self.open_head: str = ''
self.close_head: str = ''
self.low_head: str = ''
self.high_head: str = ''
self.volume_head: str = ''
def load_by_handle(
    self,
    f: TextIO,
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    datetime_head: str,
    open_head: str,
    close_head: str,
    low_head: str,
    high_head: str,
    volume_head: str,
    datetime_format: str
):
    """
    Load bar data from an already-open text-mode file handle.

    :param f: text-mode file handle positioned at the CSV header row
    :param symbol: instrument symbol to store the bars under
    :param exchange: exchange the bars belong to
    :param interval: bar interval (e.g. Interval.MINUTE)
    :param datetime_head: CSV column name holding the bar datetime
    :param open_head: CSV column name for the open price
    :param close_head: CSV column name for the close price
    :param low_head: CSV column name for the low price
    :param high_head: CSV column name for the high price
    :param volume_head: CSV column name for the volume
    :param datetime_format: strptime format; falsy means ISO-8601
    :return: (start, end, count) — datetime of the first and last bar and
             the number of bars loaded; (None, None, 0) for an empty file
    """
    reader = csv.DictReader(f)

    db_bars = []
    start = None
    end = None  # fix: previously unbound, so an empty CSV raised UnboundLocalError
    count = 0

    for item in reader:
        if datetime_format:
            dt = datetime.strptime(item[datetime_head], datetime_format)
        else:
            dt = datetime.fromisoformat(item[datetime_head])

        db_bar = DbBarData(
            symbol=symbol,
            exchange=exchange.value,
            datetime=dt,
            interval=interval.value,
            volume=item[volume_head],
            open_price=item[open_head],
            high_price=item[high_head],
            low_price=item[low_head],
            close_price=item[close_head],
        )
        db_bars.append(db_bar)

        # do some statistics
        count += 1
        if not start:
            start = db_bar.datetime
        end = db_bar.datetime

    # insert into database (upsert on the table's unique index)
    DbBarData.save_all(db_bars)

    return start, end, count
def load(
self,
@ -67,47 +118,20 @@ class CsvLoaderEngine(BaseEngine):
volume_head: str,
datetime_format: str
):
""""""
vt_symbol = f"{symbol}.{exchange.value}"
start = None
end = None
count = 0
with open(file_path, "rt") as f:
reader = csv.DictReader(f)
db_bars = []
for item in reader:
dt = datetime.strptime(item[datetime_head], datetime_format)
db_bar = {
"symbol": symbol,
"exchange": exchange.value,
"datetime": dt,
"interval": interval.value,
"volume": item[volume_head],
"open_price": item[open_head],
"high_price": item[high_head],
"low_price": item[low_head],
"close_price": item[close_head],
"vt_symbol": vt_symbol,
"gateway_name": "DB"
}
db_bars.append(db_bar)
# do some statistics
count += 1
if not start:
start = db_bar["datetime"]
end = db_bar["datetime"]
# Insert into DB
with DB.atomic():
for batch in chunked(db_bars, 50):
DbBarData.insert_many(batch).on_conflict_replace().execute()
return start, end, count
"""
load by filename
"""
with open(file_path, 'rt') as f:
return self.load_by_handle(
f,
symbol=symbol,
exchange=exchange,
interval=interval,
datetime_head=datetime_head,
open_head=open_head,
close_head=close_head,
low_head=low_head,
high_head=high_head,
volume_head=volume_head,
datetime_format=datetime_format,
)

View File

@ -1,56 +1,85 @@
""""""
from enum import Enum
from typing import List
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
SqliteDatabase
from peewee import (
AutoField,
CharField,
Database,
DateTimeField,
FloatField,
Model,
MySQLDatabase,
PostgresqlDatabase,
SqliteDatabase,
chunked,
)
from .constant import Exchange, Interval
from .object import BarData, TickData
from .setting import SETTINGS
from .utility import resolve_path
from .utility import get_file_path
class Driver(Enum):
    """Supported database backends, selected by the "database.driver" setting."""

    SQLITE = "sqlite"
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
_db: Database
_driver: Driver
def init():
db_settings = SETTINGS['database']
driver = db_settings["driver"]
global _driver
db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")}
_driver = Driver(db_settings["driver"])
init_funcs = {
"sqlite": init_sqlite,
"mysql": init_mysql,
"postgresql": init_postgresql,
Driver.SQLITE: init_sqlite,
Driver.MYSQL: init_mysql,
Driver.POSTGRESQL: init_postgresql,
}
assert driver in init_funcs
del db_settings['driver']
return init_funcs[driver](db_settings)
assert _driver in init_funcs
del db_settings["driver"]
return init_funcs[_driver](db_settings)
def init_sqlite(settings: dict):
global DB
database = settings['database']
global _db
database = settings["database"]
DB = SqliteDatabase(str(resolve_path(database)))
_db = SqliteDatabase(str(get_file_path(database)))
def init_mysql(settings: dict):
global DB
DB = MySQLDatabase(**settings)
global _db
_db = MySQLDatabase(**settings)
def init_postgresql(settings: dict):
global DB
DB = PostgresqlDatabase(**settings)
global _db
_db = PostgresqlDatabase(**settings)
init()
class DbBarData(Model):
class ModelBase(Model):
    """Common base for ORM models; adds dict conversion used by the save_all upserts."""

    def to_dict(self):
        # __data__ is the ORM's field-name -> value mapping for this row.
        # NOTE(review): this returns the live mapping, not a copy — callers
        # should not mutate the result.
        return self.__data__
class DbBarData(ModelBase):
"""
Candlestick bar data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
Index is defined unique with datetime, interval, symbol
"""
id = AutoField()
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
@ -62,12 +91,9 @@ class DbBarData(Model):
low_price = FloatField()
close_price = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "interval", "datetime"), True),)
database = _db
indexes = ((("datetime", "interval", "symbol"), True),)
@staticmethod
def from_bar(bar: BarData):
@ -85,8 +111,6 @@ class DbBarData(Model):
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
db_bar.close_price = bar.close_price
db_bar.vt_symbol = bar.vt_symbol
db_bar.gateway_name = "DB"
return db_bar
@ -104,18 +128,40 @@ class DbBarData(Model):
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name=self.gateway_name,
gateway_name="DB",
)
return bar
@staticmethod
def save_all(objs: List["DbBarData"]):
    """
    Save a list of bar objects in one transaction, updating rows that
    already exist (matched on the unique (datetime, interval, symbol) index).
    """
    # insert/insert_many take row dicts, not model instances.
    dicts = [obj.to_dict() for obj in objs]
    with _db.atomic():
        if _driver is Driver.POSTGRESQL:
            # PostgreSQL requires an explicit conflict target for upsert.
            for bar in dicts:
                DbBarData.insert(bar).on_conflict(
                    update=bar,
                    conflict_target=(
                        DbBarData.datetime,
                        DbBarData.interval,
                        DbBarData.symbol,
                    ),
                ).execute()
        else:
            # SQLite/MySQL support REPLACE semantics directly; chunk the
            # batch to stay under SQL parameter limits.
            for c in chunked(dicts, 50):
                # fix: the query was built but never run — .execute() was missing
                DbBarData.insert_many(c).on_conflict_replace().execute()
class DbTickData(Model):
class DbTickData(ModelBase):
"""
Tick data for database storage.
Index is defined unique with vt_symbol, interval and datetime.
Index is defined unique with (datetime, symbol)
"""
id = AutoField()
symbol = CharField()
exchange = CharField()
datetime = DateTimeField()
@ -131,6 +177,7 @@ class DbTickData(Model):
high_price = FloatField()
low_price = FloatField()
close_price = FloatField()
pre_close = FloatField()
bid_price_1 = FloatField()
bid_price_2 = FloatField()
@ -156,12 +203,9 @@ class DbTickData(Model):
ask_volume_4 = FloatField()
ask_volume_5 = FloatField()
vt_symbol = CharField()
gateway_name = CharField()
class Meta:
database = DB
indexes = ((("vt_symbol", "datetime"), True),)
database = _db
indexes = ((("datetime", "symbol"), True),)
@staticmethod
def from_tick(tick: TickData):
@ -210,9 +254,6 @@ class DbTickData(Model):
db_tick.ask_volume_4 = tick.ask_volume_4
db_tick.ask_volume_5 = tick.ask_volume_5
db_tick.vt_symbol = tick.vt_symbol
db_tick.gateway_name = "DB"
return tick
def to_tick(self):
@ -237,7 +278,7 @@ class DbTickData(Model):
ask_price_1=self.ask_price_1,
bid_volume_1=self.bid_volume_1,
ask_volume_1=self.ask_volume_1,
gateway_name=self.gateway_name,
gateway_name="DB",
)
if self.bid_price_2:
@ -263,6 +304,20 @@ class DbTickData(Model):
return tick
@staticmethod
def save_all(objs: List["DbTickData"]):
    """
    Save a list of tick objects in one transaction, updating rows that
    already exist (matched on the unique (datetime, symbol) index).
    """
    # insert/insert_many take row dicts, not model instances.
    dicts = [obj.to_dict() for obj in objs]
    with _db.atomic():
        if _driver is Driver.POSTGRESQL:
            for tick in dicts:
                DbTickData.insert(tick).on_conflict(
                    update=tick,
                    # fix: (X) is just X — a one-element tuple needs a comma
                    preserve=(DbTickData.id,),
                    conflict_target=(DbTickData.datetime, DbTickData.symbol),
                ).execute()
        else:
            for c in chunked(dicts, 50):
                # fix: was DbBarData (copy-paste bug) and the query was
                # built but never executed
                DbTickData.insert_many(c).on_conflict_replace().execute()
DB.connect()
DB.create_tables([DbBarData, DbTickData])
_db.connect()
_db.create_tables([DbBarData, DbTickData])

View File

@ -24,14 +24,13 @@ SETTINGS = {
"rqdata.username": "",
"rqdata.password": "",
"database": {
"driver": "sqlite", # sqlite, mysql, postgresql
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
"host": "localhost",
"port": 3306,
"user": "root",
"password": ""
}
"database.driver": "sqlite", # see database.Driver
"database.database": "database.db", # for sqlite, use this as filepath
"database.host": "localhost",
"database.port": 3306,
"database.user": "root",
"database.password": ""
}
# Load global setting from json file.