diff --git a/vnpy/gateway/coinbase/coinbase_gateway.py b/vnpy/gateway/coinbase/coinbase_gateway.py index 9c3a6f29..a2c32567 100644 --- a/vnpy/gateway/coinbase/coinbase_gateway.py +++ b/vnpy/gateway/coinbase/coinbase_gateway.py @@ -23,7 +23,6 @@ from vnpy.trader.constant import ( OrderType, Product, Status, - Offset, Interval ) from vnpy.trader.gateway import BaseGateway @@ -31,7 +30,6 @@ from vnpy.trader.object import ( TickData, OrderData, TradeData, - PositionData, AccountData, ContractData, BarData, @@ -49,8 +47,6 @@ SANDBOX_WEBSOCKET_HOST = "wss://ws-feed-public.sandbox.pro.coinbase.com" DIRECTION_VT2COINBASE = {Direction.LONG: "buy", Direction.SHORT: "sell"} DIRECTION_COINBASE2VT = {v: k for k, v in DIRECTION_VT2COINBASE.items()} -STOP_VT2COINBASE = {Direction.LONG: "entry", Direction.SHORT: "loss"} - ORDERTYPE_VT2COINBASE = { OrderType.LIMIT: "limit", OrderType.MARKET: "market", @@ -96,13 +92,6 @@ class CoinbaseGateway(BaseGateway): self.rest_api = CoinbaseRestApi(self) self.ws_api = CoinbaseWebsocketApi(self) - self.rest_api_inited = False - - self.product_id = [] - self.received_instrument = False - - event_engine.register(EVENT_TIMER, self.process_timer_event) - def connect(self, setting: dict): """""" key = setting["ID"] @@ -120,10 +109,7 @@ class CoinbaseGateway(BaseGateway): self.rest_api.connect(key, secret, passphrase, session_number, server, proxy_host, proxy_port) - while not self.received_instrument: - time.sleep(0.5) - self.write_log("合约查询成功") self.ws_api.connect( key, secret, @@ -146,7 +132,7 @@ class CoinbaseGateway(BaseGateway): def query_account(self): """""" - return self.rest_api.qry_account() + return self.rest_api.query_account() def query_position(self): """ @@ -168,12 +154,11 @@ class CoinbaseGateway(BaseGateway): def process_timer_event(self, event: Event): """""" self.rest_api.reset_rate_limit() - if self.rest_api_inited: - self.init_query() + self.init_query() def init_query(self): """""" - self.rest_api.qry_account() + self.rest_api.query_account() class CoinbaseWebsocketApi(WebsocketClient): @@ -226,6 +211,8 @@ class CoinbaseWebsocketApi(WebsocketClient): self.init(SANDBOX_WEBSOCKET_HOST, proxy_host, proxy_port) self.start() + self.gateway.event_engine.register(EVENT_TIMER, self.gateway.process_timer_event) + def subscribe(self, req: SubscribeRequest): """""" symbol = req.symbol @@ -269,10 +256,11 @@ class CoinbaseWebsocketApi(WebsocketClient): """ callback when data is received and unpacked """ - if packet['type'] == 'error': - self.gateway.write_log("Websocket API报错: %s" % packet['message']) - self.gateway.write_log("Websocket API报错原因是: %s" % packet['reason']) + self.gateway.write_log( + "Websocket API报错: %s" % packet['message']) + self.gateway.write_log( + "Websocket API报错原因是: %s" % packet['reason']) self.active = False else: @@ -385,6 +373,7 @@ class OrderBook(): self.gateway = gateway self.newest_tick = TickData( "COINBASE", symbol, exchange, datetime.now()) + self.first_update = False def on_message(self, d: dict): """ @@ -426,21 +415,19 @@ class OrderBook(): self.generate_tick(dt) + def on_ticker(self, d: dict): """ call back when type is ticker """ tick = self.newest_tick - tick.openPrice = float(d['open_24h']) - tick.highPrice = float(d['high_24h']) - tick.lowPrice = float(d['low_24h']) - tick.lastPrice = float(d['price']) + tick.open_price = float(d['open_24h']) + tick.high_price = float(d['high_24h']) + tick.low_price = float(d['low_24h']) + tick.last_price = float(d['price']) tick.volume = float(d['volume_24h']) - dt = datetime.strptime( - d['time'], "%Y-%m-%dT%H:%M:%S.%fZ") - self.gateway.on_tick(copy(tick)) def on_snapshot(self, asks: Sequence[List], bids: Sequence[List]): @@ -556,47 +543,44 @@ class CoinbaseRestApi(RestClient): self.init(SANDBOX_REST_HOST, proxy_host, proxy_port) self.start(session_number) - self.gateway.rest_api_inited = True - - self.qry_instrument() - # self.qry_orders() + self.query_instrument() + self.query_orders() self.gateway.write_log("REST API启动成功") - def qry_instrument(self): + def query_instrument(self): """ Get the instrument of Coinbase """ self.add_request( "GET", "/products", - callback=self.on_qry_instrument, + callback=self.on_query_instrument, params={}, - on_error=self.on_qry_instrument_error, + on_error=self.on_query_instrument_error, ) - def qry_orders(self): + def query_orders(self): """""" - params = {"status": "all"} self.add_request( "GET", - "/orders", - callback=self.on_qry_orders, - params=params, - on_error=self.on_qry_orders_error, + "/orders?status=all", + callback=self.on_query_orders, + params={}, + on_error=self.on_query_orders_error, ) - def qry_account(self): + def query_account(self): """""" self.add_request( "GET", "/accounts", - callback=self.on_qry_account, + callback=self.on_query_account, params={}, - on_error=self.on_qry_account_error, + on_error=self.on_query_account_error, ) - def on_qry_account(self, data, request): + def on_query_account(self, data, request): """""" for acc in data: account_id = str(acc['id']) @@ -612,7 +596,7 @@ class CoinbaseRestApi(RestClient): self.gateway.on_account(copy(account)) - def on_qry_account_error( + def on_query_account_error( self, exception_type: type, exception_value: Exception, @@ -622,7 +606,7 @@ class CoinbaseRestApi(RestClient): if not issubclass(exception_type, ConnectionError): self.on_error(exception_type, exception_value, tb, request) - def on_qry_orders_error( + def on_query_orders_error( self, exception_type: type, exception_value: Exception, @@ -632,12 +616,12 @@ class CoinbaseRestApi(RestClient): if not issubclass(exception_type, ConnectionError): self.on_error(exception_type, exception_value, tb, request) - def on_qry_orders(self, data, request): + def on_query_orders(self, data, request): """""" for d in data: date, time = d['created_at'].split('T') if d['status'] == 'open': - if not d['filled_size']: + if not float(d['filled_size']): status = Status.NOTTRADED else: status = Status.PARTTRADED @@ -661,11 +645,11 @@ class CoinbaseRestApi(RestClient): self.gateway.on_order(copy(order)) orderDict[order.orderid] = order - orderSysDict[order.orderid] = order.orderID + orderSysDict[order.orderid] = order.orderid - self.gateway.writeLog(u'委托信息查询成功') + self.gateway.write_log(u'委托信息查询成功') - def on_qry_instrument_error( + def on_query_instrument_error( self, exception_type: type, exception_value: Exception, @@ -678,7 +662,7 @@ class CoinbaseRestApi(RestClient): if not issubclass(exception_type, ConnectionError): self.on_error(exception_type, exception_value, tb, request) - def on_qry_instrument(self, data, request): + def on_query_instrument(self, data, request): """""" for d in data: contract = ContractData( @@ -688,7 +672,7 @@ class CoinbaseRestApi(RestClient): product=Product.SPOT, pricetick=d['quote_increment'], size=d['base_min_size'], - stop_supported=(not d['limit_only']), + stop_supported=False, net_position=True, history_data=False, gateway_name=self.gateway_name, @@ -696,8 +680,7 @@ class CoinbaseRestApi(RestClient): self.gateway.on_contract(contract) - self.gateway.product_id.append(d['id']) - self.gateway.received_instrument = True + self.gateway.write_log("") def send_order(self, req: OrderRequest): """""" @@ -716,9 +699,7 @@ class CoinbaseRestApi(RestClient): if req.type == OrderType.LIMIT: data['price'] = req.price - elif req.type == OrderType.STOP: - data['stop_price'] = req.price - data['stop'] = STOP_VT2COINBASE[req.Direction] + order = req.create_order_data(orderid, self.gateway_name) self.add_request(