[Add] added support for PostgreSQL
This commit is contained in:
parent
5c5e6e9e7e
commit
1074a26b77
90
tests/app/test_csv_loader.py
Normal file
90
tests/app/test_csv_loader.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""
|
||||
Test if csv loader works fine
|
||||
"""
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from vnpy.app.csv_loader import CsvLoaderEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
|
||||
|
||||
class TestCsvLoader(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.engine = CsvLoaderEngine(None, None) # no engine is necessary for CsvLoader
|
||||
|
||||
def test_load(self):
|
||||
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||
"""
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
def test_load_duplicated(self):
|
||||
data = """"Datetime","Open","High","Low","Close","Volume"
|
||||
2010-04-16 09:16:00,3450.0,3488.0,3450.0,3468.0,489
|
||||
2010-04-16 09:17:00,3468.0,3473.8,3467.0,3467.0,302
|
||||
2010-04-16 09:18:00,3467.0,3471.0,3466.0,3467.0,203
|
||||
2010-04-16 09:19:00,3467.0,3468.2,3448.0,3448.0,280
|
||||
2010-04-16 09:20:00,3448.0,3459.0,3448.0,3454.0,250
|
||||
2010-04-16 09:21:00,3454.0,3456.8,3454.0,3456.8,109
|
||||
"""
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryFile("w+t") as f:
|
||||
f.write(data)
|
||||
f.seek(0)
|
||||
|
||||
self.engine.load_by_handle(
|
||||
f,
|
||||
symbol="1",
|
||||
exchange=Exchange.BITMEX,
|
||||
interval=Interval.MINUTE,
|
||||
datetime_head="Datetime",
|
||||
open_head="Open",
|
||||
close_head="Close",
|
||||
low_head="Low",
|
||||
high_head="High",
|
||||
volume_head="Volume",
|
||||
datetime_format="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -2,7 +2,7 @@ from time import time
|
||||
|
||||
import rqdatac as rq
|
||||
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.database import DbBarData
|
||||
|
||||
USERNAME = ""
|
||||
PASSWORD = ""
|
||||
@ -39,11 +39,11 @@ def download_minute_bar(vt_symbol):
|
||||
|
||||
df = rq.get_price(symbol, frequency="1m", fields=FIELDS)
|
||||
|
||||
with DB.atomic():
|
||||
for ix, row in df.iterrows():
|
||||
print(row.name)
|
||||
bar = generate_bar_from_row(row, symbol, exchange)
|
||||
DbBarData.replace(bar.__data__).execute()
|
||||
bars = []
|
||||
for ix, row in df.iterrows():
|
||||
bar = generate_bar_from_row(row, symbol, exchange)
|
||||
bars.append(bar)
|
||||
DbBarData.save_all(bars)
|
||||
|
||||
end = time()
|
||||
cost = (end - start) * 1000
|
||||
|
@ -22,15 +22,13 @@ Sample csv file:
|
||||
|
||||
import csv
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import chunked
|
||||
from typing import TextIO
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.database import DbBarData
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
|
||||
|
||||
APP_NAME = "CsvLoader"
|
||||
|
||||
|
||||
@ -41,17 +39,70 @@ class CsvLoaderEngine(BaseEngine):
|
||||
""""""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.file_path: str = ""
|
||||
self.file_path: str = ''
|
||||
|
||||
self.symbol: str = ""
|
||||
self.exchange: Exchange = Exchange.SSE
|
||||
self.interval: Interval = Interval.MINUTE
|
||||
self.datetime_head: str = ""
|
||||
self.open_head: str = ""
|
||||
self.close_head: str = ""
|
||||
self.low_head: str = ""
|
||||
self.high_head: str = ""
|
||||
self.volume_head: str = ""
|
||||
self.datetime_head: str = ''
|
||||
self.open_head: str = ''
|
||||
self.close_head: str = ''
|
||||
self.low_head: str = ''
|
||||
self.high_head: str = ''
|
||||
self.volume_head: str = ''
|
||||
|
||||
def load_by_handle(
|
||||
self,
|
||||
f: TextIO,
|
||||
symbol: str,
|
||||
exchange: Exchange,
|
||||
interval: Interval,
|
||||
datetime_head: str,
|
||||
open_head: str,
|
||||
close_head: str,
|
||||
low_head: str,
|
||||
high_head: str,
|
||||
volume_head: str,
|
||||
datetime_format: str
|
||||
):
|
||||
"""
|
||||
load by text mode file handle
|
||||
"""
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
start = None
|
||||
count = 0
|
||||
for item in reader:
|
||||
if datetime_format:
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
else:
|
||||
dt = datetime.fromisoformat(item[datetime_head])
|
||||
|
||||
db_bar = DbBarData(
|
||||
symbol=symbol,
|
||||
exchange=exchange.value,
|
||||
datetime=dt,
|
||||
interval=interval.value,
|
||||
volume=item[volume_head],
|
||||
open_price=item[open_head],
|
||||
high_price=item[high_head],
|
||||
low_price=item[low_head],
|
||||
close_price=item[close_head],
|
||||
)
|
||||
|
||||
db_bars.append(db_bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar.datetime
|
||||
end = db_bar.datetime
|
||||
|
||||
# insert into database
|
||||
DbBarData.save_all(db_bars)
|
||||
|
||||
return start, end, count
|
||||
|
||||
def load(
|
||||
self,
|
||||
@ -67,47 +118,20 @@ class CsvLoaderEngine(BaseEngine):
|
||||
volume_head: str,
|
||||
datetime_format: str
|
||||
):
|
||||
""""""
|
||||
vt_symbol = f"{symbol}.{exchange.value}"
|
||||
|
||||
start = None
|
||||
end = None
|
||||
count = 0
|
||||
|
||||
with open(file_path, "rt") as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
|
||||
for item in reader:
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
|
||||
db_bar = {
|
||||
"symbol": symbol,
|
||||
"exchange": exchange.value,
|
||||
"datetime": dt,
|
||||
"interval": interval.value,
|
||||
"volume": item[volume_head],
|
||||
"open_price": item[open_head],
|
||||
"high_price": item[high_head],
|
||||
"low_price": item[low_head],
|
||||
"close_price": item[close_head],
|
||||
"vt_symbol": vt_symbol,
|
||||
"gateway_name": "DB"
|
||||
}
|
||||
|
||||
db_bars.append(db_bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar["datetime"]
|
||||
|
||||
end = db_bar["datetime"]
|
||||
|
||||
# Insert into DB
|
||||
with DB.atomic():
|
||||
for batch in chunked(db_bars, 50):
|
||||
DbBarData.insert_many(batch).on_conflict_replace().execute()
|
||||
|
||||
return start, end, count
|
||||
"""
|
||||
load by filename
|
||||
"""
|
||||
with open(file_path, 'rt') as f:
|
||||
return self.load_by_handle(
|
||||
f,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
datetime_head=datetime_head,
|
||||
open_head=open_head,
|
||||
close_head=close_head,
|
||||
low_head=low_head,
|
||||
high_head=high_head,
|
||||
volume_head=volume_head,
|
||||
datetime_format=datetime_format,
|
||||
)
|
||||
|
@ -1,56 +1,85 @@
|
||||
""""""
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \
|
||||
SqliteDatabase
|
||||
from peewee import (
|
||||
AutoField,
|
||||
CharField,
|
||||
Database,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
Model,
|
||||
MySQLDatabase,
|
||||
PostgresqlDatabase,
|
||||
SqliteDatabase,
|
||||
chunked,
|
||||
)
|
||||
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData, TickData
|
||||
from .setting import SETTINGS
|
||||
from .utility import resolve_path
|
||||
from .utility import get_file_path
|
||||
|
||||
|
||||
class Driver(Enum):
|
||||
SQLITE = "sqlite"
|
||||
MYSQL = "mysql"
|
||||
POSTGRESQL = "postgresql"
|
||||
|
||||
|
||||
_db: Database
|
||||
_driver: Driver
|
||||
|
||||
|
||||
def init():
|
||||
db_settings = SETTINGS['database']
|
||||
driver = db_settings["driver"]
|
||||
global _driver
|
||||
db_settings = {k[9:]: v for k, v in SETTINGS.items() if k.startswith("database.")}
|
||||
_driver = Driver(db_settings["driver"])
|
||||
|
||||
init_funcs = {
|
||||
"sqlite": init_sqlite,
|
||||
"mysql": init_mysql,
|
||||
"postgresql": init_postgresql,
|
||||
Driver.SQLITE: init_sqlite,
|
||||
Driver.MYSQL: init_mysql,
|
||||
Driver.POSTGRESQL: init_postgresql,
|
||||
}
|
||||
|
||||
assert driver in init_funcs
|
||||
del db_settings['driver']
|
||||
return init_funcs[driver](db_settings)
|
||||
assert _driver in init_funcs
|
||||
del db_settings["driver"]
|
||||
return init_funcs[_driver](db_settings)
|
||||
|
||||
|
||||
def init_sqlite(settings: dict):
|
||||
global DB
|
||||
database = settings['database']
|
||||
global _db
|
||||
database = settings["database"]
|
||||
|
||||
DB = SqliteDatabase(str(resolve_path(database)))
|
||||
_db = SqliteDatabase(str(get_file_path(database)))
|
||||
|
||||
|
||||
def init_mysql(settings: dict):
|
||||
global DB
|
||||
DB = MySQLDatabase(**settings)
|
||||
global _db
|
||||
_db = MySQLDatabase(**settings)
|
||||
|
||||
|
||||
def init_postgresql(settings: dict):
|
||||
global DB
|
||||
DB = PostgresqlDatabase(**settings)
|
||||
global _db
|
||||
_db = PostgresqlDatabase(**settings)
|
||||
|
||||
|
||||
init()
|
||||
|
||||
|
||||
class DbBarData(Model):
|
||||
class ModelBase(Model):
|
||||
def to_dict(self):
|
||||
return self.__data__
|
||||
|
||||
|
||||
class DbBarData(ModelBase):
|
||||
"""
|
||||
Candlestick bar data for database storage.
|
||||
|
||||
Index is defined unique with vt_symbol, interval and datetime.
|
||||
Index is defined unique with datetime, interval, symbol
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
@ -62,12 +91,9 @@ class DbBarData(Model):
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
|
||||
vt_symbol = CharField()
|
||||
gateway_name = CharField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
indexes = ((("vt_symbol", "interval", "datetime"), True),)
|
||||
database = _db
|
||||
indexes = ((("datetime", "interval", "symbol"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_bar(bar: BarData):
|
||||
@ -85,8 +111,6 @@ class DbBarData(Model):
|
||||
db_bar.high_price = bar.high_price
|
||||
db_bar.low_price = bar.low_price
|
||||
db_bar.close_price = bar.close_price
|
||||
db_bar.vt_symbol = bar.vt_symbol
|
||||
db_bar.gateway_name = "DB"
|
||||
|
||||
return db_bar
|
||||
|
||||
@ -104,18 +128,40 @@ class DbBarData(Model):
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name=self.gateway_name,
|
||||
gateway_name="DB",
|
||||
)
|
||||
return bar
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbBarData"]):
|
||||
"""
|
||||
save a list of objects, update if exists.
|
||||
"""
|
||||
with _db.atomic():
|
||||
if _driver is Driver.POSTGRESQL:
|
||||
for bar in objs:
|
||||
DbBarData.insert(bar.to_dict()).on_conflict(
|
||||
update=bar.to_dict(),
|
||||
conflict_target=(
|
||||
DbBarData.datetime,
|
||||
DbBarData.interval,
|
||||
DbBarData.symbol,
|
||||
),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(objs, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace()
|
||||
|
||||
class DbTickData(Model):
|
||||
|
||||
class DbTickData(ModelBase):
|
||||
"""
|
||||
Tick data for database storage.
|
||||
|
||||
Index is defined unique with vt_symbol, interval and datetime.
|
||||
Index is defined unique with (datetime, symbol)
|
||||
"""
|
||||
|
||||
id = AutoField()
|
||||
|
||||
symbol = CharField()
|
||||
exchange = CharField()
|
||||
datetime = DateTimeField()
|
||||
@ -131,6 +177,7 @@ class DbTickData(Model):
|
||||
high_price = FloatField()
|
||||
low_price = FloatField()
|
||||
close_price = FloatField()
|
||||
pre_close = FloatField()
|
||||
|
||||
bid_price_1 = FloatField()
|
||||
bid_price_2 = FloatField()
|
||||
@ -156,12 +203,9 @@ class DbTickData(Model):
|
||||
ask_volume_4 = FloatField()
|
||||
ask_volume_5 = FloatField()
|
||||
|
||||
vt_symbol = CharField()
|
||||
gateway_name = CharField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
indexes = ((("vt_symbol", "datetime"), True),)
|
||||
database = _db
|
||||
indexes = ((("datetime", "symbol"), True),)
|
||||
|
||||
@staticmethod
|
||||
def from_tick(tick: TickData):
|
||||
@ -210,9 +254,6 @@ class DbTickData(Model):
|
||||
db_tick.ask_volume_4 = tick.ask_volume_4
|
||||
db_tick.ask_volume_5 = tick.ask_volume_5
|
||||
|
||||
db_tick.vt_symbol = tick.vt_symbol
|
||||
db_tick.gateway_name = "DB"
|
||||
|
||||
return tick
|
||||
|
||||
def to_tick(self):
|
||||
@ -237,7 +278,7 @@ class DbTickData(Model):
|
||||
ask_price_1=self.ask_price_1,
|
||||
bid_volume_1=self.bid_volume_1,
|
||||
ask_volume_1=self.ask_volume_1,
|
||||
gateway_name=self.gateway_name,
|
||||
gateway_name="DB",
|
||||
)
|
||||
|
||||
if self.bid_price_2:
|
||||
@ -263,6 +304,20 @@ class DbTickData(Model):
|
||||
|
||||
return tick
|
||||
|
||||
@staticmethod
|
||||
def save_all(objs: List["DbTickData"]):
|
||||
with _db.atomic():
|
||||
if _driver is Driver.POSTGRESQL:
|
||||
for bar in objs:
|
||||
DbTickData.insert(bar.to_dict()).on_conflict(
|
||||
update=bar.to_dict(),
|
||||
preserve=(DbTickData.id),
|
||||
conflict_target=(DbTickData.datetime, DbTickData.symbol),
|
||||
).execute()
|
||||
else:
|
||||
for c in chunked(objs, 50):
|
||||
DbBarData.insert_many(c).on_conflict_replace()
|
||||
|
||||
DB.connect()
|
||||
DB.create_tables([DbBarData, DbTickData])
|
||||
|
||||
_db.connect()
|
||||
_db.create_tables([DbBarData, DbTickData])
|
||||
|
@ -24,14 +24,13 @@ SETTINGS = {
|
||||
|
||||
"rqdata.username": "",
|
||||
"rqdata.password": "",
|
||||
"database": {
|
||||
"driver": "sqlite", # sqlite, mysql, postgresql
|
||||
"database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath
|
||||
"host": "localhost",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": ""
|
||||
}
|
||||
|
||||
"database.driver": "sqlite", # see database.Driver
|
||||
"database.database": "database.db", # for sqlite, use this as filepath
|
||||
"database.host": "localhost",
|
||||
"database.port": 3306,
|
||||
"database.user": "root",
|
||||
"database.password": ""
|
||||
}
|
||||
|
||||
# Load global setting from json file.
|
||||
|
Loading…
Reference in New Issue
Block a user