From 212b864c6933edbbc99e3212595a7acfe0d314db Mon Sep 17 00:00:00 2001 From: nanoric Date: Fri, 8 Mar 2019 08:56:05 -0400 Subject: [PATCH] [Add] Add lock to make public methods of gateway thread-safe --- vnpy/gateway/oes/oes_gateway.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/vnpy/gateway/oes/oes_gateway.py b/vnpy/gateway/oes/oes_gateway.py index 9901e755..7a3f97a7 100644 --- a/vnpy/gateway/oes/oes_gateway.py +++ b/vnpy/gateway/oes/oes_gateway.py @@ -4,7 +4,7 @@ import hashlib import os from gettext import gettext as _ -from threading import Thread +from threading import Thread, Lock from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import (CancelRequest, OrderRequest, @@ -37,6 +37,12 @@ class OesGateway(BaseGateway): self.md_api = OesMdApi(self) self.td_api = OesTdApi(self) + self._lock_subscribe = Lock() + self._lock_send_order = Lock() + self._lock_cancel_order = Lock() + self._lock_query_position = Lock() + self._lock_query_account = Lock() + def connect(self, setting: dict): """""" if not setting['password'].startswith("md5:"): @@ -76,7 +82,7 @@ class OesGateway(BaseGateway): self.td_api.query_contracts() self.write_log("合约信息查询成功") self.td_api.query_position() - self.td_api.init_query_orders() + self.td_api.query_orders() self.td_api.start() else: self.write_log(_("无法连接到交易服务器,请检查你的配置")) @@ -92,23 +98,28 @@ class OesGateway(BaseGateway): def subscribe(self, req: SubscribeRequest): """""" - self.md_api.subscribe(req) + with self._lock_subscribe: + self.md_api.subscribe(req) def send_order(self, req: OrderRequest): """""" - return self.td_api.send_order(req) + with self._lock_send_order: + return self.td_api.send_order(req) def cancel_order(self, req: CancelRequest): """""" - self.td_api.cancel_order(req) + with self._lock_cancel_order: + self.td_api.cancel_order(req) def query_account(self): """""" - self.td_api.query_account() + with self._lock_query_account: + self.td_api.query_account() def query_position(self): """""" - self.td_api.query_position() + with self._lock_query_position: + self.td_api.query_position() def close(self): """"""