diff --git a/tests/trader/run.py b/tests/trader/run.py index d77acee9..0b7fb501 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -11,6 +11,7 @@ from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway from vnpy.gateway.okex import OkexGateway +from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp @@ -30,6 +31,7 @@ def main(): main_engine.add_gateway(TigerGateway) main_engine.add_gateway(OesGateway) main_engine.add_gateway(OkexGateway) + main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index bf03f958..59b8a908 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -4,18 +4,25 @@ 火币交易接口 """ +import re +import urllib +import base64 +import json +import zlib import hashlib import hmac from copy import copy from datetime import datetime -from vnpy.api.rest import Request, RestClient +from vnpy.event import Event +from vnpy.api.rest import RestClient from vnpy.api.websocket import WebsocketClient from vnpy.trader.constant import ( Direction, Exchange, Product, Status, + OrderType ) from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import ( @@ -26,25 +33,17 @@ from vnpy.trader.object import ( ContractData, OrderRequest, CancelRequest, - SubscribeRequest, - LogData, + SubscribeRequest ) - -import re -import urllib -import base64 -import json -import zlib -from vnpy.event.engine import Event, EVENT_TIMER -from vnpy.trader.event import EVENT_LOG +from vnpy.trader.event import EVENT_TIMER -REST_HOST = "https://api.huobi.pro" -WEBSOCKET_MARKET_HOST = "wss://api.huobi.pro/ws" # 行情 -WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # 资金和委托 +REST_HOST = "https://api.huobipro.com" +WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data +WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order STATUS_HUOBI2VT = { - "submitted": Status.SUBMITTING, + "submitted": Status.NOTTRADED, "partial-filled": Status.PARTTRADED, "filled": Status.ALLTRADED, "cancelling": Status.CANCELLED, @@ -52,6 +51,17 @@ STATUS_HUOBI2VT = { "canceled": Status.CANCELLED, } +ORDERTYPE_VT2HUOBI = { + (Direction.LONG, OrderType.MARKET): "buy-market", + (Direction.SHORT, OrderType.MARKET): "sell-market", + (Direction.LONG, OrderType.LIMIT): "buy-limit", + (Direction.SHORT, OrderType.LIMIT): "sell-limit", +} +ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()} + + +huobi_symbols = set() + class HuobiGateway(BaseGateway): """ @@ -59,42 +69,47 @@ class HuobiGateway(BaseGateway): """ default_setting = { - "ID": "", - "Secret": "", - "Symbols": "", + "API Key": "", + "Secret Key": "", + "会话数": 3, + "代理地址": "127.0.0.1", + "代理端口": 1080, } def __init__(self, event_engine): """Constructor""" super(HuobiGateway, self).__init__(event_engine, "HUOBI") - self.local_id = 10000 - - self.accountDict = {} - self.orderDict = {} - self.localOrderDict = {} - self.orderLocalDict = {} - - self.qry_enabled = False + self.order_count = 100000 + + self.local_huobi_map = {} # local orderid: huobi orderid + self.huobi_local_map = {} # huobi orderid: local orderid + self.local_order_map = {} # local orderid: order + self.huobi_order_data = {} # huobi orderid: data self.rest_api = HuobiRestApi(self) self.trade_ws_api = HuobiTradeWebsocketApi(self) - self.market_ws_api = HuobiMarketWebsocketApi(self) + self.market_ws_api = HuobiDataWebsocketApi(self) def connect(self, setting: dict): """""" - key = setting["ID"] - secret = setting["Secret"] - symbols = setting["Symbols"] - - self.rest_api.connect(symbols, secret, key) - self.trade_ws_api.connect(symbols, secret, key) - self.market_ws_api.connect(symbols, secret, key) - # websocket will push all account status on connected, including asset, position and orders. + key = setting["API Key"] + secret = setting["Secret Key"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + self.rest_api.connect(key, secret, session_number, + proxy_host, proxy_port) + self.trade_ws_api.connect(key, secret, proxy_host, proxy_port) + self.market_ws_api.connect(key, secret, proxy_host, proxy_port) + + self.init_query() def subscribe(self, req: SubscribeRequest): """""" - self.ws_api.subscribe(req) + self.market_ws_api.subscribe(req) + self.trade_ws_api.subscribe(req) def send_order(self, req: OrderRequest): """""" @@ -106,7 +121,7 @@ class HuobiGateway(BaseGateway): def query_account(self): """""" - self.rest_api.query_account() + self.rest_api.query_account_balance() def query_position(self): """""" @@ -118,53 +133,62 @@ class HuobiGateway(BaseGateway): self.trade_ws_api.stop() self.market_ws_api.stop() - def init_query(self): - """初始化连续查询""" - if self.qry_enabled: - # 需要循环的查询函数列表 - self.qry_functionList = [self.qry_info] - - self.qry_count = 0 # 查询触发倒计时 - self.qry_trigger = 1 # 查询触发点 - self.qry_next_function = 0 # 上次运行的查询函数索引 - - self.start_query() - - def query(self, event): - """注册到事件处理引擎上的查询函数""" - self.qry_count += 1 - - if self.qry_count > self.qry_trigger: - # 清空倒计时 - self.qry_count = 0 - - # 执行查询函数 - function = self.qry_functionList[self.qry_next_function] - function() - - # 计算下次查询函数的索引,如果超过了列表长度,则重新设为0 - self.qry_next_function += 1 - if self.qry_next_function == len(self.qry_functionList): - self.qry_next_function = 0 - - def start_query(self): - """启动连续查询""" - self.event_engine.register(EVENT_TIMER, self.query) - - def set_qry_enabled(self, qry_enabled): - """设置是否要启动循环查询""" - self.qry_enabled = qry_enabled - - def write_log(self, msg): + def process_timer_event(self, event: Event): """""" - log = LogData() - log.log_content = msg - log.gateway_name = self.gateway_name - - event = Event(EVENT_LOG) - event.dict_["data"] = log - self.event_engine.put(event) + self.count += 1 + if self.count < 3: + return + self.query_account() + + def init_query(self): + """""" + self.count = 0 + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + + def get_local_orderid(self, huobi_orderid: str): + """""" + local_orderid = self.huobi_local_map.get(huobi_orderid, None) + + if not local_orderid: + local_orderid = self.new_local_orderid() + self.update_orderid_map(local_orderid, huobi_orderid) + + return local_orderid + + def get_huobi_orderid(self, local_orderid: str): + """""" + huobi_orderid = self.local_huobi_map.get(local_orderid, "") + return huobi_orderid + + def new_local_orderid(self): + """""" + self.order_count += 1 + return str(self.order_count) + + def update_orderid_map(self, local_orderid: str, huobi_orderid: str): + """""" + self.huobi_local_map[huobi_orderid] = local_orderid + self.local_huobi_map[local_orderid] = huobi_orderid + + if huobi_orderid in self.huobi_order_data: + data = self.huobi_order_data.pop(huobi_orderid) + self.trade_ws_api.on_order(data) + + def on_order(self, order: OrderData): + """""" + self.local_order_map[order.orderid] = order + + super().on_order(copy(order)) + + def get_order(self, huobi_orderid: str): + """""" + local_orderid = self.huobi_local_map.get(huobi_orderid, None) + if not local_orderid: + return None + else: + return self.local_order_map[local_orderid] + class HuobiRestApi(RestClient): """ @@ -178,315 +202,283 @@ class HuobiRestApi(RestClient): self.gateway = gateway self.gateway_name = gateway.gateway_name - self.symbols = [] + self.host = "" self.key = "" self.secret = "" - self.sign_host = "" - self.account_id = "" - self.cancelReqDict = {} - self.orderBufDict = {} - self.accountDict = gateway.accountDict - self.orderDict = gateway.orderDict - self.orderLocalDict = gateway.orderLocalDict - self.localOrderDict = gateway.localOrderDict - - self.account_id = "" - self.cancelReqDict = {} - self.orderBufDict = {} + self.cancel_requests = {} + self.orders = {} def sign(self, request): """ Generate HUOBI signature. """ - request.headers = { "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36" } - params_with_signature = create_signature(self.key, request.method, self.sign_host, request.path, self.secret, request.params) + params_with_signature = create_signature( + self.key, + request.method, + self.host, + request.path, + self.secret, + request.params + ) request.params = params_with_signature - + if request.method == "POST": request.headers["Content-Type"] = "application/json" - + if request.data: - request.data = json.dumps(request.data) + request.data = json.dumps(request.data) + return request def connect( self, key: str, secret: str, - symbols, - session_number=3, + session_number: int, + proxy_host: str, + proxy_port: int ): """ Initialize connection to REST server. """ self.key = key self.secret = secret - self.symbols = symbols - host, path = _split_url(REST_HOST) - self.init(REST_HOST) - - self.sign_host = host + self.host, _ = _split_url(REST_HOST) + + self.init(REST_HOST, proxy_host, proxy_port) self.start(session_number) self.gateway.write_log("REST API启动成功") + self.query_contract() + self.query_account() + self.query_order() + def query_account(self): """""" self.add_request( method="GET", path="/v1/account/accounts", - callable=self.on_query_account + callback=self.on_query_account ) def query_account_balance(self): """""" - # path = "/v1/account/accounts/%s/balance" %self.account_id path = f"/v1/account/accounts/{self.account_id}/balance" self.add_request( - method="GET", - path=path, - callable=self.on_query_account_balance + method="GET", + path=path, + callback=self.on_query_account_balance ) def query_order(self): """""" - path = "/v1/order/orders" - - today_date = datetime.now().strftime("%Y-%m-%d") - states_active = "submitted, partial-filled" - - for symbol in self.symbols: - params = { - "symbol": symbol, - "states": states_active, - "end_date": today_date - } - self.add_request( - method="GET", - path=path, - callable=self.on_query_order, - params=params - ) - + self.add_request( + method="GET", + path="/v1/order/openOrders", + callback=self.on_query_order + ) + def query_contract(self): """""" self.add_request( - method="GET", - path="/v1/common/symbols", - callable=self.on_query_contract + method="GET", + path="/v1/common/symbols", + callback=self.on_query_contract ) def send_order(self, req: OrderRequest): """""" - self.gateway.local_id += 1 + huobi_type = ORDERTYPE_VT2HUOBI.get( + (req.direction, req.type), "" + ) - local_id = str(self.gateway.local_id) - - if req.direction == Direction.LONG: - type_ = "buy-limit" - else: - type_ = "sell-limit" + local_orderid = self.gateway.new_local_orderid() + order = req.create_order_data( + local_orderid, + self.gateway_name + ) + order.time = datetime.now().strftime("%H:%M:%S") data = { "account-id": self.account_id, "amount": str(req.volume), "symbol": req.symbol, - "type": type_, + "type": huobi_type, "price": str(req.price), "source": "api" } - path = "/v1/order/orders/place" - + self.add_request( - method="POST", - path=path, - callable=self.on_send_order, - data=data, - extra=local_id, + method="POST", + path="/v1/order/orders/place", + callback=self.on_send_order, + data=data, + extra=order, ) - order = OrderData( - symbol=req.symbol, - exchange=Exchange.HUOBI, - orderid=local_id, - direction=req.direction, - price=req.price, - volume=req.volume, - time=datetime.now(), - gateway_name=self.gateway_name, - ) - - self.orderBufDict[local_id] = order - self.gateway.on_order(order) return order.vt_orderid def cancel_order(self, req: CancelRequest): """""" local_id = req.orderid - order_id = self.localOrderDict.get(local_id, None) + huobi_orderid = self.gateway.get_huobi_orderid(local_id) - if order_id: - path = f"/v1/order/orders/{order_id}/submitcancel" - self.add_request( - method="POST", - path=path, - callable=self.on_cancel_order, - ) - - if local_id in self.cancelReqDict: - del self.cancelReqDict[local_id] - else: - self.cancelReqDict[local_id] = req + if not huobi_orderid: + self.cancel_requests[local_id] = req + return - def on_query_account(self, data, request): # type: (dict, Request)->None + path = f"/v1/order/orders/{huobi_orderid}/submitcancel" + self.add_request( + method="POST", + path=path, + callback=self.on_cancel_order, + extra=req + ) + + if local_id in self.cancel_requests: + self.cancel_requests.pop(local_id) + + def on_query_account(self, data, request): """""" + if self.check_error(data, "查询账户"): + return + for d in data["data"]: - if str(d["type"]) == "spot": - self.account_id = str(d["id"]) + if d["type"] == "spot": + self.account_id = d["id"] self.gateway.write_log(f"账户代码{self.account_id}查询成功") self.query_account_balance() - def on_query_account_balance(self, data, request): # type: (dict, Request)->None + def on_query_account_balance(self, data, request): """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "查询账户资金"): return - - self.gateway.write_log(u"资金信息查询成功") - + + buf = {} for d in data["data"]["list"]: currency = d["currency"] - account = self.accountDict.get(currency, None) + currency_data = buf.setdefault(currency, {}) + currency_data[d["type"]] = float(d["balance"]) - if not account: - account = AccountData( - account_id=d["currency"], - gateway_name=self.gateway_name, - available=float(d["balance"]) if d["type"] == "trade" else 0.0, - margin=float(d["balance"]) if d["type"] == "frozen" else 0.0, - balance=account.margin + account.available, - ) - self.accountDict[currency] = account + for currency, currency_data in buf.items(): + account = AccountData( + accountid=currency, + balance=currency_data["trade"] + currency_data["frozen"], + frozen=currency_data["frozen"], + gateway_name=self.gateway_name, + ) - for account in self.accountDict.values(): - self.gateway.on_account(account) - - self.query_order() + if account.balance: + self.gateway.on_account(account) - def on_query_order(self, data, request): # type: (dict, Request)->None - """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + def on_query_order(self, data, request): + """""" + if self.check_error(data, "查询委托"): return - - symbol = request.params["symbol"] - self.gateway.write_log(f"{symbol}委托信息查询成功") - - data["data"].reverse() + for d in data["data"]: - order_id = str(d["id"]) - self.gateway.local_id += 1 - local_id = str(self.gateway.local_id) + huobi_orderid = d["id"] + local_orderid = self.gateway.get_local_orderid(huobi_orderid) - self.orderLocalDict[order_id] = local_id - self.localOrderDict[local_id] = order_id - - if "buy" in d["type"]: - direction = Direction.LONG - else: - direction = Direction.SHORT - - if d["canceled-at"]: - time = datetime.fromtimestamp(d["canceled-at"] / 1000).strftime("%H:%M:%S") - else: - time = datetime.fromtimestamp(d["created-at"] / 1000).strftime("%H:%M:%S") + direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]] + dt = datetime.fromtimestamp(d["created-at"] / 1000) + time = dt.strftime("%H:%M:%S") order = OrderData( - orderid=local_id, + orderid=local_orderid, symbol=d["symbol"], exchange=Exchange.HUOBI, price=float(d["price"]), volume=float(d["amount"]), + type=order_type, direction=direction, - traded=float(d["field-amount"]), + traded=float(d["filled-amount"]), status=STATUS_HUOBI2VT.get(d["state"], None), time=time, gateway_name=self.gateway_name, - ) - self.orderDict[order_id] = order self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") def on_query_contract(self, data, request): # type: (dict, Request)->None """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "查询合约"): return - - self.gateway.write_log("合约信息查询成功") - + for d in data["data"]: + base_currency = d["base-currency"] + quote_currency = d["quote-currency"] + name = f"{base_currency.upper()}/{quote_currency.upper()}" + pricetick = 1 / pow(10, d["price-precision"]) + size = 1 / pow(10, d["amount-precision"]) + contract = ContractData( - symbol=d["base-currency"] + d["quote-currency"], + symbol=d["symbol"], exchange=Exchange.HUOBI, - name="/".join([d["base-currency"].upper(), d["quote-currency"].upper()]), - pricetick=1 / pow(10, d["price-precision"]), - size=1 / pow(10, d["amount-precision"]), + name=name, + pricetick=pricetick, + size=size, product=Product.SPOT, gateway_name=self.gateway_name, ) self.gateway.on_contract(contract) - self.query_account() + huobi_symbols.add(contract.symbol) - def on_send_order(self, data, request): # type: (dict, Request)->None + self.gateway.write_log("合约信息查询成功") + + def on_send_order(self, data, request): """""" - local_id = request.extra - order = self.orderBufDict[local_id] + order = request.extra - status = data.get("status", None) - - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) - + if self.check_error(data, "委托"): order.status = Status.REJECTED self.gateway.on_order(order) return + + huobi_orderid = str(data["data"]) + self.gateway.update_orderid_map(order.orderid, huobi_orderid) - order_id = str(data["data"]) - - self.localOrderDict[local_id] = order_id - self.orderDict[order_id] = order - - req = self.cancelReqDict.get(local_id, None) + req = self.cancel_requests.get(order.orderid, None) if req: self.cancel_order(req) - def on_cancel_order(self, data, request): # type: (dict, Request)->None + def on_cancel_order(self, data, request): """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "撤单"): return + + cancel_request = request.extra + local_orderid = cancel_request.orderid + huobi_orderid = self.gateway.get_huobi_orderid(local_orderid) + + order = self.gateway.get_order(huobi_orderid) + order.status = Status.CANCELLED - self.gateway.write_log(f"委托撤单成功:{data}") + self.gateway.on_order(copy(order)) + self.gateway.write_log(f"委托撤单成功:{order.orderid}") + + def check_error(self, data: dict, func: str = ""): + """""" + if data["status"] != "error": + return False + + error_code = data["err-code"] + error_msg = data["err-msg"] + + self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}") + return True class HuobiWebsocketApiBase(WebsocketClient): @@ -504,20 +496,28 @@ class HuobiWebsocketApiBase(WebsocketClient): self.sign_host = "" self.path = "" - def connect(self, key: str, secret: str, url: str): + def connect( + self, + key: str, + secret: str, + url: str, + proxy_host: str, + proxy_port: int + ): """""" self.key = key self.secret = secret - host, path = _split_url(url) - self.init(url) + host, path = _split_url(url) self.sign_host = host self.path = path + + self.init(url, proxy_host, proxy_port) self.start() def login(self): """""" - params = {"op": "auth", } + params = {"op": "auth"} params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) return self.send_packet(params) @@ -534,15 +534,12 @@ class HuobiWebsocketApiBase(WebsocketClient): """""" if "ping" in packet: self.send_packet({"pong": packet["ping"]}) - return - - if "err-msg" in packet: + elif "err-msg" in packet: return self.on_error_msg(packet) - - if "op" in packet and packet["op"] == "auth": + elif "op" in packet and packet["op"] == "auth": return self.on_login() - - self.on_data(packet) + else: + self.on_data(packet) def on_data(self, packet): """""" @@ -561,50 +558,32 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): """""" def __init__(self, gateway): """""" - super(HuobiTradeWebsocketApi, self).__init__(gateway) + super().__init__(gateway) - self.req_id = 10000 - - self.accountDict = gateway.accountDict - self.orderDict = gateway.orderDict - self.orderLocalDict = gateway.orderLocalDict - self.localOrderDict = gateway.localOrderDict + self.req_id = 0 - def connect(self, symbols, key, secret): + def connect(self, key, secret, proxy_host, proxy_port): """""" - self.symbols = symbols - super(HuobiTradeWebsocketApi, self).connect(key, secret, WEBSOCKET_TRADE_HOST) + super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port) - def subscribe_topic(self): + def subscribe(self, req: SubscribeRequest): """""" - # 订阅资金变动 self.req_id += 1 req = { "op": "sub", "cid": str(self.req_id), - "topic": "accounts", + "topic": f"orders.{req.symbol}" } self.send_packet(req) - - # 订阅委托变动 - for symbol in self.symbols: - self.req_id += 1 - req = { - "op": "sub", - "cid": str(self.req_id), - "topic": f"orders.{symbol}" - } - self.send_packet(req) def on_connected(self): """""" + self.gateway.write_log("交易Websocket API连接成功") self.login() def on_login(self): """""" - self.gateway.write_log("交易Websocket服务器登录成功") - - self.subscribe_topic() + self.gateway.write_log("交易Websocket API登录成功") def on_data(self, packet): # type: (dict)->None """""" @@ -613,188 +592,133 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): return topic = packet["topic"] - if topic == "accounts": - self.on_account(packet["data"]) - elif "orders" in topic: + if "orders" in topic: self.on_order(packet["data"]) - def on_account(self, data): + def on_order(self, data: dict): """""" - for d in data["list"]: - account = self.accountDict.get(d["currency"], None) - if not account: - continue - - if d["type"] == "trade": - account.available = float(d["balance"]) - elif d["type"] == "frozen": - account.margin = float(d["balance"]) - - account.balance = account.margin + account.available - self.gateway.on_account(account) - - def on_order(self, data: list): - """""" - order_id = str(data["order-id"]) - order = self.orderDict.get(order_id, None) - + huobi_orderid = str(data["order-id"]) + order = self.gateway.get_order(huobi_orderid) if not order: - local_id = self._new_order_id() - local_id = str(local_id) - - self.orderLocalDict[order_id] = local_id - self.localOrderDict[local_id] = order_id - - if "buy" in data["order-type"]: - direction = Direction.LONG - else: - direction = Direction.SHORT - - order = OrderData( - orderid=local_id, - symbol=data["symbol"], - exchange=Exchange.HUOBI, - price=float(data["order-price"]), - volume=float(data["order-amount"]), - direction=direction, - status=STATUS_HUOBI2VT.get(data["order-state"], None), - time=datetime.fromtimestamp(data["created-at"] / 1000).strftime("%H:%M:%S"), - gateway_name=self.gateway_name, - ) - order.traded += float(data['filled-amount']) - self.orderDict[order_id] = order - self.gateway.onOrder(order) + self.gateway.huobi_order_data[huobi_orderid] = data + return - if float(data["filled-amount"]): + traded_volume = float(data["filled-amount"]) + order.traded += traded_volume + order.status = STATUS_HUOBI2VT.get(data["order-state"], None) + self.gateway.on_order(order) + + if traded_volume: trade = TradeData( + symbol=order.symbol, + exchange=Exchange.HUOBI, orderid=order.orderid, tradeid=str(data["seq-id"]), - symbol=data["symbol"], - exchange=Exchange.HUOBI, direction=order.direction, price=float(data["price"]), volume=float(data["filled-amount"]), time=datetime.now().strftime("%H:%M:%S"), gateway_name=self.gateway_name, ) - self.gateway.onTrade(trade) + self.gateway.on_trade(trade) -class HuobiMarketWebsocketApi(HuobiWebsocketApiBase): +class HuobiDataWebsocketApi(HuobiWebsocketApiBase): """""" + def __init__(self, gateway): """""" - super(HuobiMarketWebsocketApi, self).__init__(gateway) + super().__init__(gateway) - self.req_id = 10000 - self.tickDict = {} + self.req_id = 0 + self.ticks = {} - def connect(self, symbols, key, secret): + def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int): """""" - self.symbols = symbols - super(HuobiMarketWebsocketApi, self).connect(key, secret, WEBSOCKET_MARKET_HOST) + super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port) def on_connected(self): """""" - self.subscribe_topic() - - def subscribe_topic(self): # type:()->None + self.gateway.write_log("行情Websocket API连接成功") + + def subscribe(self, req: SubscribeRequest): """""" - for symbol in self.symbols: - # 创建Tick对象 - tick = TickData( - symbol=symbol, - exchange=Exchange.HUOBI, - gateway_name=self.gateway_name, - ) + symbol = req.symbol - self.tickDict[symbol] = tick + # Create tick data buffer + tick = TickData( + symbol=symbol, + exchange=Exchange.HUOBI, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[symbol] = tick - # 订阅深度和成交 - self.req_id += 1 - req = { - "sub": "market.{symbol}.depth.step0", - "id": str(self.req_id) - } - self.send_packet(req) - - self.req_id += 1 - req = { - "sub": "market.{symbol}.detail", - "id": str(self.req_id) - } - self.send_packet(req) + # Subscribe to market depth update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.depth.step0", + "id": str(self.req_id) + } + self.send_packet(req) + + # Subscribe to market detail update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.detail", + "id": str(self.req_id) + } + self.send_packet(req) def on_data(self, packet): # type: (dict)->None """""" - if "ch" in packet: - if "depth.step" in packet["ch"]: + channel = packet.get("ch", None) + if channel: + if "depth.step" in channel: self.on_market_depth(packet) - elif "detail" in packet["ch"]: + elif "detail" in channel: self.on_market_detail(packet) elif "err-code" in packet: - self.gateway.write_log("错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"])) + code = packet["err-code"] + msg = packet["err-msg"] + self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}") def on_market_depth(self, data): """行情深度推送 """ symbol = data["ch"].split(".")[1] - - tick = self.tickDict.get(symbol, None) - if not tick: - return - + tick = self.ticks[symbol] tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) - tick.date = tick.datetime.strftime("%Y%m%d") - tick.time = tick.datetime.strftime("%H:%M:%S.%f") - + bids = data["tick"]["bids"] for n in range(5): - l = bids[n] - tick.__setattr__("bid_price_" + str(n + 1), float(l[0])) - tick.__setattr__("bid_volume_" + str(n + 1), float(l[1])) + price, volume = bids[n] + tick.__setattr__("bid_price_" + str(n + 1), float(price)) + tick.__setattr__("bid_volume_" + str(n + 1), float(volume)) asks = data["tick"]["asks"] for n in range(5): - l = asks[n] - tick.__setattr__("ask_price_" + str(n + 1), float(l[0])) - tick.__setattr__("ask_volume_" + str(n + 1), float(l[1])) + price, volume = asks[n] + tick.__setattr__("ask_price_" + str(n + 1), float(price)) + tick.__setattr__("ask_volume_" + str(n + 1), float(volume)) if tick.last_price: - newtick = copy(tick) - self.gateway.on_tick(newtick) + self.gateway.on_tick(copy(tick)) def on_market_detail(self, data): """市场细节推送""" symbol = data["ch"].split(".")[1] - - tick = self.tickDict.get(symbol, None) - if not tick: - return - + tick = self.ticks[symbol] tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) - tick.date = tick.datetime.strftime("%Y%m%d") - tick.time = tick.datetime.strftime("%H:%M:%S.%f") - - t = data["tick"] - tick.open_price = float(t["open"]) - tick.high_price = float(t["high"]) - tick.low_price = float(t["low"]) - tick.last_price = float(t["close"]) - tick.volume = float(t["vol"]) - tick.pre_close = float(tick.open_price) + + tick_data = data["tick"] + tick.open_price = float(tick_data["open"]) + tick.high_price = float(tick_data["high"]) + tick.low_price = float(tick_data["low"]) + tick.last_price = float(tick_data["close"]) + tick.volume = float(tick_data["vol"]) if tick.bid_price_1: - newtick = copy(tick) - self.gateway.on_tick(newtick) - - -def print_dict(d): - """""" - print("-" * 30) - l = d.keys() - l.sort() - for k in l: - print(type(k), k, d[k]) + self.gateway.on_tick(copy(tick)) def _split_url(url): @@ -802,9 +726,9 @@ def _split_url(url): 将url拆分为host和path :return: host, path """ - m = re.match("\w+://([^/]*)(.*)", url) - if m: - return m.group(1), m.group(2) + result = re.match("\w+://([^/]*)(.*)", url) # noqa + if result: + return result.group(1), result.group(2) def create_signature(api_key, method, host, path, secret_key, get_params=None): @@ -813,18 +737,19 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None): :param get_params: dict 使用GET方法时附带的额外参数(urlparams) :return: """ - sortedParams = [ + sorted_params = [ ("AccessKeyId", api_key), ("SignatureMethod", "HmacSHA256"), ("SignatureVersion", "2"), ("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")) ] + if get_params: - sortedParams.extend(get_params.items()) - sortedParams = list(sorted(sortedParams)) - encodeParams = urllib.urlencode(sortedParams) + sorted_params.extend(list(get_params.items())) + sorted_params = list(sorted(sorted_params)) + encode_params = urllib.parse.urlencode(sorted_params) - payload = [method, host, path, encodeParams] + payload = [method, host, path, encode_params] payload = "\n".join(payload) payload = payload.encode(encoding="UTF8") @@ -833,6 +758,6 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None): digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() signature = base64.b64encode(digest) - params = dict(sortedParams) - params["Signature"] = signature + params = dict(sorted_params) + params["Signature"] = signature.decode("UTF8") return params