from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import json import random import time from builtins import * import zmq try: import queue except ImportError: import queue as queue import threading import msgpack import snappy import copy qEmpty = copy.copy(queue.Empty) def _unpack_msgpack_snappy(str): if str.startswith(b'S'): tmp = snappy.uncompress(str[1:]) # print "SNAPPY: ", len(str), len(tmp) obj = msgpack.loads(tmp, encoding='utf-8') elif str.startswith(b'\0'): obj = msgpack.loads(str[1:], encoding='utf-8') else: return None return obj def _pack_msgpack_snappy(obj): # print "pack", obj tmp = msgpack.dumps(obj, encoding='utf-8') if len(tmp) > 1000: return b'S' + snappy.compress(tmp) else: return b'\0' + tmp def _unpack_msgpack(str): return msgpack.loads(str, encoding='utf-8') def _pack_msgpack(obj): return msgpack.dumps(obj, encoding='utf-8') def _unpack_json(str): return json.loads(str, encoding='utf-8') def _pack_json(obj): return json.dumps(obj, encoding='utf-8') class JRpcClient(object): def __init__(self, data_format="msgpack_snappy"): self._waiter_lock = threading.Lock() self._waiter_map = {} self._should_close = False self._next_callid = 0 self._send_lock = threading.Lock() self._callid_lock = threading.Lock() self._last_heartbeat_rsp_time = 0 self._connected = False self.on_disconnected = None self.on_rpc_callback = None self._callback_queue = queue.Queue() self._call_wait_queue = queue.Queue() self._ctx = zmq.Context() self._pull_sock = self._ctx.socket(zmq.PULL) self._pull_sock.bind("inproc://pull_sock") self._push_sock = self._ctx.socket(zmq.PUSH) self._push_sock.connect("inproc://pull_sock") self._heartbeat_interval = 1 self._heartbeat_timeout = 3 self._addr = None if data_format == "msgpack_snappy": self._pack = _pack_msgpack_snappy self._unpack = _unpack_msgpack_snappy elif data_format == "msgpack": self._pack = _pack_msgpack self._unpack = _unpack_msgpack elif data_format == "json": self._pack = _pack_json self._unpack = _unpack_json else: assert False, "unknown data_format " + data_format t = threading.Thread(target=self._recv_run) t.setDaemon(True) t.start() self._recv_thread = t t = threading.Thread(target=self._callback_run) t.setDaemon(True) t.start() self._callback_thread = t def __del__(self): self.close() def next_callid(self): self._callid_lock.acquire() self._next_callid += 1 callid = self._next_callid self._callid_lock.release() return callid def set_heartbeat_options(self, interval, timeout): self._heartbeat_interval = interval self._heartbeat_timeout = timeout def _recv_run(self): heartbeat_time = 0 poller = zmq.Poller() poller.register(self._pull_sock, zmq.POLLIN) remote_sock = None while not self._should_close: try: if self._connected and time.time() - self._last_heartbeat_rsp_time > self._heartbeat_timeout: self._connected = False if self.on_disconnected: self._async_call(self.on_disconnected) if remote_sock and time.time() - heartbeat_time > self._heartbeat_interval: self._send_hearbeat() heartbeat_time = time.time() socks = dict(poller.poll(500)) if self._pull_sock in socks and socks[self._pull_sock] == zmq.POLLIN: cmd = self._pull_sock.recv() if cmd == b"CONNECT": # print time.ctime(), "CONNECT " + self._addr if remote_sock: poller.unregister(remote_sock) remote_sock.close() remote_sock = None remote_sock = self._do_connect() if remote_sock: poller.register(remote_sock, zmq.POLLIN) elif cmd.startswith(b"SEND:") and remote_sock: # print time.ctime(), "SEND " + cmd[5:] remote_sock.send(cmd[5:]) if remote_sock and remote_sock in socks and socks[remote_sock] == zmq.POLLIN: data = remote_sock.recv() if data: # if not data.find("heartbeat"): # print time.ctime(), "RECV", data self._on_data_arrived(data) except zmq.error.Again as e: # print "RECV timeout: ", e pass except Exception as e: print("_recv_run:", e) def _callback_run(self): while not self._should_close: try: r = self._callback_queue.get(timeout=1) if r: r() except qEmpty as e: pass except TypeError as e: if str(e) == "'NoneType' object is not callable": pass else: print("_callback_run {}".format(r), type(e), e) except Exception as e: print("_callback_run {}".format(r), type(e), e) def _async_call(self, func): self._callback_queue.put(func) def _send_request(self, json): try: self._send_lock.acquire() self._push_sock.send(b"SEND:" + json) finally: self._send_lock.release() def connect(self, addr): self._addr = addr self._push_sock.send_string('CONNECT', encoding='utf-8') def _do_connect(self): client_id = str(random.randint(1000000, 100000000)) socket = self._ctx.socket(zmq.DEALER) identity = (client_id) + '$' + str(random.randint(1000000, 1000000000)) identity = identity.encode('utf-8') socket.setsockopt(zmq.IDENTITY, identity) socket.setsockopt(zmq.RCVTIMEO, 500) socket.setsockopt(zmq.SNDTIMEO, 500) socket.setsockopt(zmq.LINGER, 0) socket.connect(self._addr) return socket def close(self): self._should_close = True self._callback_thread.join() self._recv_thread.join() def _on_data_arrived(self, str): try: msg = self._unpack(str) # print "RECV", msg if not msg: print("wrong message format") return if 'method' in msg and msg['method'] == '.sys.heartbeat': self._last_heartbeat_rsp_time = time.time() if not self._connected: self._connected = True if self.on_connected: self._async_call(self.on_connected) # Let user has a chance to check message in .sys.heartbeat if 'result' in msg and self.on_rpc_callback: self._async_call(lambda: self.on_rpc_callback(msg['method'], msg['result'])) elif 'id' in msg and msg['id']: # Call result id = int(msg['id']) if self._waiter_lock.acquire(): if id in self._waiter_map: q = self._waiter_map[id] if q: q.put(msg) self._waiter_lock.release() else: # Notification message if 'method' in msg and 'result' in msg and self.on_rpc_callback: self._async_call(lambda: self.on_rpc_callback(msg['method'], msg['result'])) except Exception as e: print("_on_data_arrived:", e) pass def _send_hearbeat(self): msg = {'jsonrpc': '2.0', 'method': '.sys.heartbeat', 'params': {'time': time.time()}, 'id': str(self.next_callid())} json_str = self._pack(msg) self._send_request(json_str) def _alloc_wait_queue(self): self._waiter_lock.acquire() if self._call_wait_queue: q = self._call_wait_queue self._call_wait_queue = None else: q = queue.Queue() self._waiter_lock.release() return q def _free_wait_queue(self, q): self._waiter_lock.acquire() if not self._call_wait_queue: self._call_wait_queue = q else: del q self._waiter_lock.release() def call(self, method, params, timeout=6): # print "call", method, params, timeout callid = self.next_callid() if timeout: q = self._alloc_wait_queue() self._waiter_lock.acquire() self._waiter_map[callid] = q self._waiter_lock.release() msg = {'jsonrpc': '2.0', 'method': method, 'params': params, 'id': str(callid)} # print "SEND", msg json_str = self._pack(msg) self._send_request(json_str) if timeout: ret = {} try: r = q.get(timeout=timeout) q.task_done() except qEmpty: r = None self._waiter_lock.acquire() self._waiter_map[callid] = None self._waiter_lock.release() self._free_wait_queue(q) if r: if 'result' in r: ret['result'] = r['result'] if 'error' in r: ret['error'] = r['error'] return ret if ret else {'error': {'error': -1, 'message': "timeout"}} else: return {'result': True}