update coinbase

This commit is contained in:
LimingFang 2019-09-04 15:51:35 +08:00
parent 851c22528f
commit 2ab2e49121

View File

@ -23,7 +23,6 @@ from vnpy.trader.constant import (
OrderType, OrderType,
Product, Product,
Status, Status,
Offset,
Interval Interval
) )
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway
@ -31,7 +30,6 @@ from vnpy.trader.object import (
TickData, TickData,
OrderData, OrderData,
TradeData, TradeData,
PositionData,
AccountData, AccountData,
ContractData, ContractData,
BarData, BarData,
@ -49,8 +47,6 @@ SANDBOX_WEBSOCKET_HOST = "wss://ws-feed-public.sandbox.pro.coinbase.com"
DIRECTION_VT2COINBASE = {Direction.LONG: "buy", Direction.SHORT: "sell"} DIRECTION_VT2COINBASE = {Direction.LONG: "buy", Direction.SHORT: "sell"}
DIRECTION_COINBASE2VT = {v: k for k, v in DIRECTION_VT2COINBASE.items()} DIRECTION_COINBASE2VT = {v: k for k, v in DIRECTION_VT2COINBASE.items()}
STOP_VT2COINBASE = {Direction.LONG: "entry", Direction.SHORT: "loss"}
ORDERTYPE_VT2COINBASE = { ORDERTYPE_VT2COINBASE = {
OrderType.LIMIT: "limit", OrderType.LIMIT: "limit",
OrderType.MARKET: "market", OrderType.MARKET: "market",
@ -96,13 +92,6 @@ class CoinbaseGateway(BaseGateway):
self.rest_api = CoinbaseRestApi(self) self.rest_api = CoinbaseRestApi(self)
self.ws_api = CoinbaseWebsocketApi(self) self.ws_api = CoinbaseWebsocketApi(self)
self.rest_api_inited = False
self.product_id = []
self.received_instrument = False
event_engine.register(EVENT_TIMER, self.process_timer_event)
def connect(self, setting: dict): def connect(self, setting: dict):
"""""" """"""
key = setting["ID"] key = setting["ID"]
@ -120,10 +109,7 @@ class CoinbaseGateway(BaseGateway):
self.rest_api.connect(key, secret, passphrase, session_number, server, self.rest_api.connect(key, secret, passphrase, session_number, server,
proxy_host, proxy_port) proxy_host, proxy_port)
while not self.received_instrument:
time.sleep(0.5)
self.write_log("合约查询成功")
self.ws_api.connect( self.ws_api.connect(
key, key,
secret, secret,
@ -146,7 +132,7 @@ class CoinbaseGateway(BaseGateway):
def query_account(self): def query_account(self):
"""""" """"""
return self.rest_api.qry_account() return self.rest_api.query_account()
def query_position(self): def query_position(self):
""" """
@ -168,12 +154,11 @@ class CoinbaseGateway(BaseGateway):
def process_timer_event(self, event: Event): def process_timer_event(self, event: Event):
"""""" """"""
self.rest_api.reset_rate_limit() self.rest_api.reset_rate_limit()
if self.rest_api_inited:
self.init_query() self.init_query()
def init_query(self): def init_query(self):
"""""" """"""
self.rest_api.qry_account() self.rest_api.query_account()
class CoinbaseWebsocketApi(WebsocketClient): class CoinbaseWebsocketApi(WebsocketClient):
@ -226,6 +211,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port) self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port)
self.start() self.start()
self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
symbol = req.symbol symbol = req.symbol
@ -269,10 +256,11 @@ class CoinbaseWebsocketApi(WebsocketClient):
""" """
callback when data is received and unpacked callback when data is received and unpacked
""" """
if packet['type'] == 'error': if packet['type'] == 'error':
self.gateway.write_log("Websocket API报错 %s" % packet['message']) self.gateway.write_log(
self.gateway.write_log("Websocket API报错原因是 %s" % packet['reason']) "Websocket API报错 %s" % packet['message'])
self.gateway.write_log(
"Websocket API报错原因是 %s" % packet['reason'])
self.active = False self.active = False
else: else:
@ -385,6 +373,7 @@ class OrderBook():
self.gateway = gateway self.gateway = gateway
self.newest_tick = TickData( self.newest_tick = TickData(
"COINBASE", symbol, exchange, datetime.now()) "COINBASE", symbol, exchange, datetime.now())
self.first_update = False
def on_message(self, d: dict): def on_message(self, d: dict):
""" """
@ -426,21 +415,19 @@ class OrderBook():
self.generate_tick(dt) self.generate_tick(dt)
def on_ticker(self, d: dict): def on_ticker(self, d: dict):
""" """
call back when type is ticker call back when type is ticker
""" """
tick = self.newest_tick tick = self.newest_tick
tick.openPrice = float(d['open_24h']) tick.open_price = float(d['open_24h'])
tick.highPrice = float(d['high_24h']) tick.high_price = float(d['high_24h'])
tick.lowPrice = float(d['low_24h']) tick.low_price = float(d['low_24h'])
tick.lastPrice = float(d['price']) tick.last_price = float(d['price'])
tick.volume = float(d['volume_24h']) tick.volume = float(d['volume_24h'])
dt = datetime.strptime(
d['time'], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick)) self.gateway.on_tick(copy(tick))
def on_snapshot(self, asks: Sequence[List], bids: Sequence[List]): def on_snapshot(self, asks: Sequence[List], bids: Sequence[List]):
@ -556,47 +543,44 @@ class CoinbaseRestApi(RestClient):
self.init(SANDBOX_REST_HOST, proxy_host, proxy_port) self.init(SANDBOX_REST_HOST, proxy_host, proxy_port)
self.start(session_number) self.start(session_number)
self.gateway.rest_api_inited = True self.query_instrument()
self.query_orders()
self.qry_instrument()
# self.qry_orders()
self.gateway.write_log("REST API启动成功") self.gateway.write_log("REST API启动成功")
def qry_instrument(self): def query_instrument(self):
""" """
Get the instrument of Coinbase Get the instrument of Coinbase
""" """
self.add_request( self.add_request(
"GET", "GET",
"/products", "/products",
callback=self.on_qry_instrument, callback=self.on_query_instrument,
params={}, params={},
on_error=self.on_qry_instrument_error, on_error=self.on_query_instrument_error,
) )
def qry_orders(self): def query_orders(self):
"""""" """"""
params = {"status": "all"}
self.add_request( self.add_request(
"GET", "GET",
"/orders", "/orders?status=all",
callback=self.on_qry_orders, callback=self.on_query_orders,
params=params, params={},
on_error=self.on_qry_orders_error, on_error=self.on_query_orders_error,
) )
def qry_account(self): def query_account(self):
"""""" """"""
self.add_request( self.add_request(
"GET", "GET",
"/accounts", "/accounts",
callback=self.on_qry_account, callback=self.on_query_account,
params={}, params={},
on_error=self.on_qry_account_error, on_error=self.on_query_account_error,
) )
def on_qry_account(self, data, request): def on_query_account(self, data, request):
"""""" """"""
for acc in data: for acc in data:
account_id = str(acc['id']) account_id = str(acc['id'])
@ -612,7 +596,7 @@ class CoinbaseRestApi(RestClient):
self.gateway.on_account(copy(account)) self.gateway.on_account(copy(account))
def on_qry_account_error( def on_query_account_error(
self, self,
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
@ -622,7 +606,7 @@ class CoinbaseRestApi(RestClient):
if not issubclass(exception_type, ConnectionError): if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request) self.on_error(exception_type, exception_value, tb, request)
def on_qry_orders_error( def on_query_orders_error(
self, self,
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
@ -632,12 +616,12 @@ class CoinbaseRestApi(RestClient):
if not issubclass(exception_type, ConnectionError): if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request) self.on_error(exception_type, exception_value, tb, request)
def on_qry_orders(self, data, request): def on_query_orders(self, data, request):
"""""" """"""
for d in data: for d in data:
date, time = d['created_at'].split('T') date, time = d['created_at'].split('T')
if d['status'] == 'open': if d['status'] == 'open':
if not d['filled_size']: if not float(d['filled_size']):
status = Status.NOTTRADED status = Status.NOTTRADED
else: else:
status = Status.PARTTRADED status = Status.PARTTRADED
@ -661,11 +645,11 @@ class CoinbaseRestApi(RestClient):
self.gateway.on_order(copy(order)) self.gateway.on_order(copy(order))
orderDict[order.orderid] = order orderDict[order.orderid] = order
orderSysDict[order.orderid] = order.orderID orderSysDict[order.orderid] = order.orderid
self.gateway.writeLog(u'委托信息查询成功') self.gateway.write_log(u'委托信息查询成功')
def on_qry_instrument_error( def on_query_instrument_error(
self, self,
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
@ -678,7 +662,7 @@ class CoinbaseRestApi(RestClient):
if not issubclass(exception_type, ConnectionError): if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request) self.on_error(exception_type, exception_value, tb, request)
def on_qry_instrument(self, data, request): def on_query_instrument(self, data, request):
"""""" """"""
for d in data: for d in data:
contract = ContractData( contract = ContractData(
@ -688,7 +672,7 @@ class CoinbaseRestApi(RestClient):
product=Product.SPOT, product=Product.SPOT,
pricetick=d['quote_increment'], pricetick=d['quote_increment'],
size=d['base_min_size'], size=d['base_min_size'],
stop_supported=(not d['limit_only']), stop_supported=False,
net_position=True, net_position=True,
history_data=False, history_data=False,
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
@ -696,8 +680,7 @@ class CoinbaseRestApi(RestClient):
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
self.gateway.product_id.append(d['id']) self.gateway.write_log("")
self.gateway.received_instrument = True
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
@ -716,9 +699,7 @@ class CoinbaseRestApi(RestClient):
if req.type == OrderType.LIMIT: if req.type == OrderType.LIMIT:
data['price'] = req.price data['price'] = req.price
elif req.type == OrderType.STOP:
data['stop_price'] = req.price
data['stop'] = STOP_VT2COINBASE[req.Direction]
order = req.create_order_data(orderid, self.gateway_name) order = req.create_order_data(orderid, self.gateway_name)
self.add_request( self.add_request(