[Mod] RestClient: use shared thread pool.
This commit is contained in:
parent
bdfa2bd895
commit
a87084ad2e
@ -1,10 +1,12 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from multiprocessing.dummy import Pool
|
||||
from queue import Empty, Queue
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
@ -16,6 +18,9 @@ class RequestStatus(Enum):
|
||||
error = 3 # Exception raised
|
||||
|
||||
|
||||
pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20)
|
||||
|
||||
|
||||
class Request(object):
|
||||
"""
|
||||
Request object for status check.
|
||||
@ -83,17 +88,32 @@ class RestClient(object):
|
||||
* Reimplement on_error function to handle exception msg.
|
||||
"""
|
||||
|
||||
class Session:
|
||||
|
||||
def __init__(self, client: "RestClient", session: requests.Session):
|
||||
self.client = client
|
||||
self.session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
with self.client._sessions_lock:
|
||||
self.client._sessions.append(self.session)
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
"""
|
||||
self.url_base = '' # type: str
|
||||
self._active = False
|
||||
|
||||
self._queue = Queue()
|
||||
self._pool = None # type: Pool
|
||||
|
||||
self.proxies = None
|
||||
|
||||
self._tasks_lock = Lock()
|
||||
self._tasks: List[multiprocessing.pool.AsyncResult] = []
|
||||
self._sessions_lock = Lock()
|
||||
self._sessions: List[requests.Session] = []
|
||||
|
||||
def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0):
|
||||
"""
|
||||
Init rest client with url_base which is the API root address.
|
||||
@ -115,10 +135,7 @@ class RestClient(object):
|
||||
"""
|
||||
if self._active:
|
||||
return
|
||||
|
||||
self._active = True
|
||||
self._pool = Pool(n)
|
||||
self._pool.apply_async(self._run)
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
@ -130,7 +147,8 @@ class RestClient(object):
|
||||
"""
|
||||
Wait till all requests are processed.
|
||||
"""
|
||||
self._queue.join()
|
||||
for task in self._tasks:
|
||||
task.wait()
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@ -158,34 +176,40 @@ class RestClient(object):
|
||||
:return: Request
|
||||
"""
|
||||
request = Request(
|
||||
method,
|
||||
path,
|
||||
params,
|
||||
data,
|
||||
headers,
|
||||
callback,
|
||||
on_failed,
|
||||
on_error,
|
||||
extra,
|
||||
method=method,
|
||||
path=path,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
callback=callback,
|
||||
on_failed=on_failed,
|
||||
on_error=on_error,
|
||||
extra=extra,
|
||||
)
|
||||
self._queue.put(request)
|
||||
task = pool.apply_async(
|
||||
self._process_request,
|
||||
args=[request, ],
|
||||
callback=self._clean_finished_tasks,
|
||||
# error_callback=lambda e: self.on_error(type(e), e, e.__traceback__, request),
|
||||
)
|
||||
self._push_task(task)
|
||||
return request
|
||||
|
||||
def _run(self):
|
||||
try:
|
||||
session = self._create_session()
|
||||
while self._active:
|
||||
try:
|
||||
request = self._queue.get(timeout=1)
|
||||
try:
|
||||
self._process_request(request, session)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
except Empty:
|
||||
pass
|
||||
except Exception:
|
||||
et, ev, tb = sys.exc_info()
|
||||
self.on_error(et, ev, tb, None)
|
||||
def _push_task(self, task):
|
||||
with self._tasks_lock:
|
||||
self._tasks.append(task)
|
||||
|
||||
def _clean_finished_tasks(self, result: None):
|
||||
with self._tasks_lock:
|
||||
not_finished_tasks = [i for i in self._tasks if not i.ready()]
|
||||
self._tasks = not_finished_tasks
|
||||
|
||||
def _get_session(self):
|
||||
with self._sessions_lock:
|
||||
if self._sessions:
|
||||
return self.Session(self, self._sessions.pop())
|
||||
else:
|
||||
return self.Session(self, self._create_session())
|
||||
|
||||
def sign(self, request: Request):
|
||||
"""
|
||||
@ -234,41 +258,42 @@ class RestClient(object):
|
||||
return text
|
||||
|
||||
def _process_request(
|
||||
self, request: Request, session: requests.Session
|
||||
self, request: Request
|
||||
):
|
||||
"""
|
||||
Sending request to server and get result.
|
||||
"""
|
||||
try:
|
||||
request = self.sign(request)
|
||||
with self._get_session() as session:
|
||||
request = self.sign(request)
|
||||
|
||||
url = self.make_full_url(request.path)
|
||||
url = self.make_full_url(request.path)
|
||||
|
||||
response = session.request(
|
||||
request.method,
|
||||
url,
|
||||
headers=request.headers,
|
||||
params=request.params,
|
||||
data=request.data,
|
||||
proxies=self.proxies,
|
||||
)
|
||||
request.response = response
|
||||
status_code = response.status_code
|
||||
if status_code // 100 == 2: # 2xx codes are all successful
|
||||
if status_code == 204:
|
||||
json_body = None
|
||||
response = session.request(
|
||||
request.method,
|
||||
url,
|
||||
headers=request.headers,
|
||||
params=request.params,
|
||||
data=request.data,
|
||||
proxies=self.proxies,
|
||||
)
|
||||
request.response = response
|
||||
status_code = response.status_code
|
||||
if status_code // 100 == 2: # 2xx codes are all successful
|
||||
if status_code == 204:
|
||||
json_body = None
|
||||
else:
|
||||
json_body = response.json()
|
||||
|
||||
request.callback(json_body, request)
|
||||
request.status = RequestStatus.success
|
||||
else:
|
||||
json_body = response.json()
|
||||
request.status = RequestStatus.failed
|
||||
|
||||
request.callback(json_body, request)
|
||||
request.status = RequestStatus.success
|
||||
else:
|
||||
request.status = RequestStatus.failed
|
||||
|
||||
if request.on_failed:
|
||||
request.on_failed(status_code, request)
|
||||
else:
|
||||
self.on_failed(status_code, request)
|
||||
if request.on_failed:
|
||||
request.on_failed(status_code, request)
|
||||
else:
|
||||
self.on_failed(status_code, request)
|
||||
except Exception:
|
||||
request.status = RequestStatus.error
|
||||
t, v, tb = sys.exc_info()
|
||||
|
Loading…
Reference in New Issue
Block a user