diff --git a/vnpy/restful/RestfulClient.py b/vnpy/restful/RestfulClient.py index 2a4c585d..ab15ca17 100644 --- a/vnpy/restful/RestfulClient.py +++ b/vnpy/restful/RestfulClient.py @@ -7,15 +7,16 @@ from Queue import Empty, Queue from abc import abstractmethod from multiprocessing.dummy import Pool +import requests from enum import Enum - - ######################################################################## +from typing import Any, Callable + + class RequestStatus(Enum): ready = 0 # 刚刚构建 - success = 1 # 请求成功 code == 200 - failed = 2 # 请求失败 code != 200 - error = 3 # 发生错误 网络错误、json解析错误,等等 + finished = 1 # 请求成功 code == 200 + error = 2 # 发生错误 网络错误、json解析错误,等等 ######################################################################## @@ -27,10 +28,11 @@ class Request(object): _last_id = 0 #---------------------------------------------------------------------- - def __init__(self): + def __init__(self, extra=None): Request._last_id += 1 self._id = Request._last_id self._status = RequestStatus.ready + self.extra = extra #---------------------------------------------------------------------- @property @@ -46,11 +48,6 @@ class Request(object): @property def error(self): return self._status == RequestStatus.error - - #---------------------------------------------------------------------- - @property - def success(self): - return self._status == RequestStatus.success ######################################################################## @@ -90,7 +87,8 @@ class RestfulClient(object): self._active = False #---------------------------------------------------------------------- - def addReq(self, method, path, onSuccess, onFailed, params=None, postdict=None): + def addReq(self, method, path, callback, params=None, data=None, + extra=None): # type: (str, str, Callable[[int, dict, Request], Any], dict, dict, Any)->Any """ 发送一个请求 :param method: GET, POST, PUT, DELETE, QUERY @@ -98,12 +96,12 @@ class RestfulClient(object): :param onSuccess: callback for success action(status code == 200) type: (dict, Request) :param onFailed: callback for failed action(status code != 200) type: (code, dict, Request) :param params: dict for query string - :param postdict: dict for body + :param data: dict for body :return: """ - - req = Request() - self._queue.put((method, path, onSuccess, onFailed, params, postdict, Request())) + + req = Request(extra=extra) + self._queue.put((method, path, callback, params, data, req)) return req #---------------------------------------------------------------------- @@ -111,18 +109,18 @@ class RestfulClient(object): session = self.sessionProvider() while self._active: try: - req = self._queue.get(timeout=1) - self.processReq(req, session) + method, path, callback, params, postdict, req = self._queue.get(timeout=1) + self.processReq(method, path, callback, params, postdict, req, session) except Empty: pass #---------------------------------------------------------------------- @abstractmethod - def beforeRequest(self, method, path, params, postdict): # type: (str, str, dict, dict)->(str, dict, dict, dict) + def beforeRequest(self, method, path, params, data): # type: (str, str, dict, dict)->(str, dict, dict, dict) """ 所有请求在发送之前都会经过这个函数 签名之类的前奏可以在这里面实现 - @:return (path, params, body, headers) body可以是request中data参数能接收的任意类型,例如bytes,str,dict都可以。 + @:return (method, path, params, body, headers) body可以是request中data参数能接收的任意类型,例如bytes,str,dict都可以。 """ pass @@ -135,27 +133,25 @@ class RestfulClient(object): sys.excepthook(exceptionType, exceptionValue, tb) #---------------------------------------------------------------------- - def processReq(self, info, session): + def processReq(self, method, path, callback, params, data, req, + session): # type: (str, str, callable, dict, dict, Request, requests.Session)->None """处理请求""" - method, path, onSuccess, onFailed, params, postdict, req = info # type: str, str, callable, callable, dict, dict, Request - - path, params, body, headers = self.beforeRequest(method, path, params, postdict) - - url = self.urlBase + path - try: - # 使用长连接的session,比短连接的耗时缩短20% - resp = session.request(method, url, headers=headers, params=params, data=body) + res = self.beforeRequest(method, path, params, data) + if res is None: + headers = {} + else: + method, path, params, data, headers = res + + url = self.urlBase + path + + resp = session.request(method, url, headers=headers, params=params, data=data) code = resp.status_code jsonBody = resp.json() - - if code == 200: - req._status = RequestStatus.success - onSuccess(jsonBody, req) - else: - req._status = RequestStatus.failed - onFailed(code, jsonBody, req) + + req._status = RequestStatus.finished + callback(code, jsonBody, req) except: req._status = RequestStatus.error t, v, tb = sys.exc_info() @@ -164,5 +160,4 @@ class RestfulClient(object): ######################################################################## def requestsSessionProvider(): - import requests return requests.session()