diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index 85cdd675..b0806278 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -1,10 +1,12 @@ +import multiprocessing +import os import sys import traceback from datetime import datetime from enum import Enum from multiprocessing.dummy import Pool -from queue import Empty, Queue -from typing import Any, Callable, Optional, Union +from threading import Lock +from typing import Any, Callable, List, Optional, Union import requests @@ -16,6 +18,9 @@ class RequestStatus(Enum): error = 3 # Exception raised +pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20) + + class Request(object): """ Request object for status check. @@ -83,17 +88,32 @@ class RestClient(object): * Reimplement on_error function to handle exception msg. """ + class Session: + + def __init__(self, client: "RestClient", session: requests.Session): + self.client = client + self.session = session + + def __enter__(self): + return self.session + + def __exit__(self, exc_type, exc_val, exc_tb): + with self.client._sessions_lock: + self.client._sessions.append(self.session) + def __init__(self): """ """ self.url_base = '' # type: str self._active = False - self._queue = Queue() - self._pool = None # type: Pool - self.proxies = None + self._tasks_lock = Lock() + self._tasks: List[multiprocessing.pool.AsyncResult] = [] + self._sessions_lock = Lock() + self._sessions: List[requests.Session] = [] + def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0): """ Init rest client with url_base which is the API root address. @@ -115,10 +135,7 @@ class RestClient(object): """ if self._active: return - self._active = True - self._pool = Pool(n) - self._pool.apply_async(self._run) def stop(self): """ @@ -130,7 +147,8 @@ class RestClient(object): """ Wait till all requests are processed. """ - self._queue.join() + for task in self._tasks: + task.wait() def add_request( self, @@ -158,34 +176,40 @@ class RestClient(object): :return: Request """ request = Request( - method, - path, - params, - data, - headers, - callback, - on_failed, - on_error, - extra, + method=method, + path=path, + params=params, + data=data, + headers=headers, + callback=callback, + on_failed=on_failed, + on_error=on_error, + extra=extra, ) - self._queue.put(request) + task = pool.apply_async( + self._process_request, + args=[request, ], + callback=self._clean_finished_tasks, + # error_callback=lambda e: self.on_error(type(e), e, e.__traceback__, request), + ) + self._push_task(task) return request - def _run(self): - try: - session = self._create_session() - while self._active: - try: - request = self._queue.get(timeout=1) - try: - self._process_request(request, session) - finally: - self._queue.task_done() - except Empty: - pass - except Exception: - et, ev, tb = sys.exc_info() - self.on_error(et, ev, tb, None) + def _push_task(self, task): + with self._tasks_lock: + self._tasks.append(task) + + def _clean_finished_tasks(self, result: None): + with self._tasks_lock: + not_finished_tasks = [i for i in self._tasks if not i.ready()] + self._tasks = not_finished_tasks + + def _get_session(self): + with self._sessions_lock: + if self._sessions: + return self.Session(self, self._sessions.pop()) + else: + return self.Session(self, self._create_session()) def sign(self, request: Request): """ @@ -234,41 +258,42 @@ class RestClient(object): return text def _process_request( - self, request: Request, session: requests.Session + self, request: Request ): """ Sending request to server and get result. """ try: - request = self.sign(request) + with self._get_session() as session: + request = self.sign(request) - url = self.make_full_url(request.path) + url = self.make_full_url(request.path) - response = session.request( - request.method, - url, - headers=request.headers, - params=request.params, - data=request.data, - proxies=self.proxies, - ) - request.response = response - status_code = response.status_code - if status_code // 100 == 2: # 2xx codes are all successful - if status_code == 204: - json_body = None + response = session.request( + request.method, + url, + headers=request.headers, + params=request.params, + data=request.data, + proxies=self.proxies, + ) + request.response = response + status_code = response.status_code + if status_code // 100 == 2: # 2xx codes are all successful + if status_code == 204: + json_body = None + else: + json_body = response.json() + + request.callback(json_body, request) + request.status = RequestStatus.success else: - json_body = response.json() + request.status = RequestStatus.failed - request.callback(json_body, request) - request.status = RequestStatus.success - else: - request.status = RequestStatus.failed - - if request.on_failed: - request.on_failed(status_code, request) - else: - self.on_failed(status_code, request) + if request.on_failed: + request.on_failed(status_code, request) + else: + self.on_failed(status_code, request) except Exception: request.status = RequestStatus.error t, v, tb = sys.exc_info()