[Mod] ScriptEngine add support to return DataFrame

This commit is contained in:
vn.py 2019-07-02 17:33:12 +08:00
parent 1ed35377e6
commit 2c0362b3c5
2 changed files with 1765 additions and 1071 deletions

File diff suppressed because it is too large Load Diff

View File

@ -3,11 +3,13 @@
import sys import sys
import importlib import importlib
import traceback import traceback
from typing import Sequence from typing import Sequence, Callable, Any
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from threading import Thread from threading import Thread
from pandas import DataFrame
from vnpy.event import Event, EventEngine from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Direction, Offset, OrderType, Interval from vnpy.trader.constant import Direction, Offset, OrderType, Interval
@ -158,31 +160,39 @@ class ScriptEngine(BaseEngine):
req = order.create_cancel_request() req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name) self.main_engine.cancel_order(req, order.gateway_name)
def get_tick(self, vt_symbol: str) -> TickData: def get_tick(self, vt_symbol: str, use_df: bool = False) -> TickData:
"""""" """"""
return self.main_engine.get_tick(vt_symbol) return get_data(self.main_engine.get_tick, arg=vt_symbol, use_df=use_df)
def get_ticks(self, vt_symbols: Sequence[str]) -> Sequence[TickData]: def get_ticks(self, vt_symbols: Sequence[str], use_df: bool = False) -> Sequence[TickData]:
"""""" """"""
ticks = [] ticks = []
for vt_symbol in vt_symbols: for vt_symbol in vt_symbols:
tick = self.main_engine.get_tick(vt_symbol) tick = self.main_engine.get_tick(vt_symbol)
ticks.append(tick) ticks.append(tick)
if not use_df:
return ticks return ticks
else:
return to_df(ticks)
def get_order(self, vt_orderid: str) -> OrderData: def get_order(self, vt_orderid: str, use_df: bool = False) -> OrderData:
"""""" """"""
return self.main_engine.get_order(vt_orderid) return get_data(self.main_engine.get_order, arg=vt_orderid, use_df=use_df)
def get_orders(self, vt_orderids: Sequence[str]) -> Sequence[OrderData]: def get_orders(self, vt_orderids: Sequence[str], use_df: bool = False) -> Sequence[OrderData]:
"""""" """"""
orders = [] orders = []
for vt_orderid in vt_orderids: for vt_orderid in vt_orderids:
order = self.main_engine.get_order(vt_orderid) order = self.main_engine.get_order(vt_orderid)
orders.append(order) orders.append(order)
return orders
def get_trades(self, vt_orderid: str) -> Sequence[TradeData]: if not use_df:
return orders
else:
return to_df(orders)
def get_trades(self, vt_orderid: str, use_df: bool = False) -> Sequence[TradeData]:
"""""" """"""
trades = [] trades = []
all_trades = self.main_engine.get_all_trades() all_trades = self.main_engine.get_all_trades()
@ -191,37 +201,40 @@ class ScriptEngine(BaseEngine):
if trade.vt_orderid == vt_orderid: if trade.vt_orderid == vt_orderid:
trades.append(trade) trades.append(trade)
if not use_df:
return trades return trades
else:
return to_df(trades)
def get_all_active_orders(self) -> Sequence[OrderData]: def get_all_active_orders(self, use_df: bool = False) -> Sequence[OrderData]:
"""""" """"""
return self.main_engine.get_all_active_orders() return get_data(self.main_engine.get_all_active_orders, use_df=use_df)
def get_contract(self, vt_symbol) -> ContractData: def get_contract(self, vt_symbol, use_df: bool = False) -> ContractData:
"""""" """"""
return self.main_engine.get_contract(vt_symbol) return get_data(self.main_engine.get_contract, arg=vt_symbol, use_df=use_df)
def get_all_contracts(self) -> Sequence[ContractData]: def get_all_contracts(self, use_df: bool = False) -> Sequence[ContractData]:
"""""" """"""
return self.main_engine.get_all_contracts() return get_data(self.main_engine.get_all_contracts, use_df=use_df)
def get_account(self, vt_accountid: str) -> AccountData: def get_account(self, vt_accountid: str, use_df: bool = False) -> AccountData:
"""""" """"""
return self.main_engine.get_account(vt_accountid) return get_data(self.main_engine.get_account, arg=vt_accountid, use_df=use_df)
def get_all_accounts(self) -> Sequence[AccountData]: def get_all_accounts(self, use_df: bool = False) -> Sequence[AccountData]:
"""""" """"""
return self.main_engine.get_all_accounts() return get_data(self.main_engine.get_all_accounts, use_df=use_df)
def get_position(self, vt_positionid: str) -> PositionData: def get_position(self, vt_positionid: str, use_df: bool = False) -> PositionData:
"""""" """"""
return self.main_engine.get_position(vt_positionid) return get_data(self.main_engine.get_position, arg=vt_positionid, use_df=use_df)
def get_all_positions(self) -> Sequence[AccountData]: def get_all_positions(self, use_df: bool = False) -> Sequence[AccountData]:
"""""" """"""
return self.main_engine.get_all_positions() return get_data(self.main_engine.get_all_positions, use_df=use_df)
def get_bars(self, vt_symbol: str, start_date: str, interval: Interval) -> Sequence[BarData]: def get_bars(self, vt_symbol: str, start_date: str, interval: Interval, use_df: bool = False) -> Sequence[BarData]:
"""""" """"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if not contract: if not contract:
@ -236,10 +249,7 @@ class ScriptEngine(BaseEngine):
interval=interval interval=interval
) )
bars = rqdata_client.query_history(req) return get_data(rqdata_client.query_history, arg=req, use_df=use_df)
if not bars:
return []
return bars
def write_log(self, msg: str) -> None: def write_log(self, msg: str) -> None:
"""""" """"""
@ -248,3 +258,29 @@ class ScriptEngine(BaseEngine):
event = Event(EVENT_SCRIPT_LOG, log) event = Event(EVENT_SCRIPT_LOG, log)
self.event_engine.put(event) self.event_engine.put(event)
def to_df(data_list: Sequence):
""""""
if not data_list:
return None
dict_list = [data.__dict__ for data in data_list]
return DataFrame(dict_list)
def get_data(func: callable, arg: Any = None, use_df: bool = False):
""""""
if not arg:
data = func()
else:
data = func(arg)
if not use_df:
return data
elif data is None:
return data
else:
if not isinstance(data, list):
data = [data]
return to_df(data)