From cad96b95fcee78478e18743ad99cc4d2b795c849 Mon Sep 17 00:00:00 2001 From: nanoric Date: Fri, 8 Mar 2019 05:04:46 -0400 Subject: [PATCH] [Add] gateway.oes: reconnect --- vnpy/gateway/oes/config_template.ini | 6 -- vnpy/gateway/oes/oes_gateway.py | 48 +++++---- vnpy/gateway/oes/{md.py => oes_md.py} | 70 +++++++------ vnpy/gateway/oes/{td.py => oes_td.py} | 140 ++++++++++++++------------ 4 files changed, 149 insertions(+), 115 deletions(-) rename vnpy/gateway/oes/{md.py => oes_md.py} (84%) rename vnpy/gateway/oes/{td.py => oes_td.py} (86%) diff --git a/vnpy/gateway/oes/config_template.ini b/vnpy/gateway/oes/config_template.ini index c7db390f..367a169c 100644 --- a/vnpy/gateway/oes/config_template.ini +++ b/vnpy/gateway/oes/config_template.ini @@ -11,9 +11,6 @@ ordServer = 1 {td_ord_server} rptServer = 1 {td_rpt_server} qryServer = 1 {td_qry_server} -username = {username} -# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文) -password = {password} heartBtInt = 30 # 客户端环境号, 用于区分不同客户端实例上报的委托申报, 取值由客户端自行分配 @@ -58,9 +55,6 @@ keepCnt = 9 tcpServer = {md_tcp_server} qryServer = {md_qry_server} -username = {username} -# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文) -password = {password} heartBtInt = 30 sse.stock.enable = false diff --git a/vnpy/gateway/oes/oes_gateway.py b/vnpy/gateway/oes/oes_gateway.py index 9ef4d41b..3f7ad72e 100644 --- a/vnpy/gateway/oes/oes_gateway.py +++ b/vnpy/gateway/oes/oes_gateway.py @@ -3,14 +3,15 @@ """ import hashlib import os +from gettext import gettext as _ from threading import Thread from vnpy.trader.gateway import BaseGateway -from vnpy.trader.object import (AccountData, CancelRequest, ContractData, OrderData, OrderRequest, - PositionData, SubscribeRequest, TickData, TradeData) +from vnpy.trader.object import (CancelRequest, OrderRequest, + SubscribeRequest) from vnpy.trader.utility import get_file_path -from .md import OesMdApi -from .td import OesTdApi +from .oes_md import OesMdApi +from .oes_td import OesTdApi from .utils import config_template @@ -44,7 +45,10 @@ class OesGateway(BaseGateway): if not setting['password'].startswith("md5:"): setting['password'] = "md5:" + hashlib.md5(setting['password'].encode()).hexdigest() - config_path = get_file_path("vnoes.ini") + username = setting['username'] + password = setting['password'] + + config_path = str(get_file_path("vnoes.ini")) with open(config_path, "wt") as f: if 'test' in setting: log_level = 'DEBUG' @@ -54,7 +58,7 @@ class OesGateway(BaseGateway): log_mode = 'file' log_dir = get_file_path('oes') log_path = os.path.join(log_dir, 'log.log') - if os.path.exists(log_dir): + if not os.path.exists(log_dir): os.mkdir(log_dir) content = config_template.format(**setting, log_level=log_level, @@ -62,20 +66,30 @@ class OesGateway(BaseGateway): log_path=log_path) f.write(content) - self.td_api.connect(str(config_path)) + self.md_api.config_path = config_path + self.md_api.username = username + self.md_api.password = password + if self.md_api.connect(): + self.md_api.start() + else: + self.write_log(_("无法连接到行情服务器,请检查你的配置")) - self.td_api.query_account() - self.td_api.query_contracts() - self.td_api.query_position() - self.td_api.init_query_orders() - - self.td_api.start() - - self.md_api.connect(str(config_path)) - self.md_api.start() + self.td_api.config_path = config_path + self.td_api.username = username + self.td_api.password = password + if self.td_api.connect(): + self.write_log(_("成功连接到交易服务器")) + self.td_api.query_account() + self.td_api.query_contracts() + self.write_log("合约信息查询成功") + self.td_api.query_position() + self.td_api.init_query_orders() + self.td_api.start() + else: + self.write_log(_("无法连接到交易服务器,请检查你的配置")) def _connect_async(self, setting: dict): - Thread(target=self._connect_sync, args=(setting, )).start() + Thread(target=self._connect_sync, args=(setting,)).start() def subscribe(self, req: SubscribeRequest): """""" diff --git a/vnpy/gateway/oes/md.py b/vnpy/gateway/oes/oes_md.py similarity index 84% rename from vnpy/gateway/oes/md.py rename to vnpy/gateway/oes/oes_md.py index e256f97d..9d76e1ae 100644 --- a/vnpy/gateway/oes/md.py +++ b/vnpy/gateway/oes/oes_md.py @@ -1,4 +1,6 @@ +import time from datetime import datetime +from gettext import gettext as _ from threading import Thread # noinspection PyUnresolvedReferences from typing import Any, Callable, Dict @@ -9,7 +11,8 @@ from vnpy.api.oes.vnoes import MdsApiClientEnvT, MdsApi_DestoryAll, MdsApi_InitA MdsMktDataRequestReqT, MdsMktRspMsgBodyT, MdsStockSnapshotBodyT, SGeneralClientChannelT, \ SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, eMdsExchangeIdT, \ eMdsMktSubscribeFlagT, eMdsMsgTypeT, eMdsSecurityTypeT, eMdsSubscribeDataTypeT, \ - eMdsSubscribeModeT, eMdsSubscribedTickExpireTypeT, eSMsgProtocolTypeT + eMdsSubscribeModeT, eMdsSubscribedTickExpireTypeT, eSMsgProtocolTypeT, MdsApi_SetThreadUsername, \ + MdsApi_SetThreadPassword from vnpy.trader.constant import Exchange from vnpy.trader.gateway import BaseGateway @@ -25,10 +28,11 @@ EXCHANGE_VT2MDS = {v: k for k, v in EXCHANGE_MDS2VT.items()} class OesMdMessageLoop: - def __init__(self, gateway: BaseGateway, env: MdsApiClientEnvT): + def __init__(self, gateway: BaseGateway, md: "OesMdApi", env: MdsApiClientEnvT): self.gateway = gateway self.env = env self.alive = False + self.md = md self.th = Thread(target=self.message_loop) self.message_handlers: Dict[int, Callable[[dict], None]] = { @@ -91,6 +95,10 @@ class OesMdMessageLoop: self.gateway.write_log(f"unknown prototype : {session_info.protocolType}") return 1 + def reconnect(self): + self.gateway.write_log(_("正在尝试重新连接到行情服务器。")) + return self.md.connect() + def message_loop(self): tcp_channel = self.env.tcpChannel timeout_ms = 1000 @@ -101,13 +109,12 @@ class OesMdMessageLoop: timeout_ms, self.on_message) if ret < 0: - if is_timeout(ret): - pass + # if is_timeout(ret): + # pass # just no message if is_disconnected(ret): - # todo: handle disconnected - self.alive = False - break - pass + self.gateway.write_log(_("与行情服务器的连接已断开。")) + while not self.reconnect(): + time.sleep(1) return def on_l2_market_data_snapshot(self, d: MdsMktRspMsgBodyT): @@ -139,7 +146,6 @@ class OesMdMessageLoop: for i in range(5): tick.__dict__['ask_price_' + str(i + 1)] = data.OfferLevels[i].Price / 10000 self.gateway.on_tick(tick) - pass def on_l2_trade(self, d: MdsMktRspMsgBodyT): data = d.trade @@ -182,30 +188,39 @@ class OesMdApi: def __init__(self, gateway: BaseGateway): self.gateway = gateway - self.env = MdsApiClientEnvT() - self.message_loop = OesMdMessageLoop(gateway, self.env) + self.config_path: str = '' + self.username: str = '' + self.password: str = '' - def connect(self, config_path: str): - if not MdsApi_InitAllByConvention(self.env, config_path): - pass + self._env = MdsApiClientEnvT() + self._message_loop = OesMdMessageLoop(gateway, self, self._env) - if not MdsApi_IsValidTcpChannel(self.env.tcpChannel): - pass - if not MdsApi_IsValidQryChannel(self.env.qryChannel): - pass + def connect(self) -> bool: + """Connect to trading server. + :note set config_path before calling this function + """ + MdsApi_SetThreadUsername(self.username) + MdsApi_SetThreadPassword(self.password) + + config_path = self.config_path + if not MdsApi_InitAllByConvention(self._env, config_path): + return False + if not MdsApi_IsValidTcpChannel(self._env.tcpChannel): + return False + if not MdsApi_IsValidQryChannel(self._env.qryChannel): + return False + return True def start(self): - self.message_loop.start() + self._message_loop.start() def stop(self): - self.message_loop.stop() - if not MdsApi_LogoutAll(self.env, True): - pass # doc for this function is error - if not MdsApi_DestoryAll(self.env): - pass # doc for this function is error + self._message_loop.stop() + MdsApi_LogoutAll(self._env, True) + MdsApi_DestoryAll(self._env) def join(self): - self.message_loop.join() + self._message_loop.join() # why isn't arg a ContractData? def subscribe(self, req: SubscribeRequest): @@ -236,12 +251,11 @@ class OesMdApi: entry.securityType = eMdsSecurityTypeT.MDS_SECURITY_TYPE_STOCK # todo: option and others entry.instrId = int(req.symbol) - self.message_loop.register_symbol(req.symbol, req.exchange) + self._message_loop.register_symbol(req.symbol, req.exchange) ret = MdsApi_SubscribeMarketData( - self.env.tcpChannel, + self._env.tcpChannel, mds_req, entry) if not ret: self.gateway.write_log( f"MdsApi_SubscribeByString failed with {ret}:{error_to_str(ret)}") - pass diff --git a/vnpy/gateway/oes/td.py b/vnpy/gateway/oes/oes_td.py similarity index 86% rename from vnpy/gateway/oes/td.py rename to vnpy/gateway/oes/oes_td.py index a74e7897..649aef7f 100644 --- a/vnpy/gateway/oes/td.py +++ b/vnpy/gateway/oes/oes_td.py @@ -1,19 +1,21 @@ from dataclasses import dataclass from datetime import datetime -from threading import Thread +from gettext import gettext as _ +from threading import Lock, Thread # noinspection PyUnresolvedReferences from typing import Any, Callable, Dict, Tuple from vnpy.api.oes.vnoes import OesApiClientEnvT, OesApi_DestoryAll, OesApi_InitAllByConvention, \ OesApi_IsValidOrdChannel, OesApi_IsValidQryChannel, OesApi_IsValidRptChannel, OesApi_LogoutAll, \ - OesApi_QueryCashAsset, OesApi_QueryEtf, OesApi_QueryIssue, OesApi_QueryOptHolding, \ + OesApi_QueryCashAsset, OesApi_QueryOptHolding, \ OesApi_QueryOption, OesApi_QueryOrder, OesApi_QueryStkHolding, OesApi_QueryStock, \ OesApi_SendOrderCancelReq, OesApi_SendOrderReq, OesApi_WaitReportMsg, OesOrdCancelReqT, \ - OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, OesQryEtfFilterT, \ - OesQryIssueFilterT, OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, \ + OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, \ + OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, \ OesQryStockFilterT, OesRspMsgBodyT, OesStockBaseInfoT, OesTrdCnfmT, SGeneralClientChannelT, \ SMSG_PROTO_BINARY, SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, \ - eOesBuySellTypeT, eOesMarketIdT, eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT + eOesBuySellTypeT, eOesMarketIdT, eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT, \ + OesApi_SetThreadUsername, OesApi_SetThreadPassword from vnpy.gateway.oes.error_code import error_to_str from vnpy.trader.constant import Direction, Exchange, Offset, PriceType, Product, Status @@ -175,6 +177,10 @@ class OesTdMessageLoop: self.gateway.write_log(f"unknown prototype : {session_info.protocolType}") return 1 + def reconnect(self): + self.gateway.write_log(_("正在尝试重新连接到交易服务器。")) + self.td.connect() + def message_loop(self): rpt_channel = self.env.rptChannel timeout_ms = 1000 @@ -186,12 +192,10 @@ class OesTdMessageLoop: timeout_ms, self.on_message) if ret < 0: - if is_timeout(ret): - pass + # if is_timeout(ret): + # pass # just no message if is_disconnected(ret): - # todo: handle disconnected - self.alive = False - break + self.reconnect() return def on_order_rejected(self, d: OesRspMsgBodyT): @@ -241,7 +245,6 @@ class OesTdMessageLoop: vt_order = i.vt_order # vt_order.status = STATUS_OES2VT[data.ordStatus] - trade = TradeData( gateway_name=self.gateway.gateway_name, symbol=data.securityId, @@ -309,55 +312,65 @@ class OesTdMessageLoop: class OesTdApi: def __init__(self, gateway: BaseGateway): + self.config_path: str = None + self.username: str = '' + self.password: str = '' self.gateway = gateway - self.env = OesApiClientEnvT() - self.order_manager = OrderManager() - self.message_loop = OesTdMessageLoop(gateway, - self.env, - self.order_manager, - self) + self._env = OesApiClientEnvT() - self.account_id = None - self.last_seq_index = 1000000 # 0 has special manning for oes + self._order_manager = OrderManager() + self._message_loop = OesTdMessageLoop(gateway, + self._env, + self._order_manager, + self) - def get_new_seq_index(self): - """note: not thread safe currently""" - # todo: add lock - index = self.last_seq_index - self.last_seq_index += 1 - return index + self._last_seq_lock = Lock() + self._last_seq_index = 1000000 # 0 has special manning for oes - def connect(self, config_path: str): - if not OesApi_InitAllByConvention(self.env, config_path, -1, self.last_seq_index): - pass - self.last_seq_index = max(self.last_seq_index, self.env.ordChannel.lastOutMsgSeq + 1) + def connect(self): + """Connect to trading server. + :note set config_path before calling this function + """ + OesApi_SetThreadUsername(self.username) + OesApi_SetThreadPassword(self.password) - if not OesApi_IsValidOrdChannel(self.env.ordChannel): - pass - if not OesApi_IsValidQryChannel(self.env.qryChannel): - pass - if not OesApi_IsValidRptChannel(self.env.rptChannel): - pass + config_path = self.config_path + if not OesApi_InitAllByConvention(self._env, config_path, -1, self._last_seq_index): + return False + self._last_seq_index = max(self._last_seq_index, self._env.ordChannel.lastOutMsgSeq + 1) + + if not OesApi_IsValidOrdChannel(self._env.ordChannel): + return False + if not OesApi_IsValidQryChannel(self._env.qryChannel): + return False + if not OesApi_IsValidRptChannel(self._env.rptChannel): + return False + return True def start(self): - self.message_loop.start() + self._message_loop.start() def stop(self): - self.message_loop.stop() - if not OesApi_LogoutAll(self.env, True): - pass # doc for this function is error - if not OesApi_DestoryAll(self.env): - pass # doc for this function is error + self._message_loop.stop() + OesApi_LogoutAll(self._env, True) + OesApi_DestoryAll(self._env) def join(self): - self.message_loop.join() + self._message_loop.join() + + def _get_new_seq_index(self): + """""" + with self._last_seq_lock: + index = self._last_seq_index + self._last_seq_index += 1 + return index def query_account(self): - OesApi_QueryCashAsset(self.env.qryChannel, - OesQryCashAssetFilterT(), - self.on_query_asset - ) + OesApi_QueryCashAsset(self._env.qryChannel, + OesQryCashAssetFilterT(), + self.on_query_asset + ) def on_query_asset(self, session_info: SGeneralClientChannelT, @@ -376,7 +389,6 @@ class OesTdApi: balance=balance, frozen=balance - availiable, ) - self.account_id = account_id self.gateway.on_account(account) return 1 @@ -386,7 +398,7 @@ class OesTdApi: def _query_stock(self, ) -> bool: f = OesQryStockFilterT() - ret = OesApi_QueryStock(self.env.qryChannel, f, self.on_query_stock) + ret = OesApi_QueryStock(self._env.qryChannel, f, self.on_query_stock) return ret >= 0 def on_query_stock(self, @@ -410,7 +422,7 @@ class OesTdApi: def query_option(self) -> bool: f = OesQryOptionFilterT() - ret = OesApi_QueryOption(self.env.qryChannel, + ret = OesApi_QueryOption(self._env.qryChannel, f, self.on_query_option ) @@ -437,7 +449,7 @@ class OesTdApi: def query_stock_holding(self) -> bool: f = OesQryStkHoldingFilterT() - ret = OesApi_QueryStkHolding(self.env.qryChannel, + ret = OesApi_QueryStkHolding(self._env.qryChannel, f, self.on_query_stock_holding ) @@ -470,7 +482,7 @@ class OesTdApi: f = OesQryStkHoldingFilterT() f.mktId = eOesMarketIdT.OES_MKT_ID_UNDEFINE f.userInfo = 0 - ret = OesApi_QueryOptHolding(self.env.qryChannel, + ret = OesApi_QueryOptHolding(self._env.qryChannel, f, self.on_query_holding ) @@ -526,7 +538,7 @@ class OesTdApi: self.query_option_holding() def send_order(self, vt_req: OrderRequest): - seq_id = self.get_new_seq_index() + seq_id = self._get_new_seq_index() order_id = seq_id oes_req = OesOrdReqT() @@ -542,8 +554,8 @@ class OesTdApi: order = vt_req.create_order_data(str(order_id), self.gateway.gateway_name) order.direction = Direction.NET # fix direction into NET: stock only - self.order_manager.save_local_created(order_id, order) - ret = OesApi_SendOrderReq(self.env.ordChannel, + self._order_manager.save_local_created(order_id, order) + ret = OesApi_SendOrderReq(self._env.ordChannel, oes_req ) @@ -555,11 +567,11 @@ class OesTdApi: return order.vt_orderid def cancel_order(self, vt_req: CancelRequest): - seq_id = self.get_new_seq_index() + seq_id = self._get_new_seq_index() oes_req = OesOrdCancelReqT() order_id = int(vt_req.orderid) - internal_order = self.order_manager.get_from_order_id(order_id) + internal_order = self._order_manager.get_from_order_id(order_id) oes_req.origClOrdId = internal_order.order_id oes_req.mktId = EXCHANGE_VT2OES[vt_req.exchange] @@ -567,8 +579,8 @@ class OesTdApi: oes_req.origClSeqNo = order_id oes_req.invAcctId = "" oes_req.securityId = vt_req.symbol - OesApi_SendOrderCancelReq(self.env.ordChannel, - oes_req) + OesApi_SendOrderCancelReq(self._env.ordChannel, + oes_req) def schedule_query_order(self, internal_order: InternalOrder) -> Thread: th = Thread(target=self.query_order, args=(internal_order,)) @@ -579,7 +591,7 @@ class OesTdApi: f = OesQryOrdFilterT() f.mktId = EXCHANGE_VT2OES[internal_order.vt_order.exchange] f.clSeqNo = internal_order.order_id - ret = OesApi_QueryOrder(self.env.qryChannel, + ret = OesApi_QueryOrder(self._env.qryChannel, f, self.on_query_order ) @@ -591,7 +603,7 @@ class OesTdApi: body: Any, cursor: OesQryCursorT): data: OesOrdCnfmT = cast.toOesOrdItemT(body) - i = self.order_manager.get_from_oes_data(data) + i = self._order_manager.get_from_oes_data(data) vt_order = i.vt_order vt_order.status = STATUS_OES2VT[data.ordStatus] vt_order.volume = data.ordQty - data.canceledQty @@ -605,7 +617,7 @@ class OesTdApi: :return: """ f = OesQryOrdFilterT() - ret = OesApi_QueryOrder(self.env.qryChannel, + ret = OesApi_QueryOrder(self._env.qryChannel, f, self.on_init_query_orders ) @@ -619,7 +631,7 @@ class OesTdApi: ): data: OesOrdCnfmT = cast.toOesOrdItemT(body) try: - i = self.order_manager.get_from_oes_data(data) + i = self._order_manager.get_from_oes_data(data) vt_order = i.vt_order vt_order.status = STATUS_OES2VT[data.ordStatus] vt_order.volume = data.ordQty - data.canceledQty @@ -627,7 +639,7 @@ class OesTdApi: self.gateway.on_order(vt_order) except KeyError: # order_id = self.order_manager.new_remote_id() - order_id = self.order_manager.get_order_id_from_data(data) + order_id = self._order_manager.get_order_id_from_data(data) if data.bsType == eOesBuySellTypeT.OES_BS_TYPE_BUY: offset = Offset.OPEN @@ -649,6 +661,6 @@ class OesTdApi: # this time should be generated automatically or by a static function time=datetime.utcnow().isoformat(), ) - self.order_manager.save_remote_created(order_id, vt_order) + self._order_manager.save_remote_created(order_id, vt_order) self.gateway.on_order(vt_order) return 1