diff --git a/tests/api/base/RestClientTest.py b/tests/api/base/RestClientTest.py index 33d3d233..c6556202 100644 --- a/tests/api/base/RestClientTest.py +++ b/tests/api/base/RestClientTest.py @@ -1,12 +1,14 @@ # encoding: UTF-8 import json +import sys import unittest +import uuid from simplejson import JSONDecodeError from Promise import Promise -from vnpy.api.rest.RestClient import RestClient, Request +from vnpy.api.rest.RestClient import Request, RestClient class FailedError(RuntimeError): @@ -26,7 +28,7 @@ class TestRestClient(RestClient): req.data = json.dumps(req.data) req.headers = {'Content-Type': 'application/json'} return req - + def onError(self, exceptionType, exceptionValue, tb, req): self.p.set_exception(exceptionType, exceptionValue, tb) @@ -83,4 +85,47 @@ class RestfulClientTest(unittest.TestCase): self.c.addRequest('GET', '/image/svg', callback) with self.assertRaises(JSONDecodeError): self.c.p.get(3) + + +class RestfulClientErrorHandleTest(unittest.TestCase): + def setUp(self): + self.c = TestRestClient() + self.c.start() + + self.org_sys_hook = sys.excepthook + self.org_sys_stderr_write = sys.stderr.write + + sys.excepthook = self.nop + sys.stderr.write = self.nop + + def tearDown(self): + self.c.stop() + + @staticmethod + def nop(*args, **kwargs): + pass + + def test_onError(self): + """这个测试保证onError内不会再抛异常""" + target = uuid.uuid4() + + def callback(data, req): + pass + + def onError(*args, **kwargs): + try: + super(TestRestClient, self.c).onError(*args, **kwargs) + self.c.p.set_result(target) + except: + self.c.p.set_result(False) + self.c.onError = onError + + self.c.addRequest('GET', '/image/svg', callback) + + res = self.c.p.get(3) + self.assertEqual(target, res) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/api/base/WebSocketClientTest.py b/tests/api/base/WebSocketClientTest.py index dc9651dd..dbfb5745 100644 --- a/tests/api/base/WebSocketClientTest.py +++ b/tests/api/base/WebSocketClientTest.py @@ -1,5 +1,7 @@ # encoding: UTF-8 +import sys import unittest +import uuid from Promise import Promise from vnpy.api.websocket import WebsocketClient @@ -12,12 +14,14 @@ class TestWebsocketClient(WebsocketClient): super(TestWebsocketClient, self).__init__() self.init(host) self.p = Promise() + self.cp = Promise() def onPacket(self, packet): self.p.set_result(packet) pass def onConnected(self): + self.cp.set_result(True) pass def onError(self, exceptionType, exceptionValue, tb): @@ -30,6 +34,7 @@ class WebsocketClientTest(unittest.TestCase): def setUp(self): self.c = TestWebsocketClient() self.c.start() + self.c.cp.get(3) # wait for connected def tearDown(self): self.c.stop() @@ -42,3 +47,73 @@ class WebsocketClientTest(unittest.TestCase): res = self.c.p.get(3) self.assertEqual(res, req) + + def test_parseError(self): + class CustomException(Exception): pass + + def onPacket(packet): + raise CustomException("Just a exception") + + self.c.onPacket = onPacket + req = { + 'name': 'val' + } + self.c.sendPacket(req) + + with self.assertRaises(CustomException): + self.c.p.get(3) + + +class WebsocketClientErrorHandleTest(unittest.TestCase): + + def setUp(self): + + self.c = TestWebsocketClient() + self.c.start() + self.c.cp.get(3) # wait for connected + + self.org_sys_hook = sys.excepthook + self.org_sys_stderr_write = sys.stderr.write + + sys.excepthook = self.nop + sys.stderr.write = self.nop + + @staticmethod + def nop(*args, **kwargs): + pass + + def tearDown(self): + self.c.stop() + sys.excepthook = self.org_sys_hook + sys.stderr.write = self.org_sys_stderr_write + + def test_onError(self): + target= uuid.uuid4() + """这个测试保证onError内不会再抛Exception""" + class CustomException(Exception): + pass + + def onPacket(packet): + raise CustomException("Just a exception") + + def onError(*args, **kwargs): + try: + super(TestWebsocketClient, self.c).onError(*args, **kwargs) + self.c.p.set_result(target) + except: + self.c.p.set_result(False) + + self.c.onPacket = onPacket + self.c.onError = onError + + req = { + 'name': 'val' + } + self.c.sendPacket(req) + + res = self.c.p.get(3) + self.assertEqual(target, res) + + +if __name__ == '__main__': + unittest.main() diff --git a/vnpy/api/rest/RestClient.py b/vnpy/api/rest/RestClient.py index fa3ef0d4..32b96666 100644 --- a/vnpy/api/rest/RestClient.py +++ b/vnpy/api/rest/RestClient.py @@ -2,7 +2,9 @@ import sys +import traceback from Queue import Empty, Queue +from datetime import datetime from multiprocessing.dummy import Pool import requests @@ -34,6 +36,7 @@ class Request(object): self.headers = headers # type: dict self.onFailed = None # type: callable + self.onError = None # type: callable self.extra = None # type: Any self.response = None # type: requests.Response @@ -95,7 +98,8 @@ class RestClient(object): #---------------------------------------------------------------------- def start(self, n=3): """启动""" - assert not self._active + if self._active: + return self._active = True self._pool = Pool(n) @@ -128,6 +132,7 @@ class RestClient(object): data=None, # type: dict headers=None, # type: dict onFailed=None, # type: Callable[[int, Request], Any] + onError=None, # type: Callable[[type, Exception, traceback, Request], Any] extra=None # type: Any ): # type: (...)->Request """ @@ -139,6 +144,7 @@ class RestClient(object): :param data: dict for body :param headers: dict for headers :param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request) + :param onError: 请求出现Python错误后的回调(如果指定该值,默认的onError将不会被调用) type: (etype, evalue, tb, Request) :param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。 :return: Request """ @@ -146,6 +152,7 @@ class RestClient(object): request = Request(method, path, params, data, headers, callback) request.extra = extra request.onFailed = onFailed + request.onError = onError self._queue.put(request) return request @@ -195,9 +202,29 @@ class RestClient(object): Python内部错误处理:默认行为是仍给excepthook :param request 如果是在处理请求的时候出错,它的值就是对应的Request,否则为None """ - print("error in request : {}\n".format(request)) + sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request)) sys.excepthook(exceptionType, exceptionValue, tb) + #---------------------------------------------------------------------- + def exceptionDetail(self, + exceptionType, # type: type + exceptionValue, # type: Exception + tb, + request # type: Optional[Request] + ): + text = "[{}]: Unhandled RestClient Error:{}\n".format( + datetime.now().isoformat(), + exceptionType + ) + text += "request:{}\n".format(request) + text += "Exception trace: \n" + text += "".join(traceback.format_exception( + exceptionType, + exceptionValue, + tb, + )) + return text + #---------------------------------------------------------------------- def _processRequest(self, request, session): # type: (Request, requests.Session)->None """ @@ -231,7 +258,10 @@ class RestClient(object): except: request.status = RequestStatus.error t, v, tb = sys.exc_info() - self.onError(t, v, tb, request) + if request.onError: + request.onError(t, v, tb, request) + else: + self.onError(t, v, tb, request) #---------------------------------------------------------------------- def makeFullUrl(self, path): @@ -243,3 +273,4 @@ class RestClient(object): """ url = self.urlBase + path return url + diff --git a/vnpy/api/websocket/WebsocketClient.py b/vnpy/api/websocket/WebsocketClient.py index a029bda5..a8b9d436 100644 --- a/vnpy/api/websocket/WebsocketClient.py +++ b/vnpy/api/websocket/WebsocketClient.py @@ -3,13 +3,15 @@ ######################################################################## import json -import sys - import ssl +import sys import time -import websocket +import traceback +from datetime import datetime from threading import Lock, Thread +import websocket + class WebsocketClient(object): """ @@ -50,7 +52,11 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def start(self): - """启动""" + """ + 启动 + :note 注意:启动之后不能立即发包,需要等待websocket连接成功。 + websocket连接成功之后会响应onConnected函数 + """ self._active = True self._workerThread = Thread(target=self._run) @@ -58,6 +64,10 @@ class WebsocketClient(object): self._pingThread = Thread(target=self._runPing) self._pingThread.start() + + # for debugging: + self._lastSentText = None + self._lastReceivedText = None #---------------------------------------------------------------------- def stop(self): @@ -80,7 +90,9 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def sendPacket(self, dictObj): # type: (dict)->None """发出请求:相当于sendText(json.dumps(dictObj))""" - return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT) + text = json.dumps(dictObj) + self._recordLastSentText(text) + return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- def sendText(self, text): # type: (str)->None @@ -137,15 +149,15 @@ class WebsocketClient(object): try: ws = self._getWs() if ws: - stream = ws.recv() - if not stream: # recv在阻塞的时候ws被关闭 + text = ws.recv() + if not text: # recv在阻塞的时候ws被关闭 self._reconnect() continue - + self._recordLastReceivedText(text) try: - data = json.loads(stream) + data = json.loads(text) except ValueError as e: - print('websocket unable to parse data: ' + stream) + print('websocket unable to parse data: ' + text) raise e self.onPacket(data) except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了 @@ -207,7 +219,43 @@ class WebsocketClient(object): pass #---------------------------------------------------------------------- - @staticmethod - def onError(exceptionType, exceptionValue, tb): - """Python错误回调""" + def onError(self, exceptionType, exceptionValue, tb): + """ + Python错误回调 + todo: 以后详细的错误信息最好记录在文件里,用uuid来联系/区分具体错误 + """ + sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb)) + + # 丢给默认的错误处理函数(所以如果不重载onError,一般的结果是程序会崩溃) return sys.excepthook(exceptionType, exceptionValue, tb) + + #---------------------------------------------------------------------- + def exceptionDetail(self, exceptionType, exceptionValue, tb): + """打印详细的错误信息""" + text = "[{}]: Unhandled WebSocket Error:{}\n".format( + datetime.now().isoformat(), + exceptionType + ) + text += "LastSentText:\n{}\n".format(self._lastSentText) + text += "LastReceivedText:\n{}\n".format(self._lastReceivedText) + text += "Exception trace: \n" + text += "".join(traceback.format_exception( + exceptionType, + exceptionValue, + tb, + )) + return text + + #---------------------------------------------------------------------- + def _recordLastSentText(self, text): + """ + 用于Debug: 记录最后一次发送出去的text + """ + self._lastSentText = text[:200] + + #---------------------------------------------------------------------- + def _recordLastReceivedText(self, text): + """ + 用于Debug: 记录最后一次发送出去的text + """ + self._lastReceivedText = text[:200]