diff --git a/vnpy/trader/database/database_mongo.py b/vnpy/trader/database/database_mongo.py index 19292e66..333f54ca 100644 --- a/vnpy/trader/database/database_mongo.py +++ b/vnpy/trader/database/database_mongo.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Sequence, Optional +from typing import Optional, Sequence from mongoengine import DateTimeField, Document, FloatField, StringField, connect @@ -47,6 +47,7 @@ class DbBarData(Document): interval: str = StringField() volume: float = FloatField() + open_interest: float = FloatField() open_price: float = FloatField() high_price: float = FloatField() low_price: float = FloatField() @@ -73,6 +74,7 @@ class DbBarData(Document): db_bar.datetime = bar.datetime db_bar.interval = bar.interval.value db_bar.volume = bar.volume + db_bar.open_interest = bar.open_interest db_bar.open_price = bar.open_price db_bar.high_price = bar.high_price db_bar.low_price = bar.low_price @@ -90,6 +92,7 @@ class DbBarData(Document): datetime=self.datetime, interval=Interval(self.interval), volume=self.volume, + open_interest=self.open_interest, open_price=self.open_price, high_price=self.high_price, low_price=self.low_price, @@ -112,6 +115,7 @@ class DbTickData(Document): name: str = StringField() volume: float = FloatField() + open_interest: float = FloatField() last_price: float = FloatField() last_volume: float = FloatField() limit_up: float = FloatField() @@ -168,6 +172,7 @@ class DbTickData(Document): db_tick.datetime = tick.datetime db_tick.name = tick.name db_tick.volume = tick.volume + db_tick.open_interest = tick.open_interest db_tick.last_price = tick.last_price db_tick.last_volume = tick.last_volume db_tick.limit_up = tick.limit_up @@ -215,6 +220,7 @@ class DbTickData(Document): datetime=self.datetime, name=self.name, volume=self.volume, + open_interest=self.open_interest, last_price=self.last_price, last_volume=self.last_volume, limit_up=self.limit_up, @@ -255,6 +261,7 @@ class DbTickData(Document): class MongoManager(BaseDatabaseManager): + def load_bar_data( self, symbol: str, @@ -319,8 +326,8 @@ class MongoManager(BaseDatabaseManager): ) -> Optional["BarData"]: s = ( DbBarData.objects(symbol=symbol, exchange=exchange.value) - .order_by("-datetime") - .first() + .order_by("-datetime") + .first() ) if s: return s.to_bar() @@ -331,8 +338,8 @@ class MongoManager(BaseDatabaseManager): ) -> Optional["TickData"]: s = ( DbTickData.objects(symbol=symbol, exchange=exchange.value) - .order_by("-datetime") - .first() + .order_by("-datetime") + .first() ) if s: return s.to_tick() diff --git a/vnpy/trader/database/database_sql.py b/vnpy/trader/database/database_sql.py index 6111cfae..d3ad3c12 100644 --- a/vnpy/trader/database/database_sql.py +++ b/vnpy/trader/database/database_sql.py @@ -56,6 +56,7 @@ def init_postgresql(settings: dict): class ModelBase(Model): + def to_dict(self): return self.__data__ @@ -75,6 +76,7 @@ def init_models(db: Database, driver: Driver): interval: str = CharField() volume: float = FloatField() + open_interest: float = FloatField() open_price: float = FloatField() high_price: float = FloatField() low_price: float = FloatField() @@ -96,6 +98,7 @@ def init_models(db: Database, driver: Driver): db_bar.datetime = bar.datetime db_bar.interval = bar.interval.value db_bar.volume = bar.volume + db_bar.open_interest = bar.open_interest db_bar.open_price = bar.open_price db_bar.high_price = bar.high_price db_bar.low_price = bar.low_price @@ -133,10 +136,10 @@ def init_models(db: Database, driver: Driver): DbBarData.insert(bar).on_conflict( update=bar, conflict_target=( - DbBarData.datetime, - DbBarData.interval, DbBarData.symbol, DbBarData.exchange, + DbBarData.interval, + DbBarData.datetime, ), ).execute() else: @@ -159,6 +162,7 @@ def init_models(db: Database, driver: Driver): name: str = CharField() volume: float = FloatField() + open_interest: float = FloatField() last_price: float = FloatField() last_volume: float = FloatField() limit_up: float = FloatField() @@ -209,6 +213,7 @@ def init_models(db: Database, driver: Driver): db_tick.datetime = tick.datetime db_tick.name = tick.name db_tick.volume = tick.volume + db_tick.open_interest = tick.open_interest db_tick.last_price = tick.last_price db_tick.last_volume = tick.last_volume db_tick.limit_up = tick.limit_up @@ -256,6 +261,7 @@ def init_models(db: Database, driver: Driver): datetime=self.datetime, name=self.name, volume=self.volume, + open_interest=self.open_interest, last_price=self.last_price, last_volume=self.last_volume, limit_up=self.limit_up, @@ -303,9 +309,9 @@ def init_models(db: Database, driver: Driver): DbTickData.insert(tick).on_conflict( update=tick, conflict_target=( - DbTickData.datetime, DbTickData.symbol, DbTickData.exchange, + DbTickData.datetime, ), ).execute() else: @@ -318,6 +324,7 @@ def init_models(db: Database, driver: Driver): class SqlManager(BaseDatabaseManager): + def __init__(self, class_bar: Type[Model], class_tick: Type[Model]): self.class_bar = class_bar self.class_tick = class_tick @@ -332,14 +339,14 @@ class SqlManager(BaseDatabaseManager): ) -> Sequence[BarData]: s = ( self.class_bar.select() - .where( - (self.class_bar.symbol == symbol) - & (self.class_bar.exchange == exchange.value) - & (self.class_bar.interval == interval.value) - & (self.class_bar.datetime >= start) + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) + & (self.class_bar.interval == interval.value) + & (self.class_bar.datetime >= start) & (self.class_bar.datetime <= end) ) - .order_by(self.class_bar.datetime) + .order_by(self.class_bar.datetime) ) data = [db_bar.to_bar() for db_bar in s] return data @@ -349,13 +356,13 @@ class SqlManager(BaseDatabaseManager): ) -> Sequence[TickData]: s = ( self.class_tick.select() - .where( - (self.class_tick.symbol == symbol) - & (self.class_tick.exchange == exchange.value) - & (self.class_tick.datetime >= start) + .where( + (self.class_tick.symbol == symbol) + & (self.class_tick.exchange == exchange.value) + & (self.class_tick.datetime >= start) & (self.class_tick.datetime <= end) ) - .order_by(self.class_tick.datetime) + .order_by(self.class_tick.datetime) ) data = [db_tick.to_tick() for db_tick in s] @@ -374,13 +381,13 @@ class SqlManager(BaseDatabaseManager): ) -> Optional["BarData"]: s = ( self.class_bar.select() - .where( - (self.class_bar.symbol == symbol) - & (self.class_bar.exchange == exchange.value) + .where( + (self.class_bar.symbol == symbol) + & (self.class_bar.exchange == exchange.value) & (self.class_bar.interval == interval.value) ) - .order_by(self.class_bar.datetime.desc()) - .first() + .order_by(self.class_bar.datetime.desc()) + .first() ) if s: return s.to_bar() @@ -391,12 +398,12 @@ class SqlManager(BaseDatabaseManager): ) -> Optional["TickData"]: s = ( self.class_tick.select() - .where( - (self.class_tick.symbol == symbol) + .where( + (self.class_tick.symbol == symbol) & (self.class_tick.exchange == exchange.value) ) - .order_by(self.class_tick.datetime.desc()) - .first() + .order_by(self.class_tick.datetime.desc()) + .first() ) if s: return s.to_tick()