[Mod] RestClient: use shared thread pool.

This commit is contained in:
nanoric 2019-08-28 14:47:37 +08:00
parent bdfa2bd895
commit a87084ad2e

View File

@ -1,10 +1,12 @@
import multiprocessing
import os
import sys import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
from queue import Empty, Queue from threading import Lock
from typing import Any, Callable, Optional, Union from typing import Any, Callable, List, Optional, Union
import requests import requests
@ -16,6 +18,9 @@ class RequestStatus(Enum):
error = 3 # Exception raised error = 3 # Exception raised
pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20)
class Request(object): class Request(object):
""" """
Request object for status check. Request object for status check.
@ -83,17 +88,32 @@ class RestClient(object):
* Reimplement on_error function to handle exception msg. * Reimplement on_error function to handle exception msg.
""" """
class Session:
def __init__(self, client: "RestClient", session: requests.Session):
self.client = client
self.session = session
def __enter__(self):
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
with self.client._sessions_lock:
self.client._sessions.append(self.session)
def __init__(self): def __init__(self):
""" """
""" """
self.url_base = '' # type: str self.url_base = '' # type: str
self._active = False self._active = False
self._queue = Queue()
self._pool = None # type: Pool
self.proxies = None self.proxies = None
self._tasks_lock = Lock()
self._tasks: List[multiprocessing.pool.AsyncResult] = []
self._sessions_lock = Lock()
self._sessions: List[requests.Session] = []
def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0): def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0):
""" """
Init rest client with url_base which is the API root address. Init rest client with url_base which is the API root address.
@ -115,10 +135,7 @@ class RestClient(object):
""" """
if self._active: if self._active:
return return
self._active = True self._active = True
self._pool = Pool(n)
self._pool.apply_async(self._run)
def stop(self): def stop(self):
""" """
@ -130,7 +147,8 @@ class RestClient(object):
""" """
Wait till all requests are processed. Wait till all requests are processed.
""" """
self._queue.join() for task in self._tasks:
task.wait()
def add_request( def add_request(
self, self,
@ -158,34 +176,40 @@ class RestClient(object):
:return: Request :return: Request
""" """
request = Request( request = Request(
method, method=method,
path, path=path,
params, params=params,
data, data=data,
headers, headers=headers,
callback, callback=callback,
on_failed, on_failed=on_failed,
on_error, on_error=on_error,
extra, extra=extra,
) )
self._queue.put(request) task = pool.apply_async(
self._process_request,
args=[request, ],
callback=self._clean_finished_tasks,
# error_callback=lambda e: self.on_error(type(e), e, e.__traceback__, request),
)
self._push_task(task)
return request return request
def _run(self): def _push_task(self, task):
try: with self._tasks_lock:
session = self._create_session() self._tasks.append(task)
while self._active:
try: def _clean_finished_tasks(self, result: None):
request = self._queue.get(timeout=1) with self._tasks_lock:
try: not_finished_tasks = [i for i in self._tasks if not i.ready()]
self._process_request(request, session) self._tasks = not_finished_tasks
finally:
self._queue.task_done() def _get_session(self):
except Empty: with self._sessions_lock:
pass if self._sessions:
except Exception: return self.Session(self, self._sessions.pop())
et, ev, tb = sys.exc_info() else:
self.on_error(et, ev, tb, None) return self.Session(self, self._create_session())
def sign(self, request: Request): def sign(self, request: Request):
""" """
@ -234,12 +258,13 @@ class RestClient(object):
return text return text
def _process_request( def _process_request(
self, request: Request, session: requests.Session self, request: Request
): ):
""" """
Sending request to server and get result. Sending request to server and get result.
""" """
try: try:
with self._get_session() as session:
request = self.sign(request) request = self.sign(request)
url = self.make_full_url(request.path) url = self.make_full_url(request.path)