diff --git a/vnpy/gateway/oes/config_template.ini b/vnpy/gateway/oes/config_template.ini index c7db390f..2a3358c6 100644 --- a/vnpy/gateway/oes/config_template.ini +++ b/vnpy/gateway/oes/config_template.ini @@ -11,9 +11,6 @@ ordServer = 1 {td_ord_server} rptServer = 1 {td_rpt_server} qryServer = 1 {td_qry_server} -username = {username} -# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文) -password = {password} heartBtInt = 30 # 客户端环境号, 用于区分不同客户端实例上报的委托申报, 取值由客户端自行分配 @@ -33,7 +30,7 @@ rpt.subcribeEnvId = 0 # 比如想订阅所有委托、成交相关的回报消息,可以使用如下两种方式: # - rpt.subcribeRptTypes = 1,4,8 # - 或等价的: rpt.subcribeRptTypes = 0x0D -rpt.subcribeRptTypes = 0 +rpt.subcribeRptTypes = 1,2,4,8,0x10,0x20,0x40 # 服务器集群的集群类型 (1: 基于复制集的高可用集群, 2: 基于对等节点的服务器集群, 0: 默认为基于复制集的高可用集群) clusterType = 0 @@ -58,9 +55,6 @@ keepCnt = 9 tcpServer = {md_tcp_server} qryServer = {md_qry_server} -username = {username} -# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文) -password = {password} heartBtInt = 30 sse.stock.enable = false @@ -93,7 +87,7 @@ mktData.tickExpireType = 0 # 0x400:指数行情, 0x800:期权行情) # 要订阅多个数据种类, 可以用逗号或空格分隔, 或者设置为并集值, 如: # "mktData.dataTypes = 1,2,4" 或等价的 "mktData.dataTypes = 0x07" -mktData.dataTypes = 0 +mktData.dataTypes = 1,2,4,8,0x10 # 请求订阅的行情数据的起始时间 (格式: HHMMSS 或 HHMMSSsss) # (-1: 从头开始获取, 0: 从最新位置开始获取实时行情, 大于0: 从指定的起始时间开始获取) diff --git a/vnpy/gateway/oes/oes_gateway.py b/vnpy/gateway/oes/oes_gateway.py index 83f9b492..7a3f97a7 100644 --- a/vnpy/gateway/oes/oes_gateway.py +++ b/vnpy/gateway/oes/oes_gateway.py @@ -3,14 +3,15 @@ """ import hashlib import os -from threading import Thread +from gettext import gettext as _ +from threading import Thread, Lock from vnpy.trader.gateway import BaseGateway -from vnpy.trader.object import (AccountData, CancelRequest, ContractData, OrderData, OrderRequest, - PositionData, SubscribeRequest, TickData, TradeData) +from vnpy.trader.object import (CancelRequest, OrderRequest, + SubscribeRequest) from vnpy.trader.utility import get_file_path -from .md import OesMdApi -from .td import OesTdApi +from .oes_md import OesMdApi +from .oes_td import OesTdApi from .utils import config_template @@ -19,30 +20,12 @@ class OesGateway(BaseGateway): VN Trader Gateway for BitMEX connection. """ - def on_tick(self, tick: TickData): - super().on_tick(tick) - - def on_trade(self, trade: TradeData): - super().on_trade(trade) - - def on_order(self, order: OrderData): - super().on_order(order) - - def on_position(self, position: PositionData): - super().on_position(position) - - def on_account(self, account: AccountData): - super().on_account(account) - - def on_contract(self, contract: ContractData): - super().on_contract(contract) - default_setting = { - "td_ord_server": "tcp://106.15.58.119:6101", - "td_rpt_server": "tcp://106.15.58.119:6301", - "td_qry_server": "tcp://106.15.58.119:6401", - "md_tcp_server": "tcp://139.196.228.232:5103", - "md_qry_server": "tcp://139.196.228.232:5203", + "td_ord_server": "", + "td_rpt_server": "", + "td_qry_server": "", + "md_tcp_server": "", + "md_qry_server": "", "username": "", "password": "", } @@ -54,15 +37,21 @@ class OesGateway(BaseGateway): self.md_api = OesMdApi(self) self.td_api = OesTdApi(self) - def connect(self, setting: dict): - return self._connect_async(setting) + self._lock_subscribe = Lock() + self._lock_send_order = Lock() + self._lock_cancel_order = Lock() + self._lock_query_position = Lock() + self._lock_query_account = Lock() - def _connect_sync(self, setting: dict): + def connect(self, setting: dict): """""" if not setting['password'].startswith("md5:"): setting['password'] = "md5:" + hashlib.md5(setting['password'].encode()).hexdigest() - config_path = get_file_path("vnoes.ini") + username = setting['username'] + password = setting['password'] + + config_path = str(get_file_path("vnoes.ini")) with open(config_path, "wt") as f: if 'test' in setting: log_level = 'DEBUG' @@ -72,7 +61,7 @@ class OesGateway(BaseGateway): log_mode = 'file' log_dir = get_file_path('oes') log_path = os.path.join(log_dir, 'log.log') - if os.path.exists(log_dir): + if not os.path.exists(log_dir): os.mkdir(log_dir) content = config_template.format(**setting, log_level=log_level, @@ -80,41 +69,59 @@ class OesGateway(BaseGateway): log_path=log_path) f.write(content) - self.td_api.connect(str(config_path)) - self.td_api.query_account() - self.td_api.query_contracts() - self.td_api.query_position() - self.td_api.init_query_orders() - self.td_api.start() + Thread(target=self._connect_md_sync, args=(config_path, username, password)).start() + Thread(target=self._connect_td_sync, args=(config_path, username, password)).start() - self.md_api.connect(str(config_path)) - self.md_api.start() + def _connect_td_sync(self, config_path, username, password): + self.td_api.config_path = config_path + self.td_api.username = username + self.td_api.password = password + if self.td_api.connect(): + self.write_log(_("成功连接到交易服务器")) + self.td_api.query_account() + self.td_api.query_contracts() + self.write_log("合约信息查询成功") + self.td_api.query_position() + self.td_api.query_orders() + self.td_api.start() + else: + self.write_log(_("无法连接到交易服务器,请检查你的配置")) - def _connect_async(self, setting: dict): - Thread(target=self._connect_sync, args=(setting, )).start() + def _connect_md_sync(self, config_path, username, password): + self.md_api.config_path = config_path + self.md_api.username = username + self.md_api.password = password + if self.md_api.connect(): + self.md_api.start() + else: + self.write_log(_("无法连接到行情服务器,请检查你的配置")) def subscribe(self, req: SubscribeRequest): """""" - self.md_api.subscribe(req) + with self._lock_subscribe: + self.md_api.subscribe(req) def send_order(self, req: OrderRequest): """""" - return self.td_api.send_order(req) + with self._lock_send_order: + return self.td_api.send_order(req) def cancel_order(self, req: CancelRequest): """""" - return self.td_api.cancel_order(req) + with self._lock_cancel_order: + self.td_api.cancel_order(req) def query_account(self): """""" - return self.td_api.query_account() + with self._lock_query_account: + self.td_api.query_account() def query_position(self): """""" - return self.query_position() + with self._lock_query_position: + self.td_api.query_position() def close(self): """""" self.md_api.stop() self.td_api.stop() - pass diff --git a/vnpy/gateway/oes/md.py b/vnpy/gateway/oes/oes_md.py similarity index 71% rename from vnpy/gateway/oes/md.py rename to vnpy/gateway/oes/oes_md.py index 8fbdf8df..476799df 100644 --- a/vnpy/gateway/oes/md.py +++ b/vnpy/gateway/oes/oes_md.py @@ -1,14 +1,16 @@ +import time from datetime import datetime +from gettext import gettext as _ from threading import Thread # noinspection PyUnresolvedReferences from typing import Any, Callable, Dict from vnpy.api.oes.vnoes import MdsApiClientEnvT, MdsApi_DestoryAll, MdsApi_InitAllByConvention, \ - MdsApi_IsValidQryChannel, MdsApi_IsValidTcpChannel, MdsApi_LogoutAll, \ - MdsApi_SubscribeMarketData, MdsApi_WaitOnMsg, MdsL2StockSnapshotBodyT, MdsMktDataRequestEntryT, \ - MdsMktDataRequestReqT, MdsMktRspMsgBodyT, MdsStockSnapshotBodyT, SGeneralClientChannelT, \ - SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, eMdsExchangeIdT, \ - eMdsMktSubscribeFlagT, eMdsMsgTypeT, eMdsSecurityTypeT, eMdsSubscribeDataTypeT, \ + MdsApi_IsValidQryChannel, MdsApi_IsValidTcpChannel, MdsApi_LogoutAll, MdsApi_SetThreadPassword, \ + MdsApi_SetThreadUsername, MdsApi_SubscribeMarketData, MdsApi_WaitOnMsg, MdsL2StockSnapshotBodyT, \ + MdsMktDataRequestEntryT, MdsMktDataRequestReqT, MdsMktRspMsgBodyT, MdsStockSnapshotBodyT, \ + SGeneralClientChannelT, SMsgHeadT, SPlatform_IsNegEpipe, cast, \ + eMdsExchangeIdT, eMdsMktSubscribeFlagT, eMdsMsgTypeT, eMdsSecurityTypeT, eMdsSubscribeDataTypeT, \ eMdsSubscribeModeT, eMdsSubscribedTickExpireTypeT, eSMsgProtocolTypeT from vnpy.trader.constant import Exchange @@ -25,18 +27,23 @@ EXCHANGE_VT2MDS = {v: k for k, v in EXCHANGE_MDS2VT.items()} class OesMdMessageLoop: - def __init__(self, gateway: BaseGateway, env: MdsApiClientEnvT): + def __init__(self, gateway: BaseGateway, md: "OesMdApi", env: MdsApiClientEnvT): + """""" self.gateway = gateway self.env = env - self.alive = False - self.th = Thread(target=self.message_loop) + + self._alive = False + self._md = md + self._th = Thread(target=self._message_loop) self.message_handlers: Dict[int, Callable[[dict], None]] = { + # tick & orderbook eMdsMsgTypeT.MDS_MSGTYPE_MARKET_DATA_SNAPSHOT_FULL_REFRESH: self.on_market_full_refresh, eMdsMsgTypeT.MDS_MSGTYPE_L2_MARKET_DATA_SNAPSHOT: self.on_l2_market_data_snapshot, eMdsMsgTypeT.MDS_MSGTYPE_L2_ORDER: self.on_l2_order, eMdsMsgTypeT.MDS_MSGTYPE_L2_TRADE: self.on_l2_trade, + # others eMdsMsgTypeT.MDS_MSGTYPE_QRY_SECURITY_STATUS: self.on_security_status, eMdsMsgTypeT.MDS_MSGTYPE_L2_MARKET_DATA_INCREMENTAL: lambda x: 1, eMdsMsgTypeT.MDS_MSGTYPE_L2_BEST_ORDERS_SNAPSHOT: self.on_best_orders_snapshot, @@ -46,15 +53,36 @@ class OesMdMessageLoop: eMdsMsgTypeT.MDS_MSGTYPE_TRADING_SESSION_STATUS: self.on_trading_session_status, eMdsMsgTypeT.MDS_MSGTYPE_SECURITY_STATUS: self.on_security_status, eMdsMsgTypeT.MDS_MSGTYPE_MARKET_DATA_REQUEST: self.on_market_data_request, + eMdsMsgTypeT.MDS_MSGTYPE_HEARTBEAT: lambda x: 1, } self.last_tick: Dict[str, TickData] = {} self.symbol_to_exchange: Dict[str, Exchange] = {} def register_symbol(self, symbol: str, exchange: Exchange): + """""" self.symbol_to_exchange[symbol] = exchange - def get_last_tick(self, symbol): + def start(self): + """""" + self._alive = True + self._th.start() + + def stop(self): + """""" + self._alive = False + + def join(self): + """""" + self._th.join() + + def reconnect(self): + """""" + self.gateway.write_log(_("正在尝试重新连接到行情服务器。")) + return self._md.connect() + + def _get_last_tick(self, symbol): + """""" try: return self.last_tick[symbol] except KeyError: @@ -62,22 +90,15 @@ class OesMdMessageLoop: gateway_name=self.gateway.gateway_name, symbol=symbol, exchange=self.symbol_to_exchange[symbol], - # todo: use cache of something else to resolve exchange datetime=datetime.utcnow() ) self.last_tick[symbol] = tick return tick - def start(self): - self.alive = True - self.th.start() - - def join(self): - self.th.join() - - def on_message(self, session_info: SGeneralClientChannelT, - head: SMsgHeadT, - body: Any): + def _on_message(self, session_info: SGeneralClientChannelT, + head: SMsgHeadT, + body: Any): + """""" if session_info.protocolType == eSMsgProtocolTypeT.SMSG_PROTO_BINARY: b = cast.toMdsMktRspMsgBodyT(body) if head.msgId in self.message_handlers: @@ -91,31 +112,30 @@ class OesMdMessageLoop: self.gateway.write_log(f"unknown prototype : {session_info.protocolType}") return 1 - def message_loop(self): + def _message_loop(self): + """""" tcp_channel = self.env.tcpChannel timeout_ms = 1000 - is_timeout = SPlatform_IsNegEtimeout + # is_timeout = SPlatform_IsNegEtimeout is_disconnected = SPlatform_IsNegEpipe - while self.alive: + while self._alive: ret = MdsApi_WaitOnMsg(tcp_channel, timeout_ms, - self.on_message) + self._on_message) if ret < 0: - if is_timeout(ret): - pass + # if is_timeout(ret): + # pass # just no message if is_disconnected(ret): - # todo: handle disconnected - self.alive = False - break - pass + self.gateway.write_log(_("与行情服务器的连接已断开。")) + while self._alive and not self.reconnect(): + time.sleep(1) return def on_l2_market_data_snapshot(self, d: MdsMktRspMsgBodyT): + """""" data: MdsL2StockSnapshotBodyT = d.mktDataSnapshot.l2Stock symbol = str(data.SecurityID) - tick = self.get_last_tick(symbol) - tick.limit_up = data.HighPx / 10000 - tick.limit_down = data.LowPx / 10000 + tick = self._get_last_tick(symbol) tick.open_price = data.OpenPx / 10000 tick.pre_close = data.ClosePx / 10000 tick.high_price = data.HighPx / 10000 @@ -128,11 +148,10 @@ class OesMdMessageLoop: self.gateway.on_tick(tick) def on_market_full_refresh(self, d: MdsMktRspMsgBodyT): + """""" data: MdsStockSnapshotBodyT = d.mktDataSnapshot.stock symbol = data.SecurityID - tick = self.get_last_tick(symbol) - tick.limit_up = data.HighPx / 10000 - tick.limit_down = data.LowPx / 10000 + tick = self._get_last_tick(symbol) tick.open_price = data.OpenPx / 10000 tick.pre_close = data.ClosePx / 10000 tick.high_price = data.HighPx / 10000 @@ -143,76 +162,96 @@ class OesMdMessageLoop: for i in range(5): tick.__dict__['ask_price_' + str(i + 1)] = data.OfferLevels[i].Price / 10000 self.gateway.on_tick(tick) - pass def on_l2_trade(self, d: MdsMktRspMsgBodyT): + """""" data = d.trade symbol = data.SecurityID - tick = self.get_last_tick(symbol) + tick = self._get_last_tick(symbol) tick.datetime = datetime.utcnow() tick.volume = data.TradeQty tick.last_price = data.TradePrice / 10000 self.gateway.on_tick(tick) def on_market_data_request(self, d: MdsMktRspMsgBodyT): + """""" pass def on_trading_session_status(self, d: MdsMktRspMsgBodyT): + """""" pass def on_l2_market_overview(self, d: MdsMktRspMsgBodyT): + """""" pass def on_index_snapshot_full_refresh(self, d: MdsMktRspMsgBodyT): + """""" pass def on_option_snapshot_ful_refresh(self, d: MdsMktRspMsgBodyT): + """""" pass def on_best_orders_snapshot(self, d: MdsMktRspMsgBodyT): + """""" pass def on_l2_order(self, d: MdsMktRspMsgBodyT): + """""" pass def on_security_status(self, d: MdsMktRspMsgBodyT): + """""" pass - def stop(self): - self.alive = False - class OesMdApi: def __init__(self, gateway: BaseGateway): + """""" self.gateway = gateway - self.env = MdsApiClientEnvT() - self.message_loop = OesMdMessageLoop(gateway, self.env) + self.config_path: str = '' + self.username: str = '' + self.password: str = '' - def connect(self, config_path: str): - if not MdsApi_InitAllByConvention(self.env, config_path): - pass + self._env = MdsApiClientEnvT() + self._message_loop = OesMdMessageLoop(gateway, self, self._env) - if not MdsApi_IsValidTcpChannel(self.env.tcpChannel): - pass - if not MdsApi_IsValidQryChannel(self.env.qryChannel): - pass + def connect(self) -> bool: + """""" + """Connect to trading server. + :note set config_path before calling this function + """ + MdsApi_SetThreadUsername(self.username) + MdsApi_SetThreadPassword(self.password) + + config_path = self.config_path + if not MdsApi_InitAllByConvention(self._env, config_path): + return False + if not MdsApi_IsValidTcpChannel(self._env.tcpChannel): + return False + if not MdsApi_IsValidQryChannel(self._env.qryChannel): + return False + return True def start(self): - self.message_loop.start() + """""" + self._message_loop.start() def stop(self): - self.message_loop.stop() - if not MdsApi_LogoutAll(self.env, True): - pass # doc for this function is error - if not MdsApi_DestoryAll(self.env): - pass # doc for this function is error + """""" + self._message_loop.stop() + MdsApi_LogoutAll(self._env, True) + MdsApi_DestoryAll(self._env) def join(self): - self.message_loop.join() + """""" + self._message_loop.join() # why isn't arg a ContractData? def subscribe(self, req: SubscribeRequest): + """""" mds_req = MdsMktDataRequestReqT() entry = MdsMktDataRequestEntryT() mds_req.subMode = eMdsSubscribeModeT.MDS_SUB_MODE_APPEND @@ -240,12 +279,11 @@ class OesMdApi: entry.securityType = eMdsSecurityTypeT.MDS_SECURITY_TYPE_STOCK # todo: option and others entry.instrId = int(req.symbol) - self.message_loop.register_symbol(req.symbol, req.exchange) + self._message_loop.register_symbol(req.symbol, req.exchange) ret = MdsApi_SubscribeMarketData( - self.env.tcpChannel, + self._env.tcpChannel, mds_req, entry) if not ret: self.gateway.write_log( f"MdsApi_SubscribeByString failed with {ret}:{error_to_str(ret)}") - pass diff --git a/vnpy/gateway/oes/td.py b/vnpy/gateway/oes/oes_td.py similarity index 58% rename from vnpy/gateway/oes/td.py rename to vnpy/gateway/oes/oes_td.py index 37408d25..ea6ae7e7 100644 --- a/vnpy/gateway/oes/td.py +++ b/vnpy/gateway/oes/oes_td.py @@ -1,19 +1,20 @@ from dataclasses import dataclass -from datetime import datetime -from threading import Thread +from datetime import datetime, timedelta, timezone +from gettext import gettext as _ +from threading import Lock, Thread # noinspection PyUnresolvedReferences -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict from vnpy.api.oes.vnoes import OesApiClientEnvT, OesApi_DestoryAll, OesApi_InitAllByConvention, \ OesApi_IsValidOrdChannel, OesApi_IsValidQryChannel, OesApi_IsValidRptChannel, OesApi_LogoutAll, \ - OesApi_QueryCashAsset, OesApi_QueryEtf, OesApi_QueryIssue, OesApi_QueryOptHolding, \ - OesApi_QueryOption, OesApi_QueryOrder, OesApi_QueryStkHolding, OesApi_QueryStock, \ - OesApi_SendOrderCancelReq, OesApi_SendOrderReq, OesApi_WaitReportMsg, OesOrdCancelReqT, \ - OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, OesQryEtfFilterT, \ - OesQryIssueFilterT, OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, \ - OesQryStockFilterT, OesRspMsgBodyT, OesStockBaseInfoT, OesTrdCnfmT, SGeneralClientChannelT, \ - SMSG_PROTO_BINARY, SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, \ - eOesBuySellTypeT, eOesMarketIdT, eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT + OesApi_QueryCashAsset, OesApi_QueryOptHolding, OesApi_QueryOption, OesApi_QueryOrder, \ + OesApi_QueryStkHolding, OesApi_QueryStock, OesApi_SendOrderCancelReq, OesApi_SendOrderReq, \ + OesApi_SetThreadPassword, OesApi_SetThreadUsername, OesApi_WaitReportMsg, OesOrdCancelReqT, \ + OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, \ + OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, OesQryStockFilterT, \ + OesRspMsgBodyT, OesStockBaseInfoT, OesTrdCnfmT, SGeneralClientChannelT, SMSG_PROTO_BINARY, \ + SMsgHeadT, SPlatform_IsNegEpipe, cast, eOesBuySellTypeT, eOesMarketIdT, \ + eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT from vnpy.gateway.oes.error_code import error_to_str from vnpy.trader.constant import Direction, Exchange, Offset, PriceType, Product, Status @@ -84,87 +85,28 @@ STATUS_OES2VT = { eOesOrdStatusT.OES_ORD_STATUS_INVALID_SZ_TRY_AGAIN: Status.REJECTED, } +bjtz = timezone(timedelta(hours=8)) + @dataclass class InternalOrder: order_id: int = None vt_order: OrderData = None - req_data: OesOrdReqT = None - rpt_data: OesOrdCnfmT = None -class OrderManager: +def parse_oes_datetime(date: int, time: int): + """convert oes datetime to python datetime""" + # YYYYMMDD + year = int(date / 10000) + month = int((date % 10000) / 100) + day = int(date % 100) - def __init__(self): - self.last_order_id = 100000000 - self._orders: Dict[int, InternalOrder] = {} - - # key tuple: seqNo, ordId, envId, userInfo - self._remote_created_orders: Dict[Tuple[int, int, int, int], InternalOrder] = {} - - @staticmethod - def hash_remote_order(data): - key = (data.origClSeqNo, data.origClOrdId, data.origClEnvId, data.userInfo) - return key - - @staticmethod - def hash_remote_trade(data: OesTrdCnfmT): - key = (data.clSeqNo, data.clOrdId, data.clEnvId, data.userInfo) - return key - - def new_local_id(self): - id = self.last_order_id - self.last_order_id += 1 - return id - - def new_remote_id(self): - id = self.last_order_id - self.last_order_id += 1 - return id - - def save_local_created(self, order_id: int, order: OrderData, oes_req: OesOrdReqT): - self._orders[order_id] = InternalOrder( - order_id=order_id, - vt_order=order, - req_data=oes_req - ) - - def save_remote_created(self, order_id: int, vt_order: OrderData, data: OesOrdCnfmT): - internal_order = InternalOrder( - order_id=order_id, - vt_order=vt_order, - rpt_data=data - ) - self._orders[order_id] = internal_order - key = self.hash_remote_order(data) - self._remote_created_orders[key] = internal_order - - def get_from_order_id(self, id: int): - return self._orders[id] - - def get_remote_created_order_from_oes_data(self, data): - """ - :return: internal_order if succeed else None, will check only remote created order - """ - try: - key = self.hash_remote_order(data) - except AttributeError: - key = self.hash_remote_trade(data) - try: - return self._remote_created_orders[key] - except KeyError: - return None - - def get_from_oes_data(self, data): - try: - key = self.hash_remote_order(data) - except AttributeError: - key = self.hash_remote_trade(data) - try: - return self._remote_created_orders[key] - except KeyError: - order_id = key[3] - return self._orders[order_id] + # HHMMSSsss + hour = int(time / 10000000) + minute = int((time % 10000000) / 100000) + sec = int((time % 100000) / 1000) + mill = int(time % 1000) + return datetime(year, month, day, hour, minute, sec, mill * 1000, tzinfo=bjtz) class OesTdMessageLoop: @@ -172,38 +114,52 @@ class OesTdMessageLoop: def __init__(self, gateway: BaseGateway, env: OesApiClientEnvT, - order_manager: OrderManager, td: "OesTdApi" ): + """""" self.gateway = gateway - self.env = env - self.order_manager = order_manager - self.td = td - self.alive = False - self.th = Thread(target=self.message_loop) + self._env = env + self._td = td + + self._alive = False + self._th = Thread(target=self._message_loop) self.message_handlers: Dict[int, Callable[[dict], None]] = { - eOesMsgTypeT.OESMSG_RPT_BUSINESS_REJECT: self.on_reject, + eOesMsgTypeT.OESMSG_RPT_BUSINESS_REJECT: self.on_order_rejected, eOesMsgTypeT.OESMSG_RPT_ORDER_INSERT: self.on_order_inserted, eOesMsgTypeT.OESMSG_RPT_ORDER_REPORT: self.on_order_report, eOesMsgTypeT.OESMSG_RPT_TRADE_REPORT: self.on_trade_report, eOesMsgTypeT.OESMSG_RPT_STOCK_HOLDING_VARIATION: self.on_stock_holding, eOesMsgTypeT.OESMSG_RPT_OPTION_HOLDING_VARIATION: self.on_option_holding, eOesMsgTypeT.OESMSG_RPT_CASH_ASSET_VARIATION: self.on_cash, - eOesMsgTypeT.OESMSG_SESS_HEARTBEAT: lambda x: x, + + eOesMsgTypeT.OESMSG_RPT_REPORT_SYNCHRONIZATION: lambda x: 1, + eOesMsgTypeT.OESMSG_SESS_HEARTBEAT: lambda x: 1, } def start(self): - self.alive = True - self.th.start() + """""" + self._alive = True + self._th.start() + + def stop(self): + """""" + self._alive = False def join(self): - self.th.join() + """""" + self._th.join() - def on_message(self, session_info: SGeneralClientChannelT, - head: SMsgHeadT, - body: Any): + def reconnect(self): + """""" + self.gateway.write_log(_("正在尝试重新连接到交易服务器。")) + self._td.connect() + + def _on_message(self, session_info: SGeneralClientChannelT, + head: SMsgHeadT, + body: Any): + """""" if session_info.protocolType == SMSG_PROTO_BINARY: b = cast.toOesRspMsgBodyT(body) if head.msgId in self.message_handlers: @@ -215,67 +171,83 @@ class OesTdMessageLoop: self.gateway.write_log(f"unknown prototype : {session_info.protocolType}") return 1 - def message_loop(self): - rtp_channel = self.env.rptChannel + def _message_loop(self): + """""" + rpt_channel = self._env.rptChannel timeout_ms = 1000 - is_timeout = SPlatform_IsNegEtimeout is_disconnected = SPlatform_IsNegEpipe - while self.alive: - ret = OesApi_WaitReportMsg(rtp_channel, + while self._alive: + ret = OesApi_WaitReportMsg(rpt_channel, timeout_ms, - self.on_message) + self._on_message) if ret < 0: - if is_timeout(ret): - pass + # if is_timeout(ret): + # pass # just no message if is_disconnected(ret): - # todo: handle disconnected - self.alive = False - break - pass + self.gateway.write_log(_("与交易服务器的连接已断开。")) + while self._alive and not self.reconnect(): + pass return - def on_reject(self, d: OesRspMsgBodyT): + def on_order_rejected(self, d: OesRspMsgBodyT): + """""" error_code = d.rptMsg.rptHead.ordRejReason error_string = error_to_str(error_code) data: OesOrdRejectT = d.rptMsg.rptBody.ordRejectRsp - i = self.order_manager.get_from_oes_data(data) - vt_order = i.vt_order + if not data.origClSeqNo: + i = self._td.get_order(data.clSeqNo) + vt_order = i.vt_order - if vt_order == Status.ALLTRADED: - return + if vt_order == Status.ALLTRADED: + return - vt_order.status = Status.REJECTED + vt_order.status = Status.REJECTED - self.gateway.on_order(vt_order) - self.gateway.write_log( - f"Order: {vt_order.vt_symbol}-{vt_order.vt_orderid} Code: {error_code} Rejected: {error_string}") + self.gateway.on_order(vt_order) + self.gateway.write_log( + f"Order: {vt_order.vt_symbol}-{vt_order.vt_orderid} Code: {error_code} Rejected: {error_string}") + else: + self.gateway.write_log(f"撤单失败,订单号: {data.origClSeqNo}。原因:{error_string}") def on_order_inserted(self, d: OesRspMsgBodyT): + """""" data = d.rptMsg.rptBody.ordInsertRsp - i = self.order_manager.get_from_oes_data(data) + if not data.origClSeqNo: + i = self._td.get_order(data.clSeqNo) + else: + i = self._td.get_order(data.origClSeqNo) vt_order = i.vt_order vt_order.status = STATUS_OES2VT[data.ordStatus] - vt_order.volume = data.ordQty - data.canceledQty + vt_order.volume = data.ordQty vt_order.traded = data.cumQty + vt_order.time = parse_oes_datetime(data.ordDate, data.ordTime) self.gateway.on_order(vt_order) def on_order_report(self, d: OesRspMsgBodyT): + """""" data: OesOrdCnfmT = d.rptMsg.rptBody.ordCnfm - i = self.order_manager.get_from_oes_data(data) + if not data.origClSeqNo: + i = self._td.get_order(data.clSeqNo) + else: + i = self._td.get_order(data.origClSeqNo) vt_order = i.vt_order + vt_order.status = STATUS_OES2VT[data.ordStatus] - vt_order.volume = data.ordQty - data.canceledQty + vt_order.volume = data.ordQty vt_order.traded = data.cumQty + vt_order.time = parse_oes_datetime(data.ordDate, data.ordCnfmTime) + self.gateway.on_order(vt_order) def on_trade_report(self, d: OesRspMsgBodyT): + """""" data: OesTrdCnfmT = d.rptMsg.rptBody.trdCnfm - i = self.order_manager.get_from_oes_data(data) + i = self._td.get_order(data.clSeqNo) vt_order = i.vt_order # vt_order.status = STATUS_OES2VT[data.ordStatus] @@ -289,25 +261,20 @@ class OesTdMessageLoop: offset=vt_order.offset, price=data.trdPrice / 10000, volume=data.trdQty, - time=datetime.utcnow().isoformat() # strict + time=parse_oes_datetime(data.trdDate, data.trdTime) ) + vt_order.status = STATUS_OES2VT[data.ordStatus] + vt_order.traded = data.cumQty + vt_order.time = parse_oes_datetime(data.trdDate, data.trdTime) self.gateway.on_trade(trade) - - # hack : - # Sometimes order_report is not received after a trade is received. - # (only trade msg but no order msg) - # This cause a problem that vt_order.traded stay 0 after a trade, which is a error state. - # So we have to query new status of order for every receiving of trade. - # BUT - # Oes have no async call to query order only. - # And calling sync function here will slow down vnpy. - # So we queue it into another thread. - self.td.schedule_query_order(i) + self.gateway.on_order(vt_order) def on_option_holding(self, d: OesRspMsgBodyT): + """""" pass def on_stock_holding(self, d: OesRspMsgBodyT): + """""" data = d.rptMsg.rptBody.stkHoldingRpt position = PositionData( gateway_name=self.gateway.gateway_name, @@ -315,21 +282,22 @@ class OesTdMessageLoop: exchange=EXCHANGE_OES2VT[data.mktId], direction=Direction.NET, volume=data.sumHld, - frozen=data.lockHld, + frozen=data.lockHld, # todo: to verify price=data.costPrice / 10000, # pnl=data.costPrice - data.originalCostAmt, - pnl=0, # todo: oes只提供日初持仓价格信息,不提供最初持仓价格信息,所以pnl只有当日的 + pnl=0, yd_volume=data.originalHld, ) self.gateway.on_position(position) def on_cash(self, d: OesRspMsgBodyT): + """""" data = d.rptMsg.rptBody.cashAssetRpt balance = data.currentTotalBal availiable = data.currentAvailableBal # drawable = data.currentDrawableBal - account_id = data.custId + account_id = data.cashAcctId account = AccountData( gateway_name=self.gateway.gateway_name, accountid=account_id, @@ -339,59 +307,74 @@ class OesTdMessageLoop: self.gateway.on_account(account) return 1 - def stop(self): - self.alive = False - class OesTdApi: def __init__(self, gateway: BaseGateway): + """""" + self.config_path: str = None + self.username: str = '' + self.password: str = '' self.gateway = gateway - self.env = OesApiClientEnvT() - self.order_manager = OrderManager() - self.message_loop = OesTdMessageLoop(gateway, - self.env, - self.order_manager, - self) + self._env = OesApiClientEnvT() - self.account_id = None - self.last_seq_index = 1 # 0 has special manning for oes + self._message_loop = OesTdMessageLoop(gateway, + self._env, + self) - def connect(self, config_path: str): - if not OesApi_InitAllByConvention(self.env, config_path, -1, self.last_seq_index): - pass - self.last_seq_index = self.env.ordChannel.lastOutMsgSeq + 1 + self._last_seq_lock = Lock() + self._last_seq_index = 1000000 # 0 has special manning for oes - if not OesApi_IsValidOrdChannel(self.env.ordChannel): - pass - if not OesApi_IsValidQryChannel(self.env.qryChannel): - pass - if not OesApi_IsValidRptChannel(self.env.rptChannel): - pass + self._orders: Dict[int, InternalOrder] = {} + + def connect(self): + """Connect to trading server. + :note set config_path before calling this function + """ + OesApi_SetThreadUsername(self.username) + OesApi_SetThreadPassword(self.password) + + config_path = self.config_path + if not OesApi_InitAllByConvention(self._env, config_path, -1, self._last_seq_index): + return False + self._last_seq_index = max(self._last_seq_index, self._env.ordChannel.lastOutMsgSeq + 1) + + if not OesApi_IsValidOrdChannel(self._env.ordChannel): + return False + if not OesApi_IsValidQryChannel(self._env.qryChannel): + return False + if not OesApi_IsValidRptChannel(self._env.rptChannel): + return False + return True def start(self): - self.message_loop.start() + """""" + self._message_loop.start() def stop(self): - self.message_loop.stop() - if not OesApi_LogoutAll(self.env, True): - pass # doc for this function is error - if not OesApi_DestoryAll(self.env): - pass # doc for this function is error + """""" + self._message_loop.stop() + OesApi_LogoutAll(self._env, True) + OesApi_DestoryAll(self._env) def join(self): - self.message_loop.join() + """""" + self._message_loop.join() - def query_account(self) -> bool: - return self.query_cash_asset() + def _get_new_seq_index(self): + """""" + with self._last_seq_lock: + index = self._last_seq_index + self._last_seq_index += 1 + return index - def query_cash_asset(self) -> bool: - ret = OesApi_QueryCashAsset(self.env.qryChannel, - OesQryCashAssetFilterT(), - self.on_query_asset - ) - return ret >= 0 + def query_account(self): + """""" + OesApi_QueryCashAsset(self._env.qryChannel, + OesQryCashAssetFilterT(), + self.on_query_asset + ) def on_query_asset(self, session_info: SGeneralClientChannelT, @@ -399,28 +382,25 @@ class OesTdApi: body: Any, cursor: OesQryCursorT, ): + """""" data = cast.toOesCashAssetItemT(body) balance = data.currentTotalBal / 10000 availiable = data.currentAvailableBal / 10000 # drawable = data.currentDrawableBal - account_id = data.custId + account_id = data.cashAcctId account = AccountData( gateway_name=self.gateway.gateway_name, accountid=account_id, balance=balance, frozen=balance - availiable, ) - self.account_id = account_id self.gateway.on_account(account) return 1 def query_stock(self, ) -> bool: - # Thread(target=self._query_stock, ).start() - return self._query_stock() - - def _query_stock(self, ) -> bool: + """""" f = OesQryStockFilterT() - ret = OesApi_QueryStock(self.env.qryChannel, f, self.on_query_stock) + ret = OesApi_QueryStock(self._env.qryChannel, f, self.on_query_stock) return ret >= 0 def on_query_stock(self, @@ -429,6 +409,7 @@ class OesTdApi: body: Any, cursor: OesQryCursorT, ): + """""" data: OesStockBaseInfoT = cast.toOesStockItemT(body) contract = ContractData( gateway_name=self.gateway.gateway_name, @@ -443,8 +424,9 @@ class OesTdApi: return 1 def query_option(self) -> bool: + """""" f = OesQryOptionFilterT() - ret = OesApi_QueryOption(self.env.qryChannel, + ret = OesApi_QueryOption(self._env.qryChannel, f, self.on_query_option ) @@ -456,6 +438,7 @@ class OesTdApi: body: Any, cursor: OesQryCursorT, ): + """""" data = cast.toOesOptionItemT(body) contract = ContractData( gateway_name=self.gateway.gateway_name, @@ -469,63 +452,10 @@ class OesTdApi: self.gateway.on_contract(contract) return 1 - def query_issue(self) -> bool: - f = OesQryIssueFilterT() - ret = OesApi_QueryIssue(self.env.qryChannel, - f, - self.on_query_issue - ) - return ret >= 0 - - def on_query_issue(self, - session_info: SGeneralClientChannelT, - head: SMsgHeadT, - body: Any, - cursor: OesQryCursorT, - ): - data = cast.toOesIssueItemT(body) - contract = ContractData( - gateway_name=self.gateway.gateway_name, - symbol=data.securityId, - exchange=EXCHANGE_OES2VT[data.mktId], - name=data.securityName, - product=PRODUCT_OES2VT[data.mktId], - size=data.qtyUnit, - pricetick=1, - ) - self.gateway.on_contract(contract) - return 1 - - def query_etf(self) -> bool: - f = OesQryEtfFilterT() - ret = OesApi_QueryEtf(self.env.qryChannel, - f, - self.on_query_etf - ) - return ret >= 0 - - def on_query_etf(self, - session_info: SGeneralClientChannelT, - head: SMsgHeadT, - body: Any, - cursor: OesQryCursorT, - ): - data = cast.toOesEtfItemT(body) - contract = ContractData( - gateway_name=self.gateway.gateway_name, - symbol=data.securityId, - exchange=EXCHANGE_OES2VT[data.mktId], - name=data.securityId, - product=PRODUCT_OES2VT[data.mktId], - size=data.creRdmUnit, # todo: to verify! creRdmUnit : 每个篮子 (最小申购、赎回单位) 对应的ETF份数 - pricetick=1, - ) - self.gateway.on_contract(contract) - return 1 - def query_stock_holding(self) -> bool: + """""" f = OesQryStkHoldingFilterT() - ret = OesApi_QueryStkHolding(self.env.qryChannel, + ret = OesApi_QueryStkHolding(self._env.qryChannel, f, self.on_query_stock_holding ) @@ -537,6 +467,7 @@ class OesTdApi: body: Any, cursor: OesQryCursorT, ): + """""" data = cast.toOesStkHoldingItemT(body) position = PositionData( @@ -555,10 +486,11 @@ class OesTdApi: return 1 def query_option_holding(self) -> bool: + """""" f = OesQryStkHoldingFilterT() f.mktId = eOesMarketIdT.OES_MKT_ID_UNDEFINE f.userInfo = 0 - ret = OesApi_QueryOptHolding(self.env.qryChannel, + ret = OesApi_QueryOptHolding(self._env.qryChannel, f, self.on_query_holding ) @@ -570,6 +502,7 @@ class OesTdApi: body: Any, cursor: OesQryCursorT, ): + """""" data = cast.toOesOptHoldingItemT(body) # 权利 @@ -605,19 +538,20 @@ class OesTdApi: return 1 def query_contracts(self): + """""" self.query_stock() # self.query_option() # self.query_issue() - pass def query_position(self): + """""" self.query_stock_holding() self.query_option_holding() def send_order(self, vt_req: OrderRequest): - seq_id = self.last_seq_index - self.last_seq_index += 1 # note: thread un-safe here, conflict with on_query_order - order_id = self.order_manager.new_local_id() + """""" + seq_id = self._get_new_seq_index() + order_id = seq_id oes_req = OesOrdReqT() oes_req.clSeqNo = seq_id @@ -628,16 +562,17 @@ class OesTdApi: oes_req.securityId = vt_req.symbol oes_req.ordQty = int(vt_req.volume) oes_req.ordPrice = int(vt_req.price * 10000) - oes_req.userInfo = order_id - - ret = OesApi_SendOrderReq(self.env.ordChannel, - oes_req - ) + oes_req.origClOrdId = order_id order = vt_req.create_order_data(str(order_id), self.gateway.gateway_name) order.direction = Direction.NET # fix direction into NET: stock only + self.save_order(order_id, order) + + ret = OesApi_SendOrderReq(self._env.ordChannel, + oes_req + ) + if ret >= 0: - self.order_manager.save_local_created(order_id, order, oes_req) self.gateway.on_order(order) else: self.gateway.write_log("Failed to send_order!") @@ -645,58 +580,26 @@ class OesTdApi: return order.vt_orderid def cancel_order(self, vt_req: CancelRequest): - seq_id = self.last_seq_index - self.last_seq_index += 1 # note: thread un-safe here + """""" + seq_id = self._get_new_seq_index() oes_req = OesOrdCancelReqT() order_id = int(vt_req.orderid) - internal_order = self.order_manager.get_from_order_id(order_id) - if internal_order.rpt_data: - data = internal_order.rpt_data - # oes_req.origClSeqNo = self.local_id_to_sys_id[int(order_id)] - oes_req.origClOrdId = data.clOrdId - oes_req.origClSeqNo = data.clSeqNo - oes_req.origClEnvId = data.origClEnvId - oes_req.mktId = data.mktId - # oes_req.invAcctId = data.invAcctId - # oes_req.mktId = data.mktId - # oes_req.securityId = data.securityId - else: - data = internal_order.req_data - oes_req.origClSeqNo = data.clSeqNo - oes_req.mktId = internal_order.req_data.mktId + oes_req.mktId = EXCHANGE_VT2OES[vt_req.exchange] oes_req.clSeqNo = seq_id + oes_req.origClSeqNo = order_id oes_req.invAcctId = "" oes_req.securityId = vt_req.symbol - oes_req.userInfo = order_id - ret = OesApi_SendOrderCancelReq(self.env.ordChannel, - oes_req) - - if ret >= 0: - pass - else: - pass - return - - def schedule_query_order(self, internal_order: InternalOrder) -> Thread: - th = Thread(target=self.query_order, args=(internal_order,)) - th.start() - return th + OesApi_SendOrderCancelReq(self._env.ordChannel, + oes_req) def query_order(self, internal_order: InternalOrder) -> bool: + """""" f = OesQryOrdFilterT() - if internal_order.req_data: - f.clSeqNo = internal_order.req_data.clSeqNo - f.mktId = internal_order.req_data.mktId - f.invAcctId = internal_order.req_data.invAcctId - else: - f.clSeqNo = internal_order.rpt_data.origClSeqNo - f.clOrdId = internal_order.rpt_data.origClOrdId - f.clEnvId = internal_order.rpt_data.origClEnvId - f.mktId = internal_order.rpt_data.mktId - f.invAcctId = internal_order.rpt_data.invAcctId - ret = OesApi_QueryOrder(self.env.qryChannel, + f.mktId = EXCHANGE_VT2OES[internal_order.vt_order.exchange] + f.clSeqNo = internal_order.order_id + ret = OesApi_QueryOrder(self._env.qryChannel, f, self.on_query_order ) @@ -707,8 +610,10 @@ class OesTdApi: head: SMsgHeadT, body: Any, cursor: OesQryCursorT): + """""" data: OesOrdCnfmT = cast.toOesOrdItemT(body) - i = self.order_manager.get_from_oes_data(data) + + i = self.get_order(data.clSeqNo) vt_order = i.vt_order vt_order.status = STATUS_OES2VT[data.ordStatus] vt_order.volume = data.ordQty - data.canceledQty @@ -716,28 +621,33 @@ class OesTdApi: self.gateway.on_order(vt_order) return 1 - def init_query_orders(self) -> bool: - """ - :note: this function can be called only before calling send_order - :return: - """ + def query_orders(self) -> bool: + """""" f = OesQryOrdFilterT() - ret = OesApi_QueryOrder(self.env.qryChannel, + ret = OesApi_QueryOrder(self._env.qryChannel, f, - self.on_init_query_orders + self.on_query_orders ) return ret >= 0 - def on_init_query_orders(self, - session_info: SGeneralClientChannelT, - head: SMsgHeadT, - body: Any, - cursor: OesQryCursorT, - ): + def on_query_orders(self, + session_info: SGeneralClientChannelT, + head: SMsgHeadT, + body: Any, + cursor: OesQryCursorT, + ): + """""" data: OesOrdCnfmT = cast.toOesOrdItemT(body) - i = self.order_manager.get_remote_created_order_from_oes_data(data) - if not i: - order_id = self.order_manager.new_remote_id() + try: + i = self.get_order(data.clSeqNo) + vt_order = i.vt_order + vt_order.status = STATUS_OES2VT[data.ordStatus] + vt_order.volume = data.ordQty - data.canceledQty + vt_order.traded = data.cumQty + self.gateway.on_order(vt_order) + except KeyError: + # order_id = self.order_manager.new_remote_id() + order_id = data.clSeqNo if data.bsType == eOesBuySellTypeT.OES_BS_TYPE_BUY: offset = Offset.OPEN @@ -748,7 +658,7 @@ class OesTdApi: gateway_name=self.gateway.gateway_name, symbol=data.securityId, exchange=EXCHANGE_OES2VT[data.mktId], - orderid=order_id if order_id else data.userInfo, # generated id + orderid=order_id if order_id else data.origClSeqNo, # generated id direction=Direction.NET, offset=offset, price=data.ordPrice / 10000, @@ -759,13 +669,17 @@ class OesTdApi: # this time should be generated automatically or by a static function time=datetime.utcnow().isoformat(), ) - self.order_manager.save_remote_created(order_id, vt_order, data) + self.save_order(order_id, vt_order) self.gateway.on_order(vt_order) - return 1 - else: - vt_order = i.vt_order - vt_order.status = STATUS_OES2VT[data.ordStatus] - vt_order.volume = data.ordQty - data.canceledQty - vt_order.traded = data.cumQty - self.gateway.on_order(vt_order) - return 1 + return 1 + + def save_order(self, order_id: int, order: OrderData): + """""" + self._orders[order_id] = InternalOrder( + order_id=order_id, + vt_order=order, + ) + + def get_order(self, order_id: int): + """""" + return self._orders[order_id]