Merge branch 'dev' of https://github.com/vnpy/vnpy into dev

This commit is contained in:
vn.py 2018-10-17 14:22:18 +08:00
commit 9d124ed196
9 changed files with 470 additions and 169 deletions

View File

@ -51,7 +51,7 @@ class RestfulClientTest(unittest.TestCase):
def callback(data, req): def callback(data, req):
self.c.p.set_result(data['args']) self.c.p.set_result(data['args'])
self.c.addReq('GET', '/get', callback, params=args) self.c.addRequest('GET', '/get', callback, params=args)
res = self.c.p.get(3) res = self.c.p.get(3)
self.assertEqual(args, res) self.assertEqual(args, res)
@ -63,7 +63,7 @@ class RestfulClientTest(unittest.TestCase):
def callback(data, req): def callback(data, req):
self.c.p.set_result(data['json']) self.c.p.set_result(data['json'])
self.c.addReq('POST', '/post', callback, data=body) self.c.addRequest('POST', '/post', callback, data=body)
res = self.c.p.get(3) res = self.c.p.get(3)
self.assertEqual(body, res) self.assertEqual(body, res)
@ -72,7 +72,7 @@ class RestfulClientTest(unittest.TestCase):
def callback(data, req): def callback(data, req):
pass pass
self.c.addReq('POST', '/status/401', callback) self.c.addRequest('POST', '/status/401', callback)
with self.assertRaises(FailedError): with self.assertRaises(FailedError):
self.c.p.get(3) self.c.p.get(3)
@ -80,7 +80,7 @@ class RestfulClientTest(unittest.TestCase):
def callback(data, req): def callback(data, req):
pass pass
self.c.addReq('GET', '/image/svg', callback) self.c.addRequest('GET', '/image/svg', callback)
with self.assertRaises(JSONDecodeError): with self.assertRaises(JSONDecodeError):
self.c.p.get(3) self.c.p.get(3)

View File

@ -2,10 +2,10 @@
import unittest import unittest
from Promise import Promise from Promise import Promise
from vnpy.api.websocket import WebsocketClient from vnpy.api.websocket import WebSocketClient
class TestWebsocketClient(WebsocketClient): class TestWebsocketClient(WebSocketClient):
def __init__(self): def __init__(self):
host = 'wss://echo.websocket.org' host = 'wss://echo.websocket.org'
@ -13,11 +13,11 @@ class TestWebsocketClient(WebsocketClient):
self.init(host) self.init(host)
self.p = Promise() self.p = Promise()
def onMessage(self, packet): def onPacket(self, packet):
self.p.set_result(packet) self.p.set_result(packet)
pass pass
def onConnect(self): def onConnected(self):
pass pass
def onError(self, exceptionType, exceptionValue, tb): def onError(self, exceptionType, exceptionValue, tb):
@ -38,7 +38,7 @@ class WebsocketClientTest(unittest.TestCase):
req = { req = {
'name': 'val' 'name': 'val'
} }
self.c.sendReq(req) self.c.sendPacket(req)
res = self.c.p.get(3) res = self.c.p.get(3)
self.assertEqual(res, req) self.assertEqual(res, req)

View File

@ -75,6 +75,7 @@ class OkexFutureUserInfo(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self): def __init__(self):
self.easySymbol = None # 'etc', 'btc', 'eth', etc.
self.accountRights = None self.accountRights = None
self.keepDeposit = None self.keepDeposit = None
self.profitReal = None self.profitReal = None
@ -142,7 +143,7 @@ class OkexFutureRestClient(OkexFutureRestBase):
def sendOrder(self, symbol, contractType, orderType, volume, def sendOrder(self, symbol, contractType, orderType, volume,
onSuccess, onFailed=None, onSuccess, onFailed=None,
price=None, useMarketPrice=False, leverRate=None, price=None, useMarketPrice=False, leverRate=None,
extra=None): # type:(str, OkexFutureContractType, OkexFutureOrderType, float, Callable[[int, Any], Any], Callable[[int, Any], Any], float, bool, Union[int, None], Any)->Request extra=None): # type:(str, OkexFutureContractType, OkexFutureOrderType, float, Callable[[str, Any], Any], Callable[[int, Any], Any], float, bool, Union[int, None], Any)->Request
""" """
:param symbol: str :param symbol: str
:param contractType: OkexFutureContractType :param contractType: OkexFutureContractType
@ -171,7 +172,7 @@ class OkexFutureRestClient(OkexFutureRestBase):
if leverRate: if leverRate:
data['lever_rate'] = leverRate # 杠杆倍数 data['lever_rate'] = leverRate # 杠杆倍数
request = self.addReq('POST', request = self.addRequest('POST',
'/future_trade.do', '/future_trade.do',
callback=self.onOrderSent, callback=self.onOrderSent,
data=data, data=data,
@ -195,7 +196,7 @@ class OkexFutureRestClient(OkexFutureRestBase):
'contractType': contractType, 'contractType': contractType,
'order_id': orderId 'order_id': orderId
} }
return self.addReq('POST', return self.addRequest('POST',
'/future_cancel.do', '/future_cancel.do',
callback=self.onOrderCanceled, callback=self.onOrderCanceled,
data=data, data=data,
@ -203,12 +204,14 @@ class OkexFutureRestClient(OkexFutureRestBase):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def queryOrder(self, symbol, contractType, orderId, onSuccess, onFailed=None, def queryOrder(self, symbol, contractType, orderId, onSuccess, onFailed=None,
extra=None): # type: (str, OkexFutureContractType, str, Callable[[OkexFutureOrder, Any], Any], Callable[[int, Any], Any], Any)->Request extra=None): # type: (str, OkexFutureContractType, str, Callable[[List[OkexFutureOrder], Any], Any], Callable[[int, Any], Any], Any)->Request
""" """
@note onSuccess接收的第一个参数是列表并且有可能为空
:param symbol: str :param symbol: str
:param contractType: OkexFutureContractType :param contractType: OkexFutureContractType
:param orderId: str :param orderId: str
:param onSuccess: (OkexFutureOrder, extra:Any)->Any :param onSuccess: (orders: List[OkexFutureOrder], extra:Any)->Any
:param onFailed: (extra: Any)->Any :param onFailed: (extra: Any)->Any
:param extra: Any :param extra: Any
:return: Request :return: Request
@ -218,7 +221,7 @@ class OkexFutureRestClient(OkexFutureRestBase):
'contractType': contractType, 'contractType': contractType,
'order_id': orderId 'order_id': orderId
} }
return self.addReq('POST', return self.addRequest('POST',
'/future_order_info.do', '/future_order_info.do',
callback=self.onOrder, callback=self.onOrder,
data=data, data=data,
@ -227,28 +230,30 @@ class OkexFutureRestClient(OkexFutureRestBase):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def queryOrders(self, symbol, contractType, status, def queryOrders(self, symbol, contractType, status,
onSuccess, onFailed=None, onSuccess, onFailed=None,
pageIndex=None, pageLength=50, pageIndex=0, pageLength=50,
extra=None): # type: (str, OkexFutureContractType, OkexFutureOrderStatus, Callable[[OkexFutureOrder, Any], Any], Callable[[int, Any], Any], int, int, Any)->Request extra=None): # type: (str, OkexFutureContractType, OkexFutureOrderStatus, Callable[[List[OkexFutureOrder], Any], Any], Callable[[int, Any], Any], int, int, Any)->Request
""" """
@note onSuccess接收的第一个参数是列表并且有可能为空
:param symbol: str :param symbol: str
:param contractType: OkexFutureContractType :param contractType: OkexFutureContractType
:param orderId: str :param onSuccess: (List[OkexFutureOrder], extra:Any)->Any
:param onSuccess: (OkexFutureOrder, extra:Any)->Any
:param onFailed: (extra: Any)->Any :param onFailed: (extra: Any)->Any
:param pageIndex: 页码
:param pageLength: 最大显示数量最大值50
:param extra: Any :param extra: Any
:return: Request :return: Request
""" """
data = { data = {
'symbol': symbol, 'symbol': symbol,
'contractType': contractType, 'contract_type': contractType,
'status': status, 'status': status,
'order_id': '-1', 'order_id': -1,
'pageLength': 50 'current_page': pageIndex,
'page_length': pageLength
} }
if pageIndex:
data['page_index'] = pageIndex
return self.addReq('POST', return self.addRequest('POST',
'/future_order_info.do', '/future_order_info.do',
callback=self.onOrder, callback=self.onOrder,
data=data, data=data,
@ -264,7 +269,7 @@ class OkexFutureRestClient(OkexFutureRestBase):
:param extra: Any :param extra: Any
:return: Request :return: Request
""" """
return self.addReq('POST', return self.addRequest('POST',
'/future_userinfo.do', '/future_userinfo.do',
callback=self.onOrder, callback=self.onOrder,
extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra)) extra=_OkexFutureCustomExtra(onSuccess, onFailed, extra))
@ -273,11 +278,19 @@ class OkexFutureRestClient(OkexFutureRestBase):
def queryPosition(self, symbol, contractType, def queryPosition(self, symbol, contractType,
onSuccess, onFailed=None, onSuccess, onFailed=None,
extra=None): # type: (str, OkexFutureContractType, Callable[[OkexFuturePosition, Any], Any], Callable[[int, Any], Any], Any)->Request extra=None): # type: (str, OkexFutureContractType, Callable[[OkexFuturePosition, Any], Any], Callable[[int, Any], Any], Any)->Request
"""
:param symbol: OkexFutureSymbol
:param contractType: OkexFutureContractType
:param onSuccess: (pos:OkexFuturePosition, extra: any)->Any
:param onFailed: (errorCode: int, extra: any)->Any
:param extra:
:return:
"""
data = { data = {
'symbol': symbol, 'symbol': symbol,
'contractType': contractType 'contractType': contractType
} }
return self.addReq('POST', return self.addRequest('POST',
'/future_position.do', '/future_position.do',
data=data, data=data,
callback=self.onPosition, callback=self.onPosition,
@ -323,7 +336,8 @@ class OkexFutureRestClient(OkexFutureRestBase):
success = data['result'] success = data['result']
extra = req.extra # type: _OkexFutureCustomExtra extra = req.extra # type: _OkexFutureCustomExtra
if success: if success:
order = data['orders'][0] orders = []
for order in data['orders']:
okexOrder = OkexFutureOrder() okexOrder = OkexFutureOrder()
okexOrder.volume = order['amount'] okexOrder.volume = order['amount']
@ -339,8 +353,8 @@ class OkexFutureRestClient(OkexFutureRestBase):
okexOrder.orderType = order['type'] okexOrder.orderType = order['type']
okexOrder.unitAmount = order['unit_amount'] okexOrder.unitAmount = order['unit_amount']
okexOrder.symbol = order['symbol'] okexOrder.symbol = order['symbol']
orders.append(okexOrder)
extra.onSuccess(okexOrder, extra.extra) extra.onSuccess(orders, extra.extra)
else: else:
if extra.onFailed: if extra.onFailed:
code = 0 code = 0
@ -356,8 +370,9 @@ class OkexFutureRestClient(OkexFutureRestBase):
if success: if success:
infos = data['info'] infos = data['info']
uis = [] uis = []
for symbol, info in infos.items(): # type: str, dict for easySymbol, info in infos.items(): # type: str, dict
ui = OkexFutureUserInfo() ui = OkexFutureUserInfo()
ui.easySymbol = easySymbol
ui.accountRights = info['account_rights'] ui.accountRights = info['account_rights']
ui.keepDeposit = info['keep_deposit'] ui.keepDeposit = info['keep_deposit']
ui.profitReal = info['profit_real'] ui.profitReal = info['profit_real']
@ -406,9 +421,9 @@ class OkexFutureRestClient(OkexFutureRestBase):
code = data['error_code'] code = data['error_code']
extra.onFailed(code, extra.extra) extra.onFailed(code, extra.extra)
#---------------------------------------------------------------------- #---------------------------------------------------------------------
@staticmethod @staticmethod
def errorCode2String(code): def errorCodeToString(code):
assert code in errorCodeMap assert code in errorCodeMap
return errorCodeMap[code] return errorCodeMap[code]

View File

@ -2,13 +2,22 @@
import hashlib import hashlib
import urllib import urllib
from vnpy.api.rest import Request, RestClient
########################################################################
from vnpy.api.rest import RestClient, Request #----------------------------------------------------------------------
def sign(dataWithApiKey, apiSecret):
"""
:param dataWithApiKey: sorted urlencoded args with apiKey
:return: param 'sign' for okex api
"""
dataWithSecret = dataWithApiKey + "&secret_key=" + apiSecret
return hashlib.md5(dataWithSecret.encode()).hexdigest().upper()
######################################################################## ########################################################################
class OkexFutureRestBase(RestClient): class OkexFutureRestBase(RestClient):
host = 'https://www.okex.com/api/v1'
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self): def __init__(self):
@ -20,7 +29,7 @@ class OkexFutureRestBase(RestClient):
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
def init(self, apiKey, apiSecret): def init(self, apiKey, apiSecret):
# type: (str, str) -> any # type: (str, str) -> any
super(OkexFutureRestBase, self).init('https://www.okex.com/api/v1') super(OkexFutureRestBase, self).init(self.host)
self.apiKey = apiKey self.apiKey = apiKey
self.apiSecret = apiSecret self.apiSecret = apiSecret
@ -33,11 +42,9 @@ class OkexFutureRestBase(RestClient):
if 'apiKey' not in args: if 'apiKey' not in args:
args['api_key'] = self.apiKey args['api_key'] = self.apiKey
data = urllib.urlencode(sorted(args.items())) data = urllib.urlencode(sorted(args.items()))
data += "&secret_key=" + self.apiSecret signature = sign(data, self.apiSecret)
data += "&sign=" + signature
sign = hashlib.md5(data.encode()).hexdigest().upper()
data += "&sign=" + sign
req.headers = {'Content-Type': 'application/x-www-form-urlencoded'} req.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
req.data = data
return req return req

View File

@ -139,7 +139,7 @@ class RestClient(object):
self._queue.join() self._queue.join()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def addReq(self, method, path, callback, def addRequest(self, method, path, callback,
params=None, data=None, headers = None, params=None, data=None, headers = None,
onFailed=None, skipDefaultOnFailed=True, onFailed=None, skipDefaultOnFailed=True,
extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request extra=None): # type: (str, str, Callable[[dict, Request], Any], dict, dict, dict, Callable[[dict, Request], Any], bool, Any)->Request
@ -171,7 +171,7 @@ class RestClient(object):
try: try:
req = self._queue.get(timeout=1) req = self._queue.get(timeout=1)
try: try:
self._processReq(req, session) self._processRequest(req, session)
finally: finally:
self._queue.task_done() self._queue.task_done()
except Empty: except Empty:
@ -215,8 +215,10 @@ class RestClient(object):
sys.excepthook(exceptionType, exceptionValue, tb) sys.excepthook(exceptionType, exceptionValue, tb)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def _processReq(self, req, session): # type: (Request, requests.Session)->None def _processRequest(self, req, session): # type: (Request, requests.Session)->None
"""处理请求""" """
用于内部将请求发送出去
"""
try: try:
req = self.beforeRequest(req) req = self.beforeRequest(req)

View File

@ -9,11 +9,27 @@ import time
from abc import abstractmethod from abc import abstractmethod
from threading import Thread, Lock from threading import Thread, Lock
import vnpy.api.websocket import websocket
class WebsocketClient(object): class WebSocketClient(object):
"""Websocket API""" """
Websocket API
继承使用该类
实例化之后应调用start开始后台线程调用start()函数会自动连接websocket
若要终止后台线程请调用stop() stop()函数会顺便断开websocket
可以重写以下函数
onConnected
onDisconnected
onPacket
onError
当然为了不让用户随意自定义用自己的init函数覆盖掉原本的init(host)也是个不错的选择
@note 继承使用该类
"""
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self): def __init__(self):
@ -27,6 +43,15 @@ class WebsocketClient(object):
self._pingThread = None # type: Thread self._pingThread = None # type: Thread
self._active = False self._active = False
self.createConnection = websocket.create_connection
#----------------------------------------------------------------------
def setCreateConnection(self, func):
"""
for internal usage
"""
self.createConnection = func
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def init(self, host): def init(self, host):
self.host = host self.host = host
@ -43,7 +68,7 @@ class WebsocketClient(object):
self._pingThread = Thread(target=self._runPing) self._pingThread = Thread(target=self._runPing)
self._pingThread.start() self._pingThread.start()
self.onConnect() self.onConnected()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def stop(self): def stop(self):
@ -55,18 +80,18 @@ class WebsocketClient(object):
self._disconnect() self._disconnect()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def sendReq(self, req): # type: (dict)->None def sendPacket(self, dictObj): # type: (dict)->None
"""发出请求""" """发出请求:相当于sendText(json.dumps(dictObj))"""
return self._get_ws().send(json.dumps(req), opcode=vnpy.api.websocket.ABNF.OPCODE_TEXT) return self._get_ws().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def sendText(self, text): # type: (str)->None def sendText(self, text): # type: (str)->None
"""出请求""" """送文本数据"""
return self._get_ws().send(text, opcode=vnpy.api.websocket.ABNF.OPCODE_TEXT) return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def sendData(self, data): # type: (bytes)->None def sendData(self, data): # type: (bytes)->None
"""出请求""" """送字节数据"""
return self._get_ws().send_binary(data) return self._get_ws().send_binary(data)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@ -78,8 +103,8 @@ class WebsocketClient(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def _connect(self): def _connect(self):
"""""" """"""
self._ws = vnpy.api.websocket.create_connection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE}) self._ws = self.createConnection(self.host, sslopt={'cert_reqs': ssl.CERT_NONE})
self.onConnect() self.onConnected()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def _disconnect(self): def _disconnect(self):
@ -104,12 +129,13 @@ class WebsocketClient(object):
try: try:
stream = ws.recv() stream = ws.recv()
if not stream: if not stream:
self.onDisconnected()
if self._active: if self._active:
self._reconnect() self._reconnect()
continue continue
data = json.loads(stream) data = json.loads(stream)
self.onMessage(data) self.onPacket(data)
except: except:
et, ev, tb = sys.exc_info() et, ev, tb = sys.exc_info()
self.onError(et, ev, tb) self.onError(et, ev, tb)
@ -125,17 +151,25 @@ class WebsocketClient(object):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def _ping(self): def _ping(self):
return self._get_ws().send('ping', vnpy.api.websocket.ABNF.OPCODE_PING) return self._get_ws().send('ping', websocket.ABNF.OPCODE_PING)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@abstractmethod def onConnected(self):
def onConnect(self): """
"""连接回调""" 连接成功回调
"""
pass
#----------------------------------------------------------------------
def onDisconnected(self):
"""
连接断开回调
"""
pass pass
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@abstractmethod @abstractmethod
def onMessage(self, packet): def onPacket(self, packet):
""" """
数据回调 数据回调
只有在数据为json包的时候才会触发这个回调 只有在数据为json包的时候才会触发这个回调
@ -145,7 +179,6 @@ class WebsocketClient(object):
pass pass
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@abstractmethod
def onError(self, exceptionType, exceptionValue, tb): def onError(self, exceptionType, exceptionValue, tb):
"""Python错误回调""" """Python错误回调"""
pass return sys.excepthook(exceptionType, exceptionValue, tb)

View File

@ -1 +1 @@
from .WebSocketClient import WebsocketClient from .WebSocketClient import WebSocketClient

View File

@ -4,6 +4,7 @@ from __future__ import print_function
import json import json
from abc import abstractmethod, abstractproperty from abc import abstractmethod, abstractproperty
from datetime import datetime
from typing import Dict from typing import Dict
@ -11,32 +12,74 @@ from vnpy.api.okexfuture.OkexFutureApi import *
from vnpy.trader.vtFunction import getJsonPath from vnpy.trader.vtFunction import getJsonPath
from vnpy.trader.vtGateway import * from vnpy.trader.vtGateway import *
orderTypeMap = { _orderTypeMap = {
(constant.DIRECTION_LONG, constant.OFFSET_OPEN): OkexFutureOrderType.OpenLong, (constant.DIRECTION_LONG, constant.OFFSET_OPEN): OkexFutureOrderType.OpenLong,
(constant.DIRECTION_SHORT, constant.OFFSET_OPEN): OkexFutureOrderType.OpenShort, (constant.DIRECTION_SHORT, constant.OFFSET_OPEN): OkexFutureOrderType.OpenShort,
(constant.DIRECTION_LONG, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseLong, (constant.DIRECTION_LONG, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseLong,
(constant.DIRECTION_SHORT, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseShort, (constant.DIRECTION_SHORT, constant.OFFSET_CLOSE): OkexFutureOrderType.CloseShort,
} }
orderTypeMapReverse = {v: k for k, v in orderTypeMap.items()} _orderTypeMapReverse = {v: k for k, v in _orderTypeMap.items()}
contracts = ( _contractTypeMap = {
'btc_usd', 'ltc_usd', 'eth_usd', 'etc_usd', 'bch_usd',
)
contractTypeMap = {
'THISWEEK': OkexFutureContractType.ThisWeek, 'THISWEEK': OkexFutureContractType.ThisWeek,
'NEXTWEEK': OkexFutureContractType.NextWeek, 'NEXTWEEK': OkexFutureContractType.NextWeek,
'QUARTER': OkexFutureContractType.Quarter, 'QUARTER': OkexFutureContractType.Quarter,
} }
_contractTypeMapReverse = {v: k for k, v in _contractTypeMap.items()}
_remoteSymbols = {
OkexFutureSymbol.BTC,
OkexFutureSymbol.LTC,
OkexFutureSymbol.ETH,
OkexFutureSymbol.ETC,
OkexFutureSymbol.BCH,
}
# symbols for ui, # symbols for ui,
# keys:给用户看的symbols # keys:给用户看的symbols : f"{internalSymbol}_{contractType}"
# values: API接口使用的symbol和contractType字段 # values: API接口使用的symbol和contractType字段
symbolsForUi = {} # type: dict[str, [str, str]] _symbolsForUi = {(remoteSymbol.upper() + '_' + upperContractType.upper()): (remoteSymbol, remoteContractType)
for s in contracts: for remoteSymbol in _remoteSymbols
for vtContractType, contractType_ in contractTypeMap.items(): for upperContractType, remoteContractType in
vtSymbol = s + '_' + vtContractType _contractTypeMap.items()} # type: Dict[str, List[str, str]]
symbolsForUi[vtSymbol] = (s, contractType_) _symbolsForUiReverse = {v: k for k, v in _symbolsForUi.items()}
#----------------------------------------------------------------------
def localOrderTypeToRemote(direction, offset): # type: (str, str)->str
return _orderTypeMap[(direction, offset)]
#----------------------------------------------------------------------
def remoteOrderTypeToLocal(orderType): # type: (str)->(str, str)
"""
:param orderType:
:return: direction, offset
"""
return _orderTypeMapReverse[orderType]
#----------------------------------------------------------------------
def localContractTypeToRemote(localContractType):
return _contractTypeMap[localContractType]
#----------------------------------------------------------------------
def remoteContractTypeToLocal(remoteContractType):
return _contractTypeMapReverse[remoteContractType]
#----------------------------------------------------------------------
def localSymbolToRemote(symbol): # type: (str)->(OkexFutureSymbol, OkexFutureContractType)
"""
:return: remoteSymbol, remoteContractType
"""
return _symbolsForUi[symbol]
#----------------------------------------------------------------------
def remoteSymbolToLocal(remoteSymbol, localContractType):
return remoteSymbol.upper() + localContractType
######################################################################## ########################################################################
@ -55,11 +98,6 @@ class VnpyGateway(VtGateway):
def gatewayName(self): # type: ()->str def gatewayName(self): # type: ()->str
return 'VnpyGateway' return 'VnpyGateway'
#----------------------------------------------------------------------
@abstractproperty
def exchange(self): # type: ()->str
return constant.EXCHANGE_UNKNOWN
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def readConfig(self): def readConfig(self):
""" """
@ -88,9 +126,13 @@ class VnpyGateway(VtGateway):
""" """
pass pass
########################################################################
class _Order(object): class _Order(object):
_lastLocalId = 0 _lastLocalId = 0
def __init__(self):
#----------------------------------------------------------------------
def __ini__(self):
_Order._lastLocalId += 1 _Order._lastLocalId += 1
self.localId = str(_Order._lastLocalId) self.localId = str(_Order._lastLocalId)
self.remoteId = None self.remoteId = None
@ -102,7 +144,7 @@ class OkexFutureGateway(VnpyGateway):
"""OKEX期货交易接口""" """OKEX期货交易接口"""
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self, eventEngine, *args, **kwargs): # args, kwargs is needed for compatibility def __init__(self, eventEngine, *_, **__): # args, kwargs is needed for compatibility
"""Constructor""" """Constructor"""
super(OkexFutureGateway, self).__init__(eventEngine) super(OkexFutureGateway, self).__init__(eventEngine)
self.apiKey = None # type: str self.apiKey = None # type: str
@ -111,14 +153,17 @@ class OkexFutureGateway(VnpyGateway):
self.leverRate = 1 self.leverRate = 1
self.symbols = [] self.symbols = []
self.tradeID = 0
self._orders = {} # type: Dict[str, _Order] self._orders = {} # type: Dict[str, _Order]
self._remoteIds = {} # type: Dict[str, _Order]
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@property @property
def gatewayName(self): def gatewayName(self):
return 'OkexFutureGateway' return 'OkexFutureGateway'
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@abstractproperty @property
def exchange(self): # type: ()->str def exchange(self): # type: ()->str
return constant.EXCHANGE_OKEXFUTURE return constant.EXCHANGE_OKEXFUTURE
@ -153,48 +198,57 @@ class OkexFutureGateway(VnpyGateway):
pass pass
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@staticmethod def _getOrderByLocalId(self, localId):
def _contractTypeFromSymbol(symbol): return self._orders[localId]
return symbolsForUi[symbol]
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def _getOrder(self, localId): def _getOrderByRemoteId(self, remoteId):
return self._orders[localId] return self._remoteIds[remoteId]
#----------------------------------------------------------------------
def _saveRemoteId(self, remoteId, myorder):
myorder.remoteId = remoteId
self._remoteIds[remoteId] = myorder
#----------------------------------------------------------------------
def _genereteLocalOrder(self, symbol, price, volume, direction, offset):
myorder = _Order()
localId = myorder.localId
self._orders[localId] = myorder
myorder.vtOrder = VtOrderData.createFromGateway(self,
self.exchange,
localId,
symbol,
price,
volume,
direction,
offset)
return myorder
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def sendOrder(self, vtRequest): # type: (VtOrderReq)->str def sendOrder(self, vtRequest): # type: (VtOrderReq)->str
"""发单""" """发单"""
myorder = _Order() myorder = self._genereteLocalOrder(vtRequest.symbol,
localId = myorder.localId vtRequest.price,
vtRequest.volume,
vtRequest.direction,
vtRequest.offset)
vtOrder = VtOrderData() remoteSymbol, remoteContractType = localSymbolToRemote(vtRequest.symbol)
vtOrder.orderID = localId orderType = _orderTypeMap[(vtRequest.priceType, vtRequest.offset)] # 开多、开空、平多、平空
vtOrder.vtOrderID = ".".join([self.gatewayName, localId])
vtOrder.exchange = self.exchange
vtOrder.symbol = vtRequest.symbol
vtOrder.vtSymbol = '.'.join([vtOrder.symbol, vtOrder.exchange])
vtOrder.price = vtRequest.price
vtOrder.totalVolume = vtRequest.volume
vtOrder.direction = vtRequest.direction
myorder.vtOrder = vtOrder
symbol, contractType = self._contractTypeFromSymbol(vtRequest.symbol)
orderType = orderTypeMap[(vtRequest.priceType, vtRequest.offset)] # 开多、开空、平多、平空
userMarketPrice = False userMarketPrice = False
if vtRequest.priceType == constant.PRICETYPE_MARKETPRICE: if vtRequest.priceType == constant.PRICETYPE_MARKETPRICE:
userMarketPrice = True userMarketPrice = True
self.api.sendOrder(symbol=symbol, self.api.sendOrder(symbol=remoteSymbol,
contractType=contractType, contractType=remoteContractType,
orderType=orderType, orderType=orderType,
volume=vtRequest.volume, volume=vtRequest.volume,
price=vtRequest.price, price=vtRequest.price,
useMarketPrice=userMarketPrice, useMarketPrice=userMarketPrice,
leverRate=self.leverRate, leverRate=self.leverRate,
onSuccess=self.onOrderSent, onSuccess=self._onOrderSent,
extra=None, extra=None,
) )
@ -203,44 +257,160 @@ class OkexFutureGateway(VnpyGateway):
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def cancelOrder(self, vtCancel): # type: (VtCancelOrderReq)->None def cancelOrder(self, vtCancel): # type: (VtCancelOrderReq)->None
"""撤单""" """撤单"""
myorder = self._getOrder(vtCancel.orderID) myorder = self._getOrderByLocalId(vtCancel.orderID)
symbol, contractType = self._contractTypeFromSymbol(vtCancel.symbol) symbol, contractType = localSymbolToRemote(vtCancel.symbol)
self.api.cancelOrder(symbol=symbol, self.api.cancelOrder(symbol=symbol,
contractType=contractType, contractType=contractType,
orderId=myorder.remoteId, orderId=myorder.remoteId,
onSuccess=self.onOrderCanceled, onSuccess=self._onOrderCanceled,
extra=myorder, extra=myorder,
) )
# cancelDict: 不存在的没有localId就没有remoteId没有remoteId何来cancel # cancelDict: 不存在的没有localId就没有remoteId没有remoteId何来cancel
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def queryOrder(self): def queryOrders(self, symbol, contractType,
status): # type: (str, str, OkexFutureOrderStatus)->None
"""
:param symbol:
:param contractType: 这个参数可以传'THISWEEK', 'NEXTWEEK', 'QUARTER'也可以传OkexFutureContractType
:param status: OkexFutureOrderStatus
:return:
"""
if contractType in _contractTypeMap:
localContractType = contractType
remoteContractType = localContractTypeToRemote(localContractType)
else:
remoteContractType = contractType
localContractType = remoteContractTypeToLocal(remoteContractType)
self.api.queryOrders(symbol=symbol,
contractType=remoteContractType,
status=status,
onSuccess=self._onQueryOrders,
extra=localContractType)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def qryAccount(self): def qryAccount(self):
self.api.queryUserInfo(onSuccess=self._onQueryAccount)
"""查询账户资金""" """查询账户资金"""
pass pass
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def qryPosition(self): def qryPosition(self):
"""查询持仓""" """查询持仓"""
self.api.spotUserInfo() for remoteSymbol in _remoteSymbols:
for localContractType, remoteContractType in _contractTypeMap.items():
self.api.queryPosition(remoteSymbol,
remoteContractType,
onSuccess=self._onQueryPosition,
extra=localContractType
)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def close(self): def close(self):
"""关闭""" """关闭"""
self.api.close() self.api.stop()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def onOrderSent(self, remoteId, myorder): #type: (int, _Order)->None def _onOrderSent(self, remoteId, myorder): #type: (str, _Order)->None
myorder.remoteId = remoteId myorder.remoteId = remoteId
myorder.vtOrder.status = constant.STATUS_NOTTRADED myorder.vtOrder.status = constant.STATUS_NOTTRADED
self._saveRemoteId(remoteId, myorder)
self.onOrder(myorder.vtOrder) self.onOrder(myorder.vtOrder)
# #----------------------------------------------------------------------
# def _pushOrderAsTraded(self, order):
# trade = VtTradeData()
# trade.gatewayName = order.gatewayName
# trade.symbol = order.symbol
# trade.vtSymbol = order.vtSymbol
# trade.orderID = order.orderID
# trade.vtOrderID = order.vtOrderID
# self.tradeID += 1
# trade.tradeID = str(self.tradeID)
# trade.vtTradeID = '.'.join([self.gatewayName, trade.tradeID])
# trade.direction = order.direction
# trade.price = order.price
# trade.volume = order.tradedVolume
# trade.tradeTime = datetime.now().strftime('%H:%M:%S')
# self.onTrade(trade)
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@staticmethod @staticmethod
def onOrderCanceled(myorder): #type: (_Order)->None def _onOrderCanceled(myorder): #type: (_Order)->None
myorder.vtOrder.status = constant.STATUS_CANCELLED myorder.vtOrder.status = constant.STATUS_CANCELLED
#----------------------------------------------------------------------
def _onQueryOrders(self, orders, extra): # type: (List[OkexFutureOrder], Any)->None
localContractType = extra
for order in orders:
remoteId = order.remoteId
if remoteId in self._remoteIds:
# 如果订单已经缓存在本地,则尝试更新订单状态
myorder = self._getOrderByRemoteId(remoteId)
# 有新交易才推送更新
if order.tradedVolume != myorder.vtOrder.tradedVolume:
myorder.vtOrder.tradedVolume = order.tradedVolume
myorder.vtOrder.status = constant.STATUS_PARTTRADED
self.onOrder(myorder.vtOrder)
else:
# 本地无此订单的缓存(例如,用其他工具的下单)
# 缓存该订单,并推送
symbol = remoteSymbolToLocal(order.symbol, localContractType)
direction, offset = remoteOrderTypeToLocal(order.orderType)
myorder = self._genereteLocalOrder(symbol, order.price, order.volume, direction, offset)
myorder.vtOrder.tradedVolume = order.tradedVolume
myorder.remoteId = order.remoteId
self._saveRemoteId(myorder.remoteId, myorder)
self.onOrder(myorder.vtOrder)
# # 如果该订单已经交易完成,推送交易完成消息
# # todo: 这样写会导致同一个订单产生多次交易完成消息
# if order.status == OkexFutureOrderStatus.Finished:
# myorder.vtOrder.status = constant.STATUS_ALLTRADED
# self._pushOrderAsTraded(myorder.vtOrder)
#----------------------------------------------------------------------
def _onQueryAccount(self, infos, _): # type: (List[OkexFutureUserInfo], Any)->None
for info in infos:
vtAccount = VtAccountData()
vtAccount.accountID = info.easySymbol
vtAccount.vtAccountID = self.gatewayName + '.' + vtAccount.accountID
vtAccount.balance = info.accountRights
vtAccount.margin = info.keepDeposit
vtAccount.closeProfit = info.profitReal
vtAccount.positionProfit = info.profitUnreal
self.onAccount(vtAccount)
#----------------------------------------------------------------------
def _onQueryPosition(self, posinfo, extra): # type: (OkexFuturePosition, Any)->None
localContractType = extra
for info in posinfo.holding:
# 先生成多头持仓
pos = VtPositionData()
pos.gatewayName = self.gatewayName
pos.symbol = remoteSymbolToLocal(pos.symbol, localContractType)
pos.exchange = self.exchange
pos.vtSymbol = '.'.join([pos.symbol, pos.exchange])
pos.direction = constant.DIRECTION_NET
pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction])
pos.position = float(info.buyAmount)
self.onPosition(pos)
# 再生存空头持仓
pos = VtPositionData()
pos.gatewayName = self.gatewayName
pos.symbol = remoteSymbolToLocal(pos.symbol, localContractType)
pos.exchange = self.exchange
pos.vtSymbol = '.'.join([pos.symbol, pos.exchange])
pos.direction = constant.DIRECTION_SHORT
pos.vtPositionName = '.'.join([pos.vtSymbol, pos.direction])
pos.position = float(info.sellAmount)
self.onPosition(pos)

View File

@ -1,10 +1,11 @@
# encoding: UTF-8 # encoding: UTF-8
import time import time
from datetime import datetime
from logging import INFO from logging import INFO
from vnpy.trader.vtConstant import (EMPTY_STRING, EMPTY_UNICODE, from vnpy.trader.language import constant
EMPTY_FLOAT, EMPTY_INT) from vnpy.trader.vtConstant import (EMPTY_FLOAT, EMPTY_INT, EMPTY_STRING, EMPTY_UNICODE)
######################################################################## ########################################################################
@ -105,7 +106,10 @@ class VtBarData(VtBaseData):
######################################################################## ########################################################################
class VtTradeData(VtBaseData): class VtTradeData(VtBaseData):
"""成交数据类""" """
成交数据类
一般来说一个VtOrderData可能对应多个VtTradeData一个订单可能多次部分成交
"""
#---------------------------------------------------------------------- #----------------------------------------------------------------------
def __init__(self): def __init__(self):
@ -117,7 +121,7 @@ class VtTradeData(VtBaseData):
self.exchange = EMPTY_STRING # 交易所代码 self.exchange = EMPTY_STRING # 交易所代码
self.vtSymbol = EMPTY_STRING # 合约在vt系统中的唯一代码通常是 合约代码.交易所代码 self.vtSymbol = EMPTY_STRING # 合约在vt系统中的唯一代码通常是 合约代码.交易所代码
self.tradeID = EMPTY_STRING # 成交编号 self.tradeID = EMPTY_STRING # 成交编号 gateway内部自己生成的编号
self.vtTradeID = EMPTY_STRING # 成交在vt系统中的唯一编号通常是 Gateway名.成交编号 self.vtTradeID = EMPTY_STRING # 成交在vt系统中的唯一编号通常是 Gateway名.成交编号
self.orderID = EMPTY_STRING # 订单编号 self.orderID = EMPTY_STRING # 订单编号
@ -130,6 +134,26 @@ class VtTradeData(VtBaseData):
self.volume = EMPTY_INT # 成交数量 self.volume = EMPTY_INT # 成交数量
self.tradeTime = EMPTY_STRING # 成交时间 self.tradeTime = EMPTY_STRING # 成交时间
#----------------------------------------------------------------------
@staticmethod
def createFromOrderData(order,
tradeID,
tradePrice,
tradeVolume): # type: (VtOrderData, str, float, float)->VtTradeData
trade = VtTradeData()
trade.gatewayName = order.gatewayName
trade.symbol = order.symbol
trade.vtSymbol = order.vtSymbol
trade.orderID = order.orderID
trade.vtOrderID = order.vtOrderID
trade.tradeID = tradeID
trade.vtTradeID = trade.gatewayName + '.' + tradeID
trade.direction = order.direction
trade.price = tradePrice
trade.volume = tradeVolume
trade.tradeTime = datetime.now().strftime('%H:%M:%S')
return trade
######################################################################## ########################################################################
class VtOrderData(VtBaseData): class VtOrderData(VtBaseData):
@ -143,10 +167,10 @@ class VtOrderData(VtBaseData):
# 代码编号相关 # 代码编号相关
self.symbol = EMPTY_STRING # 合约代码 self.symbol = EMPTY_STRING # 合约代码
self.exchange = EMPTY_STRING # 交易所代码 self.exchange = EMPTY_STRING # 交易所代码
self.vtSymbol = EMPTY_STRING # 统一格式f"{symbol}.{exchange}" self.vtSymbol = EMPTY_STRING # 索引,统一格式f"{symbol}.{exchange}"
self.orderID = EMPTY_STRING # 订单编号 gateway内部自己生成的编号 self.orderID = EMPTY_STRING # 订单编号 gateway内部自己生成的编号
self.vtOrderID = EMPTY_STRING # 统一格式f"{gatewayName}.{orderId}" self.vtOrderID = EMPTY_STRING # 索引,统一格式f"{gatewayName}.{orderId}"
# 报单相关 # 报单相关
self.direction = EMPTY_UNICODE # 报单方向 self.direction = EMPTY_UNICODE # 报单方向
@ -163,6 +187,33 @@ class VtOrderData(VtBaseData):
self.frontID = EMPTY_INT # 前置机编号 self.frontID = EMPTY_INT # 前置机编号
self.sessionID = EMPTY_INT # 连接编号 self.sessionID = EMPTY_INT # 连接编号
#----------------------------------------------------------------------
@staticmethod
def createFromGateway(gateway, orderId, symbol, exchange, price, volume, direction,
offset=EMPTY_UNICODE,
tradedVolume=EMPTY_INT,
status=constant.STATUS_UNKNOWN,
orderTime=EMPTY_UNICODE,
cancelTime=EMPTY_UNICODE,
): # type: (VtGateway, str, str, str, float, float, str, str, int, str, str, str)->VtOrderData
vtOrder = VtOrderData()
vtOrder.gatewayName = gateway.gatewayName
vtOrder.symbol = symbol
vtOrder.exchange = exchange
vtOrder.vtSymbol = symbol + '.' + exchange
vtOrder.orderID = orderId
vtOrder.vtOrderID = gateway.gatewayName + '.' + orderId
vtOrder.direction = direction
vtOrder.offset = offset
vtOrder.price = price
vtOrder.totalVolume = volume
vtOrder.tradedVolume = tradedVolume
vtOrder.status = status
vtOrder.orderTime = orderTime
vtOrder.cancelTime = cancelTime
return vtOrder
######################################################################## ########################################################################
class VtPositionData(VtBaseData): class VtPositionData(VtBaseData):
@ -187,6 +238,29 @@ class VtPositionData(VtBaseData):
self.ydPosition = EMPTY_INT # 昨持仓 self.ydPosition = EMPTY_INT # 昨持仓
self.positionProfit = EMPTY_FLOAT # 持仓盈亏 self.positionProfit = EMPTY_FLOAT # 持仓盈亏
#----------------------------------------------------------------------
@staticmethod
def createFromGateway(gateway, exchange, symbol, direction, position,
frozen=EMPTY_INT,
price=EMPTY_FLOAT,
yestordayPosition=EMPTY_INT,
profit=EMPTY_FLOAT
): # type: (VtGateway, str, str, str, float, int, float, int, float)->VtPositionData
vtPosition = VtPositionData()
vtPosition.gatewayName = gateway.gatewayName
vtPosition.symbol = symbol
vtPosition.exchange = exchange
vtPosition.vtSymbol = symbol + '.' + exchange
vtPosition.direction = direction
vtPosition.position = position
vtPosition.frozen = frozen
vtPosition.price = price
vtPosition.vtPositionName = vtPosition.vtSymbol + '.' + direction
vtPosition.ydPosition = yestordayPosition
vtPosition.positionProfit = profit
return vtPosition
######################################################################## ########################################################################
class VtAccountData(VtBaseData): class VtAccountData(VtBaseData):