Merge branch 'dev' of https://github.com/vnpy/vnpy into dev

This commit is contained in:
vn.py 2018-11-01 11:10:08 +08:00
commit db7e10d3d9
10 changed files with 342 additions and 34 deletions

View File

@ -4,6 +4,15 @@ cache: pip
python:
- 2.7
- 3.6
before_install:
- sudo apt install build-essential
- wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
- tar -xzf ta-lib-0.4.0-src.tar.gz
- cd ta-lib/
- ./configure --prefix=/usr
- make
- sudo make install
- cd ..
install:
- pip install -r requirements.txt
- pip install flake8 # pytest # add another testing frameworks later

View File

@ -1,12 +1,14 @@
# encoding: UTF-8
import json
import sys
import unittest
import uuid
from simplejson import JSONDecodeError
from Promise import Promise
from vnpy.api.rest.RestClient import RestClient, Request
from vnpy.api.rest.RestClient import Request, RestClient
class FailedError(RuntimeError):
@ -84,3 +86,46 @@ class RestfulClientTest(unittest.TestCase):
with self.assertRaises(JSONDecodeError):
self.c.p.get(3)
class RestfulClientErrorHandleTest(unittest.TestCase):
def setUp(self):
self.c = TestRestClient()
self.c.start()
self.org_sys_hook = sys.excepthook
self.org_sys_stderr_write = sys.stderr.write
sys.excepthook = self.nop
sys.stderr.write = self.nop
def tearDown(self):
self.c.stop()
@staticmethod
def nop(*args, **kwargs):
pass
def test_onError(self):
"""这个测试保证onError内不会再抛异常"""
target = uuid.uuid4()
def callback(data, req):
pass
def onError(*args, **kwargs):
try:
super(TestRestClient, self.c).onError(*args, **kwargs)
self.c.p.set_result(target)
except:
self.c.p.set_result(False)
self.c.onError = onError
self.c.addRequest('GET', '/image/svg', callback)
res = self.c.p.get(3)
self.assertEqual(target, res)
if __name__ == '__main__':
unittest.main()

View File

@ -1,5 +1,7 @@
# encoding: UTF-8
import sys
import unittest
import uuid
from Promise import Promise
from vnpy.api.websocket import WebsocketClient
@ -12,12 +14,14 @@ class TestWebsocketClient(WebsocketClient):
super(TestWebsocketClient, self).__init__()
self.init(host)
self.p = Promise()
self.cp = Promise()
def onPacket(self, packet):
self.p.set_result(packet)
pass
def onConnected(self):
self.cp.set_result(True)
pass
def onError(self, exceptionType, exceptionValue, tb):
@ -30,6 +34,7 @@ class WebsocketClientTest(unittest.TestCase):
def setUp(self):
self.c = TestWebsocketClient()
self.c.start()
self.c.cp.get(3) # wait for connected
def tearDown(self):
self.c.stop()
@ -42,3 +47,73 @@ class WebsocketClientTest(unittest.TestCase):
res = self.c.p.get(3)
self.assertEqual(res, req)
def test_parseError(self):
class CustomException(Exception): pass
def onPacket(packet):
raise CustomException("Just a exception")
self.c.onPacket = onPacket
req = {
'name': 'val'
}
self.c.sendPacket(req)
with self.assertRaises(CustomException):
self.c.p.get(3)
class WebsocketClientErrorHandleTest(unittest.TestCase):
def setUp(self):
self.c = TestWebsocketClient()
self.c.start()
self.c.cp.get(3) # wait for connected
self.org_sys_hook = sys.excepthook
self.org_sys_stderr_write = sys.stderr.write
sys.excepthook = self.nop
sys.stderr.write = self.nop
@staticmethod
def nop(*args, **kwargs):
pass
def tearDown(self):
self.c.stop()
sys.excepthook = self.org_sys_hook
sys.stderr.write = self.org_sys_stderr_write
def test_onError(self):
target= uuid.uuid4()
"""这个测试保证onError内不会再抛Exception"""
class CustomException(Exception):
pass
def onPacket(packet):
raise CustomException("Just a exception")
def onError(*args, **kwargs):
try:
super(TestWebsocketClient, self.c).onError(*args, **kwargs)
self.c.p.set_result(target)
except:
self.c.p.set_result(False)
self.c.onPacket = onPacket
self.c.onError = onError
req = {
'name': 'val'
}
self.c.sendPacket(req)
res = self.c.p.get(3)
self.assertEqual(target, res)
if __name__ == '__main__':
unittest.main()

View File

@ -2,7 +2,9 @@
import sys
import traceback
from Queue import Empty, Queue
from datetime import datetime
from multiprocessing.dummy import Pool
import requests
@ -34,6 +36,7 @@ class Request(object):
self.headers = headers # type: dict
self.onFailed = None # type: callable
self.onError = None # type: callable
self.extra = None # type: Any
self.response = None # type: requests.Response
@ -95,7 +98,8 @@ class RestClient(object):
#----------------------------------------------------------------------
def start(self, n=3):
"""启动"""
assert not self._active
if self._active:
return
self._active = True
self._pool = Pool(n)
@ -128,6 +132,7 @@ class RestClient(object):
data=None, # type: dict
headers=None, # type: dict
onFailed=None, # type: Callable[[int, Request], Any]
onError=None, # type: Callable[[type, Exception, traceback, Request], Any]
extra=None # type: Any
): # type: (...)->Request
"""
@ -139,6 +144,7 @@ class RestClient(object):
:param data: dict for body
:param headers: dict for headers
:param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)如果指定该值默认的onFailed将不会被调用 type: (code, dict, Request)
:param onError: 请求出现Python错误后的回调如果指定该值默认的onError将不会被调用 type: (etype, evalue, tb, Request)
:param extra: 返回值的extra字段会被设置为这个值当然你也可以在函数调用之后再设置这个字段
:return: Request
"""
@ -146,6 +152,7 @@ class RestClient(object):
request = Request(method, path, params, data, headers, callback)
request.extra = extra
request.onFailed = onFailed
request.onError = onError
self._queue.put(request)
return request
@ -195,9 +202,29 @@ class RestClient(object):
Python内部错误处理默认行为是仍给excepthook
:param request 如果是在处理请求的时候出错它的值就是对应的Request否则为None
"""
print("error in request : {}\n".format(request))
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request))
sys.excepthook(exceptionType, exceptionValue, tb)
#----------------------------------------------------------------------
def exceptionDetail(self,
exceptionType, # type: type
exceptionValue, # type: Exception
tb,
request # type: Optional[Request]
):
text = "[{}]: Unhandled RestClient Error:{}\n".format(
datetime.now().isoformat(),
exceptionType
)
text += "request:{}\n".format(request)
text += "Exception trace: \n"
text += "".join(traceback.format_exception(
exceptionType,
exceptionValue,
tb,
))
return text
#----------------------------------------------------------------------
def _processRequest(self, request, session): # type: (Request, requests.Session)->None
"""
@ -231,6 +258,9 @@ class RestClient(object):
except:
request.status = RequestStatus.error
t, v, tb = sys.exc_info()
if request.onError:
request.onError(t, v, tb, request)
else:
self.onError(t, v, tb, request)
#----------------------------------------------------------------------
@ -243,3 +273,4 @@ class RestClient(object):
"""
url = self.urlBase + path
return url

View File

@ -3,13 +3,15 @@
########################################################################
import json
import sys
import ssl
import sys
import time
import websocket
import traceback
from datetime import datetime
from threading import Lock, Thread
import websocket
class WebsocketClient(object):
"""
@ -50,7 +52,11 @@ class WebsocketClient(object):
#----------------------------------------------------------------------
def start(self):
"""启动"""
"""
启动
:note 注意启动之后不能立即发包需要等待websocket连接成功
websocket连接成功之后会响应onConnected函数
"""
self._active = True
self._workerThread = Thread(target=self._run)
@ -59,6 +65,10 @@ class WebsocketClient(object):
self._pingThread = Thread(target=self._runPing)
self._pingThread.start()
# for debugging:
self._lastSentText = None
self._lastReceivedText = None
#----------------------------------------------------------------------
def stop(self):
"""
@ -80,7 +90,9 @@ class WebsocketClient(object):
#----------------------------------------------------------------------
def sendPacket(self, dictObj): # type: (dict)->None
"""发出请求:相当于sendText(json.dumps(dictObj))"""
return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT)
text = json.dumps(dictObj)
self._recordLastSentText(text)
return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT)
#----------------------------------------------------------------------
def sendText(self, text): # type: (str)->None
@ -137,15 +149,15 @@ class WebsocketClient(object):
try:
ws = self._getWs()
if ws:
stream = ws.recv()
if not stream: # recv在阻塞的时候ws被关闭
text = ws.recv()
if not text: # recv在阻塞的时候ws被关闭
self._reconnect()
continue
self._recordLastReceivedText(text)
try:
data = json.loads(stream)
data = self.unpackData(text)
except ValueError as e:
print('websocket unable to parse data: ' + stream)
print('websocket unable to parse data: ' + text)
raise e
self.onPacket(data)
except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了
@ -159,6 +171,17 @@ class WebsocketClient(object):
self.onError(et, ev, tb)
self._reconnect()
#----------------------------------------------------------------------
@staticmethod
def unpackData(data):
"""
解密数据默认使用json解密为dict
解密后的数据将会传入onPacket
如果需要使用不同的解密方式就重载这个函数
:param data 收到的数据可能是text frame也可能是binary frame, 目前并没有区分这两者
"""
return json.loads(data)
#----------------------------------------------------------------------
def _runPing(self):
while self._active:
@ -207,7 +230,43 @@ class WebsocketClient(object):
pass
#----------------------------------------------------------------------
@staticmethod
def onError(exceptionType, exceptionValue, tb):
"""Python错误回调"""
def onError(self, exceptionType, exceptionValue, tb):
"""
Python错误回调
todo: 以后详细的错误信息最好记录在文件里用uuid来联系/区分具体错误
"""
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb))
# 丢给默认的错误处理函数所以如果不重载onError一般的结果是程序会崩溃
return sys.excepthook(exceptionType, exceptionValue, tb)
#----------------------------------------------------------------------
def exceptionDetail(self, exceptionType, exceptionValue, tb):
"""打印详细的错误信息"""
text = "[{}]: Unhandled WebSocket Error:{}\n".format(
datetime.now().isoformat(),
exceptionType
)
text += "LastSentText:\n{}\n".format(self._lastSentText)
text += "LastReceivedText:\n{}\n".format(self._lastReceivedText)
text += "Exception trace: \n"
text += "".join(traceback.format_exception(
exceptionType,
exceptionValue,
tb,
))
return text
#----------------------------------------------------------------------
def _recordLastSentText(self, text):
"""
用于Debug 记录最后一次发送出去的text
"""
self._lastSentText = text[:200]
#----------------------------------------------------------------------
def _recordLastReceivedText(self, text):
"""
用于Debug 记录最后一次发送出去的text
"""
self._lastReceivedText = text[:200]

View File

@ -12,6 +12,7 @@ import traceback
# 用来保存策略类的字典
STRATEGY_CLASS = {}
#----------------------------------------------------------------------
def loadStrategyModule(moduleName):
"""使用importlib动态载入模块"""
@ -34,7 +35,7 @@ path = os.path.abspath(os.path.dirname(__file__))
for root, subdirs, files in os.walk(path):
for name in files:
# 只有文件名中包含strategy且以.py结尾的文件才是策略文件
if 'strategy' in name and name[-3:] == '.py':
if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name:
# 模块名称需要模块路径前缀
moduleName = 'vnpy.trader.app.ctaStrategy.strategy.' + name.replace('.py', '')
loadStrategyModule(moduleName)
@ -45,7 +46,7 @@ workingPath = os.getcwd()
for root, subdirs, files in os.walk(workingPath):
for name in files:
# 只有文件名中包含strategy且以.py结尾的文件才是策略文件
if 'strategy' in name and name[-3:] == '.py':
if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name:
# 模块名称无需前缀
moduleName = name.replace('.py', '')
loadStrategyModule(moduleName)

View File

@ -10,6 +10,7 @@ import os
import json
import hashlib
import hmac
import sys
import time
import traceback
from datetime import datetime, timedelta
@ -17,7 +18,9 @@ from copy import copy
from math import pow
from urllib import urlencode
from vnpy.api.rest import RestClient
from requests import ConnectionError
from vnpy.api.rest import RestClient, Request
from vnpy.api.websocket import WebsocketClient
from vnpy.trader.vtGateway import *
from vnpy.trader.vtFunction import getJsonPath, getTempPath
@ -68,6 +71,8 @@ class BitmexGateway(VtGateway):
self.fileName = self.gatewayName + '_connect.json'
self.filePath = getJsonPath(self.fileName, __file__)
self.exchange = constant.EXCHANGE_BITMEX
#----------------------------------------------------------------------
def connect(self):
"""连接"""
@ -172,7 +177,7 @@ class BitmexRestApi(RestClient):
"""Constructor"""
super(BitmexRestApi, self).__init__()
self.gateway = gateway # gateway对象
self.gateway = gateway # type: BitmexGateway # gateway对象
self.gatewayName = gateway.gatewayName # gateway对象名称
self.apiKey = ''
@ -240,11 +245,11 @@ class BitmexRestApi(RestClient):
self.gateway.onLog(log)
#----------------------------------------------------------------------
def sendOrder(self, orderReq):
def sendOrder(self, orderReq):# type: (VtOrderReq)->str
""""""
self.orderId += 1
orderId = self.loginTime + self.orderId
vtOrderID = '.'.join([self.gatewayName, str(orderId)])
orderId = str(self.loginTime + self.orderId)
vtOrderID = '.'.join([self.gatewayName, orderId])
data = {
'symbol': orderReq.symbol,
@ -259,7 +264,21 @@ class BitmexRestApi(RestClient):
if orderReq.priceType == PRICETYPE_LIMITPRICE:
data['price'] = orderReq.price
self.addRequest('POST', '/order', self.onSendOrder, data=data)
vtOrder = VtOrderData.createFromGateway(
self.gateway,
orderId=orderId,
symbol=orderReq.symbol,
exchange=self.gateway.exchange,
price=orderReq.price,
volume=orderReq.volume,
direction=orderReq.direction,
offset=orderReq.offset,
)
self.addRequest('POST', '/order', callback=self.onSendOrder, data=data, extra=vtOrder,
onFailed=self.onSendOrderFailed,
onError=self.onSendOrderError,
)
return vtOrderID
#----------------------------------------------------------------------
@ -272,13 +291,47 @@ class BitmexRestApi(RestClient):
else:
params = {'orderID': orderID}
self.addRequest('DELETE', '/order', self.onCancelOrder, params=params)
self.addRequest('DELETE', '/order', callback=self.onCancelOrder, params=params,
onError=self.onCancelOrderError,
)
#----------------------------------------------------------------------
def onSendOrderFailed(self, _, request):
"""
下单失败回调服务器明确告知下单失败
"""
vtOrder = request.extra # type: VtOrderData
vtOrder.status = constant.STATUS_REJECTED
self.gateway.onOrder(vtOrder)
pass
#----------------------------------------------------------------------
def onSendOrderError(self, exceptionType, exceptionValue, tb, request):
"""
下单失败回调连接错误
"""
vtOrder = request.extra # type: VtOrderData
vtOrder.status = constant.STATUS_REJECTED
self.gateway.onOrder(vtOrder)
# 意料之中的错误只有ConnectionError若还有其他错误最好还是记录一下用原来的onError记录即可
if not issubclass(exceptionType, ConnectionError):
self.onError(exceptionType, exceptionValue, tb, request)
#----------------------------------------------------------------------
def onSendOrder(self, data, request):
""""""
pass
#----------------------------------------------------------------------
def onCancelOrderError(self, exceptionType, exceptionValue, tb, request):
"""
撤单失败回调连接错误
"""
# 意料之中的错误只有ConnectionError若还有其他错误最好还是记录一下用原来的onError记录即可
if not issubclass(exceptionType, ConnectionError):
self.onError(exceptionType, exceptionValue, tb, request)
#----------------------------------------------------------------------
def onCancelOrder(self, data, request):
""""""
@ -308,7 +361,7 @@ class BitmexRestApi(RestClient):
e.errorMsg = exceptionValue
self.gateway.onError(e)
traceback.print_exc()
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request))
########################################################################
@ -517,10 +570,13 @@ class BitmexWebsocketApi(WebsocketClient):
trade.orderID = orderID
trade.vtOrderID = '.'.join([trade.gatewayName, trade.orderID])
trade.tradeID = tradeID
trade.vtTradeID = '.'.join([trade.gatewayName, trade.tradeID])
if 'side' not in d:
print('no side : \n', d)
return
trade.direction = directionMapReverse[d['side']]
trade.price = d['lastPx']
trade.volume = d['lastQty']

View File

@ -376,9 +376,33 @@ class CtpMdApi(MdApi):
tick.askVolume1 = data['AskVolume1']
# 大商所日期转换
if tick.exchange is EXCHANGE_DCE:
if tick.exchange == EXCHANGE_DCE:
tick.date = datetime.now().strftime('%Y%m%d')
# 上交所SEE股票期权相关
if tick.exchange == EXCHANGE_SSE:
tick.bidPrice2 = data['BidPrice2']
tick.bidVolume2 = data['BidVolume2']
tick.askPrice2 = data['AskPrice2']
tick.askVolume2 = data['AskVolume2']
tick.bidPrice3 = data['BidPrice3']
tick.bidVolume3 = data['BidVolume3']
tick.askPrice3 = data['AskPrice3']
tick.askVolume3 = data['AskVolume3']
tick.bidPrice4 = data['BidPrice4']
tick.bidVolume4 = data['BidVolume4']
tick.askPrice4 = data['AskPrice4']
tick.askVolume4 = data['AskVolume4']
tick.bidPrice5 = data['BidPrice5']
tick.bidVolume5 = data['BidVolume5']
tick.askPrice5 = data['AskPrice5']
tick.askVolume5 = data['AskVolume5']
tick.date = data['TradingDay']
self.gateway.onTick(tick)
#----------------------------------------------------------------------

View File

@ -121,7 +121,7 @@ class OkexFuturesWebSocketBase(WebsocketClient):
Okex期货websocket客户端
实例化后使用init设置apiKey和secretKeyapiSecret
"""
host = 'wss://real.okex.com:10440/websocket/okexapi'
host = 'wss://real.okex.com:10440/websocket/okexapi?compress=true'
def __init__(self):
super(OkexFuturesWebSocketBase, self).__init__()

View File

@ -5,6 +5,7 @@ from __future__ import print_function
import json
import sys
import traceback
import zlib
from collections import defaultdict
from enum import Enum
@ -141,6 +142,7 @@ class OkexFuturesGateway(VtGateway):
self.webSocket = OkexFuturesWebSocketBase()
self.webSocket.onPacket = self.onWebSocketPacket
self.webSocket.unpackData = self.webSocketUnpackData
self.leverRate = 1
self.symbols = []
@ -612,6 +614,12 @@ class OkexFuturesGateway(VtGateway):
)
self.onPosition(vtPos)
#----------------------------------------------------------------------
@staticmethod
def webSocketUnpackData(data):
"""重载websocket.unpackData"""
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
#----------------------------------------------------------------------
def onWebSocketPacket(self, packets):
for packet in packets: