[Mod] complete test of ToraGateway

This commit is contained in:
vn.py 2019-06-26 21:43:06 +08:00
parent a9ed153f60
commit d0bb68fc47
5 changed files with 99 additions and 33 deletions

View File

@ -9,7 +9,7 @@ from vnpy.gateway.futu import FutuGateway
from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.ctptest import CtptestGateway
from vnpy.gateway.femas import FemasGateway
# from vnpy.gateway.femas import FemasGateway
from vnpy.gateway.tiger import TigerGateway
# from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway
@ -19,7 +19,8 @@ from vnpy.gateway.onetoken import OnetokenGateway
from vnpy.gateway.okexf import OkexfGateway
# from vnpy.gateway.xtp import XtpGateway
from vnpy.gateway.hbdm import HbdmGateway
from vnpy.gateway.tap import TapGateway
# from vnpy.gateway.tap import TapGateway
from vnpy.gateway.tora import ToraGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp
@ -40,7 +41,7 @@ def main():
main_engine.add_gateway(BinanceGateway)
main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(CtptestGateway)
main_engine.add_gateway(FemasGateway)
# main_engine.add_gateway(FemasGateway)
main_engine.add_gateway(IbGateway)
main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway)
@ -53,7 +54,8 @@ def main():
main_engine.add_gateway(OkexfGateway)
main_engine.add_gateway(HbdmGateway)
# main_engine.add_gateway(XtpGateway)
main_engine.add_gateway(TapGateway)
# main_engine.add_gateway(TapGateway)
main_engine.add_gateway(ToraGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)

View File

@ -23,22 +23,28 @@ def parse_datetime(date: str, time: str):
class ToraMdSpi(CTORATstpMdSpi):
""""""
def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"):
""""""
super().__init__()
self.gateway = gateway
self._api = api
def OnFrontConnected(self) -> Any:
""""""
self.gateway.write_log("行情服务器连接成功")
def OnFrontDisconnected(self, error_code: int) -> Any:
self.gateway.write_log(f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
""""""
self.gateway.write_log(
f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
def OnRspError(
self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool
) -> Any:
""""""
error_id = error_info.ErrorID
error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务收到错误消息({error_id}){error_msg}")
@ -50,6 +56,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
@ -64,6 +71,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
@ -72,6 +80,7 @@ class ToraMdSpi(CTORATstpMdSpi):
self.gateway.write_log("行情服务器登出成功")
def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any:
""""""
if data.ExchangeID not in EXCHANGE_TORA2VT:
return
tick_data = TickData(
@ -114,8 +123,10 @@ class ToraMdSpi(CTORATstpMdSpi):
class ToraMdApi:
""""""
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.md_address = ""
@ -151,10 +162,13 @@ class ToraMdApi:
return True
def subscribe(self, symbols: List[str], exchange: Exchange):
err = self._native_api.SubscribeMarketData(symbols, EXCHANGE_VT2TORA[exchange])
""""""
err = self._native_api.SubscribeMarketData(
symbols, EXCHANGE_VT2TORA[exchange])
self._if_error_write_log(err, "subscribe")
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'

View File

@ -47,7 +47,8 @@ def _check_error(none_return: bool = True,
def wrapped(self, info, error_info, *args):
function_name = func.__name__
if print_function_name:
print(function_name, "info" if info else "None", error_info.ErrorID)
print(function_name, "info" if info else "None",
error_info.ErrorID)
# print if errors
error_code = error_info.ErrorID
@ -72,8 +73,10 @@ def _check_error(none_return: bool = True,
class QueryLoop:
""""""
def __init__(self, gateway: "BaseGateway"):
""""""
self.event_engine = gateway.event_engine
self._seconds_left = 0
@ -84,6 +87,7 @@ class QueryLoop:
self.event_engine.register(EVENT_TIMER, self._process_timer_event)
def stop(self):
""""""
self.event_engine.unregister(EVENT_TIMER, self._process_timer_event)
def _process_timer_event(self, event):
@ -96,7 +100,8 @@ class QueryLoop:
self._seconds_left = 2
# get the last one and re-queue it
func = self._query_functions.pop(0) # works fine if there is no so much items
# works fine if there is no so much items
func = self._query_functions.pop(0)
self._query_functions.append(func)
# call it
@ -107,11 +112,13 @@ OrdersType = Dict[str, "OrderInfo"]
class ToraTdSpi(CTORATstpTraderSpi):
""""""
def __init__(self, session_info: "SessionInfo",
api: "ToraTdApi",
gateway: "BaseGateway",
orders: OrdersType):
""""""
super().__init__()
self.session_info = session_info
self.gateway = gateway
@ -120,6 +127,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api: "ToraTdApi" = api
def OnRtnTrade(self, info: CTORATstpTradeField) -> None:
""""""
try:
trade_data = TradeData(
gateway_name=self.gateway.gateway_name,
@ -138,6 +146,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
return
def OnRtnOrder(self, info: CTORATstpOrderField) -> None:
""""""
self._api.update_last_local_order_id(int(info.OrderRef))
try:
@ -155,6 +164,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField) -> None:
""""""
try:
self._api.update_last_local_order_id(int(info.OrderRef))
except ValueError:
@ -173,24 +183,27 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField,
error_info: CTORATstpRspInfoField) -> None:
""""""
pass
@_check_error()
def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None:
""""""
pass
@_check_error()
def OnRspOrderAction(self, info: CTORATstpInputOrderActionField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
print("order action succeed!")
pass
@_check_error()
def OnRspOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
try:
order_data = self.parse_order_field(info)
except KeyError:
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
return
self.gateway.on_order(order_data)
@ -208,14 +221,17 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(print_function_name=False)
def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if info.InvestorID != self.session_info.investor_id:
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
return
if info.ExchangeID not in EXCHANGE_TORA2VT:
self.gateway.write_log(f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
self.gateway.write_log(
f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
return
volume = info.CurrentPosition
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + info.TodayPRFrozen + info.TodaySMPosFrozen
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + \
info.TodayPRFrozen + info.TodaySMPosFrozen
position_data = PositionData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
@ -224,7 +240,8 @@ class ToraTdSpi(CTORATstpTraderSpi):
volume=volume, # verify this: which one should vnpy use?
frozen=frozen, # verify this: which one should i use?
price=info.TotalPosCost / volume,
pnl=info.LastPrice * volume - info.TotalPosCost, # verify this: is this formula correct
# verify this: is this formula correct
pnl=info.LastPrice * volume - info.TotalPosCost,
yd_volume=info.HistoryPos,
)
self.gateway.on_position(position_data)
@ -233,7 +250,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
self.session_info.account_id = info.AccountID
account_data = AccountData(
gateway_name=self.gateway.gateway_name,
@ -247,19 +264,22 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
exchange = EXCHANGE_TORA2VT[info.ExchangeID]
self.session_info.shareholder_ids[exchange] = info.ShareholderID
@_check_error(print_function_name=False)
def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
self.session_info.investor_id = info.InvestorID
@_check_error(none_return=False, print_function_name=False)
def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if is_last:
self.gateway.write_log("合约信息查询成功!")
self.gateway.write_log("合约信息查询成功")
if not info:
return
@ -283,12 +303,14 @@ class ToraTdSpi(CTORATstpTraderSpi):
self.gateway.on_contract(contract_data)
def OnFrontConnected(self) -> None:
""""""
self.gateway.write_log("交易服务器连接成功")
self._api.login()
@_check_error(print_function_name=False)
def OnRspUserLogin(self, info: CTORATstpRspUserLoginField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
self._api.update_last_local_order_id(int(info.MaxOrderRef))
self.session_info.front_id = info.FrontID
self.session_info.session_id = info.SessionID
@ -298,7 +320,9 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api.start_query_loop() # stop at ToraTdApi.stop()
def OnFrontDisconnected(self, error_code: int) -> None:
self.gateway.write_log(f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
""""""
self.gateway.write_log(
f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
def parse_order_field(self, info):
"""
@ -330,6 +354,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
class ToraTdApi:
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.username = ""
@ -347,13 +372,16 @@ class ToraTdApi:
self._next_local_order_id = int(1e5)
def get_shareholder_id(self, exchange: Exchange):
""""""
return self.session_info.shareholder_ids[exchange]
def update_last_local_order_id(self, new_val: int):
""""""
cur = self._next_local_order_id
self._next_local_order_id = max(cur, new_val + 1)
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'
@ -361,21 +389,25 @@ class ToraTdApi:
return True
def _get_new_req_id(self):
""""""
req_id = self._last_req_id
self._last_req_id += 1
return req_id
def _get_new_order_id(self) -> str:
""""""
order_id = self._next_local_order_id
self._next_local_order_id += 1
return str(order_id)
def query_contracts(self):
""""""
info = CTORATstpQrySecurityField()
err = self._native_api.ReqQrySecurity(info, self._get_new_req_id())
self._if_error_write_log(err, "query_contracts")
def query_exchange(self, exchange: Exchange):
""""""
info = CTORATstpQryExchangeField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
err = self._native_api.ReqQryExchange(info, self._get_new_req_id())
@ -383,6 +415,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_exchange")
def query_market_data(self, symbol: str, exchange: Exchange):
""""""
info = CTORATstpQryMarketDataField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
info.SecurityID = symbol
@ -390,6 +423,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_market_data")
def stop(self):
""""""
self.stop_query_loop()
if self._native_api:
@ -419,16 +453,21 @@ class ToraTdApi:
:return:
"""
flow_path = str(get_folder_path(self.gateway.gateway_name.lower()))
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(flow_path, True)
self._spi = ToraTdSpi(self.session_info, self, self.gateway, self.orders)
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(
flow_path, True)
self._spi = ToraTdSpi(self.session_info, self,
self.gateway, self.orders)
self._native_api.RegisterSpi(self._spi)
self._native_api.RegisterFront(self.td_address)
self._native_api.SubscribePublicTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePublicTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.Init()
return True
def send_order(self, req: OrderRequest):
""""""
if req.type is OrderType.STOP:
raise NotImplementedError()
if req.type is OrderType.FAK or req.type is OrderType.FOK:
@ -465,16 +504,19 @@ class ToraTdApi:
self.session_info.session_id,
self.session_info.front_id,
)
self.gateway.on_order(req.create_order_data(order_id, self.gateway.gateway_name))
self.gateway.on_order(req.create_order_data(
order_id, self.gateway.gateway_name))
# err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id())
err = self._native_api.ReqOrderInsert(info, self._get_new_req_id())
self._if_error_write_log(err, "send_order:ReqOrderInsert")
def cancel_order(self, req: CancelRequest):
""""""
info = CTORATstpInputOrderActionField()
info.InvestorID = self.session_info.investor_id
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange] # 没有的话:(16608)VIP:未知的交易所代码
# 没有的话:(16608)VIP:未知的交易所代码
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange]
info.SecurityID = req.symbol
# info.OrderActionRef = str(self._get_new_req_id())
@ -491,6 +533,7 @@ class ToraTdApi:
self._if_error_write_log(err, "cancel_order:ReqOrderAction")
def query_initialize_status(self):
""""""
self.query_contracts()
self.query_investors()
self.query_shareholder_ids()
@ -500,41 +543,51 @@ class ToraTdApi:
self.query_trades()
def query_accounts(self):
""""""
info = CTORATstpQryTradingAccountField()
err = self._native_api.ReqQryTradingAccount(info, self._get_new_req_id())
err = self._native_api.ReqQryTradingAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_accounts")
def query_shareholder_ids(self):
""""""
info = CTORATstpQryShareholderAccountField()
err = self._native_api.ReqQryShareholderAccount(info, self._get_new_req_id())
err = self._native_api.ReqQryShareholderAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_shareholder_ids")
def query_investors(self):
""""""
info = CTORATstpQryInvestorField()
err = self._native_api.ReqQryInvestor(info, self._get_new_req_id())
self._if_error_write_log(err, "query_investors")
def query_positions(self):
""""""
info = CTORATstpQryPositionField()
err = self._native_api.ReqQryPosition(info, self._get_new_req_id())
self._if_error_write_log(err, "query_positions")
def query_orders(self):
""""""
info = CTORATstpQryOrderField()
err = self._native_api.ReqQryOrder(info, self._get_new_req_id())
self._if_error_write_log(err, "query_orders")
def query_trades(self):
""""""
info = CTORATstpQryTradeField()
err = self._native_api.ReqQryTrade(info, self._get_new_req_id())
self._if_error_write_log(err, "query_trades")
def start_query_loop(self):
""""""
if not self._query_loop:
self._query_loop = QueryLoop(self.gateway)
self._query_loop.start()
def stop_query_loop(self):
""""""
if self._query_loop:
self._query_loop.stop()
self._query_loop = None

View File

@ -5,7 +5,8 @@ TODO:
* Linux support
"""
from vnpy.api.tora.vntora import (AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.api.tora.vntora import (
AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.event import EventEngine
from vnpy.trader.gateway import BaseGateway
@ -20,6 +21,8 @@ def is_valid_front_address(address: str):
class ToraGateway(BaseGateway):
""""""
default_setting = {
"账号": "",
"密码": "",
@ -86,13 +89,6 @@ class ToraGateway(BaseGateway):
""""""
self._td_api.query_positions()
def write_log(self, msg: str):
"""
for easier test
"""
print(msg)
super().write_log(msg)
def _async_callback_exception_handler(self, e: AsyncDispatchException):
error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}"
self.write_log(error_str)

View File

@ -163,6 +163,7 @@ class MainWindow(QtWidgets.QMainWindow):
def init_toolbar(self):
""""""
self.toolbar = QtWidgets.QToolBar(self)
self.toolbar.setObjectName("工具栏")
self.toolbar.setFloatable(False)
self.toolbar.setMovable(False)