Merge pull request #1830 from nanoric/database_open_interest

[Add] database: open_interest
This commit is contained in:
vn.py 2019-06-14 15:44:23 +08:00 committed by GitHub
commit bc900dfc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 28 deletions

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Sequence, Optional from typing import Optional, Sequence
from mongoengine import DateTimeField, Document, FloatField, StringField, connect from mongoengine import DateTimeField, Document, FloatField, StringField, connect
@ -47,6 +47,7 @@ class DbBarData(Document):
interval: str = StringField() interval: str = StringField()
volume: float = FloatField() volume: float = FloatField()
open_interest: float = FloatField()
open_price: float = FloatField() open_price: float = FloatField()
high_price: float = FloatField() high_price: float = FloatField()
low_price: float = FloatField() low_price: float = FloatField()
@ -73,6 +74,7 @@ class DbBarData(Document):
db_bar.datetime = bar.datetime db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value db_bar.interval = bar.interval.value
db_bar.volume = bar.volume db_bar.volume = bar.volume
db_bar.open_interest = bar.open_interest
db_bar.open_price = bar.open_price db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price db_bar.low_price = bar.low_price
@ -90,6 +92,7 @@ class DbBarData(Document):
datetime=self.datetime, datetime=self.datetime,
interval=Interval(self.interval), interval=Interval(self.interval),
volume=self.volume, volume=self.volume,
open_interest=self.open_interest,
open_price=self.open_price, open_price=self.open_price,
high_price=self.high_price, high_price=self.high_price,
low_price=self.low_price, low_price=self.low_price,
@ -112,6 +115,7 @@ class DbTickData(Document):
name: str = StringField() name: str = StringField()
volume: float = FloatField() volume: float = FloatField()
open_interest: float = FloatField()
last_price: float = FloatField() last_price: float = FloatField()
last_volume: float = FloatField() last_volume: float = FloatField()
limit_up: float = FloatField() limit_up: float = FloatField()
@ -168,6 +172,7 @@ class DbTickData(Document):
db_tick.datetime = tick.datetime db_tick.datetime = tick.datetime
db_tick.name = tick.name db_tick.name = tick.name
db_tick.volume = tick.volume db_tick.volume = tick.volume
db_tick.open_interest = tick.open_interest
db_tick.last_price = tick.last_price db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up db_tick.limit_up = tick.limit_up
@ -215,6 +220,7 @@ class DbTickData(Document):
datetime=self.datetime, datetime=self.datetime,
name=self.name, name=self.name,
volume=self.volume, volume=self.volume,
open_interest=self.open_interest,
last_price=self.last_price, last_price=self.last_price,
last_volume=self.last_volume, last_volume=self.last_volume,
limit_up=self.limit_up, limit_up=self.limit_up,
@ -255,6 +261,7 @@ class DbTickData(Document):
class MongoManager(BaseDatabaseManager): class MongoManager(BaseDatabaseManager):
def load_bar_data( def load_bar_data(
self, self,
symbol: str, symbol: str,
@ -319,8 +326,8 @@ class MongoManager(BaseDatabaseManager):
) -> Optional["BarData"]: ) -> Optional["BarData"]:
s = ( s = (
DbBarData.objects(symbol=symbol, exchange=exchange.value) DbBarData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime") .order_by("-datetime")
.first() .first()
) )
if s: if s:
return s.to_bar() return s.to_bar()
@ -331,8 +338,8 @@ class MongoManager(BaseDatabaseManager):
) -> Optional["TickData"]: ) -> Optional["TickData"]:
s = ( s = (
DbTickData.objects(symbol=symbol, exchange=exchange.value) DbTickData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime") .order_by("-datetime")
.first() .first()
) )
if s: if s:
return s.to_tick() return s.to_tick()

View File

@ -56,6 +56,7 @@ def init_postgresql(settings: dict):
class ModelBase(Model): class ModelBase(Model):
def to_dict(self): def to_dict(self):
return self.__data__ return self.__data__
@ -75,6 +76,7 @@ def init_models(db: Database, driver: Driver):
interval: str = CharField() interval: str = CharField()
volume: float = FloatField() volume: float = FloatField()
open_interest: float = FloatField()
open_price: float = FloatField() open_price: float = FloatField()
high_price: float = FloatField() high_price: float = FloatField()
low_price: float = FloatField() low_price: float = FloatField()
@ -96,6 +98,7 @@ def init_models(db: Database, driver: Driver):
db_bar.datetime = bar.datetime db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value db_bar.interval = bar.interval.value
db_bar.volume = bar.volume db_bar.volume = bar.volume
db_bar.open_interest = bar.open_interest
db_bar.open_price = bar.open_price db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price db_bar.low_price = bar.low_price
@ -133,10 +136,10 @@ def init_models(db: Database, driver: Driver):
DbBarData.insert(bar).on_conflict( DbBarData.insert(bar).on_conflict(
update=bar, update=bar,
conflict_target=( conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol, DbBarData.symbol,
DbBarData.exchange, DbBarData.exchange,
DbBarData.interval,
DbBarData.datetime,
), ),
).execute() ).execute()
else: else:
@ -159,6 +162,7 @@ def init_models(db: Database, driver: Driver):
name: str = CharField() name: str = CharField()
volume: float = FloatField() volume: float = FloatField()
open_interest: float = FloatField()
last_price: float = FloatField() last_price: float = FloatField()
last_volume: float = FloatField() last_volume: float = FloatField()
limit_up: float = FloatField() limit_up: float = FloatField()
@ -209,6 +213,7 @@ def init_models(db: Database, driver: Driver):
db_tick.datetime = tick.datetime db_tick.datetime = tick.datetime
db_tick.name = tick.name db_tick.name = tick.name
db_tick.volume = tick.volume db_tick.volume = tick.volume
db_tick.open_interest = tick.open_interest
db_tick.last_price = tick.last_price db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up db_tick.limit_up = tick.limit_up
@ -256,6 +261,7 @@ def init_models(db: Database, driver: Driver):
datetime=self.datetime, datetime=self.datetime,
name=self.name, name=self.name,
volume=self.volume, volume=self.volume,
open_interest=self.open_interest,
last_price=self.last_price, last_price=self.last_price,
last_volume=self.last_volume, last_volume=self.last_volume,
limit_up=self.limit_up, limit_up=self.limit_up,
@ -303,9 +309,9 @@ def init_models(db: Database, driver: Driver):
DbTickData.insert(tick).on_conflict( DbTickData.insert(tick).on_conflict(
update=tick, update=tick,
conflict_target=( conflict_target=(
DbTickData.datetime,
DbTickData.symbol, DbTickData.symbol,
DbTickData.exchange, DbTickData.exchange,
DbTickData.datetime,
), ),
).execute() ).execute()
else: else:
@ -318,6 +324,7 @@ def init_models(db: Database, driver: Driver):
class SqlManager(BaseDatabaseManager): class SqlManager(BaseDatabaseManager):
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]): def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
self.class_bar = class_bar self.class_bar = class_bar
self.class_tick = class_tick self.class_tick = class_tick
@ -332,14 +339,14 @@ class SqlManager(BaseDatabaseManager):
) -> Sequence[BarData]: ) -> Sequence[BarData]:
s = ( s = (
self.class_bar.select() self.class_bar.select()
.where( .where(
(self.class_bar.symbol == symbol) (self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value) & (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value) & (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start) & (self.class_bar.datetime >= start)
& (self.class_bar.datetime <= end) & (self.class_bar.datetime <= end)
) )
.order_by(self.class_bar.datetime) .order_by(self.class_bar.datetime)
) )
data = [db_bar.to_bar() for db_bar in s] data = [db_bar.to_bar() for db_bar in s]
return data return data
@ -349,13 +356,13 @@ class SqlManager(BaseDatabaseManager):
) -> Sequence[TickData]: ) -> Sequence[TickData]:
s = ( s = (
self.class_tick.select() self.class_tick.select()
.where( .where(
(self.class_tick.symbol == symbol) (self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value) & (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start) & (self.class_tick.datetime >= start)
& (self.class_tick.datetime <= end) & (self.class_tick.datetime <= end)
) )
.order_by(self.class_tick.datetime) .order_by(self.class_tick.datetime)
) )
data = [db_tick.to_tick() for db_tick in s] data = [db_tick.to_tick() for db_tick in s]
@ -374,13 +381,13 @@ class SqlManager(BaseDatabaseManager):
) -> Optional["BarData"]: ) -> Optional["BarData"]:
s = ( s = (
self.class_bar.select() self.class_bar.select()
.where( .where(
(self.class_bar.symbol == symbol) (self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value) & (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value) & (self.class_bar.interval == interval.value)
) )
.order_by(self.class_bar.datetime.desc()) .order_by(self.class_bar.datetime.desc())
.first() .first()
) )
if s: if s:
return s.to_bar() return s.to_bar()
@ -391,12 +398,12 @@ class SqlManager(BaseDatabaseManager):
) -> Optional["TickData"]: ) -> Optional["TickData"]:
s = ( s = (
self.class_tick.select() self.class_tick.select()
.where( .where(
(self.class_tick.symbol == symbol) (self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value) & (self.class_tick.exchange == exchange.value)
) )
.order_by(self.class_tick.datetime.desc()) .order_by(self.class_tick.datetime.desc())
.first() .first()
) )
if s: if s:
return s.to_tick() return s.to_tick()