Merge pull request #2048 from nanoric/shared_pool

[Mod] RestClient: use shared thread pool.
This commit is contained in:
vn.py 2019-08-29 09:15:51 +08:00 committed by GitHub
commit 9f0cd1f472
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,10 +1,12 @@
import multiprocessing
import os
import sys import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
from queue import Empty, Queue from threading import Lock
from typing import Any, Callable, Optional, Union from typing import Any, Callable, List, Optional, Union
import requests import requests
@ -16,6 +18,9 @@ class RequestStatus(Enum):
error = 3 # Exception raised error = 3 # Exception raised
pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20)
class Request(object): class Request(object):
""" """
Request object for status check. Request object for status check.
@ -83,17 +88,32 @@ class RestClient(object):
* Reimplement on_error function to handle exception msg. * Reimplement on_error function to handle exception msg.
""" """
class Session:
def __init__(self, client: "RestClient", session: requests.Session):
self.client = client
self.session = session
def __enter__(self):
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
with self.client._sessions_lock:
self.client._sessions.append(self.session)
def __init__(self): def __init__(self):
""" """
""" """
self.url_base = '' # type: str self.url_base = '' # type: str
self._active = False self._active = False
self._queue = Queue()
self._pool = None # type: Pool
self.proxies = None self.proxies = None
self._tasks_lock = Lock()
self._tasks: List[multiprocessing.pool.AsyncResult] = []
self._sessions_lock = Lock()
self._sessions: List[requests.Session] = []
def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0): def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0):
""" """
Init rest client with url_base which is the API root address. Init rest client with url_base which is the API root address.
@ -115,10 +135,7 @@ class RestClient(object):
""" """
if self._active: if self._active:
return return
self._active = True self._active = True
self._pool = Pool(n)
self._pool.apply_async(self._run)
def stop(self): def stop(self):
""" """
@ -130,7 +147,8 @@ class RestClient(object):
""" """
Wait till all requests are processed. Wait till all requests are processed.
""" """
self._queue.join() for task in self._tasks:
task.wait()
def add_request( def add_request(
self, self,
@ -158,34 +176,40 @@ class RestClient(object):
:return: Request :return: Request
""" """
request = Request( request = Request(
method, method=method,
path, path=path,
params, params=params,
data, data=data,
headers, headers=headers,
callback, callback=callback,
on_failed, on_failed=on_failed,
on_error, on_error=on_error,
extra, extra=extra,
) )
self._queue.put(request) task = pool.apply_async(
self._process_request,
args=[request, ],
callback=self._clean_finished_tasks,
# error_callback=lambda e: self.on_error(type(e), e, e.__traceback__, request),
)
self._push_task(task)
return request return request
def _run(self): def _push_task(self, task):
try: with self._tasks_lock:
session = self._create_session() self._tasks.append(task)
while self._active:
try: def _clean_finished_tasks(self, result: None):
request = self._queue.get(timeout=1) with self._tasks_lock:
try: not_finished_tasks = [i for i in self._tasks if not i.ready()]
self._process_request(request, session) self._tasks = not_finished_tasks
finally:
self._queue.task_done() def _get_session(self):
except Empty: with self._sessions_lock:
pass if self._sessions:
except Exception: return self.Session(self, self._sessions.pop())
et, ev, tb = sys.exc_info() else:
self.on_error(et, ev, tb, None) return self.Session(self, self._create_session())
def sign(self, request: Request): def sign(self, request: Request):
""" """
@ -234,41 +258,42 @@ class RestClient(object):
return text return text
def _process_request( def _process_request(
self, request: Request, session: requests.Session self, request: Request
): ):
""" """
Sending request to server and get result. Sending request to server and get result.
""" """
try: try:
request = self.sign(request) with self._get_session() as session:
request = self.sign(request)
url = self.make_full_url(request.path) url = self.make_full_url(request.path)
response = session.request( response = session.request(
request.method, request.method,
url, url,
headers=request.headers, headers=request.headers,
params=request.params, params=request.params,
data=request.data, data=request.data,
proxies=self.proxies, proxies=self.proxies,
) )
request.response = response request.response = response
status_code = response.status_code status_code = response.status_code
if status_code // 100 == 2: # 2xx codes are all successful if status_code // 100 == 2: # 2xx codes are all successful
if status_code == 204: if status_code == 204:
json_body = None json_body = None
else:
json_body = response.json()
request.callback(json_body, request)
request.status = RequestStatus.success
else: else:
json_body = response.json() request.status = RequestStatus.failed
request.callback(json_body, request) if request.on_failed:
request.status = RequestStatus.success request.on_failed(status_code, request)
else: else:
request.status = RequestStatus.failed self.on_failed(status_code, request)
if request.on_failed:
request.on_failed(status_code, request)
else:
self.on_failed(status_code, request)
except Exception: except Exception:
request.status = RequestStatus.error request.status = RequestStatus.error
t, v, tb = sys.exc_info() t, v, tb = sys.exc_info()