[Mod] 重构RestClient, WebSocketClient
主要有: * beforeRequest改名为sign * onFailed, onError由函数指针改为函数(利用Python的特性,可以重载,也可以仍然当函数指针用) * sendData改名为sendBinary * 优化了WebSocketClient的循环逻辑
This commit is contained in:
parent
028d344fc0
commit
35518106a2
@ -22,7 +22,7 @@ class TestRestClient(RestClient):
|
||||
|
||||
self.p = Promise()
|
||||
|
||||
def beforeRequest(self, req): #type: (Request)->Request
|
||||
def sign(self, req): #type: (Request)->Request
|
||||
req.data = json.dumps(req.data)
|
||||
req.headers = {'Content-Type': 'application/json'}
|
||||
return req
|
||||
|
@ -10,9 +10,10 @@ from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
########################################################################
|
||||
class RequestStatus(Enum):
|
||||
ready = 0 # 刚刚构建
|
||||
success = 1 # 请求成功 code == 200
|
||||
success = 1 # 请求成功 code == 2xx
|
||||
failed = 2
|
||||
error = 3 # 发生错误 网络错误、json解析错误,等等
|
||||
|
||||
@ -24,53 +25,26 @@ class Request(object):
|
||||
"""
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __init__(self, method, path, callback, params, data, headers):
|
||||
def __init__(self, method, path, params, data, headers, callback):
|
||||
self.method = method # type: str
|
||||
self.path = path # type: str
|
||||
self.callback = callback # type: callable
|
||||
self.params = params # type: dict #, bytes, str
|
||||
self.data = data # type: dict #, bytes, str
|
||||
self.headers = headers # type: dict
|
||||
self.onFailed = None # type: callable
|
||||
self.skipDefaultOnFailed = None # type: callable
|
||||
self.extra = None # type: Any
|
||||
|
||||
self._response = None # type: requests.Response
|
||||
self._status = RequestStatus.ready
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def success(self):
|
||||
assert self.finished, "'success' property is only available after request is finished"
|
||||
return self._status == RequestStatus.success
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def failed(self):
|
||||
assert self.finished, "'failed' property is only available after request is finished"
|
||||
return self._status == RequestStatus.failed
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def finished(self):
|
||||
return self._status != RequestStatus.ready
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def error(self):
|
||||
return self._status == RequestStatus.error
|
||||
self.onFailed = None # type: callable
|
||||
self.extra = None # type: Any
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@property
|
||||
def response(self): # type: ()->requests.Response
|
||||
return self._response
|
||||
self.response = None # type: requests.Response
|
||||
self.status = RequestStatus.ready
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def __str__(self):
|
||||
statusCode = 'not finished'
|
||||
if self._response:
|
||||
statusCode = self._response.status_code
|
||||
return "{} {} : {} {}\n".format(self.method, self.path, self._status, statusCode)
|
||||
if self.response:
|
||||
statusCode = self.response.status_code
|
||||
return "{} {} : {} {}\n".format(self.method, self.path, self.status, statusCode)
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -79,8 +53,8 @@ class RestClient(object):
|
||||
HTTP 客户端。目前是为了对接各种RESTfulAPI而设计的。
|
||||
|
||||
如果需要给请求加上签名,请设置beforeRequest, 函数类型请参考defaultBeforeRequest。
|
||||
如果需要处理非200的请求,请设置onFailed,函数类型请参考defaultOnFailed。
|
||||
如果每一个请求的非200返回都需要单独处理,使用addReq函数的onFailed参数
|
||||
如果需要处理非2xx的请求,请设置onFailed,函数类型请参考defaultOnFailed。
|
||||
如果每一个请求的非2xx返回都需要单独处理,使用addReq函数的onFailed参数
|
||||
如果捕获Python内部错误,例如网络连接失败等等,请设置onError,函数类型请参考defaultOnError
|
||||
"""
|
||||
|
||||
@ -90,11 +64,6 @@ class RestClient(object):
|
||||
:param urlBase: 路径前缀。 例如'https://www.bitmex.com/api/v1/'
|
||||
"""
|
||||
self.urlBase = None # type: str
|
||||
self.sessionProvider = requestsSessionProvider
|
||||
self.beforeRequest = self.defaultBeforeRequest # 任何请求在发送之前都会经过这个函数,让其加工
|
||||
self.onError = self.defaultOnError # Python内部错误处理
|
||||
self.onFailed = self.defaultOnFailed # statusCode != 2xx 时触发
|
||||
|
||||
self._active = False
|
||||
|
||||
self._queue = Queue()
|
||||
@ -102,16 +71,13 @@ class RestClient(object):
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def init(self, urlBase):
|
||||
"""初始化"""
|
||||
self.urlBase = urlBase
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def setSessionProvider(self, sessionProvider):
|
||||
"""
|
||||
设置sessionProvider可以使用自定义的requests实现。
|
||||
@:param sessionProvider: callable。调用后应该返回一个对象带request函数的对象,该request函数的用法应该和requests中的一致。 \
|
||||
每个工作线程会调用该函数一次以期获得一个独立的session实例。
|
||||
"""
|
||||
self.sessionProvider = sessionProvider
|
||||
def _genereateSession(self):
|
||||
""""""
|
||||
return requests.session()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def start(self, n=3):
|
||||
@ -143,8 +109,8 @@ class RestClient(object):
|
||||
#----------------------------------------------------------------------
|
||||
def addRequest(self, method, path, callback,
|
||||
params=None, data=None, headers = None,
|
||||
onFailed=None, skipDefaultOnFailed=True,
|
||||
extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request
|
||||
onFailed=None,
|
||||
extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], Any)->Request
|
||||
"""
|
||||
发送一个请求
|
||||
:param method: GET, POST, PUT, DELETE, QUERY
|
||||
@ -153,22 +119,20 @@ class RestClient(object):
|
||||
:param params: dict for query string
|
||||
:param data: dict for body
|
||||
:param headers: dict for headers
|
||||
:param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败) type: (code, dict, Request)
|
||||
:param skipDefaultOnFailed: 仅当onFailed参数存在时有效:忽略对虚函数onFailed的调用
|
||||
:param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request)
|
||||
:param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。
|
||||
:return: Request
|
||||
"""
|
||||
|
||||
request = Request(method, path, callback, params, data, headers)
|
||||
request.onFailed = onFailed
|
||||
request.skipDefaultOnFailed = skipDefaultOnFailed
|
||||
request = Request(method, path, params, data, headers, callback)
|
||||
request.extra = extra
|
||||
request.onFailed = onFailed
|
||||
self._queue.put(request)
|
||||
return request
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _run(self):
|
||||
session = self.sessionProvider()
|
||||
session = self._genereateSession()
|
||||
while self._active:
|
||||
try:
|
||||
request = self._queue.get(timeout=1)
|
||||
@ -180,8 +144,7 @@ class RestClient(object):
|
||||
pass
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def defaultBeforeRequest(request): # type: (Request)->Request
|
||||
def sign(self, request): # type: (Request)->Request
|
||||
"""
|
||||
所有请求在发送之前都会经过这个函数
|
||||
签名之类的前奏可以在这里面实现
|
||||
@ -191,10 +154,9 @@ class RestClient(object):
|
||||
return request
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def defaultOnFailed(httpStatusCode, request): # type:(int, Request)->None
|
||||
def onFailed(self, httpStatusCode, request): # type:(int, Request)->None
|
||||
"""
|
||||
请求失败处理函数(HttpStatusCode!=200).
|
||||
请求失败处理函数(HttpStatusCode!=2xx).
|
||||
默认行为是打印到stderr
|
||||
"""
|
||||
print("reuqest : {} {} failed with {}: \n"
|
||||
@ -207,11 +169,10 @@ class RestClient(object):
|
||||
request.headers,
|
||||
request.params,
|
||||
request.data,
|
||||
request._response.raw))
|
||||
request.response.raw))
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def defaultOnError(exceptionType, exceptionValue, tb, request):
|
||||
def onError(self, exceptionType, exceptionValue, tb, request):
|
||||
"""
|
||||
Python内部错误处理:默认行为是仍给excepthook
|
||||
"""
|
||||
@ -224,37 +185,37 @@ class RestClient(object):
|
||||
用于内部:将请求发送出去
|
||||
"""
|
||||
try:
|
||||
request = self.beforeRequest(request)
|
||||
request = self.sign(request)
|
||||
|
||||
url = self.makeFullUrl(request.path)
|
||||
|
||||
response = session.request(request.method, url, headers=request.headers, params=request.params, data=request.data)
|
||||
request._response = response
|
||||
request.response = response
|
||||
|
||||
httpStatusCode = response.status_code
|
||||
if httpStatusCode/100 == 2:
|
||||
if httpStatusCode / 100 == 2: # 2xx都算成功,尽管交易所都用200
|
||||
jsonBody = response.json()
|
||||
request.callback(jsonBody, request)
|
||||
request._status = RequestStatus.success
|
||||
request.status = RequestStatus.success
|
||||
else:
|
||||
request._status = RequestStatus.failed
|
||||
request.status = RequestStatus.failed
|
||||
|
||||
if request.onFailed:
|
||||
request.onFailed(httpStatusCode, response.raw, request)
|
||||
|
||||
# 若没有onFailed或者没设置skipDefaultOnFailed,则调用默认的处理函数
|
||||
if not request.onFailed or not request.skipDefaultOnFailed:
|
||||
request.onFailed(httpStatusCode, request)
|
||||
else:
|
||||
self.onFailed(httpStatusCode, request)
|
||||
except:
|
||||
request._status = RequestStatus.error
|
||||
request.status = RequestStatus.error
|
||||
t, v, tb = sys.exc_info()
|
||||
self.onError(t, v, tb, request)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def makeFullUrl(self, path):
|
||||
"""
|
||||
将相对路径补充成绝对路径:
|
||||
eg: makeFullUrl('/get') == 'http://xxxxx/get'
|
||||
:param path:
|
||||
:return:
|
||||
"""
|
||||
url = self.urlBase + path
|
||||
return url
|
||||
|
||||
|
||||
########################################################################
|
||||
def requestsSessionProvider():
|
||||
return requests.session()
|
||||
|
@ -1 +1 @@
|
||||
from .RestClient import Request, RequestStatus, RestClient, requestsSessionProvider
|
||||
from .RestClient import Request, RequestStatus, RestClient
|
||||
|
@ -42,7 +42,7 @@ class WebSocketClient(object):
|
||||
self.onPacket = self.defaultOnPacket
|
||||
self.onError = self.defaultOnError
|
||||
|
||||
self.createConnection = websocket.create_connection
|
||||
self._createConnection = websocket.create_connection
|
||||
|
||||
self._ws_lock = Lock()
|
||||
self._ws = None # type: websocket.WebSocket
|
||||
@ -57,7 +57,7 @@ class WebSocketClient(object):
|
||||
for internal usage
|
||||
:param func: a function like websocket.create_connection
|
||||
"""
|
||||
self.createConnection = func
|
||||
self._createConnection = func
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def init(self, host):
|
||||
@ -87,28 +87,29 @@ class WebSocketClient(object):
|
||||
#----------------------------------------------------------------------
|
||||
def sendPacket(self, dictObj): # type: (dict)->None
|
||||
"""发出请求:相当于sendText(json.dumps(dictObj))"""
|
||||
return self._get_ws().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT)
|
||||
return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def sendText(self, text): # type: (str)->None
|
||||
"""发送文本数据"""
|
||||
return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT)
|
||||
return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def sendData(self, data): # type: (bytes)->None
|
||||
def sendBinary(self, data): # type: (bytes)->None
|
||||
"""发送字节数据"""
|
||||
return self._get_ws().send_binary(data)
|
||||
return self._getWs().send_binary(data)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _reconnect(self):
|
||||
"""重连"""
|
||||
self._disconnect()
|
||||
self._connect()
|
||||
if self._active:
|
||||
self._disconnect()
|
||||
self._connect()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _connect(self):
|
||||
""""""
|
||||
self._ws = self.createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE})
|
||||
self._ws = self._createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE})
|
||||
self.onConnected()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@ -122,36 +123,43 @@ class WebSocketClient(object):
|
||||
self._ws = None
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _get_ws(self):
|
||||
def _getWs(self):
|
||||
with self._ws_lock:
|
||||
return self._ws
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _run(self):
|
||||
"""运行"""
|
||||
ws = self._get_ws()
|
||||
"""
|
||||
运行,直到stop()被调用
|
||||
"""
|
||||
|
||||
# todo: onDisconnect
|
||||
while self._active:
|
||||
try:
|
||||
stream = ws.recv()
|
||||
if not stream:
|
||||
self.onDisconnected()
|
||||
if self._active:
|
||||
ws = self._getWs()
|
||||
if ws:
|
||||
stream = ws.recv()
|
||||
if not stream: # recv在阻塞的时候ws被关闭
|
||||
self._reconnect()
|
||||
continue
|
||||
|
||||
data = json.loads(stream)
|
||||
self.onPacket(data)
|
||||
except websocket.WebSocketConnectionClosedException:
|
||||
if self._active:
|
||||
self._reconnect()
|
||||
except:
|
||||
|
||||
data = json.loads(stream)
|
||||
self.onPacket(data)
|
||||
except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了
|
||||
self._reconnect()
|
||||
except: # Python内部错误(onPacket内出错)
|
||||
et, ev, tb = sys.exc_info()
|
||||
self.onError(et, ev, tb)
|
||||
|
||||
self._reconnect()
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _runPing(self):
|
||||
while self._active:
|
||||
self._ping()
|
||||
try:
|
||||
self._ping()
|
||||
except:
|
||||
et, ev, tb = sys.exc_info()
|
||||
# todo: just log this, notifying user is not necessary
|
||||
self.onError(et, ev, tb)
|
||||
for i in range(60):
|
||||
if not self._active:
|
||||
break
|
||||
@ -159,7 +167,9 @@ class WebSocketClient(object):
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def _ping(self):
|
||||
return self._get_ws().send('ping', websocket.ABNF.OPCODE_PING)
|
||||
ws = self._getWs()
|
||||
if ws:
|
||||
ws.send('ping', websocket.ABNF.OPCODE_PING)
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
@staticmethod
|
||||
|
Loading…
Reference in New Issue
Block a user