From eec5d0a7692b8ddac92199c0b87eec4588fae0d9 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 5 Sep 2019 13:13:39 +0800 Subject: [PATCH] [Mod] improve code quality of coinbase gateway --- vnpy/gateway/coinbase/coinbase_gateway.py | 210 +++++++--------------- 1 file changed, 66 insertions(+), 144 deletions(-) diff --git a/vnpy/gateway/coinbase/coinbase_gateway.py b/vnpy/gateway/coinbase/coinbase_gateway.py index 6438aba6..6c262cae 100644 --- a/vnpy/gateway/coinbase/coinbase_gateway.py +++ b/vnpy/gateway/coinbase/coinbase_gateway.py @@ -5,7 +5,6 @@ import sys import time from copy import copy from datetime import datetime, timedelta -from threading import Lock import base64 import uuid @@ -112,12 +111,12 @@ class CoinbaseGateway(BaseGateway): proxy_port = 0 self.rest_api.connect( - key, - secret, - passphrase, - session_number, + key, + secret, + passphrase, + session_number, server, - proxy_host, + proxy_host, proxy_port ) @@ -127,7 +126,10 @@ class CoinbaseGateway(BaseGateway): passphrase, server, proxy_host, - proxy_port) + proxy_port + ) + + self.init_query() def subscribe(self, req: SubscribeRequest): """""" @@ -164,12 +166,11 @@ class CoinbaseGateway(BaseGateway): def process_timer_event(self, event: Event): """""" - self.rest_api.reset_rate_limit() - self.init_query() + self.rest_api.query_account() def init_query(self): """""" - self.rest_api.query_account() + self.event_engine.register(EVENT_TIMER, self.process_timer_event) class CoinbaseWebsocketApi(WebsocketClient): @@ -198,18 +199,17 @@ class CoinbaseWebsocketApi(WebsocketClient): "match": self.on_order_match, } - self.ticks = {} - self.accounts = {} self.orderbooks = {} def connect( - self, - key: str, - secret: str, - passphrase: str, - server: str, - proxy_host: str, - proxy_port: int): + self, + key: str, + secret: str, + passphrase: str, + server: str, + proxy_host: str, + proxy_port: int + ): """""" self.gateway.write_log("开始连接ws接口") self.key = key @@ -220,9 +220,8 @@ class CoinbaseWebsocketApi(WebsocketClient): self.init(WEBSOCKET_HOST, proxy_host, proxy_port) else: self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port) - self.start() - self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event) + self.start() def subscribe(self, req: SubscribeRequest): """""" @@ -240,12 +239,15 @@ class CoinbaseWebsocketApi(WebsocketClient): timestamp = str(time.time()) message = timestamp + 'GET' + '/users/self/verify' + auth_headers = get_auth_header( timestamp, message, self.key, self.secret, - self.passphrase) + self.passphrase + ) + sub_req['signature'] = auth_headers['CB-ACCESS-SIGN'] sub_req['key'] = auth_headers['CB-ACCESS-KEY'] sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE'] @@ -340,11 +342,14 @@ class CoinbaseWebsocketApi(WebsocketClient): order = orderDict.get(packet['order_id'], None) if not order: return + order.traded = order.volume - float(packet['remaining_size']) + if packet['reason'] == 'filled': order.status = Status.ALLTRADED else: order.status = Status.CANCELLED + self.gateway.on_order(copy(order)) def on_order_match(self, packet: dict): @@ -381,13 +386,15 @@ class OrderBook(): self.asks = dict() self.bids = dict() self.gateway = gateway - self.newest_tick = TickData( - symbol=symbol, - exchange=exchange, + + self.tick = TickData( + symbol=symbol, + exchange=exchange, name=symbol_name_map.get(symbol, ""), datetime=datetime.now(), gateway_name=gateway.gateway_name, ) + self.first_update = False def on_message(self, d: dict): @@ -430,12 +437,11 @@ class OrderBook(): self.generate_tick(dt) - def on_ticker(self, d: dict): """ call back when type is ticker """ - tick = self.newest_tick + tick = self.tick tick.open_price = float(d['open_24h']) tick.high_price = float(d['high_24h']) @@ -457,7 +463,7 @@ class OrderBook(): def generate_tick(self, dt: datetime): """""" - tick = self.newest_tick + tick = self.tick bids_keys = self.bids.keys() bids_keys = sorted(bids_keys, reverse=True) @@ -509,16 +515,8 @@ class CoinbaseRestApi(RestClient): self.secret = "" self.passphrase = "" - self.order_count = 1_000_000 - self.order_count_lock = Lock() - - self.connect_time = 0 - self.accounts = {} - self.rate_limit = 5 - self.rate_limit_remaining = 5 - def sign(self, request): """ Generate Coinbase signature @@ -549,17 +547,15 @@ class CoinbaseRestApi(RestClient): self.secret = secret.encode() self.passphrase = passphrase - self.connect_time = ( - int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count - ) if server == "REAL": self.init(REST_HOST, proxy_host, proxy_port) else: self.init(SANDBOX_REST_HOST, proxy_host, proxy_port) self.start(session_number) + self.query_instrument() - self.query_orders() + self.query_order() self.gateway.write_log("REST API启动成功") @@ -570,19 +566,15 @@ class CoinbaseRestApi(RestClient): self.add_request( "GET", "/products", - callback=self.on_query_instrument, - params={}, - on_error=self.on_query_instrument_error, + callback=self.on_query_instrument ) - def query_orders(self): + def query_order(self): """""" self.add_request( "GET", - "/orders?status=all", - callback=self.on_query_orders, - params={}, - on_error=self.on_query_orders_error, + "/orders?status=open", + callback=self.on_query_order ) def query_account(self): @@ -591,8 +583,6 @@ class CoinbaseRestApi(RestClient): "GET", "/accounts", callback=self.on_query_account, - params={}, - on_error=self.on_query_account_error, ) def on_query_account(self, data, request): @@ -614,30 +604,11 @@ class CoinbaseRestApi(RestClient): self.gateway.on_account(copy(account)) - def on_query_account_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request): - """""" - if not issubclass(exception_type, ConnectionError): - self.on_error(exception_type, exception_value, tb, request) - - def on_query_orders_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request): - """""" - if not issubclass(exception_type, ConnectionError): - self.on_error(exception_type, exception_value, tb, request) - - def on_query_orders(self, data, request): + def on_query_order(self, data, request): """""" for d in data: date, time = d['created_at'].split('T') + if d['status'] == 'open': if not float(d['filled_size']): status = Status.NOTTRADED @@ -648,6 +619,7 @@ class CoinbaseRestApi(RestClient): status = Status.ALLTRADED else: status = Status.CANCELLED + order = OrderData( symbol=d['product_id'], gateway_name=self.gateway_name, @@ -667,19 +639,6 @@ class CoinbaseRestApi(RestClient): self.gateway.write_log(u'委托信息查询成功') - def on_query_instrument_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request): - """ - Callback when sending order caused exception. - """ - # Record exception if not ConnectionError - if not issubclass(exception_type, ConnectionError): - self.on_error(exception_type, exception_value, tb, request) - def on_query_instrument(self, data, request): """""" for d in data: @@ -703,9 +662,6 @@ class CoinbaseRestApi(RestClient): def send_order(self, req: OrderRequest): """""" - if not self.check_rate_limit(): - return - orderid = str(uuid.uuid1()) data = { @@ -719,7 +675,6 @@ class CoinbaseRestApi(RestClient): if req.type == OrderType.LIMIT: data['price'] = req.price - order = req.create_order_data(orderid, self.gateway_name) self.add_request( "POST", @@ -732,6 +687,7 @@ class CoinbaseRestApi(RestClient): on_error=self.on_send_order_error, ) + self.gateway.on_order(order) return order.vt_orderid def on_send_order_failed(self, status_code: str, request: Request): @@ -752,11 +708,12 @@ class CoinbaseRestApi(RestClient): self.gateway.write_log(msg) def on_send_order_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request): + self, + exception_type: type, + exception_value: Exception, + tb, + request: Request + ): """ Callback when sending order caused exception. """ @@ -775,36 +732,15 @@ class CoinbaseRestApi(RestClient): def cancel_order(self, req: CancelRequest): """""" - if not self.check_rate_limit(): - return - orderid = req.orderid - if orderid not in orderSysDict: - cancelDict[orderid] = req - self.add_request( "DELETE", - "/orders/" + orderid, + f"/orders/client:{orderid}", callback=self.on_cancel_order, - params={}, - on_error=self.on_cancel_order_error, on_failed=self.on_cancel_order_failed, ) - def on_cancel_order_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request): - """ - Callback when cancelling order failed on server. - """ - # Record exception if not ConnectionError - if not issubclass(exception_type, ConnectionError): - self.on_error(exception_type, exception_value, tb, request) - def on_cancel_order(self, data, request): """Websocket will push a new order status""" pass @@ -816,9 +752,9 @@ class CoinbaseRestApi(RestClient): if request.response.text: data = request.response.json() error = data["message"] - msg = f"委托失败,状态码:{status_code},信息:{error}" + msg = f"撤单失败,状态码:{status_code},信息:{error}" else: - msg = f"委托失败,状态码:{status_code}" + msg = f"撤单失败,状态码:{status_code}" self.gateway.write_log(msg) @@ -832,11 +768,12 @@ class CoinbaseRestApi(RestClient): self.gateway.write_log(msg) def on_error( - self, - exception_type: type, - exception_value: Exception, - tb, - request: Request): + self, + exception_type: type, + exception_value: Exception, + tb, + request: Request + ): """ Callback to handler request exception. """ @@ -847,35 +784,20 @@ class CoinbaseRestApi(RestClient): self.exception_detail(exception_type, exception_value, tb, request) ) - def reset_rate_limit(self): - """ - reset the rate limit every 1 sec - """ - self.rate_limit_remaining = 5 - - def check_rate_limit(self): - """ - Called before send requests - """ - if self.rate_limit_remaining: - self.rate_limit_remaining -= 1 - return True - else: - self.gateway.write_log("已超出请求速率上限,请稍后重试") - return False - def get_auth_header( - timestamp, - message, - api_key, - secret_key, - passphrase): + timestamp, + message, + api_key, + secret_key, + passphrase +): """""" message = message.encode("ascii") hmac_key = base64.b64decode(secret_key) signature = hmac.new(hmac_key, message, hashlib.sha256) signature_b64 = base64.b64encode(signature.digest()).decode('utf-8') + return{ 'Content-Type': 'Application/JSON', 'CB-ACCESS-SIGN': signature_b64,