From 8fcccd8b1c4d36f8b1579141a3c8092e623fd632 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Wed, 12 Jun 2019 13:46:44 +0800 Subject: [PATCH] [Add] support UserProductInfo in CtpGateway --- vnpy/gateway/ctp/ctp_gateway.py | 25 ++++++++++++++++++++--- vnpy/gateway/ctptest/ctptest_gateway.py | 27 +++++++++++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index 94c35de9..78255313 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -131,7 +131,8 @@ class CtpGateway(BaseGateway): "交易服务器": "", "行情服务器": "", "产品名称": "", - "授权编码": "" + "授权编码": "", + "产品信息": "" } exchanges = list(EXCHANGE_CTP2VT.values()) @@ -152,13 +153,14 @@ class CtpGateway(BaseGateway): md_address = setting["行情服务器"] appid = setting["产品名称"] auth_code = setting["授权编码"] + product_info = setting["产品信息"] if not td_address.startswith("tcp://"): td_address = "tcp://" + td_address if not md_address.startswith("tcp://"): md_address = "tcp://" + md_address - self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid) + self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.md_api.connect(md_address, userid, password, brokerid) self.init_query() @@ -378,6 +380,7 @@ class CtpTdApi(TdApi): self.brokerid = 0 self.auth_code = "" self.appid = "" + self.product_info = "" self.frontid = 0 self.sessionid = 0 @@ -635,7 +638,16 @@ class CtpTdApi(TdApi): ) self.gateway.on_trade(trade) - def connect(self, address: str, userid: str, password: str, brokerid: int, auth_code: str, appid: str): + def connect( + self, + address: str, + userid: str, + password: str, + brokerid: int, + auth_code: str, + appid: str, + product_info + ): """ Start connection to server. """ @@ -644,6 +656,7 @@ class CtpTdApi(TdApi): self.brokerid = brokerid self.auth_code = auth_code self.appid = appid + self.product_info = product_info if not self.connect_status: path = get_folder_path(self.gateway_name.lower()) @@ -667,6 +680,9 @@ class CtpTdApi(TdApi): "AuthCode": self.auth_code, "AppID": self.appid } + + if self.product_info: + req["UserProductInfo"] = self.product_info self.reqid += 1 self.reqAuthenticate(req, self.reqid) @@ -684,6 +700,9 @@ class CtpTdApi(TdApi): "BrokerID": self.brokerid, "AppID": self.appid } + + if self.product_info: + req["UserProductInfo"] = self.product_info self.reqid += 1 self.reqUserLogin(req, self.reqid) diff --git a/vnpy/gateway/ctptest/ctptest_gateway.py b/vnpy/gateway/ctptest/ctptest_gateway.py index 6a95dc97..ad3dd73f 100644 --- a/vnpy/gateway/ctptest/ctptest_gateway.py +++ b/vnpy/gateway/ctptest/ctptest_gateway.py @@ -131,7 +131,8 @@ class CtptestGateway(BaseGateway): "交易服务器": "", "行情服务器": "", "产品名称": "", - "授权编码": "" + "授权编码": "", + "产品信息": "" } exchanges = list(EXCHANGE_CTP2VT.values()) @@ -152,13 +153,14 @@ class CtptestGateway(BaseGateway): md_address = setting["行情服务器"] appid = setting["产品名称"] auth_code = setting["授权编码"] + product_info = setting["产品信息"] if not td_address.startswith("tcp://"): td_address = "tcp://" + td_address if not md_address.startswith("tcp://"): md_address = "tcp://" + md_address - self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid) + self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) self.md_api.connect(md_address, userid, password, brokerid) self.init_query() @@ -378,6 +380,7 @@ class CtpTdApi(TdApi): self.brokerid = 0 self.auth_code = "" self.appid = "" + self.product_info = "" self.frontid = 0 self.sessionid = 0 @@ -406,7 +409,7 @@ class CtpTdApi(TdApi): def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool): """""" if not error['ErrorID']: - self.authStatus = True + self.auth_staus = True self.gateway.write_log("交易服务器授权验证成功") self.login() else: @@ -635,7 +638,16 @@ class CtpTdApi(TdApi): ) self.gateway.on_trade(trade) - def connect(self, address: str, userid: str, password: str, brokerid: int, auth_code: str, appid: str): + def connect( + self, + address: str, + userid: str, + password: str, + brokerid: int, + auth_code: str, + appid: str, + product_info + ): """ Start connection to server. """ @@ -644,6 +656,7 @@ class CtpTdApi(TdApi): self.brokerid = brokerid self.auth_code = auth_code self.appid = appid + self.product_info = product_info if not self.connect_status: path = get_folder_path(self.gateway_name.lower()) @@ -667,6 +680,9 @@ class CtpTdApi(TdApi): "AuthCode": self.auth_code, "AppID": self.appid } + + if self.product_info: + req["UserProductInfo"] = self.product_info self.reqid += 1 self.reqAuthenticate(req, self.reqid) @@ -684,6 +700,9 @@ class CtpTdApi(TdApi): "BrokerID": self.brokerid, "AppID": self.appid } + + if self.product_info: + req["UserProductInfo"] = self.product_info self.reqid += 1 self.reqUserLogin(req, self.reqid)