Merge pull request #1459 from nanoric/oes_fix

Oes fix
This commit is contained in:
vn.py 2019-03-08 22:32:52 +08:00 committed by GitHub
commit cc129b3e2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 388 additions and 435 deletions

View File

@ -11,9 +11,6 @@ ordServer = 1 {td_ord_server}
rptServer = 1 {td_rpt_server}
qryServer = 1 {td_qry_server}
username = {username}
# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文)
password = {password}
heartBtInt = 30
# 客户端环境号, 用于区分不同客户端实例上报的委托申报, 取值由客户端自行分配
@ -33,7 +30,7 @@ rpt.subcribeEnvId = 0
# 比如想订阅所有委托、成交相关的回报消息,可以使用如下两种方式:
# - rpt.subcribeRptTypes = 1,4,8
# - 或等价的: rpt.subcribeRptTypes = 0x0D
rpt.subcribeRptTypes = 0
rpt.subcribeRptTypes = 1,2,4,8,0x10,0x20,0x40
# 服务器集群的集群类型 (1: 基于复制集的高可用集群, 2: 基于对等节点的服务器集群, 0: 默认为基于复制集的高可用集群)
clusterType = 0
@ -58,9 +55,6 @@ keepCnt = 9
tcpServer = {md_tcp_server}
qryServer = {md_qry_server}
username = {username}
# 密码支持明文和MD5两种格式 (如 txt:XXX 或 md5:XXX..., 不带前缀则默认为明文)
password = {password}
heartBtInt = 30
sse.stock.enable = false
@ -93,7 +87,7 @@ mktData.tickExpireType = 0
# 0x400:指数行情, 0x800:期权行情)
# 要订阅多个数据种类, 可以用逗号或空格分隔, 或者设置为并集值, 如:
# "mktData.dataTypes = 1,2,4" 或等价的 "mktData.dataTypes = 0x07"
mktData.dataTypes = 0
mktData.dataTypes = 1,2,4,8,0x10
# 请求订阅的行情数据的起始时间 (格式: HHMMSS 或 HHMMSSsss)
# (-1: 从头开始获取, 0: 从最新位置开始获取实时行情, 大于0: 从指定的起始时间开始获取)

View File

@ -3,14 +3,15 @@
"""
import hashlib
import os
from threading import Thread
from gettext import gettext as _
from threading import Thread, Lock
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (AccountData, CancelRequest, ContractData, OrderData, OrderRequest,
PositionData, SubscribeRequest, TickData, TradeData)
from vnpy.trader.object import (CancelRequest, OrderRequest,
SubscribeRequest)
from vnpy.trader.utility import get_file_path
from .md import OesMdApi
from .td import OesTdApi
from .oes_md import OesMdApi
from .oes_td import OesTdApi
from .utils import config_template
@ -19,30 +20,12 @@ class OesGateway(BaseGateway):
VN Trader Gateway for BitMEX connection.
"""
def on_tick(self, tick: TickData):
super().on_tick(tick)
def on_trade(self, trade: TradeData):
super().on_trade(trade)
def on_order(self, order: OrderData):
super().on_order(order)
def on_position(self, position: PositionData):
super().on_position(position)
def on_account(self, account: AccountData):
super().on_account(account)
def on_contract(self, contract: ContractData):
super().on_contract(contract)
default_setting = {
"td_ord_server": "tcp://106.15.58.119:6101",
"td_rpt_server": "tcp://106.15.58.119:6301",
"td_qry_server": "tcp://106.15.58.119:6401",
"md_tcp_server": "tcp://139.196.228.232:5103",
"md_qry_server": "tcp://139.196.228.232:5203",
"td_ord_server": "",
"td_rpt_server": "",
"td_qry_server": "",
"md_tcp_server": "",
"md_qry_server": "",
"username": "",
"password": "",
}
@ -54,15 +37,21 @@ class OesGateway(BaseGateway):
self.md_api = OesMdApi(self)
self.td_api = OesTdApi(self)
def connect(self, setting: dict):
return self._connect_async(setting)
self._lock_subscribe = Lock()
self._lock_send_order = Lock()
self._lock_cancel_order = Lock()
self._lock_query_position = Lock()
self._lock_query_account = Lock()
def _connect_sync(self, setting: dict):
def connect(self, setting: dict):
""""""
if not setting['password'].startswith("md5:"):
setting['password'] = "md5:" + hashlib.md5(setting['password'].encode()).hexdigest()
config_path = get_file_path("vnoes.ini")
username = setting['username']
password = setting['password']
config_path = str(get_file_path("vnoes.ini"))
with open(config_path, "wt") as f:
if 'test' in setting:
log_level = 'DEBUG'
@ -72,7 +61,7 @@ class OesGateway(BaseGateway):
log_mode = 'file'
log_dir = get_file_path('oes')
log_path = os.path.join(log_dir, 'log.log')
if os.path.exists(log_dir):
if not os.path.exists(log_dir):
os.mkdir(log_dir)
content = config_template.format(**setting,
log_level=log_level,
@ -80,41 +69,59 @@ class OesGateway(BaseGateway):
log_path=log_path)
f.write(content)
self.td_api.connect(str(config_path))
self.td_api.query_account()
self.td_api.query_contracts()
self.td_api.query_position()
self.td_api.init_query_orders()
self.td_api.start()
Thread(target=self._connect_md_sync, args=(config_path, username, password)).start()
Thread(target=self._connect_td_sync, args=(config_path, username, password)).start()
self.md_api.connect(str(config_path))
self.md_api.start()
def _connect_td_sync(self, config_path, username, password):
self.td_api.config_path = config_path
self.td_api.username = username
self.td_api.password = password
if self.td_api.connect():
self.write_log(_("成功连接到交易服务器"))
self.td_api.query_account()
self.td_api.query_contracts()
self.write_log("合约信息查询成功")
self.td_api.query_position()
self.td_api.query_orders()
self.td_api.start()
else:
self.write_log(_("无法连接到交易服务器,请检查你的配置"))
def _connect_async(self, setting: dict):
Thread(target=self._connect_sync, args=(setting, )).start()
def _connect_md_sync(self, config_path, username, password):
self.md_api.config_path = config_path
self.md_api.username = username
self.md_api.password = password
if self.md_api.connect():
self.md_api.start()
else:
self.write_log(_("无法连接到行情服务器,请检查你的配置"))
def subscribe(self, req: SubscribeRequest):
""""""
self.md_api.subscribe(req)
with self._lock_subscribe:
self.md_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.td_api.send_order(req)
with self._lock_send_order:
return self.td_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
return self.td_api.cancel_order(req)
with self._lock_cancel_order:
self.td_api.cancel_order(req)
def query_account(self):
""""""
return self.td_api.query_account()
with self._lock_query_account:
self.td_api.query_account()
def query_position(self):
""""""
return self.query_position()
with self._lock_query_position:
self.td_api.query_position()
def close(self):
""""""
self.md_api.stop()
self.td_api.stop()
pass

View File

@ -1,14 +1,16 @@
import time
from datetime import datetime
from gettext import gettext as _
from threading import Thread
# noinspection PyUnresolvedReferences
from typing import Any, Callable, Dict
from vnpy.api.oes.vnoes import MdsApiClientEnvT, MdsApi_DestoryAll, MdsApi_InitAllByConvention, \
MdsApi_IsValidQryChannel, MdsApi_IsValidTcpChannel, MdsApi_LogoutAll, \
MdsApi_SubscribeMarketData, MdsApi_WaitOnMsg, MdsL2StockSnapshotBodyT, MdsMktDataRequestEntryT, \
MdsMktDataRequestReqT, MdsMktRspMsgBodyT, MdsStockSnapshotBodyT, SGeneralClientChannelT, \
SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, eMdsExchangeIdT, \
eMdsMktSubscribeFlagT, eMdsMsgTypeT, eMdsSecurityTypeT, eMdsSubscribeDataTypeT, \
MdsApi_IsValidQryChannel, MdsApi_IsValidTcpChannel, MdsApi_LogoutAll, MdsApi_SetThreadPassword, \
MdsApi_SetThreadUsername, MdsApi_SubscribeMarketData, MdsApi_WaitOnMsg, MdsL2StockSnapshotBodyT, \
MdsMktDataRequestEntryT, MdsMktDataRequestReqT, MdsMktRspMsgBodyT, MdsStockSnapshotBodyT, \
SGeneralClientChannelT, SMsgHeadT, SPlatform_IsNegEpipe, cast, \
eMdsExchangeIdT, eMdsMktSubscribeFlagT, eMdsMsgTypeT, eMdsSecurityTypeT, eMdsSubscribeDataTypeT, \
eMdsSubscribeModeT, eMdsSubscribedTickExpireTypeT, eSMsgProtocolTypeT
from vnpy.trader.constant import Exchange
@ -25,18 +27,23 @@ EXCHANGE_VT2MDS = {v: k for k, v in EXCHANGE_MDS2VT.items()}
class OesMdMessageLoop:
def __init__(self, gateway: BaseGateway, env: MdsApiClientEnvT):
def __init__(self, gateway: BaseGateway, md: "OesMdApi", env: MdsApiClientEnvT):
""""""
self.gateway = gateway
self.env = env
self.alive = False
self.th = Thread(target=self.message_loop)
self._alive = False
self._md = md
self._th = Thread(target=self._message_loop)
self.message_handlers: Dict[int, Callable[[dict], None]] = {
# tick & orderbook
eMdsMsgTypeT.MDS_MSGTYPE_MARKET_DATA_SNAPSHOT_FULL_REFRESH: self.on_market_full_refresh,
eMdsMsgTypeT.MDS_MSGTYPE_L2_MARKET_DATA_SNAPSHOT: self.on_l2_market_data_snapshot,
eMdsMsgTypeT.MDS_MSGTYPE_L2_ORDER: self.on_l2_order,
eMdsMsgTypeT.MDS_MSGTYPE_L2_TRADE: self.on_l2_trade,
# others
eMdsMsgTypeT.MDS_MSGTYPE_QRY_SECURITY_STATUS: self.on_security_status,
eMdsMsgTypeT.MDS_MSGTYPE_L2_MARKET_DATA_INCREMENTAL: lambda x: 1,
eMdsMsgTypeT.MDS_MSGTYPE_L2_BEST_ORDERS_SNAPSHOT: self.on_best_orders_snapshot,
@ -46,15 +53,36 @@ class OesMdMessageLoop:
eMdsMsgTypeT.MDS_MSGTYPE_TRADING_SESSION_STATUS: self.on_trading_session_status,
eMdsMsgTypeT.MDS_MSGTYPE_SECURITY_STATUS: self.on_security_status,
eMdsMsgTypeT.MDS_MSGTYPE_MARKET_DATA_REQUEST: self.on_market_data_request,
eMdsMsgTypeT.MDS_MSGTYPE_HEARTBEAT: lambda x: 1,
}
self.last_tick: Dict[str, TickData] = {}
self.symbol_to_exchange: Dict[str, Exchange] = {}
def register_symbol(self, symbol: str, exchange: Exchange):
""""""
self.symbol_to_exchange[symbol] = exchange
def get_last_tick(self, symbol):
def start(self):
""""""
self._alive = True
self._th.start()
def stop(self):
""""""
self._alive = False
def join(self):
""""""
self._th.join()
def reconnect(self):
""""""
self.gateway.write_log(_("正在尝试重新连接到行情服务器。"))
return self._md.connect()
def _get_last_tick(self, symbol):
""""""
try:
return self.last_tick[symbol]
except KeyError:
@ -62,22 +90,15 @@ class OesMdMessageLoop:
gateway_name=self.gateway.gateway_name,
symbol=symbol,
exchange=self.symbol_to_exchange[symbol],
# todo: use cache of something else to resolve exchange
datetime=datetime.utcnow()
)
self.last_tick[symbol] = tick
return tick
def start(self):
self.alive = True
self.th.start()
def join(self):
self.th.join()
def on_message(self, session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any):
def _on_message(self, session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any):
""""""
if session_info.protocolType == eSMsgProtocolTypeT.SMSG_PROTO_BINARY:
b = cast.toMdsMktRspMsgBodyT(body)
if head.msgId in self.message_handlers:
@ -91,31 +112,30 @@ class OesMdMessageLoop:
self.gateway.write_log(f"unknown prototype : {session_info.protocolType}")
return 1
def message_loop(self):
def _message_loop(self):
""""""
tcp_channel = self.env.tcpChannel
timeout_ms = 1000
is_timeout = SPlatform_IsNegEtimeout
# is_timeout = SPlatform_IsNegEtimeout
is_disconnected = SPlatform_IsNegEpipe
while self.alive:
while self._alive:
ret = MdsApi_WaitOnMsg(tcp_channel,
timeout_ms,
self.on_message)
self._on_message)
if ret < 0:
if is_timeout(ret):
pass
# if is_timeout(ret):
# pass # just no message
if is_disconnected(ret):
# todo: handle disconnected
self.alive = False
break
pass
self.gateway.write_log(_("与行情服务器的连接已断开。"))
while self._alive and not self.reconnect():
time.sleep(1)
return
def on_l2_market_data_snapshot(self, d: MdsMktRspMsgBodyT):
""""""
data: MdsL2StockSnapshotBodyT = d.mktDataSnapshot.l2Stock
symbol = str(data.SecurityID)
tick = self.get_last_tick(symbol)
tick.limit_up = data.HighPx / 10000
tick.limit_down = data.LowPx / 10000
tick = self._get_last_tick(symbol)
tick.open_price = data.OpenPx / 10000
tick.pre_close = data.ClosePx / 10000
tick.high_price = data.HighPx / 10000
@ -128,11 +148,10 @@ class OesMdMessageLoop:
self.gateway.on_tick(tick)
def on_market_full_refresh(self, d: MdsMktRspMsgBodyT):
""""""
data: MdsStockSnapshotBodyT = d.mktDataSnapshot.stock
symbol = data.SecurityID
tick = self.get_last_tick(symbol)
tick.limit_up = data.HighPx / 10000
tick.limit_down = data.LowPx / 10000
tick = self._get_last_tick(symbol)
tick.open_price = data.OpenPx / 10000
tick.pre_close = data.ClosePx / 10000
tick.high_price = data.HighPx / 10000
@ -143,76 +162,96 @@ class OesMdMessageLoop:
for i in range(5):
tick.__dict__['ask_price_' + str(i + 1)] = data.OfferLevels[i].Price / 10000
self.gateway.on_tick(tick)
pass
def on_l2_trade(self, d: MdsMktRspMsgBodyT):
""""""
data = d.trade
symbol = data.SecurityID
tick = self.get_last_tick(symbol)
tick = self._get_last_tick(symbol)
tick.datetime = datetime.utcnow()
tick.volume = data.TradeQty
tick.last_price = data.TradePrice / 10000
self.gateway.on_tick(tick)
def on_market_data_request(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_trading_session_status(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_l2_market_overview(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_index_snapshot_full_refresh(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_option_snapshot_ful_refresh(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_best_orders_snapshot(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_l2_order(self, d: MdsMktRspMsgBodyT):
""""""
pass
def on_security_status(self, d: MdsMktRspMsgBodyT):
""""""
pass
def stop(self):
self.alive = False
class OesMdApi:
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.env = MdsApiClientEnvT()
self.message_loop = OesMdMessageLoop(gateway, self.env)
self.config_path: str = ''
self.username: str = ''
self.password: str = ''
def connect(self, config_path: str):
if not MdsApi_InitAllByConvention(self.env, config_path):
pass
self._env = MdsApiClientEnvT()
self._message_loop = OesMdMessageLoop(gateway, self, self._env)
if not MdsApi_IsValidTcpChannel(self.env.tcpChannel):
pass
if not MdsApi_IsValidQryChannel(self.env.qryChannel):
pass
def connect(self) -> bool:
""""""
"""Connect to trading server.
:note set config_path before calling this function
"""
MdsApi_SetThreadUsername(self.username)
MdsApi_SetThreadPassword(self.password)
config_path = self.config_path
if not MdsApi_InitAllByConvention(self._env, config_path):
return False
if not MdsApi_IsValidTcpChannel(self._env.tcpChannel):
return False
if not MdsApi_IsValidQryChannel(self._env.qryChannel):
return False
return True
def start(self):
self.message_loop.start()
""""""
self._message_loop.start()
def stop(self):
self.message_loop.stop()
if not MdsApi_LogoutAll(self.env, True):
pass # doc for this function is error
if not MdsApi_DestoryAll(self.env):
pass # doc for this function is error
""""""
self._message_loop.stop()
MdsApi_LogoutAll(self._env, True)
MdsApi_DestoryAll(self._env)
def join(self):
self.message_loop.join()
""""""
self._message_loop.join()
# why isn't arg a ContractData?
def subscribe(self, req: SubscribeRequest):
""""""
mds_req = MdsMktDataRequestReqT()
entry = MdsMktDataRequestEntryT()
mds_req.subMode = eMdsSubscribeModeT.MDS_SUB_MODE_APPEND
@ -240,12 +279,11 @@ class OesMdApi:
entry.securityType = eMdsSecurityTypeT.MDS_SECURITY_TYPE_STOCK # todo: option and others
entry.instrId = int(req.symbol)
self.message_loop.register_symbol(req.symbol, req.exchange)
self._message_loop.register_symbol(req.symbol, req.exchange)
ret = MdsApi_SubscribeMarketData(
self.env.tcpChannel,
self._env.tcpChannel,
mds_req,
entry)
if not ret:
self.gateway.write_log(
f"MdsApi_SubscribeByString failed with {ret}:{error_to_str(ret)}")
pass

View File

@ -1,19 +1,20 @@
from dataclasses import dataclass
from datetime import datetime
from threading import Thread
from datetime import datetime, timedelta, timezone
from gettext import gettext as _
from threading import Lock, Thread
# noinspection PyUnresolvedReferences
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict
from vnpy.api.oes.vnoes import OesApiClientEnvT, OesApi_DestoryAll, OesApi_InitAllByConvention, \
OesApi_IsValidOrdChannel, OesApi_IsValidQryChannel, OesApi_IsValidRptChannel, OesApi_LogoutAll, \
OesApi_QueryCashAsset, OesApi_QueryEtf, OesApi_QueryIssue, OesApi_QueryOptHolding, \
OesApi_QueryOption, OesApi_QueryOrder, OesApi_QueryStkHolding, OesApi_QueryStock, \
OesApi_SendOrderCancelReq, OesApi_SendOrderReq, OesApi_WaitReportMsg, OesOrdCancelReqT, \
OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, OesQryEtfFilterT, \
OesQryIssueFilterT, OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, \
OesQryStockFilterT, OesRspMsgBodyT, OesStockBaseInfoT, OesTrdCnfmT, SGeneralClientChannelT, \
SMSG_PROTO_BINARY, SMsgHeadT, SPlatform_IsNegEpipe, SPlatform_IsNegEtimeout, cast, \
eOesBuySellTypeT, eOesMarketIdT, eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT
OesApi_QueryCashAsset, OesApi_QueryOptHolding, OesApi_QueryOption, OesApi_QueryOrder, \
OesApi_QueryStkHolding, OesApi_QueryStock, OesApi_SendOrderCancelReq, OesApi_SendOrderReq, \
OesApi_SetThreadPassword, OesApi_SetThreadUsername, OesApi_WaitReportMsg, OesOrdCancelReqT, \
OesOrdCnfmT, OesOrdRejectT, OesOrdReqT, OesQryCashAssetFilterT, OesQryCursorT, \
OesQryOptionFilterT, OesQryOrdFilterT, OesQryStkHoldingFilterT, OesQryStockFilterT, \
OesRspMsgBodyT, OesStockBaseInfoT, OesTrdCnfmT, SGeneralClientChannelT, SMSG_PROTO_BINARY, \
SMsgHeadT, SPlatform_IsNegEpipe, cast, eOesBuySellTypeT, eOesMarketIdT, \
eOesMsgTypeT, eOesOrdStatusT, eOesOrdTypeShT, eOesOrdTypeSzT
from vnpy.gateway.oes.error_code import error_to_str
from vnpy.trader.constant import Direction, Exchange, Offset, PriceType, Product, Status
@ -84,87 +85,28 @@ STATUS_OES2VT = {
eOesOrdStatusT.OES_ORD_STATUS_INVALID_SZ_TRY_AGAIN: Status.REJECTED,
}
bjtz = timezone(timedelta(hours=8))
@dataclass
class InternalOrder:
order_id: int = None
vt_order: OrderData = None
req_data: OesOrdReqT = None
rpt_data: OesOrdCnfmT = None
class OrderManager:
def parse_oes_datetime(date: int, time: int):
"""convert oes datetime to python datetime"""
# YYYYMMDD
year = int(date / 10000)
month = int((date % 10000) / 100)
day = int(date % 100)
def __init__(self):
self.last_order_id = 100000000
self._orders: Dict[int, InternalOrder] = {}
# key tuple: seqNo, ordId, envId, userInfo
self._remote_created_orders: Dict[Tuple[int, int, int, int], InternalOrder] = {}
@staticmethod
def hash_remote_order(data):
key = (data.origClSeqNo, data.origClOrdId, data.origClEnvId, data.userInfo)
return key
@staticmethod
def hash_remote_trade(data: OesTrdCnfmT):
key = (data.clSeqNo, data.clOrdId, data.clEnvId, data.userInfo)
return key
def new_local_id(self):
id = self.last_order_id
self.last_order_id += 1
return id
def new_remote_id(self):
id = self.last_order_id
self.last_order_id += 1
return id
def save_local_created(self, order_id: int, order: OrderData, oes_req: OesOrdReqT):
self._orders[order_id] = InternalOrder(
order_id=order_id,
vt_order=order,
req_data=oes_req
)
def save_remote_created(self, order_id: int, vt_order: OrderData, data: OesOrdCnfmT):
internal_order = InternalOrder(
order_id=order_id,
vt_order=vt_order,
rpt_data=data
)
self._orders[order_id] = internal_order
key = self.hash_remote_order(data)
self._remote_created_orders[key] = internal_order
def get_from_order_id(self, id: int):
return self._orders[id]
def get_remote_created_order_from_oes_data(self, data):
"""
:return: internal_order if succeed else None, will check only remote created order
"""
try:
key = self.hash_remote_order(data)
except AttributeError:
key = self.hash_remote_trade(data)
try:
return self._remote_created_orders[key]
except KeyError:
return None
def get_from_oes_data(self, data):
try:
key = self.hash_remote_order(data)
except AttributeError:
key = self.hash_remote_trade(data)
try:
return self._remote_created_orders[key]
except KeyError:
order_id = key[3]
return self._orders[order_id]
# HHMMSSsss
hour = int(time / 10000000)
minute = int((time % 10000000) / 100000)
sec = int((time % 100000) / 1000)
mill = int(time % 1000)
return datetime(year, month, day, hour, minute, sec, mill * 1000, tzinfo=bjtz)
class OesTdMessageLoop:
@ -172,38 +114,52 @@ class OesTdMessageLoop:
def __init__(self,
gateway: BaseGateway,
env: OesApiClientEnvT,
order_manager: OrderManager,
td: "OesTdApi"
):
""""""
self.gateway = gateway
self.env = env
self.order_manager = order_manager
self.td = td
self.alive = False
self.th = Thread(target=self.message_loop)
self._env = env
self._td = td
self._alive = False
self._th = Thread(target=self._message_loop)
self.message_handlers: Dict[int, Callable[[dict], None]] = {
eOesMsgTypeT.OESMSG_RPT_BUSINESS_REJECT: self.on_reject,
eOesMsgTypeT.OESMSG_RPT_BUSINESS_REJECT: self.on_order_rejected,
eOesMsgTypeT.OESMSG_RPT_ORDER_INSERT: self.on_order_inserted,
eOesMsgTypeT.OESMSG_RPT_ORDER_REPORT: self.on_order_report,
eOesMsgTypeT.OESMSG_RPT_TRADE_REPORT: self.on_trade_report,
eOesMsgTypeT.OESMSG_RPT_STOCK_HOLDING_VARIATION: self.on_stock_holding,
eOesMsgTypeT.OESMSG_RPT_OPTION_HOLDING_VARIATION: self.on_option_holding,
eOesMsgTypeT.OESMSG_RPT_CASH_ASSET_VARIATION: self.on_cash,
eOesMsgTypeT.OESMSG_SESS_HEARTBEAT: lambda x: x,
eOesMsgTypeT.OESMSG_RPT_REPORT_SYNCHRONIZATION: lambda x: 1,
eOesMsgTypeT.OESMSG_SESS_HEARTBEAT: lambda x: 1,
}
def start(self):
self.alive = True
self.th.start()
""""""
self._alive = True
self._th.start()
def stop(self):
""""""
self._alive = False
def join(self):
self.th.join()
""""""
self._th.join()
def on_message(self, session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any):
def reconnect(self):
""""""
self.gateway.write_log(_("正在尝试重新连接到交易服务器。"))
self._td.connect()
def _on_message(self, session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any):
""""""
if session_info.protocolType == SMSG_PROTO_BINARY:
b = cast.toOesRspMsgBodyT(body)
if head.msgId in self.message_handlers:
@ -215,67 +171,83 @@ class OesTdMessageLoop:
self.gateway.write_log(f"unknown prototype : {session_info.protocolType}")
return 1
def message_loop(self):
rtp_channel = self.env.rptChannel
def _message_loop(self):
""""""
rpt_channel = self._env.rptChannel
timeout_ms = 1000
is_timeout = SPlatform_IsNegEtimeout
is_disconnected = SPlatform_IsNegEpipe
while self.alive:
ret = OesApi_WaitReportMsg(rtp_channel,
while self._alive:
ret = OesApi_WaitReportMsg(rpt_channel,
timeout_ms,
self.on_message)
self._on_message)
if ret < 0:
if is_timeout(ret):
pass
# if is_timeout(ret):
# pass # just no message
if is_disconnected(ret):
# todo: handle disconnected
self.alive = False
break
pass
self.gateway.write_log(_("与交易服务器的连接已断开。"))
while self._alive and not self.reconnect():
pass
return
def on_reject(self, d: OesRspMsgBodyT):
def on_order_rejected(self, d: OesRspMsgBodyT):
""""""
error_code = d.rptMsg.rptHead.ordRejReason
error_string = error_to_str(error_code)
data: OesOrdRejectT = d.rptMsg.rptBody.ordRejectRsp
i = self.order_manager.get_from_oes_data(data)
vt_order = i.vt_order
if not data.origClSeqNo:
i = self._td.get_order(data.clSeqNo)
vt_order = i.vt_order
if vt_order == Status.ALLTRADED:
return
if vt_order == Status.ALLTRADED:
return
vt_order.status = Status.REJECTED
vt_order.status = Status.REJECTED
self.gateway.on_order(vt_order)
self.gateway.write_log(
f"Order: {vt_order.vt_symbol}-{vt_order.vt_orderid} Code: {error_code} Rejected: {error_string}")
self.gateway.on_order(vt_order)
self.gateway.write_log(
f"Order: {vt_order.vt_symbol}-{vt_order.vt_orderid} Code: {error_code} Rejected: {error_string}")
else:
self.gateway.write_log(f"撤单失败,订单号: {data.origClSeqNo}。原因:{error_string}")
def on_order_inserted(self, d: OesRspMsgBodyT):
""""""
data = d.rptMsg.rptBody.ordInsertRsp
i = self.order_manager.get_from_oes_data(data)
if not data.origClSeqNo:
i = self._td.get_order(data.clSeqNo)
else:
i = self._td.get_order(data.origClSeqNo)
vt_order = i.vt_order
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.volume = data.ordQty - data.canceledQty
vt_order.volume = data.ordQty
vt_order.traded = data.cumQty
vt_order.time = parse_oes_datetime(data.ordDate, data.ordTime)
self.gateway.on_order(vt_order)
def on_order_report(self, d: OesRspMsgBodyT):
""""""
data: OesOrdCnfmT = d.rptMsg.rptBody.ordCnfm
i = self.order_manager.get_from_oes_data(data)
if not data.origClSeqNo:
i = self._td.get_order(data.clSeqNo)
else:
i = self._td.get_order(data.origClSeqNo)
vt_order = i.vt_order
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.volume = data.ordQty - data.canceledQty
vt_order.volume = data.ordQty
vt_order.traded = data.cumQty
vt_order.time = parse_oes_datetime(data.ordDate, data.ordCnfmTime)
self.gateway.on_order(vt_order)
def on_trade_report(self, d: OesRspMsgBodyT):
""""""
data: OesTrdCnfmT = d.rptMsg.rptBody.trdCnfm
i = self.order_manager.get_from_oes_data(data)
i = self._td.get_order(data.clSeqNo)
vt_order = i.vt_order
# vt_order.status = STATUS_OES2VT[data.ordStatus]
@ -289,25 +261,20 @@ class OesTdMessageLoop:
offset=vt_order.offset,
price=data.trdPrice / 10000,
volume=data.trdQty,
time=datetime.utcnow().isoformat() # strict
time=parse_oes_datetime(data.trdDate, data.trdTime)
)
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.traded = data.cumQty
vt_order.time = parse_oes_datetime(data.trdDate, data.trdTime)
self.gateway.on_trade(trade)
# hack :
# Sometimes order_report is not received after a trade is received.
# (only trade msg but no order msg)
# This cause a problem that vt_order.traded stay 0 after a trade, which is a error state.
# So we have to query new status of order for every receiving of trade.
# BUT
# Oes have no async call to query order only.
# And calling sync function here will slow down vnpy.
# So we queue it into another thread.
self.td.schedule_query_order(i)
self.gateway.on_order(vt_order)
def on_option_holding(self, d: OesRspMsgBodyT):
""""""
pass
def on_stock_holding(self, d: OesRspMsgBodyT):
""""""
data = d.rptMsg.rptBody.stkHoldingRpt
position = PositionData(
gateway_name=self.gateway.gateway_name,
@ -315,21 +282,22 @@ class OesTdMessageLoop:
exchange=EXCHANGE_OES2VT[data.mktId],
direction=Direction.NET,
volume=data.sumHld,
frozen=data.lockHld,
frozen=data.lockHld, # todo: to verify
price=data.costPrice / 10000,
# pnl=data.costPrice - data.originalCostAmt,
pnl=0, # todo: oes只提供日初持仓价格信息不提供最初持仓价格信息所以pnl只有当日的
pnl=0,
yd_volume=data.originalHld,
)
self.gateway.on_position(position)
def on_cash(self, d: OesRspMsgBodyT):
""""""
data = d.rptMsg.rptBody.cashAssetRpt
balance = data.currentTotalBal
availiable = data.currentAvailableBal
# drawable = data.currentDrawableBal
account_id = data.custId
account_id = data.cashAcctId
account = AccountData(
gateway_name=self.gateway.gateway_name,
accountid=account_id,
@ -339,59 +307,74 @@ class OesTdMessageLoop:
self.gateway.on_account(account)
return 1
def stop(self):
self.alive = False
class OesTdApi:
def __init__(self, gateway: BaseGateway):
""""""
self.config_path: str = None
self.username: str = ''
self.password: str = ''
self.gateway = gateway
self.env = OesApiClientEnvT()
self.order_manager = OrderManager()
self.message_loop = OesTdMessageLoop(gateway,
self.env,
self.order_manager,
self)
self._env = OesApiClientEnvT()
self.account_id = None
self.last_seq_index = 1 # 0 has special manning for oes
self._message_loop = OesTdMessageLoop(gateway,
self._env,
self)
def connect(self, config_path: str):
if not OesApi_InitAllByConvention(self.env, config_path, -1, self.last_seq_index):
pass
self.last_seq_index = self.env.ordChannel.lastOutMsgSeq + 1
self._last_seq_lock = Lock()
self._last_seq_index = 1000000 # 0 has special manning for oes
if not OesApi_IsValidOrdChannel(self.env.ordChannel):
pass
if not OesApi_IsValidQryChannel(self.env.qryChannel):
pass
if not OesApi_IsValidRptChannel(self.env.rptChannel):
pass
self._orders: Dict[int, InternalOrder] = {}
def connect(self):
"""Connect to trading server.
:note set config_path before calling this function
"""
OesApi_SetThreadUsername(self.username)
OesApi_SetThreadPassword(self.password)
config_path = self.config_path
if not OesApi_InitAllByConvention(self._env, config_path, -1, self._last_seq_index):
return False
self._last_seq_index = max(self._last_seq_index, self._env.ordChannel.lastOutMsgSeq + 1)
if not OesApi_IsValidOrdChannel(self._env.ordChannel):
return False
if not OesApi_IsValidQryChannel(self._env.qryChannel):
return False
if not OesApi_IsValidRptChannel(self._env.rptChannel):
return False
return True
def start(self):
self.message_loop.start()
""""""
self._message_loop.start()
def stop(self):
self.message_loop.stop()
if not OesApi_LogoutAll(self.env, True):
pass # doc for this function is error
if not OesApi_DestoryAll(self.env):
pass # doc for this function is error
""""""
self._message_loop.stop()
OesApi_LogoutAll(self._env, True)
OesApi_DestoryAll(self._env)
def join(self):
self.message_loop.join()
""""""
self._message_loop.join()
def query_account(self) -> bool:
return self.query_cash_asset()
def _get_new_seq_index(self):
""""""
with self._last_seq_lock:
index = self._last_seq_index
self._last_seq_index += 1
return index
def query_cash_asset(self) -> bool:
ret = OesApi_QueryCashAsset(self.env.qryChannel,
OesQryCashAssetFilterT(),
self.on_query_asset
)
return ret >= 0
def query_account(self):
""""""
OesApi_QueryCashAsset(self._env.qryChannel,
OesQryCashAssetFilterT(),
self.on_query_asset
)
def on_query_asset(self,
session_info: SGeneralClientChannelT,
@ -399,28 +382,25 @@ class OesTdApi:
body: Any,
cursor: OesQryCursorT,
):
""""""
data = cast.toOesCashAssetItemT(body)
balance = data.currentTotalBal / 10000
availiable = data.currentAvailableBal / 10000
# drawable = data.currentDrawableBal
account_id = data.custId
account_id = data.cashAcctId
account = AccountData(
gateway_name=self.gateway.gateway_name,
accountid=account_id,
balance=balance,
frozen=balance - availiable,
)
self.account_id = account_id
self.gateway.on_account(account)
return 1
def query_stock(self, ) -> bool:
# Thread(target=self._query_stock, ).start()
return self._query_stock()
def _query_stock(self, ) -> bool:
""""""
f = OesQryStockFilterT()
ret = OesApi_QueryStock(self.env.qryChannel, f, self.on_query_stock)
ret = OesApi_QueryStock(self._env.qryChannel, f, self.on_query_stock)
return ret >= 0
def on_query_stock(self,
@ -429,6 +409,7 @@ class OesTdApi:
body: Any,
cursor: OesQryCursorT,
):
""""""
data: OesStockBaseInfoT = cast.toOesStockItemT(body)
contract = ContractData(
gateway_name=self.gateway.gateway_name,
@ -443,8 +424,9 @@ class OesTdApi:
return 1
def query_option(self) -> bool:
""""""
f = OesQryOptionFilterT()
ret = OesApi_QueryOption(self.env.qryChannel,
ret = OesApi_QueryOption(self._env.qryChannel,
f,
self.on_query_option
)
@ -456,6 +438,7 @@ class OesTdApi:
body: Any,
cursor: OesQryCursorT,
):
""""""
data = cast.toOesOptionItemT(body)
contract = ContractData(
gateway_name=self.gateway.gateway_name,
@ -469,63 +452,10 @@ class OesTdApi:
self.gateway.on_contract(contract)
return 1
def query_issue(self) -> bool:
f = OesQryIssueFilterT()
ret = OesApi_QueryIssue(self.env.qryChannel,
f,
self.on_query_issue
)
return ret >= 0
def on_query_issue(self,
session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any,
cursor: OesQryCursorT,
):
data = cast.toOesIssueItemT(body)
contract = ContractData(
gateway_name=self.gateway.gateway_name,
symbol=data.securityId,
exchange=EXCHANGE_OES2VT[data.mktId],
name=data.securityName,
product=PRODUCT_OES2VT[data.mktId],
size=data.qtyUnit,
pricetick=1,
)
self.gateway.on_contract(contract)
return 1
def query_etf(self) -> bool:
f = OesQryEtfFilterT()
ret = OesApi_QueryEtf(self.env.qryChannel,
f,
self.on_query_etf
)
return ret >= 0
def on_query_etf(self,
session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any,
cursor: OesQryCursorT,
):
data = cast.toOesEtfItemT(body)
contract = ContractData(
gateway_name=self.gateway.gateway_name,
symbol=data.securityId,
exchange=EXCHANGE_OES2VT[data.mktId],
name=data.securityId,
product=PRODUCT_OES2VT[data.mktId],
size=data.creRdmUnit, # todo: to verify! creRdmUnit : 每个篮子 (最小申购、赎回单位) 对应的ETF份数
pricetick=1,
)
self.gateway.on_contract(contract)
return 1
def query_stock_holding(self) -> bool:
""""""
f = OesQryStkHoldingFilterT()
ret = OesApi_QueryStkHolding(self.env.qryChannel,
ret = OesApi_QueryStkHolding(self._env.qryChannel,
f,
self.on_query_stock_holding
)
@ -537,6 +467,7 @@ class OesTdApi:
body: Any,
cursor: OesQryCursorT,
):
""""""
data = cast.toOesStkHoldingItemT(body)
position = PositionData(
@ -555,10 +486,11 @@ class OesTdApi:
return 1
def query_option_holding(self) -> bool:
""""""
f = OesQryStkHoldingFilterT()
f.mktId = eOesMarketIdT.OES_MKT_ID_UNDEFINE
f.userInfo = 0
ret = OesApi_QueryOptHolding(self.env.qryChannel,
ret = OesApi_QueryOptHolding(self._env.qryChannel,
f,
self.on_query_holding
)
@ -570,6 +502,7 @@ class OesTdApi:
body: Any,
cursor: OesQryCursorT,
):
""""""
data = cast.toOesOptHoldingItemT(body)
# 权利
@ -605,19 +538,20 @@ class OesTdApi:
return 1
def query_contracts(self):
""""""
self.query_stock()
# self.query_option()
# self.query_issue()
pass
def query_position(self):
""""""
self.query_stock_holding()
self.query_option_holding()
def send_order(self, vt_req: OrderRequest):
seq_id = self.last_seq_index
self.last_seq_index += 1 # note: thread un-safe here, conflict with on_query_order
order_id = self.order_manager.new_local_id()
""""""
seq_id = self._get_new_seq_index()
order_id = seq_id
oes_req = OesOrdReqT()
oes_req.clSeqNo = seq_id
@ -628,16 +562,17 @@ class OesTdApi:
oes_req.securityId = vt_req.symbol
oes_req.ordQty = int(vt_req.volume)
oes_req.ordPrice = int(vt_req.price * 10000)
oes_req.userInfo = order_id
ret = OesApi_SendOrderReq(self.env.ordChannel,
oes_req
)
oes_req.origClOrdId = order_id
order = vt_req.create_order_data(str(order_id), self.gateway.gateway_name)
order.direction = Direction.NET # fix direction into NET: stock only
self.save_order(order_id, order)
ret = OesApi_SendOrderReq(self._env.ordChannel,
oes_req
)
if ret >= 0:
self.order_manager.save_local_created(order_id, order, oes_req)
self.gateway.on_order(order)
else:
self.gateway.write_log("Failed to send_order!")
@ -645,58 +580,26 @@ class OesTdApi:
return order.vt_orderid
def cancel_order(self, vt_req: CancelRequest):
seq_id = self.last_seq_index
self.last_seq_index += 1 # note: thread un-safe here
""""""
seq_id = self._get_new_seq_index()
oes_req = OesOrdCancelReqT()
order_id = int(vt_req.orderid)
internal_order = self.order_manager.get_from_order_id(order_id)
if internal_order.rpt_data:
data = internal_order.rpt_data
# oes_req.origClSeqNo = self.local_id_to_sys_id[int(order_id)]
oes_req.origClOrdId = data.clOrdId
oes_req.origClSeqNo = data.clSeqNo
oes_req.origClEnvId = data.origClEnvId
oes_req.mktId = data.mktId
# oes_req.invAcctId = data.invAcctId
# oes_req.mktId = data.mktId
# oes_req.securityId = data.securityId
else:
data = internal_order.req_data
oes_req.origClSeqNo = data.clSeqNo
oes_req.mktId = internal_order.req_data.mktId
oes_req.mktId = EXCHANGE_VT2OES[vt_req.exchange]
oes_req.clSeqNo = seq_id
oes_req.origClSeqNo = order_id
oes_req.invAcctId = ""
oes_req.securityId = vt_req.symbol
oes_req.userInfo = order_id
ret = OesApi_SendOrderCancelReq(self.env.ordChannel,
oes_req)
if ret >= 0:
pass
else:
pass
return
def schedule_query_order(self, internal_order: InternalOrder) -> Thread:
th = Thread(target=self.query_order, args=(internal_order,))
th.start()
return th
OesApi_SendOrderCancelReq(self._env.ordChannel,
oes_req)
def query_order(self, internal_order: InternalOrder) -> bool:
""""""
f = OesQryOrdFilterT()
if internal_order.req_data:
f.clSeqNo = internal_order.req_data.clSeqNo
f.mktId = internal_order.req_data.mktId
f.invAcctId = internal_order.req_data.invAcctId
else:
f.clSeqNo = internal_order.rpt_data.origClSeqNo
f.clOrdId = internal_order.rpt_data.origClOrdId
f.clEnvId = internal_order.rpt_data.origClEnvId
f.mktId = internal_order.rpt_data.mktId
f.invAcctId = internal_order.rpt_data.invAcctId
ret = OesApi_QueryOrder(self.env.qryChannel,
f.mktId = EXCHANGE_VT2OES[internal_order.vt_order.exchange]
f.clSeqNo = internal_order.order_id
ret = OesApi_QueryOrder(self._env.qryChannel,
f,
self.on_query_order
)
@ -707,8 +610,10 @@ class OesTdApi:
head: SMsgHeadT,
body: Any,
cursor: OesQryCursorT):
""""""
data: OesOrdCnfmT = cast.toOesOrdItemT(body)
i = self.order_manager.get_from_oes_data(data)
i = self.get_order(data.clSeqNo)
vt_order = i.vt_order
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.volume = data.ordQty - data.canceledQty
@ -716,28 +621,33 @@ class OesTdApi:
self.gateway.on_order(vt_order)
return 1
def init_query_orders(self) -> bool:
"""
:note: this function can be called only before calling send_order
:return:
"""
def query_orders(self) -> bool:
""""""
f = OesQryOrdFilterT()
ret = OesApi_QueryOrder(self.env.qryChannel,
ret = OesApi_QueryOrder(self._env.qryChannel,
f,
self.on_init_query_orders
self.on_query_orders
)
return ret >= 0
def on_init_query_orders(self,
session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any,
cursor: OesQryCursorT,
):
def on_query_orders(self,
session_info: SGeneralClientChannelT,
head: SMsgHeadT,
body: Any,
cursor: OesQryCursorT,
):
""""""
data: OesOrdCnfmT = cast.toOesOrdItemT(body)
i = self.order_manager.get_remote_created_order_from_oes_data(data)
if not i:
order_id = self.order_manager.new_remote_id()
try:
i = self.get_order(data.clSeqNo)
vt_order = i.vt_order
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.volume = data.ordQty - data.canceledQty
vt_order.traded = data.cumQty
self.gateway.on_order(vt_order)
except KeyError:
# order_id = self.order_manager.new_remote_id()
order_id = data.clSeqNo
if data.bsType == eOesBuySellTypeT.OES_BS_TYPE_BUY:
offset = Offset.OPEN
@ -748,7 +658,7 @@ class OesTdApi:
gateway_name=self.gateway.gateway_name,
symbol=data.securityId,
exchange=EXCHANGE_OES2VT[data.mktId],
orderid=order_id if order_id else data.userInfo, # generated id
orderid=order_id if order_id else data.origClSeqNo, # generated id
direction=Direction.NET,
offset=offset,
price=data.ordPrice / 10000,
@ -759,13 +669,17 @@ class OesTdApi:
# this time should be generated automatically or by a static function
time=datetime.utcnow().isoformat(),
)
self.order_manager.save_remote_created(order_id, vt_order, data)
self.save_order(order_id, vt_order)
self.gateway.on_order(vt_order)
return 1
else:
vt_order = i.vt_order
vt_order.status = STATUS_OES2VT[data.ordStatus]
vt_order.volume = data.ordQty - data.canceledQty
vt_order.traded = data.cumQty
self.gateway.on_order(vt_order)
return 1
return 1
def save_order(self, order_id: int, order: OrderData):
""""""
self._orders[order_id] = InternalOrder(
order_id=order_id,
vt_order=order,
)
def get_order(self, order_id: int):
""""""
return self._orders[order_id]