[Mod] ScriptEngine add support to return DataFrame

This commit is contained in:
vn.py 2019-07-02 17:33:12 +08:00
parent 1ed35377e6
commit 2c0362b3c5
2 changed files with 1765 additions and 1071 deletions

File diff suppressed because it is too large Load Diff

View File

@ -3,11 +3,13 @@
import sys
import importlib
import traceback
from typing import Sequence
from typing import Sequence, Callable, Any
from pathlib import Path
from datetime import datetime
from threading import Thread
from pandas import DataFrame
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.constant import Direction, Offset, OrderType, Interval
@ -158,31 +160,39 @@ class ScriptEngine(BaseEngine):
req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name)
def get_tick(self, vt_symbol: str) -> TickData:
def get_tick(self, vt_symbol: str, use_df: bool = False) -> TickData:
""""""
return self.main_engine.get_tick(vt_symbol)
return get_data(self.main_engine.get_tick, arg=vt_symbol, use_df=use_df)
def get_ticks(self, vt_symbols: Sequence[str]) -> Sequence[TickData]:
def get_ticks(self, vt_symbols: Sequence[str], use_df: bool = False) -> Sequence[TickData]:
""""""
ticks = []
for vt_symbol in vt_symbols:
tick = self.main_engine.get_tick(vt_symbol)
ticks.append(tick)
return ticks
def get_order(self, vt_orderid: str) -> OrderData:
if not use_df:
return ticks
else:
return to_df(ticks)
def get_order(self, vt_orderid: str, use_df: bool = False) -> OrderData:
""""""
return self.main_engine.get_order(vt_orderid)
return get_data(self.main_engine.get_order, arg=vt_orderid, use_df=use_df)
def get_orders(self, vt_orderids: Sequence[str]) -> Sequence[OrderData]:
def get_orders(self, vt_orderids: Sequence[str], use_df: bool = False) -> Sequence[OrderData]:
""""""
orders = []
for vt_orderid in vt_orderids:
order = self.main_engine.get_order(vt_orderid)
orders.append(order)
return orders
def get_trades(self, vt_orderid: str) -> Sequence[TradeData]:
if not use_df:
return orders
else:
return to_df(orders)
def get_trades(self, vt_orderid: str, use_df: bool = False) -> Sequence[TradeData]:
""""""
trades = []
all_trades = self.main_engine.get_all_trades()
@ -191,37 +201,40 @@ class ScriptEngine(BaseEngine):
if trade.vt_orderid == vt_orderid:
trades.append(trade)
return trades
if not use_df:
return trades
else:
return to_df(trades)
def get_all_active_orders(self) -> Sequence[OrderData]:
def get_all_active_orders(self, use_df: bool = False) -> Sequence[OrderData]:
""""""
return self.main_engine.get_all_active_orders()
return get_data(self.main_engine.get_all_active_orders, use_df=use_df)
def get_contract(self, vt_symbol) -> ContractData:
def get_contract(self, vt_symbol, use_df: bool = False) -> ContractData:
""""""
return self.main_engine.get_contract(vt_symbol)
return get_data(self.main_engine.get_contract, arg=vt_symbol, use_df=use_df)
def get_all_contracts(self) -> Sequence[ContractData]:
def get_all_contracts(self, use_df: bool = False) -> Sequence[ContractData]:
""""""
return self.main_engine.get_all_contracts()
return get_data(self.main_engine.get_all_contracts, use_df=use_df)
def get_account(self, vt_accountid: str) -> AccountData:
def get_account(self, vt_accountid: str, use_df: bool = False) -> AccountData:
""""""
return self.main_engine.get_account(vt_accountid)
return get_data(self.main_engine.get_account, arg=vt_accountid, use_df=use_df)
def get_all_accounts(self) -> Sequence[AccountData]:
def get_all_accounts(self, use_df: bool = False) -> Sequence[AccountData]:
""""""
return self.main_engine.get_all_accounts()
return get_data(self.main_engine.get_all_accounts, use_df=use_df)
def get_position(self, vt_positionid: str) -> PositionData:
def get_position(self, vt_positionid: str, use_df: bool = False) -> PositionData:
""""""
return self.main_engine.get_position(vt_positionid)
return get_data(self.main_engine.get_position, arg=vt_positionid, use_df=use_df)
def get_all_positions(self) -> Sequence[AccountData]:
def get_all_positions(self, use_df: bool = False) -> Sequence[AccountData]:
""""""
return self.main_engine.get_all_positions()
return get_data(self.main_engine.get_all_positions, use_df=use_df)
def get_bars(self, vt_symbol: str, start_date: str, interval: Interval) -> Sequence[BarData]:
def get_bars(self, vt_symbol: str, start_date: str, interval: Interval, use_df: bool = False) -> Sequence[BarData]:
""""""
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
@ -236,10 +249,7 @@ class ScriptEngine(BaseEngine):
interval=interval
)
bars = rqdata_client.query_history(req)
if not bars:
return []
return bars
return get_data(rqdata_client.query_history, arg=req, use_df=use_df)
def write_log(self, msg: str) -> None:
""""""
@ -248,3 +258,29 @@ class ScriptEngine(BaseEngine):
event = Event(EVENT_SCRIPT_LOG, log)
self.event_engine.put(event)
def to_df(data_list: Sequence):
""""""
if not data_list:
return None
dict_list = [data.__dict__ for data in data_list]
return DataFrame(dict_list)
def get_data(func: callable, arg: Any = None, use_df: bool = False):
""""""
if not arg:
data = func()
else:
data = func(arg)
if not use_df:
return data
elif data is None:
return data
else:
if not isinstance(data, list):
data = [data]
return to_df(data)