diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index b0806278..0d0260f6 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -1,7 +1,9 @@ +import logging import multiprocessing import os import sys import traceback +import uuid from datetime import datetime from enum import Enum from multiprocessing.dummy import Pool @@ -10,6 +12,8 @@ from typing import Any, Callable, List, Optional, Union import requests +from vnpy.trader.utility import get_file_logger + class RequestStatus(Enum): ready = 0 # Request created @@ -105,21 +109,35 @@ class RestClient(object): """ """ self.url_base = '' # type: str - self._active = False + self.logger: Optional[logging.Logger] = None self.proxies = None + self._active = False + self._tasks_lock = Lock() self._tasks: List[multiprocessing.pool.AsyncResult] = [] self._sessions_lock = Lock() self._sessions: List[requests.Session] = [] - def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0): + def init(self, + url_base: str, + proxy_host: str = "", + proxy_port: int = 0, + log_path: Optional[str] = None, + ): """ Init rest client with url_base which is the API root address. e.g. 'https://www.bitmex.com/api/v1/' + :param url_base: + :param proxy_host: + :param proxy_port: + :param log_path: optional. file to save log. """ self.url_base = url_base + if log_path is not None: + self.logger = get_file_logger(log_path) + self.logger.setLevel(logging.DEBUG) if proxy_host and proxy_port: proxy = f"{proxy_host}:{proxy_port}" @@ -257,6 +275,11 @@ class RestClient(object): ) return text + def _log(self, msg, *args): + logger = self.logger + if logger: + logger.debug(msg, *args) + def _process_request( self, request: Request ): @@ -265,19 +288,32 @@ class RestClient(object): """ try: with self._get_session() as session: + # sign request = self.sign(request) + # send request url = self.make_full_url(request.path) + uid = uuid.uuid4() + method = request.method + headers = request.headers + params = request.params + data = request.data + self._log("[%s] sending request %s %s, headers:%s, params:%s, data:%s", + uid, method, url, + headers, params, data) response = session.request( - request.method, + method, url, - headers=request.headers, - params=request.params, - data=request.data, + headers=headers, + params=params, + data=data, proxies=self.proxies, ) request.response = response + self._log("[%s] received response from %s:%s", uid, method, url) + + # check result & call corresponding callbacks status_code = response.status_code if status_code // 100 == 2: # 2xx codes are all successful if status_code == 204: diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index e7c767a0..d89529a4 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -1,14 +1,18 @@ import json +import logging +import socket import ssl import sys import traceback -import socket from datetime import datetime from threading import Lock, Thread from time import sleep +from typing import Optional import websocket +from vnpy.trader.utility import get_file_logger + class WebsocketClient(object): """ @@ -47,19 +51,36 @@ class WebsocketClient(object): self.proxy_host = None self.proxy_port = None - self.ping_interval = 60 # seconds + self.ping_interval = 60 # seconds self.header = {} + self.logger: Optional[logging.Logger] = None + # For debugging self._last_sent_text = None self._last_received_text = None - def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60, header: dict = None): + def init(self, + host: str, + proxy_host: str = "", + proxy_port: int = 0, + ping_interval: int = 60, + header: dict = None, + log_path: Optional[str] = None, + ): """ + :param host: + :param proxy_host: + :param proxy_port: + :param header: :param ping_interval: unit: seconds, type: int + :param log_path: optional. file to save log. """ self.host = host self.ping_interval = ping_interval # seconds + if log_path is not None: + self.logger = get_file_logger(log_path) + self.logger.setLevel(logging.DEBUG) if header: self.header = header @@ -109,6 +130,11 @@ class WebsocketClient(object): self._record_last_sent_text(text) return self._send_text(text) + def _log(self, msg, *args): + logger = self.logger + if logger: + logger.debug(msg, *args) + def _send_text(self, text: str): """ Send a text string to server. @@ -116,6 +142,7 @@ class WebsocketClient(object): ws = self._ws if ws: ws.send(text, opcode=websocket.ABNF.OPCODE_TEXT) + self._log('sent text: %s', text) def _send_binary(self, data: bytes): """ @@ -124,6 +151,7 @@ class WebsocketClient(object): ws = self._ws if ws: ws._send_binary(data) + self._log('sent binary: %s', data) def _create_connection(self, *args, **kwargs): """""" @@ -184,6 +212,7 @@ class WebsocketClient(object): print("websocket unable to parse data: " + text) raise e + self._log('recv data: %s', data) self.on_packet(data) # ws is closed before recv function is called # For socket.error, see Issue #1608 diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 38a6e7db..51471b70 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,8 +3,9 @@ General utility functions. """ import json +import logging from pathlib import Path -from typing import Callable +from typing import Callable, Dict from decimal import Decimal import numpy as np @@ -14,6 +15,9 @@ from .object import BarData, TickData from .constant import Exchange, Interval +log_formatter = logging.Formatter('[%(asctime)s] %(message)s') + + def extract_vt_symbol(vt_symbol: str): """ :return: (symbol, exchange) @@ -461,3 +465,25 @@ def virtual(func: "callable"): that can be (re)implemented by subclasses. """ return func + + +file_handlers: Dict[str, logging.FileHandler] = {} + + +def _get_file_logger_handler(filename: str): + handler = file_handlers.get(filename, None) + if handler is None: + handler = logging.FileHandler(filename) + file_handlers[filename] = handler # Am i need a lock? + return handler + + +def get_file_logger(filename: str): + """ + return a logger that writes records into a file. + """ + logger = logging.getLogger(filename) + handler = _get_file_logger_handler(filename) # get singleton handler. + handler.setFormatter(log_formatter) + logger.addHandler(handler) # each handler will be added only once. + return logger