[Mod] improve code quality of coinbase gateway
This commit is contained in:
parent
77c2b52842
commit
eec5d0a769
@ -5,7 +5,6 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from threading import Lock
|
|
||||||
import base64
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@ -127,7 +126,10 @@ class CoinbaseGateway(BaseGateway):
|
|||||||
passphrase,
|
passphrase,
|
||||||
server,
|
server,
|
||||||
proxy_host,
|
proxy_host,
|
||||||
proxy_port)
|
proxy_port
|
||||||
|
)
|
||||||
|
|
||||||
|
self.init_query()
|
||||||
|
|
||||||
def subscribe(self, req: SubscribeRequest):
|
def subscribe(self, req: SubscribeRequest):
|
||||||
""""""
|
""""""
|
||||||
@ -164,12 +166,11 @@ class CoinbaseGateway(BaseGateway):
|
|||||||
|
|
||||||
def process_timer_event(self, event: Event):
|
def process_timer_event(self, event: Event):
|
||||||
""""""
|
""""""
|
||||||
self.rest_api.reset_rate_limit()
|
self.rest_api.query_account()
|
||||||
self.init_query()
|
|
||||||
|
|
||||||
def init_query(self):
|
def init_query(self):
|
||||||
""""""
|
""""""
|
||||||
self.rest_api.query_account()
|
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||||
|
|
||||||
|
|
||||||
class CoinbaseWebsocketApi(WebsocketClient):
|
class CoinbaseWebsocketApi(WebsocketClient):
|
||||||
@ -198,18 +199,17 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
|||||||
"match": self.on_order_match,
|
"match": self.on_order_match,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.ticks = {}
|
|
||||||
self.accounts = {}
|
|
||||||
self.orderbooks = {}
|
self.orderbooks = {}
|
||||||
|
|
||||||
def connect(
|
def connect(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
secret: str,
|
secret: str,
|
||||||
passphrase: str,
|
passphrase: str,
|
||||||
server: str,
|
server: str,
|
||||||
proxy_host: str,
|
proxy_host: str,
|
||||||
proxy_port: int):
|
proxy_port: int
|
||||||
|
):
|
||||||
""""""
|
""""""
|
||||||
self.gateway.write_log("开始连接ws接口")
|
self.gateway.write_log("开始连接ws接口")
|
||||||
self.key = key
|
self.key = key
|
||||||
@ -220,9 +220,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
|||||||
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
|
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||||
else:
|
else:
|
||||||
self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port)
|
self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||||
self.start()
|
|
||||||
|
|
||||||
self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event)
|
self.start()
|
||||||
|
|
||||||
def subscribe(self, req: SubscribeRequest):
|
def subscribe(self, req: SubscribeRequest):
|
||||||
""""""
|
""""""
|
||||||
@ -240,12 +239,15 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
|||||||
|
|
||||||
timestamp = str(time.time())
|
timestamp = str(time.time())
|
||||||
message = timestamp + 'GET' + '/users/self/verify'
|
message = timestamp + 'GET' + '/users/self/verify'
|
||||||
|
|
||||||
auth_headers = get_auth_header(
|
auth_headers = get_auth_header(
|
||||||
timestamp,
|
timestamp,
|
||||||
message,
|
message,
|
||||||
self.key,
|
self.key,
|
||||||
self.secret,
|
self.secret,
|
||||||
self.passphrase)
|
self.passphrase
|
||||||
|
)
|
||||||
|
|
||||||
sub_req['signature'] = auth_headers['CB-ACCESS-SIGN']
|
sub_req['signature'] = auth_headers['CB-ACCESS-SIGN']
|
||||||
sub_req['key'] = auth_headers['CB-ACCESS-KEY']
|
sub_req['key'] = auth_headers['CB-ACCESS-KEY']
|
||||||
sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE']
|
sub_req['passphrase'] = auth_headers['CB-ACCESS-PASSPHRASE']
|
||||||
@ -340,11 +342,14 @@ class CoinbaseWebsocketApi(WebsocketClient):
|
|||||||
order = orderDict.get(packet['order_id'], None)
|
order = orderDict.get(packet['order_id'], None)
|
||||||
if not order:
|
if not order:
|
||||||
return
|
return
|
||||||
|
|
||||||
order.traded = order.volume - float(packet['remaining_size'])
|
order.traded = order.volume - float(packet['remaining_size'])
|
||||||
|
|
||||||
if packet['reason'] == 'filled':
|
if packet['reason'] == 'filled':
|
||||||
order.status = Status.ALLTRADED
|
order.status = Status.ALLTRADED
|
||||||
else:
|
else:
|
||||||
order.status = Status.CANCELLED
|
order.status = Status.CANCELLED
|
||||||
|
|
||||||
self.gateway.on_order(copy(order))
|
self.gateway.on_order(copy(order))
|
||||||
|
|
||||||
def on_order_match(self, packet: dict):
|
def on_order_match(self, packet: dict):
|
||||||
@ -381,13 +386,15 @@ class OrderBook():
|
|||||||
self.asks = dict()
|
self.asks = dict()
|
||||||
self.bids = dict()
|
self.bids = dict()
|
||||||
self.gateway = gateway
|
self.gateway = gateway
|
||||||
self.newest_tick = TickData(
|
|
||||||
|
self.tick = TickData(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
exchange=exchange,
|
exchange=exchange,
|
||||||
name=symbol_name_map.get(symbol, ""),
|
name=symbol_name_map.get(symbol, ""),
|
||||||
datetime=datetime.now(),
|
datetime=datetime.now(),
|
||||||
gateway_name=gateway.gateway_name,
|
gateway_name=gateway.gateway_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.first_update = False
|
self.first_update = False
|
||||||
|
|
||||||
def on_message(self, d: dict):
|
def on_message(self, d: dict):
|
||||||
@ -430,12 +437,11 @@ class OrderBook():
|
|||||||
|
|
||||||
self.generate_tick(dt)
|
self.generate_tick(dt)
|
||||||
|
|
||||||
|
|
||||||
def on_ticker(self, d: dict):
|
def on_ticker(self, d: dict):
|
||||||
"""
|
"""
|
||||||
call back when type is ticker
|
call back when type is ticker
|
||||||
"""
|
"""
|
||||||
tick = self.newest_tick
|
tick = self.tick
|
||||||
|
|
||||||
tick.open_price = float(d['open_24h'])
|
tick.open_price = float(d['open_24h'])
|
||||||
tick.high_price = float(d['high_24h'])
|
tick.high_price = float(d['high_24h'])
|
||||||
@ -457,7 +463,7 @@ class OrderBook():
|
|||||||
|
|
||||||
def generate_tick(self, dt: datetime):
|
def generate_tick(self, dt: datetime):
|
||||||
""""""
|
""""""
|
||||||
tick = self.newest_tick
|
tick = self.tick
|
||||||
|
|
||||||
bids_keys = self.bids.keys()
|
bids_keys = self.bids.keys()
|
||||||
bids_keys = sorted(bids_keys, reverse=True)
|
bids_keys = sorted(bids_keys, reverse=True)
|
||||||
@ -509,16 +515,8 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.secret = ""
|
self.secret = ""
|
||||||
self.passphrase = ""
|
self.passphrase = ""
|
||||||
|
|
||||||
self.order_count = 1_000_000
|
|
||||||
self.order_count_lock = Lock()
|
|
||||||
|
|
||||||
self.connect_time = 0
|
|
||||||
|
|
||||||
self.accounts = {}
|
self.accounts = {}
|
||||||
|
|
||||||
self.rate_limit = 5
|
|
||||||
self.rate_limit_remaining = 5
|
|
||||||
|
|
||||||
def sign(self, request):
|
def sign(self, request):
|
||||||
"""
|
"""
|
||||||
Generate Coinbase signature
|
Generate Coinbase signature
|
||||||
@ -549,17 +547,15 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.secret = secret.encode()
|
self.secret = secret.encode()
|
||||||
self.passphrase = passphrase
|
self.passphrase = passphrase
|
||||||
|
|
||||||
self.connect_time = (
|
|
||||||
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
|
|
||||||
)
|
|
||||||
if server == "REAL":
|
if server == "REAL":
|
||||||
self.init(REST_HOST, proxy_host, proxy_port)
|
self.init(REST_HOST, proxy_host, proxy_port)
|
||||||
else:
|
else:
|
||||||
self.init(SANDBOX_REST_HOST, proxy_host, proxy_port)
|
self.init(SANDBOX_REST_HOST, proxy_host, proxy_port)
|
||||||
|
|
||||||
self.start(session_number)
|
self.start(session_number)
|
||||||
|
|
||||||
self.query_instrument()
|
self.query_instrument()
|
||||||
self.query_orders()
|
self.query_order()
|
||||||
|
|
||||||
self.gateway.write_log("REST API启动成功")
|
self.gateway.write_log("REST API启动成功")
|
||||||
|
|
||||||
@ -570,19 +566,15 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.add_request(
|
self.add_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/products",
|
"/products",
|
||||||
callback=self.on_query_instrument,
|
callback=self.on_query_instrument
|
||||||
params={},
|
|
||||||
on_error=self.on_query_instrument_error,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def query_orders(self):
|
def query_order(self):
|
||||||
""""""
|
""""""
|
||||||
self.add_request(
|
self.add_request(
|
||||||
"GET",
|
"GET",
|
||||||
"/orders?status=all",
|
"/orders?status=open",
|
||||||
callback=self.on_query_orders,
|
callback=self.on_query_order
|
||||||
params={},
|
|
||||||
on_error=self.on_query_orders_error,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def query_account(self):
|
def query_account(self):
|
||||||
@ -591,8 +583,6 @@ class CoinbaseRestApi(RestClient):
|
|||||||
"GET",
|
"GET",
|
||||||
"/accounts",
|
"/accounts",
|
||||||
callback=self.on_query_account,
|
callback=self.on_query_account,
|
||||||
params={},
|
|
||||||
on_error=self.on_query_account_error,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_query_account(self, data, request):
|
def on_query_account(self, data, request):
|
||||||
@ -614,30 +604,11 @@ class CoinbaseRestApi(RestClient):
|
|||||||
|
|
||||||
self.gateway.on_account(copy(account))
|
self.gateway.on_account(copy(account))
|
||||||
|
|
||||||
def on_query_account_error(
|
def on_query_order(self, data, request):
|
||||||
self,
|
|
||||||
exception_type: type,
|
|
||||||
exception_value: Exception,
|
|
||||||
tb,
|
|
||||||
request):
|
|
||||||
""""""
|
|
||||||
if not issubclass(exception_type, ConnectionError):
|
|
||||||
self.on_error(exception_type, exception_value, tb, request)
|
|
||||||
|
|
||||||
def on_query_orders_error(
|
|
||||||
self,
|
|
||||||
exception_type: type,
|
|
||||||
exception_value: Exception,
|
|
||||||
tb,
|
|
||||||
request):
|
|
||||||
""""""
|
|
||||||
if not issubclass(exception_type, ConnectionError):
|
|
||||||
self.on_error(exception_type, exception_value, tb, request)
|
|
||||||
|
|
||||||
def on_query_orders(self, data, request):
|
|
||||||
""""""
|
""""""
|
||||||
for d in data:
|
for d in data:
|
||||||
date, time = d['created_at'].split('T')
|
date, time = d['created_at'].split('T')
|
||||||
|
|
||||||
if d['status'] == 'open':
|
if d['status'] == 'open':
|
||||||
if not float(d['filled_size']):
|
if not float(d['filled_size']):
|
||||||
status = Status.NOTTRADED
|
status = Status.NOTTRADED
|
||||||
@ -648,6 +619,7 @@ class CoinbaseRestApi(RestClient):
|
|||||||
status = Status.ALLTRADED
|
status = Status.ALLTRADED
|
||||||
else:
|
else:
|
||||||
status = Status.CANCELLED
|
status = Status.CANCELLED
|
||||||
|
|
||||||
order = OrderData(
|
order = OrderData(
|
||||||
symbol=d['product_id'],
|
symbol=d['product_id'],
|
||||||
gateway_name=self.gateway_name,
|
gateway_name=self.gateway_name,
|
||||||
@ -667,19 +639,6 @@ class CoinbaseRestApi(RestClient):
|
|||||||
|
|
||||||
self.gateway.write_log(u'委托信息查询成功')
|
self.gateway.write_log(u'委托信息查询成功')
|
||||||
|
|
||||||
def on_query_instrument_error(
|
|
||||||
self,
|
|
||||||
exception_type: type,
|
|
||||||
exception_value: Exception,
|
|
||||||
tb,
|
|
||||||
request: Request):
|
|
||||||
"""
|
|
||||||
Callback when sending order caused exception.
|
|
||||||
"""
|
|
||||||
# Record exception if not ConnectionError
|
|
||||||
if not issubclass(exception_type, ConnectionError):
|
|
||||||
self.on_error(exception_type, exception_value, tb, request)
|
|
||||||
|
|
||||||
def on_query_instrument(self, data, request):
|
def on_query_instrument(self, data, request):
|
||||||
""""""
|
""""""
|
||||||
for d in data:
|
for d in data:
|
||||||
@ -703,9 +662,6 @@ class CoinbaseRestApi(RestClient):
|
|||||||
|
|
||||||
def send_order(self, req: OrderRequest):
|
def send_order(self, req: OrderRequest):
|
||||||
""""""
|
""""""
|
||||||
if not self.check_rate_limit():
|
|
||||||
return
|
|
||||||
|
|
||||||
orderid = str(uuid.uuid1())
|
orderid = str(uuid.uuid1())
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
@ -719,7 +675,6 @@ class CoinbaseRestApi(RestClient):
|
|||||||
if req.type == OrderType.LIMIT:
|
if req.type == OrderType.LIMIT:
|
||||||
data['price'] = req.price
|
data['price'] = req.price
|
||||||
|
|
||||||
|
|
||||||
order = req.create_order_data(orderid, self.gateway_name)
|
order = req.create_order_data(orderid, self.gateway_name)
|
||||||
self.add_request(
|
self.add_request(
|
||||||
"POST",
|
"POST",
|
||||||
@ -732,6 +687,7 @@ class CoinbaseRestApi(RestClient):
|
|||||||
on_error=self.on_send_order_error,
|
on_error=self.on_send_order_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.gateway.on_order(order)
|
||||||
return order.vt_orderid
|
return order.vt_orderid
|
||||||
|
|
||||||
def on_send_order_failed(self, status_code: str, request: Request):
|
def on_send_order_failed(self, status_code: str, request: Request):
|
||||||
@ -752,11 +708,12 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.gateway.write_log(msg)
|
self.gateway.write_log(msg)
|
||||||
|
|
||||||
def on_send_order_error(
|
def on_send_order_error(
|
||||||
self,
|
self,
|
||||||
exception_type: type,
|
exception_type: type,
|
||||||
exception_value: Exception,
|
exception_value: Exception,
|
||||||
tb,
|
tb,
|
||||||
request: Request):
|
request: Request
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Callback when sending order caused exception.
|
Callback when sending order caused exception.
|
||||||
"""
|
"""
|
||||||
@ -775,36 +732,15 @@ class CoinbaseRestApi(RestClient):
|
|||||||
|
|
||||||
def cancel_order(self, req: CancelRequest):
|
def cancel_order(self, req: CancelRequest):
|
||||||
""""""
|
""""""
|
||||||
if not self.check_rate_limit():
|
|
||||||
return
|
|
||||||
|
|
||||||
orderid = req.orderid
|
orderid = req.orderid
|
||||||
|
|
||||||
if orderid not in orderSysDict:
|
|
||||||
cancelDict[orderid] = req
|
|
||||||
|
|
||||||
self.add_request(
|
self.add_request(
|
||||||
"DELETE",
|
"DELETE",
|
||||||
"/orders/" + orderid,
|
f"/orders/client:{orderid}",
|
||||||
callback=self.on_cancel_order,
|
callback=self.on_cancel_order,
|
||||||
params={},
|
|
||||||
on_error=self.on_cancel_order_error,
|
|
||||||
on_failed=self.on_cancel_order_failed,
|
on_failed=self.on_cancel_order_failed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_cancel_order_error(
|
|
||||||
self,
|
|
||||||
exception_type: type,
|
|
||||||
exception_value: Exception,
|
|
||||||
tb,
|
|
||||||
request: Request):
|
|
||||||
"""
|
|
||||||
Callback when cancelling order failed on server.
|
|
||||||
"""
|
|
||||||
# Record exception if not ConnectionError
|
|
||||||
if not issubclass(exception_type, ConnectionError):
|
|
||||||
self.on_error(exception_type, exception_value, tb, request)
|
|
||||||
|
|
||||||
def on_cancel_order(self, data, request):
|
def on_cancel_order(self, data, request):
|
||||||
"""Websocket will push a new order status"""
|
"""Websocket will push a new order status"""
|
||||||
pass
|
pass
|
||||||
@ -816,9 +752,9 @@ class CoinbaseRestApi(RestClient):
|
|||||||
if request.response.text:
|
if request.response.text:
|
||||||
data = request.response.json()
|
data = request.response.json()
|
||||||
error = data["message"]
|
error = data["message"]
|
||||||
msg = f"委托失败,状态码:{status_code},信息:{error}"
|
msg = f"撤单失败,状态码:{status_code},信息:{error}"
|
||||||
else:
|
else:
|
||||||
msg = f"委托失败,状态码:{status_code}"
|
msg = f"撤单失败,状态码:{status_code}"
|
||||||
|
|
||||||
self.gateway.write_log(msg)
|
self.gateway.write_log(msg)
|
||||||
|
|
||||||
@ -832,11 +768,12 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.gateway.write_log(msg)
|
self.gateway.write_log(msg)
|
||||||
|
|
||||||
def on_error(
|
def on_error(
|
||||||
self,
|
self,
|
||||||
exception_type: type,
|
exception_type: type,
|
||||||
exception_value: Exception,
|
exception_value: Exception,
|
||||||
tb,
|
tb,
|
||||||
request: Request):
|
request: Request
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Callback to handler request exception.
|
Callback to handler request exception.
|
||||||
"""
|
"""
|
||||||
@ -847,35 +784,20 @@ class CoinbaseRestApi(RestClient):
|
|||||||
self.exception_detail(exception_type, exception_value, tb, request)
|
self.exception_detail(exception_type, exception_value, tb, request)
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset_rate_limit(self):
|
|
||||||
"""
|
|
||||||
reset the rate limit every 1 sec
|
|
||||||
"""
|
|
||||||
self.rate_limit_remaining = 5
|
|
||||||
|
|
||||||
def check_rate_limit(self):
|
|
||||||
"""
|
|
||||||
Called before send requests
|
|
||||||
"""
|
|
||||||
if self.rate_limit_remaining:
|
|
||||||
self.rate_limit_remaining -= 1
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
self.gateway.write_log("已超出请求速率上限,请稍后重试")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_header(
|
def get_auth_header(
|
||||||
timestamp,
|
timestamp,
|
||||||
message,
|
message,
|
||||||
api_key,
|
api_key,
|
||||||
secret_key,
|
secret_key,
|
||||||
passphrase):
|
passphrase
|
||||||
|
):
|
||||||
""""""
|
""""""
|
||||||
message = message.encode("ascii")
|
message = message.encode("ascii")
|
||||||
hmac_key = base64.b64decode(secret_key)
|
hmac_key = base64.b64decode(secret_key)
|
||||||
signature = hmac.new(hmac_key, message, hashlib.sha256)
|
signature = hmac.new(hmac_key, message, hashlib.sha256)
|
||||||
signature_b64 = base64.b64encode(signature.digest()).decode('utf-8')
|
signature_b64 = base64.b64encode(signature.digest()).decode('utf-8')
|
||||||
|
|
||||||
return{
|
return{
|
||||||
'Content-Type': 'Application/JSON',
|
'Content-Type': 'Application/JSON',
|
||||||
'CB-ACCESS-SIGN': signature_b64,
|
'CB-ACCESS-SIGN': signature_b64,
|
||||||
|
Loading…
Reference in New Issue
Block a user