[Mod]complete trading test of okex gateway

This commit is contained in:
vn.py 2019-04-04 07:40:44 +08:00
parent 7024a7d2ca
commit 7d86efce39
2 changed files with 254 additions and 191 deletions

View File

@ -10,6 +10,7 @@ from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway
from vnpy.gateway.tiger import TigerGateway
from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp
@ -22,12 +23,13 @@ def main():
event_engine = EventEngine()
main_engine = MainEngine(event_engine)
main_engine.add_gateway(CtpGateway)
main_engine.add_gateway(IbGateway)
main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway)
main_engine.add_gateway(TigerGateway)
main_engine.add_gateway(OesGateway)
# main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(IbGateway)
# main_engine.add_gateway(FutuGateway)
# main_engine.add_gateway(BitmexGateway)
# main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(OesGateway)
main_engine.add_gateway(OkexGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CsvLoaderApp)

View File

@ -8,6 +8,7 @@ import sys
import time
import json
import base64
import zlib
from copy import copy
from datetime import datetime
from threading import Lock
@ -39,7 +40,7 @@ from vnpy.trader.object import (
)
REST_HOST = "https://www.okex.com"
WEBSOCKET_HOST = "wss://real.okex.com:10440/websocket/okexapi?compress=true"
WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3"
STATUS_OKEX2VT = {
"ordering": Status.SUBMITTING,
@ -88,7 +89,7 @@ class OkexGateway(BaseGateway):
def connect(self, setting: dict):
""""""
key = setting["API KEY"]
key = setting["API Key"]
secret = setting["Secret Key"]
passphrase = setting["Passphrase"]
session_number = setting["会话数"]
@ -142,7 +143,7 @@ class OkexRestApi(RestClient):
self.secret = ""
self.passphrase = ""
self.order_count = 1_000_000
self.order_count = 10000
self.order_count_lock = Lock()
self.connect_time = 0
@ -152,7 +153,8 @@ class OkexRestApi(RestClient):
Generate OKEX signature.
"""
# Sign
timestamp = str(time.time())
# timestamp = str(time.time())
timestamp = get_timestamp()
request.data = json.dumps(request.data)
if request.params:
@ -177,7 +179,7 @@ class OkexRestApi(RestClient):
self,
key: str,
secret: str,
passphrase: str
passphrase: str,
session_number: int,
proxy_host: str,
proxy_port: int,
@ -185,18 +187,21 @@ class OkexRestApi(RestClient):
"""
Initialize connection to REST server.
"""
self.key = key.encode()
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = (
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
)
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(REST_HOST, proxy_host, proxy_port)
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.query_time()
self.query_contract()
self.query_account()
self.query_order()
def _new_order_id(self):
with self.order_count_lock:
self.order_count += 1
@ -204,8 +209,8 @@ class OkexRestApi(RestClient):
def send_order(self, req: OrderRequest):
""""""
orderid = str(self.connect_time + self._new_order_id())
orderid = f"a{self.connect_time}{self._new_order_id()}"
data = {
"client_oid": orderid,
"type": ORDERTYPE_VT2OKEX[req.type],
@ -227,11 +232,11 @@ class OkexRestApi(RestClient):
self.add_request(
"POST",
"/api/spot/v3/orders",
callback=self.on_send_order,
data=data,
extra=order,
on_failed=self.on_send_order_failed,
on_error=self.on_send_order_error,
callback = self.on_send_order,
data = data,
extra = order,
on_failed = self.on_send_order_failed,
on_error = self.on_send_order_error,
)
self.gateway.on_order(order)
@ -239,7 +244,7 @@ class OkexRestApi(RestClient):
def cancel_order(self, req: CancelRequest):
""""""
data = {
data={
"instrument_id": req.symbol,
"client_oid": req.orderid
}
@ -248,25 +253,41 @@ class OkexRestApi(RestClient):
self.add_request(
"POST",
path,
callback=self.on_cancel_order,
data=data,
on_error=self.on_cancel_order_error,
callback = self.on_cancel_order,
data = data,
on_error = self.on_cancel_order_error,
)
def query_contract(self):
""""""
data = {
"instrument_id": req.symbol,
"client_oid": req.orderid
}
path = "/api/spot/v3/cancel_orders/" + req.orderid
self.add_request(
"POST",
path,
callback=self.on_cancel_order,
data=data,
on_error=self.on_cancel_order_error,
"GET",
"/api/spot/v3/instruments",
callback = self.on_query_contract
)
def query_account(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/accounts",
callback = self.on_query_account
)
def query_order(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/orders_pending",
callback = self.on_query_order
)
def query_time(self):
""""""
self.add_request(
"GET",
"/api/general/v3/time",
callback=self.on_query_time
)
def on_query_contract(self, data, request):
@ -279,8 +300,8 @@ class OkexRestApi(RestClient):
name=symbol,
product=Product.SPOT,
size=1,
pricetick=instrument_data["tick_size"]
pricetick = instrument_data["tick_size"],
gateway_name = self.gateway_name
)
self.gateway.on_contract(contract)
@ -290,6 +311,48 @@ class OkexRestApi(RestClient):
self.gateway.write_log("合约信息查询成功")
# Start websocket api after instruments data collected
self.gateway.ws_api.start()
def on_query_account(self, data, request):
""""""
for account_data in data:
account = AccountData(
accountid=account_data["currency"],
balance=float(account_data["balance"]),
frozen=float(account_data["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
self.gateway.write_log("账户资金查询成功")
def on_query_order(self, data, request):
""""""
for order_data in data:
order = OrderData(
symbol=order_data["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[order_data["type"]],
orderid=order_data["client_oid"],
direction=DIRECTION_OKEX2VT[order_data["side"]],
price=float(order_data["price"]),
volume=float(order_data["size"]),
time=order_data["timestamp"][11:19],
status=STATUS_OKEX2VT[order_data["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
self.gateway.write_log("委托信息查询成功")
def on_query_time(self, data, request):
""""""
server_time = data["iso"]
local_time = datetime.utcnow().isoformat()
msg = f"服务器时间:{server_time},本机时间:{local_time}"
self.gateway.write_log(msg)
def on_send_order_failed(self, status_code: str, request: Request):
"""
Callback when sending order failed on server.
@ -368,22 +431,33 @@ class OkexWebsocketApi(WebsocketClient):
self.secret = ""
self.passphrase = ""
self.callbacks = {}
self.trade_count = 10000
self.connect_time = 0
self.callbacks = {}
self.ticks = {}
self.accounts = {}
self.orders = {}
self.trades = set()
def connect(
self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int
self,
key: str,
secret: str,
passphrase: str,
proxy_host: str,
proxy_port: int
):
""""""
self.key = key.encode()
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
self.start()
# self.start()
def unpack_data(self, data):
""""""
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
def subscribe(self, req: SubscribeRequest):
"""
@ -398,10 +472,22 @@ class OkexWebsocketApi(WebsocketClient):
)
self.ticks[req.symbol] = tick
channel_ticker = f"spot/ticker:{req.symbol}"
channel_depth = f"spot/depth5:{req.symbol}"
self.callbacks[channel_ticker] = self.on_ticker
self.callbacks[channel_depth] = self.on_depth
req = {
"op": "subscribe",
"args": [channel_ticker, channel_depth]
}
self.send_packet(req)
def on_connected(self):
""""""
self.gateway.write_log("Websocket API连接成功")
self.authenticate()
self.login()
def on_disconnected(self):
""""""
@ -409,30 +495,27 @@ class OkexWebsocketApi(WebsocketClient):
def on_packet(self, packet: dict):
""""""
if "error" in packet:
self.gateway.write_log("Websocket API报错%s" % packet["error"])
if "event" in packet:
event = packet["event"]
if event == "subscribe":
return
elif event == "error":
msg = packet["message"]
self.gateway.write_log(f"Websocket API请求异常{msg}")
elif event == "login":
self.on_login(packet)
else:
channel = packet["table"]
data = packet["data"]
callback = self.callbacks[channel]
if "not valid" in packet["error"]:
self.active = False
elif "request" in packet:
req = packet["request"]
success = packet["success"]
if success:
if req["op"] == "authKey":
self.gateway.write_log("Websocket API验证授权成功")
self.subscribe_topic()
elif "table" in packet:
name = packet["table"]
callback = self.callbacks[name]
if isinstance(packet["data"], list):
for d in packet["data"]:
try:
for d in data:
callback(d)
else:
callback(packet["data"])
except:
import traceback
traceback.print_exc()
print(packet)
def on_error(self, exception_type: type, exception_value: Exception, tb):
""""""
@ -457,173 +540,151 @@ class OkexWebsocketApi(WebsocketClient):
self.key,
self.passphrase,
timestamp,
signature
signature.decode("utf-8")
]
}
self.send_packet(req)
self.callbacks['login'] = self.on_login
def subscribe_topic(self):
"""
Subscribe to all private topics.
"""
self.callbacks["spot/ticker"] = self.on_ticker
self.callbacks["spot/depth5"] = self.on_depth
self.callbacks["spot/account"] = self.on_account
self.callbacks["spot/order"] = self.on_order
# Subscribe to order update
channels = []
for instrument_id in instruments:
channel = f"spot/order:{instrument_id}"
req = {"op": "subscribe", "args": [channel]}
self.send_packet(req)
self.callbacks[channel] = self.on_trade
channels.append(channel)
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
# Subscribe to account update
channels = []
for currency in currencies:
channel = f"spot/account:{currency}"
req = {"op": "subscribe", "args": [channel]}
self.send_packet(req)
self.callbacks[channel] = self.on_account
channels.append(channel)
def on_login(self, d: dict):
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
def on_login(self, data: dict):
""""""
data = d['data']
success = data.get("success", False)
if data['success']:
self.gateway.write_log("Websocket接口登录成功")
if success:
self.gateway.write_log("Websocket API登录成功")
self.subscribe_topic()
else:
self.gateway.write_log("Websocket接口登录失败")
self.gateway.write_log("Websocket API登录失败")
def on_tick(self, d):
def on_ticker(self, d):
""""""
symbol = d["symbol"]
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
tick.last_price = d["price"]
tick.last_price = d["last"]
tick.open = d["open_24h"]
tick.high = d["high_24h"]
tick.low = d["low_24h"]
tick.volume = d["base_volume_24h"]
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_depth(self, d):
""""""
symbol = d["symbol"]
tick = self.ticks.get(symbol, None)
if not tick:
return
for tick_data in d:
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
for n, buf in enumerate(d["bids"][:5]):
price, volume = buf
tick.__setattr__("bid_price_%s" % (n + 1), price)
tick.__setattr__("bid_volume_%s" % (n + 1), volume)
bids = d["bids"]
asks = d["asks"]
for n, buf in enumerate(bids):
price, volume, _ = buf
tick.__setattr__("bid_price_%s" % (n + 1), price)
tick.__setattr__("bid_volume_%s" % (n + 1), volume)
for n, buf in enumerate(d["asks"][:5]):
price, volume = buf
tick.__setattr__("ask_price_%s" % (n + 1), price)
tick.__setattr__("ask_volume_%s" % (n + 1), volume)
for n, buf in enumerate(asks):
price, volume, _ = buf
tick.__setattr__("ask_price_%s" % (n + 1), price)
tick.__setattr__("ask_volume_%s" % (n + 1), volume)
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_trade(self, d):
""""""
# Filter trade update with no trade volume and side (funding)
if not d["lastQty"] or not d["side"]:
return
tradeid = d["execID"]
if tradeid in self.trades:
return
self.trades.add(tradeid)
if d["clOrdID"]:
orderid = d["clOrdID"]
else:
orderid = d["orderID"]
trade = TradeData(
symbol=d["symbol"],
exchange=Exchange.OKEX,
orderid=orderid,
tradeid=tradeid,
direction=DIRECTION_OKEX2VT[d["side"]],
price=d["lastPx"],
volume=d["lastQty"],
time=d["timestamp"][11:19],
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_order(self, d):
""""""
if "ordStatus" not in d:
order = OrderData(
symbol=d["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[d["type"]],
orderid=d["client_oid"],
direction=DIRECTION_OKEX2VT[d["side"]],
price=d["price"],
volume=d["size"],
traded=d["filled_size"],
time=d["timestamp"][11:19],
status=STATUS_OKEX2VT[d["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(copy(order))
trade_volume = float(d.get("last_fill_qty", 0))
if not trade_volume:
return
sysid = d["orderID"]
order = self.orders.get(sysid, None)
if not order:
if d["clOrdID"]:
orderid = d["clOrdID"]
else:
orderid = sysid
self.trade_count += 1
tradeid = f"{self.connect_time}{self.trade_count}"
# time = d["timestamp"][11:19]
order = OrderData(
symbol=d["symbol"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[d["ordType"]],
orderid=orderid,
direction=DIRECTION_OKEX2VT[d["side"]],
price=d["price"],
volume=d["orderQty"],
time=d["timestamp"][11:19],
gateway_name=self.gateway_name,
)
self.orders[sysid] = order
order.traded = d.get("cumQty", order.traded)
order.status = STATUS_OKEX2VT.get(d["ordStatus"], order.status)
self.gateway.on_order(copy(order))
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=tradeid,
direction=order.direction,
price=float(d["last_fill_px"]),
volume=float(trade_volume),
time=d["last_fill_time"][11:19],
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def on_account(self, d):
""""""
accountid = str(d["account"])
account = self.accounts.get(accountid, None)
if not account:
account = AccountData(accountid=accountid,
gateway_name=self.gateway_name)
self.accounts[accountid] = account
account.balance = d.get("marginBalance", account.balance)
account.available = d.get("availableMargin", account.available)
account.frozen = account.balance - account.available
self.gateway.on_account(copy(account))
def on_contract(self, d):
""""""
if "tickSize" not in d:
return
if not d["lotSize"]:
return
contract = ContractData(
symbol=d["symbol"],
exchange=Exchange.OKEX,
name=d["symbol"],
product=Product.FUTURES,
pricetick=d["tickSize"],
size=d["lotSize"],
stop_supported=True,
net_position=True,
gateway_name=self.gateway_name,
account = AccountData(
accountid=d["currency"],
balance=float(d["balance"]),
frozen=float(d["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_contract(contract)
self.gateway.on_account(copy(account))
def generate_signature(msg: str, secret_key: str):
"""OKEX V3 signature"""
return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest())
def get_timestamp():
""""""
now = datetime.utcnow()
timestamp = now.isoformat("T", "milliseconds")
return timestamp + "Z"