[Mod] improve code quality of coinbase gateway

This commit is contained in:
vn.py 2019-09-05 13:13:39 +08:00
parent 77c2b52842
commit eec5d0a769

View File

@ -5,7 +5,6 @@ import sys
import time import time
from copy import copy from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from threading import Lock
import base64 import base64
import uuid import uuid
@ -127,7 +126,10 @@ class CoinbaseGateway(BaseGateway):
passphrase, passphrase,
server, server,
proxy_host, proxy_host,
proxy_port) proxy_port
)
self.init_query()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
@ -164,12 +166,11 @@ class CoinbaseGateway(BaseGateway):
def process_timer_event(self, event: Event): def process_timer_event(self, event: Event):
"""""" """"""
self.rest_api.reset_rate_limit() self.rest_api.query_account()
self.init_query()
def init_query(self): def init_query(self):
"""""" """"""
self.rest_api.query_account() self.event_engine.register(EVENT_TIMER, self.process_timer_event)
class CoinbaseWebsocketApi(WebsocketClient): class CoinbaseWebsocketApi(WebsocketClient):
@ -198,8 +199,6 @@ class CoinbaseWebsocketApi(WebsocketClient):
"match": self.on_order_match, "match": self.on_order_match,
} }
self.ticks = {}
self.accounts = {}
self.orderbooks = {} self.orderbooks = {}
def connect( def connect(
@ -209,7 +208,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
passphrase: str, passphrase: str,
server: str, server: str,
proxy_host: str, proxy_host: str,
proxy_port: int): proxy_port: int
):
"""""" """"""
self.gateway.write_log("开始连接ws接口") self.gateway.write_log("开始连接ws接口")
self.key = key self.key = key
@ -220,9 +220,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
self.init(WEBSOCKET_HOST, proxy_host, proxy_port) self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
else: else:
self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port) self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port)
self.start()
self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event) self.start()
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
@ -240,12 +239,15 @@ class CoinbaseWebsocketApi(WebsocketClient):
timestamp = str(time.time()) timestamp = str(time.time())
message = timestamp + 'GET' + '/users/self/verify' message = timestamp + 'GET' + '/users/self/verify'
auth_headers = get_auth_header( auth_headers = get_auth_header(
timestamp, timestamp,
message, message,
self.key, self.key,
self.secret, self.secret,
self.passphrase) self.passphrase
)
sub_req['signature'] = auth_headers['CB-ACCESS-SIGN'] sub_req['signature'] = auth_headers['CB-ACCESS-SIGN']
sub_req['key'] = auth_headers['CB-ACCESS-KEY'] sub_req['key'] = auth_headers['CB-ACCESS-KEY']
sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE'] sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE']
@ -340,11 +342,14 @@ class CoinbaseWebsocketApi(WebsocketClient):
order = orderDict.get(packet['order_id'], None) order = orderDict.get(packet['order_id'], None)
if not order: if not order:
return return
order.traded = order.volume - float(packet['remaining_size']) order.traded = order.volume - float(packet['remaining_size'])
if packet['reason'] == 'filled': if packet['reason'] == 'filled':
order.status = Status.ALLTRADED order.status = Status.ALLTRADED
else: else:
order.status = Status.CANCELLED order.status = Status.CANCELLED
self.gateway.on_order(copy(order)) self.gateway.on_order(copy(order))
def on_order_match(self, packet: dict): def on_order_match(self, packet: dict):
@ -381,13 +386,15 @@ class OrderBook():
self.asks = dict() self.asks = dict()
self.bids = dict() self.bids = dict()
self.gateway = gateway self.gateway = gateway
self.newest_tick = TickData(
self.tick = TickData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
name=symbol_name_map.get(symbol, ""), name=symbol_name_map.get(symbol, ""),
datetime=datetime.now(), datetime=datetime.now(),
gateway_name=gateway.gateway_name, gateway_name=gateway.gateway_name,
) )
self.first_update = False self.first_update = False
def on_message(self, d: dict): def on_message(self, d: dict):
@ -430,12 +437,11 @@ class OrderBook():
self.generate_tick(dt) self.generate_tick(dt)
def on_ticker(self, d: dict): def on_ticker(self, d: dict):
""" """
call back when type is ticker call back when type is ticker
""" """
tick = self.newest_tick tick = self.tick
tick.open_price = float(d['open_24h']) tick.open_price = float(d['open_24h'])
tick.high_price = float(d['high_24h']) tick.high_price = float(d['high_24h'])
@ -457,7 +463,7 @@ class OrderBook():
def generate_tick(self, dt: datetime): def generate_tick(self, dt: datetime):
"""""" """"""
tick = self.newest_tick tick = self.tick
bids_keys = self.bids.keys() bids_keys = self.bids.keys()
bids_keys = sorted(bids_keys, reverse=True) bids_keys = sorted(bids_keys, reverse=True)
@ -509,16 +515,8 @@ class CoinbaseRestApi(RestClient):
self.secret = "" self.secret = ""
self.passphrase = "" self.passphrase = ""
self.order_count = 1_000_000
self.order_count_lock = Lock()
self.connect_time = 0
self.accounts = {} self.accounts = {}
self.rate_limit = 5
self.rate_limit_remaining = 5
def sign(self, request): def sign(self, request):
""" """
Generate Coinbase signature Generate Coinbase signature
@ -549,17 +547,15 @@ class CoinbaseRestApi(RestClient):
self.secret = secret.encode() self.secret = secret.encode()
self.passphrase = passphrase self.passphrase = passphrase
self.connect_time = (
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
)
if server == "REAL": if server == "REAL":
self.init(REST_HOST, proxy_host, proxy_port) self.init(REST_HOST, proxy_host, proxy_port)
else: else:
self.init(SANDBOX_REST_HOST, proxy_host, proxy_port) self.init(SANDBOX_REST_HOST, proxy_host, proxy_port)
self.start(session_number) self.start(session_number)
self.query_instrument() self.query_instrument()
self.query_orders() self.query_order()
self.gateway.write_log("REST API启动成功") self.gateway.write_log("REST API启动成功")
@ -570,19 +566,15 @@ class CoinbaseRestApi(RestClient):
self.add_request( self.add_request(
"GET", "GET",
"/products", "/products",
callback=self.on_query_instrument, callback=self.on_query_instrument
params={},
on_error=self.on_query_instrument_error,
) )
def query_orders(self): def query_order(self):
"""""" """"""
self.add_request( self.add_request(
"GET", "GET",
"/orders?status=all", "/orders?status=open",
callback=self.on_query_orders, callback=self.on_query_order
params={},
on_error=self.on_query_orders_error,
) )
def query_account(self): def query_account(self):
@ -591,8 +583,6 @@ class CoinbaseRestApi(RestClient):
"GET", "GET",
"/accounts", "/accounts",
callback=self.on_query_account, callback=self.on_query_account,
params={},
on_error=self.on_query_account_error,
) )
def on_query_account(self, data, request): def on_query_account(self, data, request):
@ -614,30 +604,11 @@ class CoinbaseRestApi(RestClient):
self.gateway.on_account(copy(account)) self.gateway.on_account(copy(account))
def on_query_account_error( def on_query_order(self, data, request):
self,
exception_type: type,
exception_value: Exception,
tb,
request):
""""""
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_query_orders_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request):
""""""
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_query_orders(self, data, request):
"""""" """"""
for d in data: for d in data:
date, time = d['created_at'].split('T') date, time = d['created_at'].split('T')
if d['status'] == 'open': if d['status'] == 'open':
if not float(d['filled_size']): if not float(d['filled_size']):
status = Status.NOTTRADED status = Status.NOTTRADED
@ -648,6 +619,7 @@ class CoinbaseRestApi(RestClient):
status = Status.ALLTRADED status = Status.ALLTRADED
else: else:
status = Status.CANCELLED status = Status.CANCELLED
order = OrderData( order = OrderData(
symbol=d['product_id'], symbol=d['product_id'],
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
@ -667,19 +639,6 @@ class CoinbaseRestApi(RestClient):
self.gateway.write_log(u'委托信息查询成功') self.gateway.write_log(u'委托信息查询成功')
def on_query_instrument_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request):
"""
Callback when sending order caused exception.
"""
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_query_instrument(self, data, request): def on_query_instrument(self, data, request):
"""""" """"""
for d in data: for d in data:
@ -703,9 +662,6 @@ class CoinbaseRestApi(RestClient):
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
if not self.check_rate_limit():
return
orderid = str(uuid.uuid1()) orderid = str(uuid.uuid1())
data = { data = {
@ -719,7 +675,6 @@ class CoinbaseRestApi(RestClient):
if req.type == OrderType.LIMIT: if req.type == OrderType.LIMIT:
data['price'] = req.price data['price'] = req.price
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.add_request( self.add_request(
"POST", "POST",
@ -732,6 +687,7 @@ class CoinbaseRestApi(RestClient):
on_error=self.on_send_order_error, on_error=self.on_send_order_error,
) )
self.gateway.on_order(order)
return order.vt_orderid return order.vt_orderid
def on_send_order_failed(self, status_code: str, request: Request): def on_send_order_failed(self, status_code: str, request: Request):
@ -756,7 +712,8 @@ class CoinbaseRestApi(RestClient):
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
tb, tb,
request: Request): request: Request
):
""" """
Callback when sending order caused exception. Callback when sending order caused exception.
""" """
@ -775,36 +732,15 @@ class CoinbaseRestApi(RestClient):
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
"""""" """"""
if not self.check_rate_limit():
return
orderid = req.orderid orderid = req.orderid
if orderid not in orderSysDict:
cancelDict[orderid] = req
self.add_request( self.add_request(
"DELETE", "DELETE",
"/orders/" + orderid, f"/orders/client:{orderid}",
callback=self.on_cancel_order, callback=self.on_cancel_order,
params={},
on_error=self.on_cancel_order_error,
on_failed=self.on_cancel_order_failed, on_failed=self.on_cancel_order_failed,
) )
def on_cancel_order_error(
self,
exception_type: type,
exception_value: Exception,
tb,
request: Request):
"""
Callback when cancelling order failed on server.
"""
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_cancel_order(self, data, request): def on_cancel_order(self, data, request):
"""Websocket will push a new order status""" """Websocket will push a new order status"""
pass pass
@ -816,9 +752,9 @@ class CoinbaseRestApi(RestClient):
if request.response.text: if request.response.text:
data = request.response.json() data = request.response.json()
error = data["message"] error = data["message"]
msg = f"委托失败,状态码:{status_code},信息:{error}" msg = f"撤单失败,状态码:{status_code},信息:{error}"
else: else:
msg = f"委托失败,状态码:{status_code}" msg = f"撤单失败,状态码:{status_code}"
self.gateway.write_log(msg) self.gateway.write_log(msg)
@ -836,7 +772,8 @@ class CoinbaseRestApi(RestClient):
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
tb, tb,
request: Request): request: Request
):
""" """
Callback to handler request exception. Callback to handler request exception.
""" """
@ -847,35 +784,20 @@ class CoinbaseRestApi(RestClient):
self.exception_detail(exception_type, exception_value, tb, request) self.exception_detail(exception_type, exception_value, tb, request)
) )
def reset_rate_limit(self):
"""
reset the rate limit every 1 sec
"""
self.rate_limit_remaining = 5
def check_rate_limit(self):
"""
Called before send requests
"""
if self.rate_limit_remaining:
self.rate_limit_remaining -= 1
return True
else:
self.gateway.write_log("已超出请求速率上限,请稍后重试")
return False
def get_auth_header( def get_auth_header(
timestamp, timestamp,
message, message,
api_key, api_key,
secret_key, secret_key,
passphrase): passphrase
):
"""""" """"""
message = message.encode("ascii") message = message.encode("ascii")
hmac_key = base64.b64decode(secret_key) hmac_key = base64.b64decode(secret_key)
signature = hmac.new(hmac_key, message, hashlib.sha256) signature = hmac.new(hmac_key, message, hashlib.sha256)
signature_b64 = base64.b64encode(signature.digest()).decode('utf-8') signature_b64 = base64.b64encode(signature.digest()).decode('utf-8')
return{ return{
'Content-Type': 'Application/JSON', 'Content-Type': 'Application/JSON',
'CB-ACCESS-SIGN': signature_b64, 'CB-ACCESS-SIGN': signature_b64,