[Mod] improve code quality of coinbase gateway
This commit is contained in:
parent
77c2b52842
commit
eec5d0a769
@ -5,7 +5,6 @@ import sys
|
||||
import time
|
||||
from copy import copy
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
@ -127,7 +126,10 @@ class CoinbaseGateway(BaseGateway):
|
||||
passphrase,
|
||||
server,
|
||||
proxy_host,
|
||||
proxy_port)
|
||||
proxy_port
|
||||
)
|
||||
|
||||
self.init_query()
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
@ -164,12 +166,11 @@ class CoinbaseGateway(BaseGateway):
|
||||
|
||||
def process_timer_event(self, event: Event):
|
||||
""""""
|
||||
self.rest_api.reset_rate_limit()
|
||||
self.init_query()
|
||||
self.rest_api.query_account()
|
||||
|
||||
def init_query(self):
|
||||
""""""
|
||||
self.rest_api.query_account()
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
|
||||
|
||||
class CoinbaseWebsocketApi(WebsocketClient):
|
||||
@ -198,8 +199,6 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
||||
"match": self.on_order_match,
|
||||
}
|
||||
|
||||
self.ticks = {}
|
||||
self.accounts = {}
|
||||
self.orderbooks = {}
|
||||
|
||||
def connect(
|
||||
@ -209,7 +208,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
||||
passphrase: str,
|
||||
server: str,
|
||||
proxy_host: str,
|
||||
proxy_port: int):
|
||||
proxy_port: int
|
||||
):
|
||||
""""""
|
||||
self.gateway.write_log("开始连接ws接口")
|
||||
self.key = key
|
||||
@ -220,9 +220,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
||||
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||
else:
|
||||
self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||
self.start()
|
||||
|
||||
self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event)
|
||||
self.start()
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
@ -240,12 +239,15 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
||||
|
||||
timestamp = str(time.time())
|
||||
message = timestamp + 'GET' + '/users/self/verify'
|
||||
|
||||
auth_headers = get_auth_header(
|
||||
timestamp,
|
||||
message,
|
||||
self.key,
|
||||
self.secret,
|
||||
self.passphrase)
|
||||
self.passphrase
|
||||
)
|
||||
|
||||
sub_req['signature'] = auth_headers['CB-ACCESS-SIGN']
|
||||
sub_req['key'] = auth_headers['CB-ACCESS-KEY']
|
||||
sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE']
|
||||
@ -340,11 +342,14 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
||||
order = orderDict.get(packet['order_id'], None)
|
||||
if not order:
|
||||
return
|
||||
|
||||
order.traded = order.volume - float(packet['remaining_size'])
|
||||
|
||||
if packet['reason'] == 'filled':
|
||||
order.status = Status.ALLTRADED
|
||||
else:
|
||||
order.status = Status.CANCELLED
|
||||
|
||||
self.gateway.on_order(copy(order))
|
||||
|
||||
def on_order_match(self, packet: dict):
|
||||
@ -381,13 +386,15 @@ class OrderBook():
|
||||
self.asks = dict()
|
||||
self.bids = dict()
|
||||
self.gateway = gateway
|
||||
self.newest_tick = TickData(
|
||||
|
||||
self.tick = TickData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
name=symbol_name_map.get(symbol, ""),
|
||||
datetime=datetime.now(),
|
||||
gateway_name=gateway.gateway_name,
|
||||
)
|
||||
|
||||
self.first_update = False
|
||||
|
||||
def on_message(self, d: dict):
|
||||
@ -430,12 +437,11 @@ class OrderBook():
|
||||
|
||||
self.generate_tick(dt)
|
||||
|
||||
|
||||
def on_ticker(self, d: dict):
|
||||
"""
|
||||
call back when type is ticker
|
||||
"""
|
||||
tick = self.newest_tick
|
||||
tick = self.tick
|
||||
|
||||
tick.open_price = float(d['open_24h'])
|
||||
tick.high_price = float(d['high_24h'])
|
||||
@ -457,7 +463,7 @@ class OrderBook():
|
||||
|
||||
def generate_tick(self, dt: datetime):
|
||||
""""""
|
||||
tick = self.newest_tick
|
||||
tick = self.tick
|
||||
|
||||
bids_keys = self.bids.keys()
|
||||
bids_keys = sorted(bids_keys, reverse=True)
|
||||
@ -509,16 +515,8 @@ class CoinbaseRestApi(RestClient):
|
||||
self.secret = ""
|
||||
self.passphrase = ""
|
||||
|
||||
self.order_count = 1_000_000
|
||||
self.order_count_lock = Lock()
|
||||
|
||||
self.connect_time = 0
|
||||
|
||||
self.accounts = {}
|
||||
|
||||
self.rate_limit = 5
|
||||
self.rate_limit_remaining = 5
|
||||
|
||||
def sign(self, request):
|
||||
"""
|
||||
Generate Coinbase signature
|
||||
@ -549,17 +547,15 @@ class CoinbaseRestApi(RestClient):
|
||||
self.secret = secret.encode()
|
||||
self.passphrase = passphrase
|
||||
|
||||
self.connect_time = (
|
||||
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
|
||||
)
|
||||
if server == "REAL":
|
||||
self.init(REST_HOST, proxy_host, proxy_port)
|
||||
else:
|
||||
self.init(SANDBOX_REST_HOST, proxy_host, proxy_port)
|
||||
|
||||
self.start(session_number)
|
||||
|
||||
self.query_instrument()
|
||||
self.query_orders()
|
||||
self.query_order()
|
||||
|
||||
self.gateway.write_log("REST API启动成功")
|
||||
|
||||
@ -570,19 +566,15 @@ class CoinbaseRestApi(RestClient):
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/products",
|
||||
callback=self.on_query_instrument,
|
||||
params={},
|
||||
on_error=self.on_query_instrument_error,
|
||||
callback=self.on_query_instrument
|
||||
)
|
||||
|
||||
def query_orders(self):
|
||||
def query_order(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/orders?status=all",
|
||||
callback=self.on_query_orders,
|
||||
params={},
|
||||
on_error=self.on_query_orders_error,
|
||||
"/orders?status=open",
|
||||
callback=self.on_query_order
|
||||
)
|
||||
|
||||
def query_account(self):
|
||||
@ -591,8 +583,6 @@ class CoinbaseRestApi(RestClient):
|
||||
"GET",
|
||||
"/accounts",
|
||||
callback=self.on_query_account,
|
||||
params={},
|
||||
on_error=self.on_query_account_error,
|
||||
)
|
||||
|
||||
def on_query_account(self, data, request):
|
||||
@ -614,30 +604,11 @@ class CoinbaseRestApi(RestClient):
|
||||
|
||||
self.gateway.on_account(copy(account))
|
||||
|
||||
def on_query_account_error(
|
||||
self,
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request):
|
||||
""""""
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_query_orders_error(
|
||||
self,
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request):
|
||||
""""""
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_query_orders(self, data, request):
|
||||
def on_query_order(self, data, request):
|
||||
""""""
|
||||
for d in data:
|
||||
date, time = d['created_at'].split('T')
|
||||
|
||||
if d['status'] == 'open':
|
||||
if not float(d['filled_size']):
|
||||
status = Status.NOTTRADED
|
||||
@ -648,6 +619,7 @@ class CoinbaseRestApi(RestClient):
|
||||
status = Status.ALLTRADED
|
||||
else:
|
||||
status = Status.CANCELLED
|
||||
|
||||
order = OrderData(
|
||||
symbol=d['product_id'],
|
||||
gateway_name=self.gateway_name,
|
||||
@ -667,19 +639,6 @@ class CoinbaseRestApi(RestClient):
|
||||
|
||||
self.gateway.write_log(u'委托信息查询成功')
|
||||
|
||||
def on_query_instrument_error(
|
||||
self,
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request: Request):
|
||||
"""
|
||||
Callback when sending order caused exception.
|
||||
"""
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_query_instrument(self, data, request):
|
||||
""""""
|
||||
for d in data:
|
||||
@ -703,9 +662,6 @@ class CoinbaseRestApi(RestClient):
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
if not self.check_rate_limit():
|
||||
return
|
||||
|
||||
orderid = str(uuid.uuid1())
|
||||
|
||||
data = {
|
||||
@ -719,7 +675,6 @@ class CoinbaseRestApi(RestClient):
|
||||
if req.type == OrderType.LIMIT:
|
||||
data['price'] = req.price
|
||||
|
||||
|
||||
order = req.create_order_data(orderid, self.gateway_name)
|
||||
self.add_request(
|
||||
"POST",
|
||||
@ -732,6 +687,7 @@ class CoinbaseRestApi(RestClient):
|
||||
on_error=self.on_send_order_error,
|
||||
)
|
||||
|
||||
self.gateway.on_order(order)
|
||||
return order.vt_orderid
|
||||
|
||||
def on_send_order_failed(self, status_code: str, request: Request):
|
||||
@ -756,7 +712,8 @@ class CoinbaseRestApi(RestClient):
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request: Request):
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
Callback when sending order caused exception.
|
||||
"""
|
||||
@ -775,36 +732,15 @@ class CoinbaseRestApi(RestClient):
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
if not self.check_rate_limit():
|
||||
return
|
||||
|
||||
orderid = req.orderid
|
||||
|
||||
if orderid not in orderSysDict:
|
||||
cancelDict[orderid] = req
|
||||
|
||||
self.add_request(
|
||||
"DELETE",
|
||||
"/orders/" + orderid,
|
||||
f"/orders/client:{orderid}",
|
||||
callback=self.on_cancel_order,
|
||||
params={},
|
||||
on_error=self.on_cancel_order_error,
|
||||
on_failed=self.on_cancel_order_failed,
|
||||
)
|
||||
|
||||
def on_cancel_order_error(
|
||||
self,
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request: Request):
|
||||
"""
|
||||
Callback when cancelling order failed on server.
|
||||
"""
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_cancel_order(self, data, request):
|
||||
"""Websocket will push a new order status"""
|
||||
pass
|
||||
@ -816,9 +752,9 @@ class CoinbaseRestApi(RestClient):
|
||||
if request.response.text:
|
||||
data = request.response.json()
|
||||
error = data["message"]
|
||||
msg = f"委托失败,状态码:{status_code},信息:{error}"
|
||||
msg = f"撤单失败,状态码:{status_code},信息:{error}"
|
||||
else:
|
||||
msg = f"委托失败,状态码:{status_code}"
|
||||
msg = f"撤单失败,状态码:{status_code}"
|
||||
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
@ -836,7 +772,8 @@ class CoinbaseRestApi(RestClient):
|
||||
exception_type: type,
|
||||
exception_value: Exception,
|
||||
tb,
|
||||
request: Request):
|
||||
request: Request
|
||||
):
|
||||
"""
|
||||
Callback to handler request exception.
|
||||
"""
|
||||
@ -847,35 +784,20 @@ class CoinbaseRestApi(RestClient):
|
||||
self.exception_detail(exception_type, exception_value, tb, request)
|
||||
)
|
||||
|
||||
def reset_rate_limit(self):
|
||||
"""
|
||||
reset the rate limit every 1 sec
|
||||
"""
|
||||
self.rate_limit_remaining = 5
|
||||
|
||||
def check_rate_limit(self):
|
||||
"""
|
||||
Called before send requests
|
||||
"""
|
||||
if self.rate_limit_remaining:
|
||||
self.rate_limit_remaining -= 1
|
||||
return True
|
||||
else:
|
||||
self.gateway.write_log("已超出请求速率上限,请稍后重试")
|
||||
return False
|
||||
|
||||
|
||||
def get_auth_header(
|
||||
timestamp,
|
||||
message,
|
||||
api_key,
|
||||
secret_key,
|
||||
passphrase):
|
||||
passphrase
|
||||
):
|
||||
""""""
|
||||
message = message.encode("ascii")
|
||||
hmac_key = base64.b64decode(secret_key)
|
||||
signature = hmac.new(hmac_key, message, hashlib.sha256)
|
||||
signature_b64 = base64.b64encode(signature.digest()).decode('utf-8')
|
||||
|
||||
return{
|
||||
'Content-Type': 'Application/JSON',
|
||||
'CB-ACCESS-SIGN': signature_b64,
|
||||
|
Loading…
Reference in New Issue
Block a user