[Mod] ScriptEngine add support to return DataFrame
This commit is contained in:
parent
1ed35377e6
commit
2c0362b3c5
File diff suppressed because it is too large
Load Diff
@ -3,11 +3,13 @@
|
|||||||
import sys
|
import sys
|
||||||
import importlib
|
import importlib
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Sequence
|
from typing import Sequence, Callable, Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
from vnpy.event import Event, EventEngine
|
from vnpy.event import Event, EventEngine
|
||||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||||
from vnpy.trader.constant import Direction, Offset, OrderType, Interval
|
from vnpy.trader.constant import Direction, Offset, OrderType, Interval
|
||||||
@ -158,31 +160,39 @@ class ScriptEngine(BaseEngine):
|
|||||||
req = order.create_cancel_request()
|
req = order.create_cancel_request()
|
||||||
self.main_engine.cancel_order(req, order.gateway_name)
|
self.main_engine.cancel_order(req, order.gateway_name)
|
||||||
|
|
||||||
def get_tick(self, vt_symbol: str) -> TickData:
|
def get_tick(self, vt_symbol: str, use_df: bool = False) -> TickData:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_tick(vt_symbol)
|
return get_data(self.main_engine.get_tick, arg=vt_symbol, use_df=use_df)
|
||||||
|
|
||||||
def get_ticks(self, vt_symbols: Sequence[str]) -> Sequence[TickData]:
|
def get_ticks(self, vt_symbols: Sequence[str], use_df: bool = False) -> Sequence[TickData]:
|
||||||
""""""
|
""""""
|
||||||
ticks = []
|
ticks = []
|
||||||
for vt_symbol in vt_symbols:
|
for vt_symbol in vt_symbols:
|
||||||
tick = self.main_engine.get_tick(vt_symbol)
|
tick = self.main_engine.get_tick(vt_symbol)
|
||||||
ticks.append(tick)
|
ticks.append(tick)
|
||||||
|
|
||||||
|
if not use_df:
|
||||||
return ticks
|
return ticks
|
||||||
|
else:
|
||||||
|
return to_df(ticks)
|
||||||
|
|
||||||
def get_order(self, vt_orderid: str) -> OrderData:
|
def get_order(self, vt_orderid: str, use_df: bool = False) -> OrderData:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_order(vt_orderid)
|
return get_data(self.main_engine.get_order, arg=vt_orderid, use_df=use_df)
|
||||||
|
|
||||||
def get_orders(self, vt_orderids: Sequence[str]) -> Sequence[OrderData]:
|
def get_orders(self, vt_orderids: Sequence[str], use_df: bool = False) -> Sequence[OrderData]:
|
||||||
""""""
|
""""""
|
||||||
orders = []
|
orders = []
|
||||||
for vt_orderid in vt_orderids:
|
for vt_orderid in vt_orderids:
|
||||||
order = self.main_engine.get_order(vt_orderid)
|
order = self.main_engine.get_order(vt_orderid)
|
||||||
orders.append(order)
|
orders.append(order)
|
||||||
return orders
|
|
||||||
|
|
||||||
def get_trades(self, vt_orderid: str) -> Sequence[TradeData]:
|
if not use_df:
|
||||||
|
return orders
|
||||||
|
else:
|
||||||
|
return to_df(orders)
|
||||||
|
|
||||||
|
def get_trades(self, vt_orderid: str, use_df: bool = False) -> Sequence[TradeData]:
|
||||||
""""""
|
""""""
|
||||||
trades = []
|
trades = []
|
||||||
all_trades = self.main_engine.get_all_trades()
|
all_trades = self.main_engine.get_all_trades()
|
||||||
@ -191,37 +201,40 @@ class ScriptEngine(BaseEngine):
|
|||||||
if trade.vt_orderid == vt_orderid:
|
if trade.vt_orderid == vt_orderid:
|
||||||
trades.append(trade)
|
trades.append(trade)
|
||||||
|
|
||||||
|
if not use_df:
|
||||||
return trades
|
return trades
|
||||||
|
else:
|
||||||
|
return to_df(trades)
|
||||||
|
|
||||||
def get_all_active_orders(self) -> Sequence[OrderData]:
|
def get_all_active_orders(self, use_df: bool = False) -> Sequence[OrderData]:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_all_active_orders()
|
return get_data(self.main_engine.get_all_active_orders, use_df=use_df)
|
||||||
|
|
||||||
def get_contract(self, vt_symbol) -> ContractData:
|
def get_contract(self, vt_symbol, use_df: bool = False) -> ContractData:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_contract(vt_symbol)
|
return get_data(self.main_engine.get_contract, arg=vt_symbol, use_df=use_df)
|
||||||
|
|
||||||
def get_all_contracts(self) -> Sequence[ContractData]:
|
def get_all_contracts(self, use_df: bool = False) -> Sequence[ContractData]:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_all_contracts()
|
return get_data(self.main_engine.get_all_contracts, use_df=use_df)
|
||||||
|
|
||||||
def get_account(self, vt_accountid: str) -> AccountData:
|
def get_account(self, vt_accountid: str, use_df: bool = False) -> AccountData:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_account(vt_accountid)
|
return get_data(self.main_engine.get_account, arg=vt_accountid, use_df=use_df)
|
||||||
|
|
||||||
def get_all_accounts(self) -> Sequence[AccountData]:
|
def get_all_accounts(self, use_df: bool = False) -> Sequence[AccountData]:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_all_accounts()
|
return get_data(self.main_engine.get_all_accounts, use_df=use_df)
|
||||||
|
|
||||||
def get_position(self, vt_positionid: str) -> PositionData:
|
def get_position(self, vt_positionid: str, use_df: bool = False) -> PositionData:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_position(vt_positionid)
|
return get_data(self.main_engine.get_position, arg=vt_positionid, use_df=use_df)
|
||||||
|
|
||||||
def get_all_positions(self) -> Sequence[AccountData]:
|
def get_all_positions(self, use_df: bool = False) -> Sequence[AccountData]:
|
||||||
""""""
|
""""""
|
||||||
return self.main_engine.get_all_positions()
|
return get_data(self.main_engine.get_all_positions, use_df=use_df)
|
||||||
|
|
||||||
def get_bars(self, vt_symbol: str, start_date: str, interval: Interval) -> Sequence[BarData]:
|
def get_bars(self, vt_symbol: str, start_date: str, interval: Interval, use_df: bool = False) -> Sequence[BarData]:
|
||||||
""""""
|
""""""
|
||||||
contract = self.main_engine.get_contract(vt_symbol)
|
contract = self.main_engine.get_contract(vt_symbol)
|
||||||
if not contract:
|
if not contract:
|
||||||
@ -236,10 +249,7 @@ class ScriptEngine(BaseEngine):
|
|||||||
interval=interval
|
interval=interval
|
||||||
)
|
)
|
||||||
|
|
||||||
bars = rqdata_client.query_history(req)
|
return get_data(rqdata_client.query_history, arg=req, use_df=use_df)
|
||||||
if not bars:
|
|
||||||
return []
|
|
||||||
return bars
|
|
||||||
|
|
||||||
def write_log(self, msg: str) -> None:
|
def write_log(self, msg: str) -> None:
|
||||||
""""""
|
""""""
|
||||||
@ -248,3 +258,29 @@ class ScriptEngine(BaseEngine):
|
|||||||
|
|
||||||
event = Event(EVENT_SCRIPT_LOG, log)
|
event = Event(EVENT_SCRIPT_LOG, log)
|
||||||
self.event_engine.put(event)
|
self.event_engine.put(event)
|
||||||
|
|
||||||
|
|
||||||
|
def to_df(data_list: Sequence):
|
||||||
|
""""""
|
||||||
|
if not data_list:
|
||||||
|
return None
|
||||||
|
|
||||||
|
dict_list = [data.__dict__ for data in data_list]
|
||||||
|
return DataFrame(dict_list)
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(func: callable, arg: Any = None, use_df: bool = False):
|
||||||
|
""""""
|
||||||
|
if not arg:
|
||||||
|
data = func()
|
||||||
|
else:
|
||||||
|
data = func(arg)
|
||||||
|
|
||||||
|
if not use_df:
|
||||||
|
return data
|
||||||
|
elif data is None:
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
if not isinstance(data, list):
|
||||||
|
data = [data]
|
||||||
|
return to_df(data)
|
||||||
|
Loading…
Reference in New Issue
Block a user