Merge pull request #1830 from nanoric/database_open_interest

[Add] database: open_interest
This commit is contained in:
vn.py 2019-06-14 15:44:23 +08:00 committed by GitHub
commit bc900dfc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 28 deletions

View File

@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import Sequence, Optional
from typing import Optional, Sequence
from mongoengine import DateTimeField, Document, FloatField, StringField, connect
@ -47,6 +47,7 @@ class DbBarData(Document):
interval: str = StringField()
volume: float = FloatField()
open_interest: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
@ -73,6 +74,7 @@ class DbBarData(Document):
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_interest = bar.open_interest
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
@ -90,6 +92,7 @@ class DbBarData(Document):
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_interest=self.open_interest,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
@ -112,6 +115,7 @@ class DbTickData(Document):
name: str = StringField()
volume: float = FloatField()
open_interest: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
@ -168,6 +172,7 @@ class DbTickData(Document):
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.open_interest = tick.open_interest
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
@ -215,6 +220,7 @@ class DbTickData(Document):
datetime=self.datetime,
name=self.name,
volume=self.volume,
open_interest=self.open_interest,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
@ -255,6 +261,7 @@ class DbTickData(Document):
class MongoManager(BaseDatabaseManager):
def load_bar_data(
self,
symbol: str,
@ -319,8 +326,8 @@ class MongoManager(BaseDatabaseManager):
) -> Optional["BarData"]:
s = (
DbBarData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime")
.first()
.order_by("-datetime")
.first()
)
if s:
return s.to_bar()
@ -331,8 +338,8 @@ class MongoManager(BaseDatabaseManager):
) -> Optional["TickData"]:
s = (
DbTickData.objects(symbol=symbol, exchange=exchange.value)
.order_by("-datetime")
.first()
.order_by("-datetime")
.first()
)
if s:
return s.to_tick()

View File

@ -56,6 +56,7 @@ def init_postgresql(settings: dict):
class ModelBase(Model):
def to_dict(self):
return self.__data__
@ -75,6 +76,7 @@ def init_models(db: Database, driver: Driver):
interval: str = CharField()
volume: float = FloatField()
open_interest: float = FloatField()
open_price: float = FloatField()
high_price: float = FloatField()
low_price: float = FloatField()
@ -96,6 +98,7 @@ def init_models(db: Database, driver: Driver):
db_bar.datetime = bar.datetime
db_bar.interval = bar.interval.value
db_bar.volume = bar.volume
db_bar.open_interest = bar.open_interest
db_bar.open_price = bar.open_price
db_bar.high_price = bar.high_price
db_bar.low_price = bar.low_price
@ -133,10 +136,10 @@ def init_models(db: Database, driver: Driver):
DbBarData.insert(bar).on_conflict(
update=bar,
conflict_target=(
DbBarData.datetime,
DbBarData.interval,
DbBarData.symbol,
DbBarData.exchange,
DbBarData.interval,
DbBarData.datetime,
),
).execute()
else:
@ -159,6 +162,7 @@ def init_models(db: Database, driver: Driver):
name: str = CharField()
volume: float = FloatField()
open_interest: float = FloatField()
last_price: float = FloatField()
last_volume: float = FloatField()
limit_up: float = FloatField()
@ -209,6 +213,7 @@ def init_models(db: Database, driver: Driver):
db_tick.datetime = tick.datetime
db_tick.name = tick.name
db_tick.volume = tick.volume
db_tick.open_interest = tick.open_interest
db_tick.last_price = tick.last_price
db_tick.last_volume = tick.last_volume
db_tick.limit_up = tick.limit_up
@ -256,6 +261,7 @@ def init_models(db: Database, driver: Driver):
datetime=self.datetime,
name=self.name,
volume=self.volume,
open_interest=self.open_interest,
last_price=self.last_price,
last_volume=self.last_volume,
limit_up=self.limit_up,
@ -303,9 +309,9 @@ def init_models(db: Database, driver: Driver):
DbTickData.insert(tick).on_conflict(
update=tick,
conflict_target=(
DbTickData.datetime,
DbTickData.symbol,
DbTickData.exchange,
DbTickData.datetime,
),
).execute()
else:
@ -318,6 +324,7 @@ def init_models(db: Database, driver: Driver):
class SqlManager(BaseDatabaseManager):
def __init__(self, class_bar: Type[Model], class_tick: Type[Model]):
self.class_bar = class_bar
self.class_tick = class_tick
@ -332,14 +339,14 @@ class SqlManager(BaseDatabaseManager):
) -> Sequence[BarData]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start)
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
& (self.class_bar.datetime >= start)
& (self.class_bar.datetime <= end)
)
.order_by(self.class_bar.datetime)
.order_by(self.class_bar.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@ -349,13 +356,13 @@ class SqlManager(BaseDatabaseManager):
) -> Sequence[TickData]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start)
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
& (self.class_tick.datetime >= start)
& (self.class_tick.datetime <= end)
)
.order_by(self.class_tick.datetime)
.order_by(self.class_tick.datetime)
)
data = [db_tick.to_tick() for db_tick in s]
@ -374,13 +381,13 @@ class SqlManager(BaseDatabaseManager):
) -> Optional["BarData"]:
s = (
self.class_bar.select()
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
.where(
(self.class_bar.symbol == symbol)
& (self.class_bar.exchange == exchange.value)
& (self.class_bar.interval == interval.value)
)
.order_by(self.class_bar.datetime.desc())
.first()
.order_by(self.class_bar.datetime.desc())
.first()
)
if s:
return s.to_bar()
@ -391,12 +398,12 @@ class SqlManager(BaseDatabaseManager):
) -> Optional["TickData"]:
s = (
self.class_tick.select()
.where(
(self.class_tick.symbol == symbol)
.where(
(self.class_tick.symbol == symbol)
& (self.class_tick.exchange == exchange.value)
)
.order_by(self.class_tick.datetime.desc())
.first()
.order_by(self.class_tick.datetime.desc())
.first()
)
if s:
return s.to_tick()