diff --git a/tests/api/base/RestClientTest.py b/tests/api/base/RestClientTest.py index 36da9289..6f350638 100644 --- a/tests/api/base/RestClientTest.py +++ b/tests/api/base/RestClientTest.py @@ -51,7 +51,7 @@ class RestfulClientTest(unittest.TestCase): def callback(data, req): self.c.p.set_result(data['args']) - self.c.addReq('GET', '/get', callback, params=args) + self.c.addRequest('GET', '/get', callback, params=args) res = self.c.p.get(3) self.assertEqual(args, res) @@ -63,7 +63,7 @@ class RestfulClientTest(unittest.TestCase): def callback(data, req): self.c.p.set_result(data['json']) - self.c.addReq('POST', '/post', callback, data=body) + self.c.addRequest('POST', '/post', callback, data=body) res = self.c.p.get(3) self.assertEqual(body, res) @@ -72,7 +72,7 @@ class RestfulClientTest(unittest.TestCase): def callback(data, req): pass - self.c.addReq('POST', '/status/401', callback) + self.c.addRequest('POST', '/status/401', callback) with self.assertRaises(FailedError): self.c.p.get(3) @@ -80,7 +80,7 @@ class RestfulClientTest(unittest.TestCase): def callback(data, req): pass - self.c.addReq('GET', '/image/svg', callback) + self.c.addRequest('GET', '/image/svg', callback) with self.assertRaises(JSONDecodeError): self.c.p.get(3) diff --git a/tests/api/base/WebSocketClientTest.py b/tests/api/base/WebSocketClientTest.py index 8361f647..8d931ecf 100644 --- a/tests/api/base/WebSocketClientTest.py +++ b/tests/api/base/WebSocketClientTest.py @@ -2,10 +2,10 @@ import unittest from Promise import Promise -from vnpy.api.websocket import WebsocketClient +from vnpy.api.websocket import WebSocketClient -class TestWebsocketClient(WebsocketClient): +class TestWebsocketClient(WebSocketClient): def __init__(self): host = 'wss://echo.websocket.org' @@ -13,11 +13,11 @@ class TestWebsocketClient(WebsocketClient): self.init(host) self.p = Promise() - def onMessage(self, packet): + def onPacket(self, packet): self.p.set_result(packet) pass - def onConnect(self): + def onConnected(self): pass def onError(self, exceptionType, exceptionValue, tb): @@ -38,7 +38,7 @@ class WebsocketClientTest(unittest.TestCase): req = { 'name': 'val' } - self.c.sendReq(req) + self.c.sendPacket(req) res = self.c.p.get(3) self.assertEqual(res, req) diff --git a/vnpy/api/okexfuture/OkexFutureApi.py b/vnpy/api/okexfuture/OkexFutureApi.py index 2849c8c9..b6d801c9 100644 --- a/vnpy/api/okexfuture/OkexFutureApi.py +++ b/vnpy/api/okexfuture/OkexFutureApi.py @@ -75,6 +75,7 @@ class OkexFutureUserInfo(object): #---------------------------------------------------------------------- def __init__(self): + self.easySymbol = None # 'etc', 'btc', 'eth', etc. self.accountRights = None self.keepDeposit = None self.profitReal = None @@ -142,7 +143,7 @@ class OkexFutureRestClient(OkexFutureRestBase): def sendOrder(self, symbol, contractType, orderType, volume, onSuccess, onFailed=None, price=None, useMarketPrice=False, leverRate=None, - extra=None): # type:(str, OkexFutureContractType, OkexFutureOrderType, float, Callable[[int, Any], Any], Callable[[int, Any], Any], float, bool, Union[int, None], Any)->Request + extra=None): # type:(str, OkexFutureContractType, OkexFutureOrderType, float, Callable[[str, Any], Any], Callable[[int, Any], Any], float, bool, Union[int, None], Any)->Request """ :param symbol: str :param contractType: OkexFutureContractType @@ -171,11 +172,11 @@ class OkexFutureRestClient(OkexFutureRestBase): if leverRate: data['lever_rate'] = leverRate # 杠杆倍数 - request = self.addReq('POST', + request = self.addRequest('POST', '/future_trade.do', - callback=self.onOrderSent, - data=data, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + callback=self.onOrderSent, + data=data, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) return request #---------------------------------------------------------------------- @@ -195,20 +196,22 @@ class OkexFutureRestClient(OkexFutureRestBase): 'contractType': contractType, 'order_id': orderId } - return self.addReq('POST', + return self.addRequest('POST', '/future_cancel.do', - callback=self.onOrderCanceled, - data=data, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + callback=self.onOrderCanceled, + data=data, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) #---------------------------------------------------------------------- def queryOrder(self, symbol, contractType, orderId, onSuccess, onFailed=None, - extra=None): # type: (str, OkexFutureContractType, str, Callable[[OkexFutureOrder, Any], Any], Callable[[int, Any], Any], Any)->Request + extra=None): # type: (str, OkexFutureContractType, str, Callable[[List[OkexFutureOrder], Any], Any], Callable[[int, Any], Any], Any)->Request """ + @note onSuccess接收的第一个参数是列表,并且有可能为空 + :param symbol: str :param contractType: OkexFutureContractType :param orderId: str - :param onSuccess: (OkexFutureOrder, extra:Any)->Any + :param onSuccess: (orders: List[OkexFutureOrder], extra:Any)->Any :param onFailed: (extra: Any)->Any :param extra: Any :return: Request @@ -218,41 +221,43 @@ class OkexFutureRestClient(OkexFutureRestBase): 'contractType': contractType, 'order_id': orderId } - return self.addReq('POST', + return self.addRequest('POST', '/future_order_info.do', - callback=self.onOrder, - data=data, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + callback=self.onOrder, + data=data, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) #---------------------------------------------------------------------- def queryOrders(self, symbol, contractType, status, onSuccess, onFailed=None, - pageIndex=None, pageLength=50, - extra=None): # type: (str, OkexFutureContractType, OkexFutureOrderStatus, Callable[[OkexFutureOrder, Any], Any], Callable[[int, Any], Any], int, int, Any)->Request + pageIndex=0, pageLength=50, + extra=None): # type: (str, OkexFutureContractType, OkexFutureOrderStatus, Callable[[List[OkexFutureOrder], Any], Any], Callable[[int, Any], Any], int, int, Any)->Request """ + @note onSuccess接收的第一个参数是列表,并且有可能为空 + :param symbol: str :param contractType: OkexFutureContractType - :param orderId: str - :param onSuccess: (OkexFutureOrder, extra:Any)->Any + :param onSuccess: (List[OkexFutureOrder], extra:Any)->Any :param onFailed: (extra: Any)->Any + :param pageIndex: 页码 + :param pageLength: 最大显示数量(最大值50) :param extra: Any :return: Request """ data = { 'symbol': symbol, - 'contractType': contractType, + 'contract_type': contractType, 'status': status, - 'order_id': '-1', - 'pageLength': 50 + 'order_id': -1, + 'current_page': pageIndex, + 'page_length': pageLength } - if pageIndex: - data['page_index'] = pageIndex - - return self.addReq('POST', + + return self.addRequest('POST', '/future_order_info.do', - callback=self.onOrder, - data=data, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + callback=self.onOrder, + data=data, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) #---------------------------------------------------------------------- def queryUserInfo(self, onSuccess, onFailed=None, @@ -264,24 +269,32 @@ class OkexFutureRestClient(OkexFutureRestBase): :param extra: Any :return: Request """ - return self.addReq('POST', + return self.addRequest('POST', '/future_userinfo.do', - callback=self.onOrder, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + callback=self.onOrder, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) #---------------------------------------------------------------------- def queryPosition(self, symbol, contractType, onSuccess, onFailed=None, extra=None): # type: (str, OkexFutureContractType, Callable[[OkexFuturePosition, Any], Any], Callable[[int, Any], Any], Any)->Request + """ + :param symbol: OkexFutureSymbol + :param contractType: OkexFutureContractType + :param onSuccess: (pos:OkexFuturePosition, extra: any)->Any + :param onFailed: (errorCode: int, extra: any)->Any + :param extra: + :return: + """ data = { 'symbol': symbol, 'contractType': contractType } - return self.addReq('POST', + return self.addRequest('POST', '/future_position.do', - data=data, - callback=self.onPosition, - extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) + data=data, + callback=self.onPosition, + extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) #---------------------------------------------------------------------- @staticmethod @@ -323,24 +336,25 @@ class OkexFutureRestClient(OkexFutureRestBase): success = data['result'] extra = req.extra # type: _OkexFutureCustomExtra if success: - order = data['orders'][0] - okexOrder = OkexFutureOrder() - - okexOrder.volume = order['amount'] - okexOrder.contractName = order['contract_name'] - okexOrder.createDate = order['create_date'] - okexOrder.tradedVolume = order['deal_amount'] - okexOrder.fee = order['fee'] - okexOrder.leverRate = order['lever_rate'] - okexOrder.remoteId = order['order_id'] - okexOrder.price = order['price'] - okexOrder.priceAvg = order['price_avg'] - okexOrder.status = order['status'] - okexOrder.orderType = order['type'] - okexOrder.unitAmount = order['unit_amount'] - okexOrder.symbol = order['symbol'] - - extra.onSuccess(okexOrder, extra.extra) + orders = [] + for order in data['orders']: + okexOrder = OkexFutureOrder() + + okexOrder.volume = order['amount'] + okexOrder.contractName = order['contract_name'] + okexOrder.createDate = order['create_date'] + okexOrder.tradedVolume = order['deal_amount'] + okexOrder.fee = order['fee'] + okexOrder.leverRate = order['lever_rate'] + okexOrder.remoteId = order['order_id'] + okexOrder.price = order['price'] + okexOrder.priceAvg = order['price_avg'] + okexOrder.status = order['status'] + okexOrder.orderType = order['type'] + okexOrder.unitAmount = order['unit_amount'] + okexOrder.symbol = order['symbol'] + orders.append(okexOrder) + extra.onSuccess(orders, extra.extra) else: if extra.onFailed: code = 0 @@ -356,8 +370,9 @@ class OkexFutureRestClient(OkexFutureRestBase): if success: infos = data['info'] uis = [] - for symbol, info in infos.items(): # type: str, dict + for easySymbol, info in infos.items(): # type: str, dict ui = OkexFutureUserInfo() + ui.easySymbol = easySymbol ui.accountRights = info['account_rights'] ui.keepDeposit = info['keep_deposit'] ui.profitReal = info['profit_real'] @@ -406,9 +421,9 @@ class OkexFutureRestClient(OkexFutureRestBase): code = data['error_code'] extra.onFailed(code, extra.extra) - #---------------------------------------------------------------------- + #--------------------------------------------------------------------- @staticmethod - def errorCode2String(code): + def errorCodeToString(code): assert code in errorCodeMap return errorCodeMap[code] diff --git a/vnpy/api/okexfuture/vnokexFuture.py b/vnpy/api/okexfuture/vnokexFuture.py index 535ea7af..0b3a7e2e 100644 --- a/vnpy/api/okexfuture/vnokexFuture.py +++ b/vnpy/api/okexfuture/vnokexFuture.py @@ -2,13 +2,22 @@ import hashlib import urllib +from vnpy.api.rest import Request, RestClient -######################################################################## -from vnpy.api.rest import RestClient, Request + +#---------------------------------------------------------------------- +def sign(dataWithApiKey, apiSecret): + """ + :param dataWithApiKey: sorted urlencoded args with apiKey + :return: param 'sign' for okex api + """ + dataWithSecret = dataWithApiKey + "&secret_key=" + apiSecret + return hashlib.md5(dataWithSecret.encode()).hexdigest().upper() ######################################################################## class OkexFutureRestBase(RestClient): + host = 'https://www.okex.com/api/v1' #---------------------------------------------------------------------- def __init__(self): @@ -20,11 +29,11 @@ class OkexFutureRestBase(RestClient): # noinspection PyMethodOverriding def init(self, apiKey, apiSecret): # type: (str, str) -> any - super(OkexFutureRestBase, self).init('https://www.okex.com/api/v1') + super(OkexFutureRestBase, self).init(self.host) self.apiKey = apiKey self.apiSecret = apiSecret -#---------------------------------------------------------------------- + #---------------------------------------------------------------------- def beforeRequest(self, req): # type: (Request)->Request args = req.params or {} args.update(req.data or {}) @@ -33,11 +42,9 @@ class OkexFutureRestBase(RestClient): if 'apiKey' not in args: args['api_key'] = self.apiKey data = urllib.urlencode(sorted(args.items())) - data += "&secret_key=" + self.apiSecret - - sign = hashlib.md5(data.encode()).hexdigest().upper() - data += "&sign=" + sign + signature = sign(data, self.apiSecret) + data += "&sign=" + signature req.headers = {'Content-Type': 'application/x-www-form-urlencoded'} + req.data = data return req - diff --git a/vnpy/api/rest/RestClient.py b/vnpy/api/rest/RestClient.py index a1f5c06e..420637b1 100644 --- a/vnpy/api/rest/RestClient.py +++ b/vnpy/api/rest/RestClient.py @@ -139,10 +139,10 @@ class RestClient(object): self._queue.join() #---------------------------------------------------------------------- - def addReq(self, method, path, callback, - params=None, data=None, headers = None, - onFailed=None, skipDefaultOnFailed=True, - extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request + def addRequest(self, method, path, callback, + params=None, data=None, headers = None, + onFailed=None, skipDefaultOnFailed=True, + extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request """ 发送一个请求 :param method: GET, POST, PUT, DELETE, QUERY @@ -171,7 +171,7 @@ class RestClient(object): try: req = self._queue.get(timeout=1) try: - self._processReq(req, session) + self._processRequest(req, session) finally: self._queue.task_done() except Empty: @@ -215,8 +215,10 @@ class RestClient(object): sys.excepthook(exceptionType, exceptionValue, tb) #---------------------------------------------------------------------- - def _processReq(self, req, session): # type: (Request, requests.Session)->None - """处理请求""" + def _processRequest(self, req, session): # type: (Request, requests.Session)->None + """ + 用于内部:将请求发送出去 + """ try: req = self.beforeRequest(req) diff --git a/vnpy/api/websocket/WebSocketClient.py b/vnpy/api/websocket/WebSocketClient.py index 12cbfca5..1ae22dae 100644 --- a/vnpy/api/websocket/WebSocketClient.py +++ b/vnpy/api/websocket/WebSocketClient.py @@ -9,11 +9,27 @@ import time from abc import abstractmethod from threading import Thread, Lock -import vnpy.api.websocket +import websocket -class WebsocketClient(object): - """Websocket API""" +class WebSocketClient(object): + """ + Websocket API + + 继承使用该类。 + 实例化之后,应调用start开始后台线程。调用start()函数会自动连接websocket。 + 若要终止后台线程,请调用stop()。 stop()函数会顺便断开websocket。 + + 可以重写以下函数: + onConnected + onDisconnected + onPacket + onError + + 当然,为了不让用户随意自定义,用自己的init函数覆盖掉原本的init(host)也是个不错的选择。 + + @note 继承使用该类 + """ #---------------------------------------------------------------------- def __init__(self): @@ -26,6 +42,15 @@ class WebsocketClient(object): self._workerThread = None # type: Thread self._pingThread = None # type: Thread self._active = False + + self.createConnection = websocket.create_connection + + #---------------------------------------------------------------------- + def setCreateConnection(self, func): + """ + for internal usage + """ + self.createConnection = func #---------------------------------------------------------------------- def init(self, host): @@ -43,7 +68,7 @@ class WebsocketClient(object): self._pingThread = Thread(target=self._runPing) self._pingThread.start() - self.onConnect() + self.onConnected() #---------------------------------------------------------------------- def stop(self): @@ -55,18 +80,18 @@ class WebsocketClient(object): self._disconnect() #---------------------------------------------------------------------- - def sendReq(self, req): # type: (dict)->None - """发出请求""" - return self._get_ws().send(json.dumps(req), opcode=vnpy.api.websocket.ABNF.OPCODE_TEXT) + def sendPacket(self, dictObj): # type: (dict)->None + """发出请求:相当于sendText(json.dumps(dictObj))""" + return self._get_ws().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- def sendText(self, text): # type: (str)->None - """发出请求""" - return self._get_ws().send(text, opcode=vnpy.api.websocket.ABNF.OPCODE_TEXT) + """发送文本数据""" + return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- def sendData(self, data): # type: (bytes)->None - """发出请求""" + """发送字节数据""" return self._get_ws().send_binary(data) #---------------------------------------------------------------------- @@ -78,8 +103,8 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def _connect(self): """""" - self._ws = vnpy.api.websocket.create_connection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE}) - self.onConnect() + self._ws = self.createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE}) + self.onConnected() #---------------------------------------------------------------------- def _disconnect(self): @@ -104,12 +129,13 @@ class WebsocketClient(object): try: stream = ws.recv() if not stream: + self.onDisconnected() if self._active: self._reconnect() continue data = json.loads(stream) - self.onMessage(data) + self.onPacket(data) except: et, ev, tb = sys.exc_info() self.onError(et, ev, tb) @@ -125,17 +151,25 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def _ping(self): - return self._get_ws().send('ping', vnpy.api.websocket.ABNF.OPCODE_PING) + return self._get_ws().send('ping', websocket.ABNF.OPCODE_PING) #---------------------------------------------------------------------- - @abstractmethod - def onConnect(self): - """连接回调""" + def onConnected(self): + """ + 连接成功回调 + """ + pass + + #---------------------------------------------------------------------- + def onDisconnected(self): + """ + 连接断开回调 + """ pass #---------------------------------------------------------------------- @abstractmethod - def onMessage(self, packet): + def onPacket(self, packet): """ 数据回调。 只有在数据为json包的时候才会触发这个回调 @@ -145,7 +179,6 @@ class WebsocketClient(object): pass #---------------------------------------------------------------------- - @abstractmethod def onError(self, exceptionType, exceptionValue, tb): """Python错误回调""" - pass + return sys.excepthook(exceptionType, exceptionValue, tb) diff --git a/vnpy/api/websocket/__init__.py b/vnpy/api/websocket/__init__.py index 3e58eeb4..5c44c43f 100644 --- a/vnpy/api/websocket/__init__.py +++ b/vnpy/api/websocket/__init__.py @@ -1 +1 @@ -from .WebSocketClient import WebsocketClient +from .WebSocketClient import WebSocketClient diff --git a/vnpy/trader/gateway/okexFutureGateway/okexFutureGateway.py b/vnpy/trader/gateway/okexFutureGateway/okexFutureGateway.py index 3b5287c3..23bc9bed 100644 --- a/vnpy/trader/gateway/okexFutureGateway/okexFutureGateway.py +++ b/vnpy/trader/gateway/okexFutureGateway/okexFutureGateway.py @@ -4,6 +4,7 @@ from __future__ import print_function import json from abc import abstractmethod, abstractproperty +from datetime import datetime from typing import Dict @@ -11,32 +12,74 @@ from vnpy.api.okexfuture.OkexFutureApi import * from vnpy.trader.vtFunction import getJsonPath from vnpy.trader.vtGateway import * -orderTypeMap = { +_orderTypeMap = { (constant.DIRECTION_LONG, constant.OFFSET_OPEN): OkexFutureOrderType.OpenLong, (constant.DIRECTION_SHORT, constant.OFFSET_OPEN): OkexFutureOrderType.OpenShort, (constant.DIRECTION_LONG, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseLong, (constant.DIRECTION_SHORT, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseShort, } -orderTypeMapReverse = {v: k for k, v in orderTypeMap.items()} +_orderTypeMapReverse = {v: k for k, v in _orderTypeMap.items()} -contracts = ( - 'btc_usd', 'ltc_usd', 'eth_usd', 'etc_usd', 'bch_usd', -) - -contractTypeMap = { +_contractTypeMap = { 'THISWEEK': OkexFutureContractType.ThisWeek, 'NEXTWEEK': OkexFutureContractType.NextWeek, 'QUARTER': OkexFutureContractType.Quarter, } +_contractTypeMapReverse = {v: k for k, v in _contractTypeMap.items()} + +_remoteSymbols = { + OkexFutureSymbol.BTC, + OkexFutureSymbol.LTC, + OkexFutureSymbol.ETH, + OkexFutureSymbol.ETC, + OkexFutureSymbol.BCH, +} # symbols for ui, -# keys:给用户看的symbols +# keys:给用户看的symbols : f"{internalSymbol}_{contractType}" # values: API接口使用的symbol和contractType字段 -symbolsForUi = {} # type: dict[str, [str, str]] -for s in contracts: - for vtContractType, contractType_ in contractTypeMap.items(): - vtSymbol = s + '_' + vtContractType - symbolsForUi[vtSymbol] = (s, contractType_) +_symbolsForUi = {(remoteSymbol.upper() + '_' + upperContractType.upper()): (remoteSymbol, remoteContractType) + for remoteSymbol in _remoteSymbols + for upperContractType, remoteContractType in + _contractTypeMap.items()} # type: Dict[str, List[str, str]] +_symbolsForUiReverse = {v: k for k, v in _symbolsForUi.items()} + + +#---------------------------------------------------------------------- +def localOrderTypeToRemote(direction, offset): # type: (str, str)->str + return _orderTypeMap[(direction, offset)] + + +#---------------------------------------------------------------------- +def remoteOrderTypeToLocal(orderType): # type: (str)->(str, str) + """ + :param orderType: + :return: direction, offset + """ + return _orderTypeMapReverse[orderType] + + +#---------------------------------------------------------------------- +def localContractTypeToRemote(localContractType): + return _contractTypeMap[localContractType] + + +#---------------------------------------------------------------------- +def remoteContractTypeToLocal(remoteContractType): + return _contractTypeMapReverse[remoteContractType] + + +#---------------------------------------------------------------------- +def localSymbolToRemote(symbol): # type: (str)->(OkexFutureSymbol, OkexFutureContractType) + """ + :return: remoteSymbol, remoteContractType + """ + return _symbolsForUi[symbol] + + +#---------------------------------------------------------------------- +def remoteSymbolToLocal(remoteSymbol, localContractType): + return remoteSymbol.upper() + localContractType ######################################################################## @@ -55,11 +98,6 @@ class VnpyGateway(VtGateway): def gatewayName(self): # type: ()->str return 'VnpyGateway' - #---------------------------------------------------------------------- - @abstractproperty - def exchange(self): # type: ()->str - return constant.EXCHANGE_UNKNOWN - #---------------------------------------------------------------------- def readConfig(self): """ @@ -88,9 +126,13 @@ class VnpyGateway(VtGateway): """ pass + +######################################################################## class _Order(object): _lastLocalId = 0 - def __init__(self): + + #---------------------------------------------------------------------- + def __ini__(self): _Order._lastLocalId += 1 self.localId = str(_Order._lastLocalId) self.remoteId = None @@ -102,7 +144,7 @@ class OkexFutureGateway(VnpyGateway): """OKEX期货交易接口""" #---------------------------------------------------------------------- - def __init__(self, eventEngine, *args, **kwargs): # args, kwargs is needed for compatibility + def __init__(self, eventEngine, *_, **__): # args, kwargs is needed for compatibility """Constructor""" super(OkexFutureGateway, self).__init__(eventEngine) self.apiKey = None # type: str @@ -111,14 +153,17 @@ class OkexFutureGateway(VnpyGateway): self.leverRate = 1 self.symbols = [] + self.tradeID = 0 self._orders = {} # type: Dict[str, _Order] + self._remoteIds = {} # type: Dict[str, _Order] + #---------------------------------------------------------------------- @property def gatewayName(self): return 'OkexFutureGateway' #---------------------------------------------------------------------- - @abstractproperty + @property def exchange(self): # type: ()->str return constant.EXCHANGE_OKEXFUTURE @@ -153,48 +198,57 @@ class OkexFutureGateway(VnpyGateway): pass #---------------------------------------------------------------------- - @staticmethod - def _contractTypeFromSymbol(symbol): - return symbolsForUi[symbol] + def _getOrderByLocalId(self, localId): + return self._orders[localId] #---------------------------------------------------------------------- - def _getOrder(self, localId): - return self._orders[localId] - + def _getOrderByRemoteId(self, remoteId): + return self._remoteIds[remoteId] + + #---------------------------------------------------------------------- + def _saveRemoteId(self, remoteId, myorder): + myorder.remoteId = remoteId + self._remoteIds[remoteId] = myorder + + #---------------------------------------------------------------------- + def _genereteLocalOrder(self, symbol, price, volume, direction, offset): + myorder = _Order() + localId = myorder.localId + self._orders[localId] = myorder + myorder.vtOrder = VtOrderData.createFromGateway(self, + self.exchange, + localId, + symbol, + price, + volume, + direction, + offset) + return myorder + #---------------------------------------------------------------------- def sendOrder(self, vtRequest): # type: (VtOrderReq)->str """发单""" - myorder = _Order() - localId = myorder.localId - - vtOrder = VtOrderData() - vtOrder.orderID = localId - vtOrder.vtOrderID = ".".join([self.gatewayName, localId]) - vtOrder.exchange = self.exchange - - vtOrder.symbol = vtRequest.symbol - vtOrder.vtSymbol = '.'.join([vtOrder.symbol, vtOrder.exchange]) - vtOrder.price = vtRequest.price - vtOrder.totalVolume = vtRequest.volume - vtOrder.direction = vtRequest.direction - - myorder.vtOrder = vtOrder + myorder = self._genereteLocalOrder(vtRequest.symbol, + vtRequest.price, + vtRequest.volume, + vtRequest.direction, + vtRequest.offset) - symbol, contractType = self._contractTypeFromSymbol(vtRequest.symbol) - orderType = orderTypeMap[(vtRequest.priceType, vtRequest.offset)] # 开多、开空、平多、平空 + remoteSymbol, remoteContractType = localSymbolToRemote(vtRequest.symbol) + orderType = _orderTypeMap[(vtRequest.priceType, vtRequest.offset)] # 开多、开空、平多、平空 userMarketPrice = False if vtRequest.priceType == constant.PRICETYPE_MARKETPRICE: userMarketPrice = True - - self.api.sendOrder(symbol=symbol, - contractType=contractType, + + self.api.sendOrder(symbol=remoteSymbol, + contractType=remoteContractType, orderType=orderType, volume=vtRequest.volume, price=vtRequest.price, useMarketPrice=userMarketPrice, leverRate=self.leverRate, - onSuccess=self.onOrderSent, + onSuccess=self._onOrderSent, extra=None, ) @@ -203,44 +257,160 @@ class OkexFutureGateway(VnpyGateway): #---------------------------------------------------------------------- def cancelOrder(self, vtCancel): # type: (VtCancelOrderReq)->None """撤单""" - myorder = self._getOrder(vtCancel.orderID) - symbol, contractType = self._contractTypeFromSymbol(vtCancel.symbol) + myorder = self._getOrderByLocalId(vtCancel.orderID) + symbol, contractType = localSymbolToRemote(vtCancel.symbol) self.api.cancelOrder(symbol=symbol, contractType=contractType, orderId=myorder.remoteId, - onSuccess=self.onOrderCanceled, + onSuccess=self._onOrderCanceled, extra=myorder, ) # cancelDict: 不存在的,没有localId就没有remoteId,没有remoteId何来cancel #---------------------------------------------------------------------- - def queryOrder(self): + def queryOrders(self, symbol, contractType, + status): # type: (str, str, OkexFutureOrderStatus)->None + """ + :param symbol: + :param contractType: 这个参数可以传'THISWEEK', 'NEXTWEEK', 'QUARTER',也可以传OkexFutureContractType + :param status: OkexFutureOrderStatus + :return: + """ + + if contractType in _contractTypeMap: + localContractType = contractType + remoteContractType = localContractTypeToRemote(localContractType) + else: + remoteContractType = contractType + localContractType = remoteContractTypeToLocal(remoteContractType) + + self.api.queryOrders(symbol=symbol, + contractType=remoteContractType, + status=status, + onSuccess=self._onQueryOrders, + extra=localContractType) #---------------------------------------------------------------------- def qryAccount(self): + self.api.queryUserInfo(onSuccess=self._onQueryAccount) """查询账户资金""" pass #---------------------------------------------------------------------- def qryPosition(self): """查询持仓""" - self.api.spotUserInfo() + for remoteSymbol in _remoteSymbols: + for localContractType, remoteContractType in _contractTypeMap.items(): + self.api.queryPosition(remoteSymbol, + remoteContractType, + onSuccess=self._onQueryPosition, + extra=localContractType + ) #---------------------------------------------------------------------- def close(self): """关闭""" - self.api.close() + self.api.stop() #---------------------------------------------------------------------- - def onOrderSent(self, remoteId, myorder): #type: (int, _Order)->None + def _onOrderSent(self, remoteId, myorder): #type: (str, _Order)->None myorder.remoteId = remoteId myorder.vtOrder.status = constant.STATUS_NOTTRADED + self._saveRemoteId(remoteId, myorder) self.onOrder(myorder.vtOrder) - + + # #---------------------------------------------------------------------- + # def _pushOrderAsTraded(self, order): + # trade = VtTradeData() + # trade.gatewayName = order.gatewayName + # trade.symbol = order.symbol + # trade.vtSymbol = order.vtSymbol + # trade.orderID = order.orderID + # trade.vtOrderID = order.vtOrderID + # self.tradeID += 1 + # trade.tradeID = str(self.tradeID) + # trade.vtTradeID = '.'.join([self.gatewayName, trade.tradeID]) + # trade.direction = order.direction + # trade.price = order.price + # trade.volume = order.tradedVolume + # trade.tradeTime = datetime.now().strftime('%H:%M:%S') + # self.onTrade(trade) + #---------------------------------------------------------------------- @staticmethod - def onOrderCanceled(myorder): #type: (_Order)->None + def _onOrderCanceled(myorder): #type: (_Order)->None myorder.vtOrder.status = constant.STATUS_CANCELLED + + #---------------------------------------------------------------------- + def _onQueryOrders(self, orders, extra): # type: (List[OkexFutureOrder], Any)->None + localContractType = extra + for order in orders: + remoteId = order.remoteId + if remoteId in self._remoteIds: + # 如果订单已经缓存在本地,则尝试更新订单状态 + myorder = self._getOrderByRemoteId(remoteId) + + # 有新交易才推送更新 + if order.tradedVolume != myorder.vtOrder.tradedVolume: + myorder.vtOrder.tradedVolume = order.tradedVolume + myorder.vtOrder.status = constant.STATUS_PARTTRADED + self.onOrder(myorder.vtOrder) + else: + # 本地无此订单的缓存(例如,用其他工具的下单) + # 缓存该订单,并推送 + symbol = remoteSymbolToLocal(order.symbol, localContractType) + direction, offset = remoteOrderTypeToLocal(order.orderType) + myorder = self._genereteLocalOrder(symbol, order.price, order.volume, direction, offset) + myorder.vtOrder.tradedVolume = order.tradedVolume + myorder.remoteId = order.remoteId + self._saveRemoteId(myorder.remoteId, myorder) + self.onOrder(myorder.vtOrder) - + # # 如果该订单已经交易完成,推送交易完成消息 + # # todo: 这样写会导致同一个订单产生多次交易完成消息 + # if order.status == OkexFutureOrderStatus.Finished: + # myorder.vtOrder.status = constant.STATUS_ALLTRADED + # self._pushOrderAsTraded(myorder.vtOrder) + + #---------------------------------------------------------------------- + def _onQueryAccount(self, infos, _): # type: (List[OkexFutureUserInfo], Any)->None + for info in infos: + vtAccount = VtAccountData() + vtAccount.accountID = info.easySymbol + vtAccount.vtAccountID = self.gatewayName + '.' + vtAccount.accountID + vtAccount.balance = info.accountRights + vtAccount.margin = info.keepDeposit + vtAccount.closeProfit = info.profitReal + vtAccount.positionProfit = info.profitUnreal + self.onAccount(vtAccount) + + #---------------------------------------------------------------------- + def _onQueryPosition(self, posinfo, extra): # type: (OkexFuturePosition, Any)->None + localContractType = extra + for info in posinfo.holding: + # 先生成多头持仓 + pos = VtPositionData() + pos.gatewayName = self.gatewayName + pos.symbol = remoteSymbolToLocal(pos.symbol, localContractType) + pos.exchange = self.exchange + pos.vtSymbol = '.'.join([pos.symbol, pos.exchange]) + + pos.direction = constant.DIRECTION_NET + pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction]) + pos.position = float(info.buyAmount) + + self.onPosition(pos) + + # 再生存空头持仓 + pos = VtPositionData() + pos.gatewayName = self.gatewayName + pos.symbol = remoteSymbolToLocal(pos.symbol, localContractType) + pos.exchange = self.exchange + pos.vtSymbol = '.'.join([pos.symbol, pos.exchange]) + + pos.direction = constant.DIRECTION_SHORT + pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction]) + pos.position = float(info.sellAmount) + + self.onPosition(pos) diff --git a/vnpy/trader/vtObject.py b/vnpy/trader/vtObject.py index 9aa62507..1a844183 100644 --- a/vnpy/trader/vtObject.py +++ b/vnpy/trader/vtObject.py @@ -1,10 +1,11 @@ # encoding: UTF-8 import time +from datetime import datetime from logging import INFO -from vnpy.trader.vtConstant import (EMPTY_STRING, EMPTY_UNICODE, - EMPTY_FLOAT, EMPTY_INT) +from vnpy.trader.language import constant +from vnpy.trader.vtConstant import (EMPTY_FLOAT, EMPTY_INT, EMPTY_STRING, EMPTY_UNICODE) ######################################################################## @@ -105,7 +106,10 @@ class VtBarData(VtBaseData): ######################################################################## class VtTradeData(VtBaseData): - """成交数据类""" + """ + 成交数据类 + 一般来说,一个VtOrderData可能对应多个VtTradeData:一个订单可能多次部分成交 + """ #---------------------------------------------------------------------- def __init__(self): @@ -116,8 +120,8 @@ class VtTradeData(VtBaseData): self.symbol = EMPTY_STRING # 合约代码 self.exchange = EMPTY_STRING # 交易所代码 self.vtSymbol = EMPTY_STRING # 合约在vt系统中的唯一代码,通常是 合约代码.交易所代码 - - self.tradeID = EMPTY_STRING # 成交编号 + + self.tradeID = EMPTY_STRING # 成交编号 gateway内部自己生成的编号 self.vtTradeID = EMPTY_STRING # 成交在vt系统中的唯一编号,通常是 Gateway名.成交编号 self.orderID = EMPTY_STRING # 订单编号 @@ -129,7 +133,27 @@ class VtTradeData(VtBaseData): self.price = EMPTY_FLOAT # 成交价格 self.volume = EMPTY_INT # 成交数量 self.tradeTime = EMPTY_STRING # 成交时间 - + + #---------------------------------------------------------------------- + @staticmethod + def createFromOrderData(order, + tradeID, + tradePrice, + tradeVolume): # type: (VtOrderData, str, float, float)->VtTradeData + trade = VtTradeData() + trade.gatewayName = order.gatewayName + trade.symbol = order.symbol + trade.vtSymbol = order.vtSymbol + trade.orderID = order.orderID + trade.vtOrderID = order.vtOrderID + trade.tradeID = tradeID + trade.vtTradeID = trade.gatewayName + '.' + tradeID + trade.direction = order.direction + trade.price = tradePrice + trade.volume = tradeVolume + trade.tradeTime = datetime.now().strftime('%H:%M:%S') + return trade + ######################################################################## class VtOrderData(VtBaseData): @@ -143,10 +167,10 @@ class VtOrderData(VtBaseData): # 代码编号相关 self.symbol = EMPTY_STRING # 合约代码 self.exchange = EMPTY_STRING # 交易所代码 - self.vtSymbol = EMPTY_STRING # 统一格式:f"{symbol}.{exchange}" + self.vtSymbol = EMPTY_STRING # 索引,统一格式:f"{symbol}.{exchange}" self.orderID = EMPTY_STRING # 订单编号 gateway内部自己生成的编号 - self.vtOrderID = EMPTY_STRING # 统一格式:f"{gatewayName}.{orderId}" + self.vtOrderID = EMPTY_STRING # 索引,统一格式:f"{gatewayName}.{orderId}" # 报单相关 self.direction = EMPTY_UNICODE # 报单方向 @@ -163,6 +187,33 @@ class VtOrderData(VtBaseData): self.frontID = EMPTY_INT # 前置机编号 self.sessionID = EMPTY_INT # 连接编号 + #---------------------------------------------------------------------- + @staticmethod + def createFromGateway(gateway, orderId, symbol, exchange, price, volume, direction, + offset=EMPTY_UNICODE, + tradedVolume=EMPTY_INT, + status=constant.STATUS_UNKNOWN, + orderTime=EMPTY_UNICODE, + cancelTime=EMPTY_UNICODE, + ): # type: (VtGateway, str, str, str, float, float, str, str, int, str, str, str)->VtOrderData + vtOrder = VtOrderData() + vtOrder.gatewayName = gateway.gatewayName + vtOrder.symbol = symbol + vtOrder.exchange = exchange + vtOrder.vtSymbol = symbol + '.' + exchange + vtOrder.orderID = orderId + vtOrder.vtOrderID = gateway.gatewayName + '.' + orderId + + vtOrder.direction = direction + vtOrder.offset = offset + vtOrder.price = price + vtOrder.totalVolume = volume + vtOrder.tradedVolume = tradedVolume + vtOrder.status = status + vtOrder.orderTime = orderTime + vtOrder.cancelTime = cancelTime + return vtOrder + ######################################################################## class VtPositionData(VtBaseData): @@ -187,6 +238,29 @@ class VtPositionData(VtBaseData): self.ydPosition = EMPTY_INT # 昨持仓 self.positionProfit = EMPTY_FLOAT # 持仓盈亏 + #---------------------------------------------------------------------- + @staticmethod + def createFromGateway(gateway, exchange, symbol, direction, position, + frozen=EMPTY_INT, + price=EMPTY_FLOAT, + yestordayPosition=EMPTY_INT, + profit=EMPTY_FLOAT + ): # type: (VtGateway, str, str, str, float, int, float, int, float)->VtPositionData + vtPosition = VtPositionData() + vtPosition.gatewayName = gateway.gatewayName + vtPosition.symbol = symbol + vtPosition.exchange = exchange + vtPosition.vtSymbol = symbol + '.' + exchange + + vtPosition.direction = direction + vtPosition.position = position + vtPosition.frozen = frozen + vtPosition.price = price + vtPosition.vtPositionName = vtPosition.vtSymbol + '.' + direction + vtPosition.ydPosition = yestordayPosition + vtPosition.positionProfit = profit + return vtPosition + ######################################################################## class VtAccountData(VtBaseData):