From 7d86efce398c65f2b4f918db426c639f7351ab23 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 07:40:44 +0800 Subject: [PATCH] [Mod]complete trading test of okex gateway --- tests/trader/run.py | 14 +- vnpy/gateway/okex/okex_gateway.py | 431 +++++++++++++++++------------- 2 files changed, 254 insertions(+), 191 deletions(-) diff --git a/tests/trader/run.py b/tests/trader/run.py index 1a402f82..04e2aecc 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -10,6 +10,7 @@ from vnpy.gateway.ib import IbGateway from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway +from vnpy.gateway.okex import OkexGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp @@ -22,12 +23,13 @@ def main(): event_engine = EventEngine() main_engine = MainEngine(event_engine) - main_engine.add_gateway(CtpGateway) - main_engine.add_gateway(IbGateway) - main_engine.add_gateway(FutuGateway) - main_engine.add_gateway(BitmexGateway) - main_engine.add_gateway(TigerGateway) - main_engine.add_gateway(OesGateway) + # main_engine.add_gateway(CtpGateway) + # main_engine.add_gateway(IbGateway) + # main_engine.add_gateway(FutuGateway) + # main_engine.add_gateway(BitmexGateway) + # main_engine.add_gateway(TigerGateway) + # main_engine.add_gateway(OesGateway) + main_engine.add_gateway(OkexGateway) main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 3edba175..d59888b3 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -8,6 +8,7 @@ import sys import time import json import base64 +import zlib from copy import copy from datetime import datetime from threading import Lock @@ -39,7 +40,7 @@ from vnpy.trader.object import ( ) REST_HOST = "https://www.okex.com" -WEBSOCKET_HOST = "wss://real.okex.com:10440/websocket/okexapi?compress=true" +WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3" STATUS_OKEX2VT = { "ordering": Status.SUBMITTING, @@ -88,7 +89,7 @@ class OkexGateway(BaseGateway): def connect(self, setting: dict): """""" - key = setting["API KEY"] + key = setting["API Key"] secret = setting["Secret Key"] passphrase = setting["Passphrase"] session_number = setting["会话数"] @@ -142,7 +143,7 @@ class OkexRestApi(RestClient): self.secret = "" self.passphrase = "" - self.order_count = 1_000_000 + self.order_count = 10000 self.order_count_lock = Lock() self.connect_time = 0 @@ -152,7 +153,8 @@ class OkexRestApi(RestClient): Generate OKEX signature. """ # Sign - timestamp = str(time.time()) + # timestamp = str(time.time()) + timestamp = get_timestamp() request.data = json.dumps(request.data) if request.params: @@ -177,7 +179,7 @@ class OkexRestApi(RestClient): self, key: str, secret: str, - passphrase: str + passphrase: str, session_number: int, proxy_host: str, proxy_port: int, @@ -185,18 +187,21 @@ class OkexRestApi(RestClient): """ Initialize connection to REST server. """ - self.key = key.encode() + self.key = key self.secret = secret.encode() self.passphrase = passphrase - self.connect_time = ( - int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count - ) - + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + self.init(REST_HOST, proxy_host, proxy_port) self.start(session_number) self.gateway.write_log("REST API启动成功") + self.query_time() + self.query_contract() + self.query_account() + self.query_order() + def _new_order_id(self): with self.order_count_lock: self.order_count += 1 @@ -204,8 +209,8 @@ class OkexRestApi(RestClient): def send_order(self, req: OrderRequest): """""" - orderid = str(self.connect_time + self._new_order_id()) - + orderid = f"a{self.connect_time}{self._new_order_id()}" + data = { "client_oid": orderid, "type": ORDERTYPE_VT2OKEX[req.type], @@ -227,11 +232,11 @@ class OkexRestApi(RestClient): self.add_request( "POST", "/api/spot/v3/orders", - callback=self.on_send_order, - data=data, - extra=order, - on_failed=self.on_send_order_failed, - on_error=self.on_send_order_error, + callback = self.on_send_order, + data = data, + extra = order, + on_failed = self.on_send_order_failed, + on_error = self.on_send_order_error, ) self.gateway.on_order(order) @@ -239,7 +244,7 @@ class OkexRestApi(RestClient): def cancel_order(self, req: CancelRequest): """""" - data = { + data={ "instrument_id": req.symbol, "client_oid": req.orderid } @@ -248,25 +253,41 @@ class OkexRestApi(RestClient): self.add_request( "POST", path, - callback=self.on_cancel_order, - data=data, - on_error=self.on_cancel_order_error, + callback = self.on_cancel_order, + data = data, + on_error = self.on_cancel_order_error, ) def query_contract(self): """""" - data = { - "instrument_id": req.symbol, - "client_oid": req.orderid - } - - path = "/api/spot/v3/cancel_orders/" + req.orderid self.add_request( - "POST", - path, - callback=self.on_cancel_order, - data=data, - on_error=self.on_cancel_order_error, + "GET", + "/api/spot/v3/instruments", + callback = self.on_query_contract + ) + + def query_account(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/accounts", + callback = self.on_query_account + ) + + def query_order(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/orders_pending", + callback = self.on_query_order + ) + + def query_time(self): + """""" + self.add_request( + "GET", + "/api/general/v3/time", + callback=self.on_query_time ) def on_query_contract(self, data, request): @@ -279,8 +300,8 @@ class OkexRestApi(RestClient): name=symbol, product=Product.SPOT, size=1, - pricetick=instrument_data["tick_size"] - + pricetick = instrument_data["tick_size"], + gateway_name = self.gateway_name ) self.gateway.on_contract(contract) @@ -290,6 +311,48 @@ class OkexRestApi(RestClient): self.gateway.write_log("合约信息查询成功") + # Start websocket api after instruments data collected + self.gateway.ws_api.start() + + def on_query_account(self, data, request): + """""" + for account_data in data: + account = AccountData( + accountid=account_data["currency"], + balance=float(account_data["balance"]), + frozen=float(account_data["hold"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + self.gateway.write_log("账户资金查询成功") + + def on_query_order(self, data, request): + """""" + for order_data in data: + order = OrderData( + symbol=order_data["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[order_data["type"]], + orderid=order_data["client_oid"], + direction=DIRECTION_OKEX2VT[order_data["side"]], + price=float(order_data["price"]), + volume=float(order_data["size"]), + time=order_data["timestamp"][11:19], + status=STATUS_OKEX2VT[order_data["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_time(self, data, request): + """""" + server_time = data["iso"] + local_time = datetime.utcnow().isoformat() + msg = f"服务器时间:{server_time},本机时间:{local_time}" + self.gateway.write_log(msg) + def on_send_order_failed(self, status_code: str, request: Request): """ Callback when sending order failed on server. @@ -368,22 +431,33 @@ class OkexWebsocketApi(WebsocketClient): self.secret = "" self.passphrase = "" - self.callbacks = {} + self.trade_count = 10000 + self.connect_time = 0 + self.callbacks = {} self.ticks = {} - self.accounts = {} - self.orders = {} - self.trades = set() def connect( - self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int + self, + key: str, + secret: str, + passphrase: str, + proxy_host: str, + proxy_port: int ): """""" - self.key = key.encode() + self.key = key self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) self.init(WEBSOCKET_HOST, proxy_host, proxy_port) - self.start() + # self.start() + + def unpack_data(self, data): + """""" + return json.loads(zlib.decompress(data, -zlib.MAX_WBITS)) def subscribe(self, req: SubscribeRequest): """ @@ -398,10 +472,22 @@ class OkexWebsocketApi(WebsocketClient): ) self.ticks[req.symbol] = tick + channel_ticker = f"spot/ticker:{req.symbol}" + channel_depth = f"spot/depth5:{req.symbol}" + + self.callbacks[channel_ticker] = self.on_ticker + self.callbacks[channel_depth] = self.on_depth + + req = { + "op": "subscribe", + "args": [channel_ticker, channel_depth] + } + self.send_packet(req) + def on_connected(self): """""" self.gateway.write_log("Websocket API连接成功") - self.authenticate() + self.login() def on_disconnected(self): """""" @@ -409,30 +495,27 @@ class OkexWebsocketApi(WebsocketClient): def on_packet(self, packet: dict): """""" - if "error" in packet: - self.gateway.write_log("Websocket API报错:%s" % packet["error"]) + if "event" in packet: + event = packet["event"] + if event == "subscribe": + return + elif event == "error": + msg = packet["message"] + self.gateway.write_log(f"Websocket API请求异常:{msg}") + elif event == "login": + self.on_login(packet) + else: + channel = packet["table"] + data = packet["data"] + callback = self.callbacks[channel] - if "not valid" in packet["error"]: - self.active = False - - elif "request" in packet: - req = packet["request"] - success = packet["success"] - - if success: - if req["op"] == "authKey": - self.gateway.write_log("Websocket API验证授权成功") - self.subscribe_topic() - - elif "table" in packet: - name = packet["table"] - callback = self.callbacks[name] - - if isinstance(packet["data"], list): - for d in packet["data"]: + try: + for d in data: callback(d) - else: - callback(packet["data"]) + except: + import traceback + traceback.print_exc() + print(packet) def on_error(self, exception_type: type, exception_value: Exception, tb): """""" @@ -457,173 +540,151 @@ class OkexWebsocketApi(WebsocketClient): self.key, self.passphrase, timestamp, - signature + signature.decode("utf-8") ] } self.send_packet(req) - self.callbacks['login'] = self.on_login def subscribe_topic(self): """ Subscribe to all private topics. """ + self.callbacks["spot/ticker"] = self.on_ticker + self.callbacks["spot/depth5"] = self.on_depth + self.callbacks["spot/account"] = self.on_account + self.callbacks["spot/order"] = self.on_order + + # Subscribe to order update + channels = [] for instrument_id in instruments: channel = f"spot/order:{instrument_id}" - req = {"op": "subscribe", "args": [channel]} - self.send_packet(req) - self.callbacks[channel] = self.on_trade + channels.append(channel) + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to account update + channels = [] for currency in currencies: channel = f"spot/account:{currency}" - req = {"op": "subscribe", "args": [channel]} - self.send_packet(req) - self.callbacks[channel] = self.on_account + channels.append(channel) - def on_login(self, d: dict): + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + def on_login(self, data: dict): """""" - data = d['data'] + success = data.get("success", False) - if data['success']: - self.gateway.write_log("Websocket接口登录成功") + if success: + self.gateway.write_log("Websocket API登录成功") self.subscribe_topic() else: - self.gateway.write_log("Websocket接口登录失败") + self.gateway.write_log("Websocket API登录失败") - def on_tick(self, d): + def on_ticker(self, d): """""" - symbol = d["symbol"] + symbol = d["instrument_id"] tick = self.ticks.get(symbol, None) if not tick: return - tick.last_price = d["price"] + tick.last_price = d["last"] + tick.open = d["open_24h"] + tick.high = d["high_24h"] + tick.low = d["low_24h"] + tick.volume = d["base_volume_24h"] tick.datetime = datetime.strptime( d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") self.gateway.on_tick(copy(tick)) def on_depth(self, d): """""" - symbol = d["symbol"] - tick = self.ticks.get(symbol, None) - if not tick: - return + for tick_data in d: + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return - for n, buf in enumerate(d["bids"][:5]): - price, volume = buf - tick.__setattr__("bid_price_%s" % (n + 1), price) - tick.__setattr__("bid_volume_%s" % (n + 1), volume) + bids = d["bids"] + asks = d["asks"] + for n, buf in enumerate(bids): + price, volume, _ = buf + tick.__setattr__("bid_price_%s" % (n + 1), price) + tick.__setattr__("bid_volume_%s" % (n + 1), volume) - for n, buf in enumerate(d["asks"][:5]): - price, volume = buf - tick.__setattr__("ask_price_%s" % (n + 1), price) - tick.__setattr__("ask_volume_%s" % (n + 1), volume) + for n, buf in enumerate(asks): + price, volume, _ = buf + tick.__setattr__("ask_price_%s" % (n + 1), price) + tick.__setattr__("ask_volume_%s" % (n + 1), volume) - tick.datetime = datetime.strptime( - d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") - self.gateway.on_tick(copy(tick)) - - def on_trade(self, d): - """""" - # Filter trade update with no trade volume and side (funding) - if not d["lastQty"] or not d["side"]: - return - - tradeid = d["execID"] - if tradeid in self.trades: - return - self.trades.add(tradeid) - - if d["clOrdID"]: - orderid = d["clOrdID"] - else: - orderid = d["orderID"] - - trade = TradeData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - orderid=orderid, - tradeid=tradeid, - direction=DIRECTION_OKEX2VT[d["side"]], - price=d["lastPx"], - volume=d["lastQty"], - time=d["timestamp"][11:19], - gateway_name=self.gateway_name, - ) - - self.gateway.on_trade(trade) + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) def on_order(self, d): """""" - if "ordStatus" not in d: + order = OrderData( + symbol=d["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[d["type"]], + orderid=d["client_oid"], + direction=DIRECTION_OKEX2VT[d["side"]], + price=d["price"], + volume=d["size"], + traded=d["filled_size"], + time=d["timestamp"][11:19], + status=STATUS_OKEX2VT[d["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(copy(order)) + + trade_volume = float(d.get("last_fill_qty", 0)) + if not trade_volume: return - sysid = d["orderID"] - order = self.orders.get(sysid, None) - if not order: - if d["clOrdID"]: - orderid = d["clOrdID"] - else: - orderid = sysid + self.trade_count += 1 + tradeid = f"{self.connect_time}{self.trade_count}" - # time = d["timestamp"][11:19] - - order = OrderData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - type=ORDERTYPE_OKEX2VT[d["ordType"]], - orderid=orderid, - direction=DIRECTION_OKEX2VT[d["side"]], - price=d["price"], - volume=d["orderQty"], - time=d["timestamp"][11:19], - gateway_name=self.gateway_name, - ) - self.orders[sysid] = order - - order.traded = d.get("cumQty", order.traded) - order.status = STATUS_OKEX2VT.get(d["ordStatus"], order.status) - - self.gateway.on_order(copy(order)) + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=tradeid, + direction=order.direction, + price=float(d["last_fill_px"]), + volume=float(trade_volume), + time=d["last_fill_time"][11:19], + gateway_name=self.gateway_name + ) + self.gateway.on_trade(trade) def on_account(self, d): """""" - accountid = str(d["account"]) - account = self.accounts.get(accountid, None) - if not account: - account = AccountData(accountid=accountid, - gateway_name=self.gateway_name) - self.accounts[accountid] = account - - account.balance = d.get("marginBalance", account.balance) - account.available = d.get("availableMargin", account.available) - account.frozen = account.balance - account.available - - self.gateway.on_account(copy(account)) - - def on_contract(self, d): - """""" - if "tickSize" not in d: - return - - if not d["lotSize"]: - return - - contract = ContractData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - name=d["symbol"], - product=Product.FUTURES, - pricetick=d["tickSize"], - size=d["lotSize"], - stop_supported=True, - net_position=True, - gateway_name=self.gateway_name, + account = AccountData( + accountid=d["currency"], + balance=float(d["balance"]), + frozen=float(d["hold"]), + gateway_name=self.gateway_name ) - - self.gateway.on_contract(contract) + + self.gateway.on_account(copy(account)) def generate_signature(msg: str, secret_key: str): """OKEX V3 signature""" return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest()) + + +def get_timestamp(): + """""" + now = datetime.utcnow() + timestamp = now.isoformat("T", "milliseconds") + return timestamp + "Z"