From 35518106a2a5886fa3bb7148236f628346abf5a5 Mon Sep 17 00:00:00 2001 From: nanoric Date: Thu, 18 Oct 2018 05:22:38 -0400 Subject: [PATCH] =?UTF-8?q?[Mod]=20=E9=87=8D=E6=9E=84RestClient,=20WebSock?= =?UTF-8?q?etClient?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要有: * beforeRequest改名为sign * onFailed, onError由函数指针改为函数(利用Python的特性,可以重载,也可以仍然当函数指针用) * sendData改名为sendBinary * 优化了WebSocketClient的循环逻辑 --- tests/api/base/RestClientTest.py | 2 +- vnpy/api/rest/RestClient.py | 125 +++++++++----------------- vnpy/api/rest/__init__.py | 2 +- vnpy/api/websocket/WebSocketClient.py | 64 +++++++------ 4 files changed, 82 insertions(+), 111 deletions(-) diff --git a/tests/api/base/RestClientTest.py b/tests/api/base/RestClientTest.py index 6f350638..33d3d233 100644 --- a/tests/api/base/RestClientTest.py +++ b/tests/api/base/RestClientTest.py @@ -22,7 +22,7 @@ class TestRestClient(RestClient): self.p = Promise() - def beforeRequest(self, req): #type: (Request)->Request + def sign(self, req): #type: (Request)->Request req.data = json.dumps(req.data) req.headers = {'Content-Type': 'application/json'} return req diff --git a/vnpy/api/rest/RestClient.py b/vnpy/api/rest/RestClient.py index 4560671d..e4bbec41 100644 --- a/vnpy/api/rest/RestClient.py +++ b/vnpy/api/rest/RestClient.py @@ -10,9 +10,10 @@ from enum import Enum from typing import Any, Callable +######################################################################## class RequestStatus(Enum): ready = 0 # 刚刚构建 - success = 1 # 请求成功 code == 200 + success = 1 # 请求成功 code == 2xx failed = 2 error = 3 # 发生错误 网络错误、json解析错误,等等 @@ -24,53 +25,26 @@ class Request(object): """ #---------------------------------------------------------------------- - def __init__(self, method, path, callback, params, data, headers): + def __init__(self, method, path, params, data, headers, callback): self.method = method # type: str self.path = path # type: str self.callback = callback # type: callable self.params = params # type: dict #, bytes, str self.data = data # type: dict #, bytes, str self.headers = headers # type: dict - self.onFailed = None # type: callable - self.skipDefaultOnFailed = None # type: callable - self.extra = None # type: Any - self._response = None # type: requests.Response - self._status = RequestStatus.ready - - #---------------------------------------------------------------------- - @property - def success(self): - assert self.finished, "'success' property is only available after request is finished" - return self._status == RequestStatus.success - - #---------------------------------------------------------------------- - @property - def failed(self): - assert self.finished, "'failed' property is only available after request is finished" - return self._status == RequestStatus.failed - - #---------------------------------------------------------------------- - @property - def finished(self): - return self._status != RequestStatus.ready - - #---------------------------------------------------------------------- - @property - def error(self): - return self._status == RequestStatus.error + self.onFailed = None # type: callable + self.extra = None # type: Any - #---------------------------------------------------------------------- - @property - def response(self): # type: ()->requests.Response - return self._response + self.response = None # type: requests.Response + self.status = RequestStatus.ready #---------------------------------------------------------------------- def __str__(self): statusCode = 'not finished' - if self._response: - statusCode = self._response.status_code - return "{} {} : {} {}\n".format(self.method, self.path, self._status, statusCode) + if self.response: + statusCode = self.response.status_code + return "{} {} : {} {}\n".format(self.method, self.path, self.status, statusCode) ######################################################################## @@ -79,8 +53,8 @@ class RestClient(object): HTTP 客户端。目前是为了对接各种RESTfulAPI而设计的。 如果需要给请求加上签名,请设置beforeRequest, 函数类型请参考defaultBeforeRequest。 - 如果需要处理非200的请求,请设置onFailed,函数类型请参考defaultOnFailed。 - 如果每一个请求的非200返回都需要单独处理,使用addReq函数的onFailed参数 + 如果需要处理非2xx的请求,请设置onFailed,函数类型请参考defaultOnFailed。 + 如果每一个请求的非2xx返回都需要单独处理,使用addReq函数的onFailed参数 如果捕获Python内部错误,例如网络连接失败等等,请设置onError,函数类型请参考defaultOnError """ @@ -90,11 +64,6 @@ class RestClient(object): :param urlBase: 路径前缀。 例如'https://www.bitmex.com/api/v1/' """ self.urlBase = None # type: str - self.sessionProvider = requestsSessionProvider - self.beforeRequest = self.defaultBeforeRequest # 任何请求在发送之前都会经过这个函数,让其加工 - self.onError = self.defaultOnError # Python内部错误处理 - self.onFailed = self.defaultOnFailed # statusCode != 2xx 时触发 - self._active = False self._queue = Queue() @@ -102,16 +71,13 @@ class RestClient(object): #---------------------------------------------------------------------- def init(self, urlBase): + """初始化""" self.urlBase = urlBase - + #---------------------------------------------------------------------- - def setSessionProvider(self, sessionProvider): - """ - 设置sessionProvider可以使用自定义的requests实现。 - @:param sessionProvider: callable。调用后应该返回一个对象带request函数的对象,该request函数的用法应该和requests中的一致。 \ - 每个工作线程会调用该函数一次以期获得一个独立的session实例。 - """ - self.sessionProvider = sessionProvider + def _genereateSession(self): + """""" + return requests.session() #---------------------------------------------------------------------- def start(self, n=3): @@ -143,8 +109,8 @@ class RestClient(object): #---------------------------------------------------------------------- def addRequest(self, method, path, callback, params=None, data=None, headers = None, - onFailed=None, skipDefaultOnFailed=True, - extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request + onFailed=None, + extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], Any)->Request """ 发送一个请求 :param method: GET, POST, PUT, DELETE, QUERY @@ -153,22 +119,20 @@ class RestClient(object): :param params: dict for query string :param data: dict for body :param headers: dict for headers - :param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败) type: (code, dict, Request) - :param skipDefaultOnFailed: 仅当onFailed参数存在时有效:忽略对虚函数onFailed的调用 + :param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request) :param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。 :return: Request """ - request = Request(method, path, callback, params, data, headers) - request.onFailed = onFailed - request.skipDefaultOnFailed = skipDefaultOnFailed + request = Request(method, path, params, data, headers, callback) request.extra = extra + request.onFailed = onFailed self._queue.put(request) return request #---------------------------------------------------------------------- def _run(self): - session = self.sessionProvider() + session = self._genereateSession() while self._active: try: request = self._queue.get(timeout=1) @@ -180,8 +144,7 @@ class RestClient(object): pass #---------------------------------------------------------------------- - @staticmethod - def defaultBeforeRequest(request): # type: (Request)->Request + def sign(self, request): # type: (Request)->Request """ 所有请求在发送之前都会经过这个函数 签名之类的前奏可以在这里面实现 @@ -191,10 +154,9 @@ class RestClient(object): return request #---------------------------------------------------------------------- - @staticmethod - def defaultOnFailed(httpStatusCode, request): # type:(int, Request)->None + def onFailed(self, httpStatusCode, request): # type:(int, Request)->None """ - 请求失败处理函数(HttpStatusCode!=200). + 请求失败处理函数(HttpStatusCode!=2xx). 默认行为是打印到stderr """ print("reuqest : {} {} failed with {}: \n" @@ -207,11 +169,10 @@ class RestClient(object): request.headers, request.params, request.data, - request._response.raw)) + request.response.raw)) #---------------------------------------------------------------------- - @staticmethod - def defaultOnError(exceptionType, exceptionValue, tb, request): + def onError(self, exceptionType, exceptionValue, tb, request): """ Python内部错误处理:默认行为是仍给excepthook """ @@ -224,37 +185,37 @@ class RestClient(object): 用于内部:将请求发送出去 """ try: - request = self.beforeRequest(request) + request = self.sign(request) url = self.makeFullUrl(request.path) response = session.request(request.method, url, headers=request.headers, params=request.params, data=request.data) - request._response = response + request.response = response httpStatusCode = response.status_code - if httpStatusCode/100 == 2: + if httpStatusCode / 100 == 2: # 2xx都算成功,尽管交易所都用200 jsonBody = response.json() request.callback(jsonBody, request) - request._status = RequestStatus.success + request.status = RequestStatus.success else: - request._status = RequestStatus.failed + request.status = RequestStatus.failed if request.onFailed: - request.onFailed(httpStatusCode, response.raw, request) - - # 若没有onFailed或者没设置skipDefaultOnFailed,则调用默认的处理函数 - if not request.onFailed or not request.skipDefaultOnFailed: + request.onFailed(httpStatusCode, request) + else: self.onFailed(httpStatusCode, request) except: - request._status = RequestStatus.error + request.status = RequestStatus.error t, v, tb = sys.exc_info() self.onError(t, v, tb, request) + #---------------------------------------------------------------------- def makeFullUrl(self, path): + """ + 将相对路径补充成绝对路径: + eg: makeFullUrl('/get') == 'http://xxxxx/get' + :param path: + :return: + """ url = self.urlBase + path return url - - -######################################################################## -def requestsSessionProvider(): - return requests.session() diff --git a/vnpy/api/rest/__init__.py b/vnpy/api/rest/__init__.py index f0cbea27..f1e7410c 100644 --- a/vnpy/api/rest/__init__.py +++ b/vnpy/api/rest/__init__.py @@ -1 +1 @@ -from .RestClient import Request, RequestStatus, RestClient, requestsSessionProvider +from .RestClient import Request, RequestStatus, RestClient diff --git a/vnpy/api/websocket/WebSocketClient.py b/vnpy/api/websocket/WebSocketClient.py index c98bb339..f2a92b55 100644 --- a/vnpy/api/websocket/WebSocketClient.py +++ b/vnpy/api/websocket/WebSocketClient.py @@ -42,7 +42,7 @@ class WebSocketClient(object): self.onPacket = self.defaultOnPacket self.onError = self.defaultOnError - self.createConnection = websocket.create_connection + self._createConnection = websocket.create_connection self._ws_lock = Lock() self._ws = None # type: websocket.WebSocket @@ -57,7 +57,7 @@ class WebSocketClient(object): for internal usage :param func: a function like websocket.create_connection """ - self.createConnection = func + self._createConnection = func #---------------------------------------------------------------------- def init(self, host): @@ -87,28 +87,29 @@ class WebSocketClient(object): #---------------------------------------------------------------------- def sendPacket(self, dictObj): # type: (dict)->None """发出请求:相当于sendText(json.dumps(dictObj))""" - return self._get_ws().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT) + return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- def sendText(self, text): # type: (str)->None """发送文本数据""" - return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT) + return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- - def sendData(self, data): # type: (bytes)->None + def sendBinary(self, data): # type: (bytes)->None """发送字节数据""" - return self._get_ws().send_binary(data) + return self._getWs().send_binary(data) #---------------------------------------------------------------------- def _reconnect(self): """重连""" - self._disconnect() - self._connect() + if self._active: + self._disconnect() + self._connect() #---------------------------------------------------------------------- def _connect(self): """""" - self._ws = self.createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE}) + self._ws = self._createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE}) self.onConnected() #---------------------------------------------------------------------- @@ -122,36 +123,43 @@ class WebSocketClient(object): self._ws = None #---------------------------------------------------------------------- - def _get_ws(self): + def _getWs(self): with self._ws_lock: return self._ws #---------------------------------------------------------------------- def _run(self): - """运行""" - ws = self._get_ws() + """ + 运行,直到stop()被调用 + """ + + # todo: onDisconnect while self._active: try: - stream = ws.recv() - if not stream: - self.onDisconnected() - if self._active: + ws = self._getWs() + if ws: + stream = ws.recv() + if not stream: # recv在阻塞的时候ws被关闭 self._reconnect() - continue - - data = json.loads(stream) - self.onPacket(data) - except websocket.WebSocketConnectionClosedException: - if self._active: - self._reconnect() - except: + + data = json.loads(stream) + self.onPacket(data) + except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了 + self._reconnect() + except: # Python内部错误(onPacket内出错) et, ev, tb = sys.exc_info() self.onError(et, ev, tb) - + self._reconnect() + #---------------------------------------------------------------------- def _runPing(self): while self._active: - self._ping() + try: + self._ping() + except: + et, ev, tb = sys.exc_info() + # todo: just log this, notifying user is not necessary + self.onError(et, ev, tb) for i in range(60): if not self._active: break @@ -159,7 +167,9 @@ class WebSocketClient(object): #---------------------------------------------------------------------- def _ping(self): - return self._get_ws().send('ping', websocket.ABNF.OPCODE_PING) + ws = self._getWs() + if ws: + ws.send('ping', websocket.ABNF.OPCODE_PING) #---------------------------------------------------------------------- @staticmethod