diff --git a/vnpy/api/rest/RestClient.py b/vnpy/api/rest/RestClient.py index e4bbec41..a013668d 100644 --- a/vnpy/api/rest/RestClient.py +++ b/vnpy/api/rest/RestClient.py @@ -37,14 +37,25 @@ class Request(object): self.extra = None # type: Any self.response = None # type: requests.Response - self.status = RequestStatus.ready + self.status = RequestStatus.ready # type: RequestStatus #---------------------------------------------------------------------- def __str__(self): - statusCode = 'not finished' - if self.response: + if self.response is None: + statusCode = 'terminated' + else: statusCode = self.response.status_code - return "{} {} : {} {}\n".format(self.method, self.path, self.status, statusCode) + return ("reuqest : {} {} {} because {}: \n" + "headers: {}\n" + "params: {}\n" + "data: {}\n" + "response:" + "{}\n" + .format(self.method, self.path, self.status.name, statusCode, + self.headers, + self.params, + self.data, + '' if self.response is None else self.response.text)) ######################################################################## @@ -61,7 +72,6 @@ class RestClient(object): #---------------------------------------------------------------------- def __init__(self): """ - :param urlBase: 路径前缀。 例如'https://www.bitmex.com/api/v1/' """ self.urlBase = None # type: str self._active = False @@ -71,11 +81,14 @@ class RestClient(object): #---------------------------------------------------------------------- def init(self, urlBase): - """初始化""" + """ + 初始化 + :param urlBase: 路径前缀。 例如'https://www.bitmex.com/api/v1/' + """ self.urlBase = urlBase #---------------------------------------------------------------------- - def _genereateSession(self): + def _generateSession(self): """""" return requests.session() @@ -107,10 +120,16 @@ class RestClient(object): self._queue.join() #---------------------------------------------------------------------- - def addRequest(self, method, path, callback, - params=None, data=None, headers = None, - onFailed=None, - extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], Any)->Request + def addRequest(self, + method, # type: str + path, # type: str + callback, # type: Callable[[dict, Request], Any] + params=None, # type: dict + data=None, # type: dict + headers=None, # type: dict + onFailed=None, # type: Callable[[dict, Request], Any] + extra=None # type: Any + ): # type: (...)->Request """ 发送一个请求 :param method: GET, POST, PUT, DELETE, QUERY @@ -132,7 +151,7 @@ class RestClient(object): #---------------------------------------------------------------------- def _run(self): - session = self._genereateSession() + session = self._generateSession() while self._active: try: request = self._queue.get(timeout=1) @@ -159,17 +178,7 @@ class RestClient(object): 请求失败处理函数(HttpStatusCode!=2xx). 默认行为是打印到stderr """ - print("reuqest : {} {} failed with {}: \n" - "headers: {}\n" - "params: {}\n" - "data: {}\n" - "response:" - "{}\n" - .format(request.method, request.path, httpStatusCode, - request.headers, - request.params, - request.data, - request.response.raw)) + sys.stderr.write(str(request)) #---------------------------------------------------------------------- def onError(self, exceptionType, exceptionValue, tb, request): @@ -184,12 +193,17 @@ class RestClient(object): """ 用于内部:将请求发送出去 """ + # noinspection PyBroadException try: request = self.sign(request) url = self.makeFullUrl(request.path) - response = session.request(request.method, url, headers=request.headers, params=request.params, data=request.data) + response = session.request(request.method, + url, + headers=request.headers, + params=request.params, + data=request.data) request.response = response httpStatusCode = response.status_code diff --git a/vnpy/api/websocket/WebSocketClient.py b/vnpy/api/websocket/WebSocketClient.py index f2a92b55..7576d5c0 100644 --- a/vnpy/api/websocket/WebSocketClient.py +++ b/vnpy/api/websocket/WebSocketClient.py @@ -37,11 +37,6 @@ class WebSocketClient(object): """Constructor""" self.host = None # type: str - self.onConnected = self.defaultOnConnected - self.onDisconnected = self.defaultOnDisconnected - self.onPacket = self.defaultOnPacket - self.onError = self.defaultOnError - self._createConnection = websocket.create_connection self._ws_lock = Lock() @@ -51,14 +46,6 @@ class WebSocketClient(object): self._pingThread = None # type: Thread self._active = False - #---------------------------------------------------------------------- - def setCreateConnection(self, func): - """ - for internal usage - :param func: a function like websocket.create_connection - """ - self._createConnection = func - #---------------------------------------------------------------------- def init(self, host): self.host = host @@ -83,6 +70,14 @@ class WebSocketClient(object): """ self._active = False self._disconnect() + + def join(self): + """ + 等待所有工作线程退出 + 正确调用方式:先stop()后join() + """ + self._pingThread.join() + self._workerThread.join() #---------------------------------------------------------------------- def sendPacket(self, dictObj): # type: (dict)->None @@ -141,8 +136,13 @@ class WebSocketClient(object): stream = ws.recv() if not stream: # recv在阻塞的时候ws被关闭 self._reconnect() + continue - data = json.loads(stream) + try: + data = json.loads(stream) + except ValueError as e: + print('websocket unable to parse data: ' + stream) + raise e self.onPacket(data) except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了 self._reconnect() @@ -173,7 +173,7 @@ class WebSocketClient(object): #---------------------------------------------------------------------- @staticmethod - def defaultOnConnected(): + def onConnected(): """ 连接成功回调 """ @@ -181,7 +181,7 @@ class WebSocketClient(object): #---------------------------------------------------------------------- @staticmethod - def defaultOnDisconnected(): + def onDisconnected(): """ 连接断开回调 """ @@ -189,7 +189,7 @@ class WebSocketClient(object): #---------------------------------------------------------------------- @staticmethod - def defaultOnPacket(packet): + def onPacket(packet): """ 数据回调。 只有在数据为json包的时候才会触发这个回调 @@ -200,6 +200,6 @@ class WebSocketClient(object): #---------------------------------------------------------------------- @staticmethod - def defaultOnError(exceptionType, exceptionValue, tb): + def onError(exceptionType, exceptionValue, tb): """Python错误回调""" return sys.excepthook(exceptionType, exceptionValue, tb)