Merge pull request #2048 from nanoric/shared_pool

[Mod] RestClient: use shared thread pool.
This commit is contained in:
vn.py 2019-08-29 09:15:51 +08:00 committed by GitHub
commit 9f0cd1f472
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,10 +1,12 @@
import multiprocessing
import os
import sys
import traceback
from datetime import datetime
from enum import Enum
from multiprocessing.dummy import Pool
from queue import Empty, Queue
from typing import Any, Callable, Optional, Union
from threading import Lock
from typing import Any, Callable, List, Optional, Union
import requests
@ -16,6 +18,9 @@ class RequestStatus(Enum):
error = 3 # Exception raised
pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20)
class Request(object):
"""
Request object for status check.
@ -83,17 +88,32 @@ class RestClient(object):
* Reimplement on_error function to handle exception msg.
"""
class Session:
def __init__(self, client: "RestClient", session: requests.Session):
self.client = client
self.session = session
def __enter__(self):
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
with self.client._sessions_lock:
self.client._sessions.append(self.session)
def __init__(self):
"""
"""
self.url_base = '' # type: str
self._active = False
self._queue = Queue()
self._pool = None # type: Pool
self.proxies = None
self._tasks_lock = Lock()
self._tasks: List[multiprocessing.pool.AsyncResult] = []
self._sessions_lock = Lock()
self._sessions: List[requests.Session] = []
def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0):
"""
Init rest client with url_base which is the API root address.
@ -115,10 +135,7 @@ class RestClient(object):
"""
if self._active:
return
self._active = True
self._pool = Pool(n)
self._pool.apply_async(self._run)
def stop(self):
"""
@ -130,7 +147,8 @@ class RestClient(object):
"""
Wait till all requests are processed.
"""
self._queue.join()
for task in self._tasks:
task.wait()
def add_request(
self,
@ -158,34 +176,40 @@ class RestClient(object):
:return: Request
"""
request = Request(
method,
path,
params,
data,
headers,
callback,
on_failed,
on_error,
extra,
method=method,
path=path,
params=params,
data=data,
headers=headers,
callback=callback,
on_failed=on_failed,
on_error=on_error,
extra=extra,
)
self._queue.put(request)
task = pool.apply_async(
self._process_request,
args=[request, ],
callback=self._clean_finished_tasks,
# error_callback=lambda e: self.on_error(type(e), e, e.__traceback__, request),
)
self._push_task(task)
return request
def _run(self):
try:
session = self._create_session()
while self._active:
try:
request = self._queue.get(timeout=1)
try:
self._process_request(request, session)
finally:
self._queue.task_done()
except Empty:
pass
except Exception:
et, ev, tb = sys.exc_info()
self.on_error(et, ev, tb, None)
def _push_task(self, task):
with self._tasks_lock:
self._tasks.append(task)
def _clean_finished_tasks(self, result: None):
with self._tasks_lock:
not_finished_tasks = [i for i in self._tasks if not i.ready()]
self._tasks = not_finished_tasks
def _get_session(self):
with self._sessions_lock:
if self._sessions:
return self.Session(self, self._sessions.pop())
else:
return self.Session(self, self._create_session())
def sign(self, request: Request):
"""
@ -234,41 +258,42 @@ class RestClient(object):
return text
def _process_request(
self, request: Request, session: requests.Session
self, request: Request
):
"""
Sending request to server and get result.
"""
try:
request = self.sign(request)
with self._get_session() as session:
request = self.sign(request)
url = self.make_full_url(request.path)
url = self.make_full_url(request.path)
response = session.request(
request.method,
url,
headers=request.headers,
params=request.params,
data=request.data,
proxies=self.proxies,
)
request.response = response
status_code = response.status_code
if status_code // 100 == 2: # 2xx codes are all successful
if status_code == 204:
json_body = None
response = session.request(
request.method,
url,
headers=request.headers,
params=request.params,
data=request.data,
proxies=self.proxies,
)
request.response = response
status_code = response.status_code
if status_code // 100 == 2: # 2xx codes are all successful
if status_code == 204:
json_body = None
else:
json_body = response.json()
request.callback(json_body, request)
request.status = RequestStatus.success
else:
json_body = response.json()
request.status = RequestStatus.failed
request.callback(json_body, request)
request.status = RequestStatus.success
else:
request.status = RequestStatus.failed
if request.on_failed:
request.on_failed(status_code, request)
else:
self.on_failed(status_code, request)
if request.on_failed:
request.on_failed(status_code, request)
else:
self.on_failed(status_code, request)
except Exception:
request.status = RequestStatus.error
t, v, tb = sys.exc_info()