diff --git a/.travis.yml b/.travis.yml index fd330019..5db367a4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,15 @@ cache: pip python: - 2.7 - 3.6 +before_install: + - sudo apt install build-essential + - wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz + - tar -xzf ta-lib-0.4.0-src.tar.gz + - cd ta-lib/ + - ./configure --prefix=/usr + - make + - sudo make install + - cd .. install: - pip install -r requirements.txt - pip install flake8 # pytest # add another testing frameworks later diff --git a/tests/api/base/RestClientTest.py b/tests/api/base/RestClientTest.py index 33d3d233..c6556202 100644 --- a/tests/api/base/RestClientTest.py +++ b/tests/api/base/RestClientTest.py @@ -1,12 +1,14 @@ # encoding: UTF-8 import json +import sys import unittest +import uuid from simplejson import JSONDecodeError from Promise import Promise -from vnpy.api.rest.RestClient import RestClient, Request +from vnpy.api.rest.RestClient import Request, RestClient class FailedError(RuntimeError): @@ -26,7 +28,7 @@ class TestRestClient(RestClient): req.data = json.dumps(req.data) req.headers = {'Content-Type': 'application/json'} return req - + def onError(self, exceptionType, exceptionValue, tb, req): self.p.set_exception(exceptionType, exceptionValue, tb) @@ -83,4 +85,47 @@ class RestfulClientTest(unittest.TestCase): self.c.addRequest('GET', '/image/svg', callback) with self.assertRaises(JSONDecodeError): self.c.p.get(3) + + +class RestfulClientErrorHandleTest(unittest.TestCase): + def setUp(self): + self.c = TestRestClient() + self.c.start() + + self.org_sys_hook = sys.excepthook + self.org_sys_stderr_write = sys.stderr.write + + sys.excepthook = self.nop + sys.stderr.write = self.nop + + def tearDown(self): + self.c.stop() + + @staticmethod + def nop(*args, **kwargs): + pass + + def test_onError(self): + """这个测试保证onError内不会再抛异常""" + target = uuid.uuid4() + + def callback(data, req): + pass + + def onError(*args, **kwargs): + try: + super(TestRestClient, self.c).onError(*args, **kwargs) + self.c.p.set_result(target) + except: + self.c.p.set_result(False) + self.c.onError = onError + + self.c.addRequest('GET', '/image/svg', callback) + + res = self.c.p.get(3) + self.assertEqual(target, res) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/api/base/WebSocketClientTest.py b/tests/api/base/WebSocketClientTest.py index dc9651dd..dbfb5745 100644 --- a/tests/api/base/WebSocketClientTest.py +++ b/tests/api/base/WebSocketClientTest.py @@ -1,5 +1,7 @@ # encoding: UTF-8 +import sys import unittest +import uuid from Promise import Promise from vnpy.api.websocket import WebsocketClient @@ -12,12 +14,14 @@ class TestWebsocketClient(WebsocketClient): super(TestWebsocketClient, self).__init__() self.init(host) self.p = Promise() + self.cp = Promise() def onPacket(self, packet): self.p.set_result(packet) pass def onConnected(self): + self.cp.set_result(True) pass def onError(self, exceptionType, exceptionValue, tb): @@ -30,6 +34,7 @@ class WebsocketClientTest(unittest.TestCase): def setUp(self): self.c = TestWebsocketClient() self.c.start() + self.c.cp.get(3) # wait for connected def tearDown(self): self.c.stop() @@ -42,3 +47,73 @@ class WebsocketClientTest(unittest.TestCase): res = self.c.p.get(3) self.assertEqual(res, req) + + def test_parseError(self): + class CustomException(Exception): pass + + def onPacket(packet): + raise CustomException("Just a exception") + + self.c.onPacket = onPacket + req = { + 'name': 'val' + } + self.c.sendPacket(req) + + with self.assertRaises(CustomException): + self.c.p.get(3) + + +class WebsocketClientErrorHandleTest(unittest.TestCase): + + def setUp(self): + + self.c = TestWebsocketClient() + self.c.start() + self.c.cp.get(3) # wait for connected + + self.org_sys_hook = sys.excepthook + self.org_sys_stderr_write = sys.stderr.write + + sys.excepthook = self.nop + sys.stderr.write = self.nop + + @staticmethod + def nop(*args, **kwargs): + pass + + def tearDown(self): + self.c.stop() + sys.excepthook = self.org_sys_hook + sys.stderr.write = self.org_sys_stderr_write + + def test_onError(self): + target= uuid.uuid4() + """这个测试保证onError内不会再抛Exception""" + class CustomException(Exception): + pass + + def onPacket(packet): + raise CustomException("Just a exception") + + def onError(*args, **kwargs): + try: + super(TestWebsocketClient, self.c).onError(*args, **kwargs) + self.c.p.set_result(target) + except: + self.c.p.set_result(False) + + self.c.onPacket = onPacket + self.c.onError = onError + + req = { + 'name': 'val' + } + self.c.sendPacket(req) + + res = self.c.p.get(3) + self.assertEqual(target, res) + + +if __name__ == '__main__': + unittest.main() diff --git a/vnpy/api/rest/RestClient.py b/vnpy/api/rest/RestClient.py index fa3ef0d4..32b96666 100644 --- a/vnpy/api/rest/RestClient.py +++ b/vnpy/api/rest/RestClient.py @@ -2,7 +2,9 @@ import sys +import traceback from Queue import Empty, Queue +from datetime import datetime from multiprocessing.dummy import Pool import requests @@ -34,6 +36,7 @@ class Request(object): self.headers = headers # type: dict self.onFailed = None # type: callable + self.onError = None # type: callable self.extra = None # type: Any self.response = None # type: requests.Response @@ -95,7 +98,8 @@ class RestClient(object): #---------------------------------------------------------------------- def start(self, n=3): """启动""" - assert not self._active + if self._active: + return self._active = True self._pool = Pool(n) @@ -128,6 +132,7 @@ class RestClient(object): data=None, # type: dict headers=None, # type: dict onFailed=None, # type: Callable[[int, Request], Any] + onError=None, # type: Callable[[type, Exception, traceback, Request], Any] extra=None # type: Any ): # type: (...)->Request """ @@ -139,6 +144,7 @@ class RestClient(object): :param data: dict for body :param headers: dict for headers :param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request) + :param onError: 请求出现Python错误后的回调(如果指定该值,默认的onError将不会被调用) type: (etype, evalue, tb, Request) :param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。 :return: Request """ @@ -146,6 +152,7 @@ class RestClient(object): request = Request(method, path, params, data, headers, callback) request.extra = extra request.onFailed = onFailed + request.onError = onError self._queue.put(request) return request @@ -195,9 +202,29 @@ class RestClient(object): Python内部错误处理:默认行为是仍给excepthook :param request 如果是在处理请求的时候出错,它的值就是对应的Request,否则为None """ - print("error in request : {}\n".format(request)) + sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request)) sys.excepthook(exceptionType, exceptionValue, tb) + #---------------------------------------------------------------------- + def exceptionDetail(self, + exceptionType, # type: type + exceptionValue, # type: Exception + tb, + request # type: Optional[Request] + ): + text = "[{}]: Unhandled RestClient Error:{}\n".format( + datetime.now().isoformat(), + exceptionType + ) + text += "request:{}\n".format(request) + text += "Exception trace: \n" + text += "".join(traceback.format_exception( + exceptionType, + exceptionValue, + tb, + )) + return text + #---------------------------------------------------------------------- def _processRequest(self, request, session): # type: (Request, requests.Session)->None """ @@ -231,7 +258,10 @@ class RestClient(object): except: request.status = RequestStatus.error t, v, tb = sys.exc_info() - self.onError(t, v, tb, request) + if request.onError: + request.onError(t, v, tb, request) + else: + self.onError(t, v, tb, request) #---------------------------------------------------------------------- def makeFullUrl(self, path): @@ -243,3 +273,4 @@ class RestClient(object): """ url = self.urlBase + path return url + diff --git a/vnpy/api/websocket/WebsocketClient.py b/vnpy/api/websocket/WebsocketClient.py index a029bda5..909e2ab2 100644 --- a/vnpy/api/websocket/WebsocketClient.py +++ b/vnpy/api/websocket/WebsocketClient.py @@ -3,13 +3,15 @@ ######################################################################## import json -import sys - import ssl +import sys import time -import websocket +import traceback +from datetime import datetime from threading import Lock, Thread +import websocket + class WebsocketClient(object): """ @@ -50,7 +52,11 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def start(self): - """启动""" + """ + 启动 + :note 注意:启动之后不能立即发包,需要等待websocket连接成功。 + websocket连接成功之后会响应onConnected函数 + """ self._active = True self._workerThread = Thread(target=self._run) @@ -58,6 +64,10 @@ class WebsocketClient(object): self._pingThread = Thread(target=self._runPing) self._pingThread.start() + + # for debugging: + self._lastSentText = None + self._lastReceivedText = None #---------------------------------------------------------------------- def stop(self): @@ -80,7 +90,9 @@ class WebsocketClient(object): #---------------------------------------------------------------------- def sendPacket(self, dictObj): # type: (dict)->None """发出请求:相当于sendText(json.dumps(dictObj))""" - return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT) + text = json.dumps(dictObj) + self._recordLastSentText(text) + return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT) #---------------------------------------------------------------------- def sendText(self, text): # type: (str)->None @@ -137,15 +149,15 @@ class WebsocketClient(object): try: ws = self._getWs() if ws: - stream = ws.recv() - if not stream: # recv在阻塞的时候ws被关闭 + text = ws.recv() + if not text: # recv在阻塞的时候ws被关闭 self._reconnect() continue - + self._recordLastReceivedText(text) try: - data = json.loads(stream) + data = self.unpackData(text) except ValueError as e: - print('websocket unable to parse data: ' + stream) + print('websocket unable to parse data: ' + text) raise e self.onPacket(data) except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了 @@ -158,6 +170,17 @@ class WebsocketClient(object): et, ev, tb = sys.exc_info() self.onError(et, ev, tb) self._reconnect() + + #---------------------------------------------------------------------- + @staticmethod + def unpackData(data): + """ + 解密数据,默认使用json解密为dict + 解密后的数据将会传入onPacket + 如果需要使用不同的解密方式,就重载这个函数。 + :param data 收到的数据,可能是text frame,也可能是binary frame, 目前并没有区分这两者 + """ + return json.loads(data) #---------------------------------------------------------------------- def _runPing(self): @@ -207,7 +230,43 @@ class WebsocketClient(object): pass #---------------------------------------------------------------------- - @staticmethod - def onError(exceptionType, exceptionValue, tb): - """Python错误回调""" + def onError(self, exceptionType, exceptionValue, tb): + """ + Python错误回调 + todo: 以后详细的错误信息最好记录在文件里,用uuid来联系/区分具体错误 + """ + sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb)) + + # 丢给默认的错误处理函数(所以如果不重载onError,一般的结果是程序会崩溃) return sys.excepthook(exceptionType, exceptionValue, tb) + + #---------------------------------------------------------------------- + def exceptionDetail(self, exceptionType, exceptionValue, tb): + """打印详细的错误信息""" + text = "[{}]: Unhandled WebSocket Error:{}\n".format( + datetime.now().isoformat(), + exceptionType + ) + text += "LastSentText:\n{}\n".format(self._lastSentText) + text += "LastReceivedText:\n{}\n".format(self._lastReceivedText) + text += "Exception trace: \n" + text += "".join(traceback.format_exception( + exceptionType, + exceptionValue, + tb, + )) + return text + + #---------------------------------------------------------------------- + def _recordLastSentText(self, text): + """ + 用于Debug: 记录最后一次发送出去的text + """ + self._lastSentText = text[:200] + + #---------------------------------------------------------------------- + def _recordLastReceivedText(self, text): + """ + 用于Debug: 记录最后一次发送出去的text + """ + self._lastReceivedText = text[:200] diff --git a/vnpy/trader/app/ctaStrategy/strategy/__init__.py b/vnpy/trader/app/ctaStrategy/strategy/__init__.py index 5b10bc02..2aea54b3 100644 --- a/vnpy/trader/app/ctaStrategy/strategy/__init__.py +++ b/vnpy/trader/app/ctaStrategy/strategy/__init__.py @@ -12,6 +12,7 @@ import traceback # 用来保存策略类的字典 STRATEGY_CLASS = {} + #---------------------------------------------------------------------- def loadStrategyModule(moduleName): """使用importlib动态载入模块""" @@ -34,7 +35,7 @@ path = os.path.abspath(os.path.dirname(__file__)) for root, subdirs, files in os.walk(path): for name in files: # 只有文件名中包含strategy且以.py结尾的文件,才是策略文件 - if 'strategy' in name and name[-3:] == '.py': + if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name: # 模块名称需要模块路径前缀 moduleName = 'vnpy.trader.app.ctaStrategy.strategy.' + name.replace('.py', '') loadStrategyModule(moduleName) @@ -45,7 +46,7 @@ workingPath = os.getcwd() for root, subdirs, files in os.walk(workingPath): for name in files: # 只有文件名中包含strategy且以.py结尾的文件,才是策略文件 - if 'strategy' in name and name[-3:] == '.py': + if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name: # 模块名称无需前缀 moduleName = name.replace('.py', '') loadStrategyModule(moduleName) diff --git a/vnpy/trader/gateway/bitmexGateway/bitmexGateway.py b/vnpy/trader/gateway/bitmexGateway/bitmexGateway.py index 0df8d238..95745adb 100644 --- a/vnpy/trader/gateway/bitmexGateway/bitmexGateway.py +++ b/vnpy/trader/gateway/bitmexGateway/bitmexGateway.py @@ -10,6 +10,7 @@ import os import json import hashlib import hmac +import sys import time import traceback from datetime import datetime, timedelta @@ -17,7 +18,9 @@ from copy import copy from math import pow from urllib import urlencode -from vnpy.api.rest import RestClient +from requests import ConnectionError + +from vnpy.api.rest import RestClient, Request from vnpy.api.websocket import WebsocketClient from vnpy.trader.vtGateway import * from vnpy.trader.vtFunction import getJsonPath, getTempPath @@ -67,6 +70,8 @@ class BitmexGateway(VtGateway): self.fileName = self.gatewayName + '_connect.json' self.filePath = getJsonPath(self.fileName, __file__) + + self.exchange = constant.EXCHANGE_BITMEX #---------------------------------------------------------------------- def connect(self): @@ -172,7 +177,7 @@ class BitmexRestApi(RestClient): """Constructor""" super(BitmexRestApi, self).__init__() - self.gateway = gateway # gateway对象 + self.gateway = gateway # type: BitmexGateway # gateway对象 self.gatewayName = gateway.gatewayName # gateway对象名称 self.apiKey = '' @@ -240,11 +245,11 @@ class BitmexRestApi(RestClient): self.gateway.onLog(log) #---------------------------------------------------------------------- - def sendOrder(self, orderReq): + def sendOrder(self, orderReq):# type: (VtOrderReq)->str """""" self.orderId += 1 - orderId = self.loginTime + self.orderId - vtOrderID = '.'.join([self.gatewayName, str(orderId)]) + orderId = str(self.loginTime + self.orderId) + vtOrderID = '.'.join([self.gatewayName, orderId]) data = { 'symbol': orderReq.symbol, @@ -258,8 +263,22 @@ class BitmexRestApi(RestClient): # 只有限价单才有price字段 if orderReq.priceType == PRICETYPE_LIMITPRICE: data['price'] = orderReq.price + + vtOrder = VtOrderData.createFromGateway( + self.gateway, + orderId=orderId, + symbol=orderReq.symbol, + exchange=self.gateway.exchange, + price=orderReq.price, + volume=orderReq.volume, + direction=orderReq.direction, + offset=orderReq.offset, + ) - self.addRequest('POST', '/order', self.onSendOrder, data=data) + self.addRequest('POST', '/order', callback=self.onSendOrder, data=data, extra=vtOrder, + onFailed=self.onSendOrderFailed, + onError=self.onSendOrderError, + ) return vtOrderID #---------------------------------------------------------------------- @@ -272,12 +291,46 @@ class BitmexRestApi(RestClient): else: params = {'orderID': orderID} - self.addRequest('DELETE', '/order', self.onCancelOrder, params=params) - + self.addRequest('DELETE', '/order', callback=self.onCancelOrder, params=params, + onError=self.onCancelOrderError, + ) + + #---------------------------------------------------------------------- + def onSendOrderFailed(self, _, request): + """ + 下单失败回调:服务器明确告知下单失败 + """ + vtOrder = request.extra # type: VtOrderData + vtOrder.status = constant.STATUS_REJECTED + self.gateway.onOrder(vtOrder) + pass + + #---------------------------------------------------------------------- + def onSendOrderError(self, exceptionType, exceptionValue, tb, request): + """ + 下单失败回调:连接错误 + """ + vtOrder = request.extra # type: VtOrderData + vtOrder.status = constant.STATUS_REJECTED + self.gateway.onOrder(vtOrder) + + # 意料之中的错误只有ConnectionError,若还有其他错误,最好还是记录一下,用原来的onError记录即可 + if not issubclass(exceptionType, ConnectionError): + self.onError(exceptionType, exceptionValue, tb, request) + #---------------------------------------------------------------------- def onSendOrder(self, data, request): """""" pass + + #---------------------------------------------------------------------- + def onCancelOrderError(self, exceptionType, exceptionValue, tb, request): + """ + 撤单失败回调:连接错误 + """ + # 意料之中的错误只有ConnectionError,若还有其他错误,最好还是记录一下,用原来的onError记录即可 + if not issubclass(exceptionType, ConnectionError): + self.onError(exceptionType, exceptionValue, tb, request) #---------------------------------------------------------------------- def onCancelOrder(self, data, request): @@ -307,8 +360,8 @@ class BitmexRestApi(RestClient): e.errorID = exceptionType e.errorMsg = exceptionValue self.gateway.onError(e) - - traceback.print_exc() + + sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request)) ######################################################################## @@ -517,10 +570,13 @@ class BitmexWebsocketApi(WebsocketClient): trade.orderID = orderID trade.vtOrderID = '.'.join([trade.gatewayName, trade.orderID]) - trade.tradeID = tradeID trade.vtTradeID = '.'.join([trade.gatewayName, trade.tradeID]) + if 'side' not in d: + print('no side : \n', d) + return + trade.direction = directionMapReverse[d['side']] trade.price = d['lastPx'] trade.volume = d['lastQty'] diff --git a/vnpy/trader/gateway/ctpGateway/ctpGateway.py b/vnpy/trader/gateway/ctpGateway/ctpGateway.py index d6cfff21..3d726019 100644 --- a/vnpy/trader/gateway/ctpGateway/ctpGateway.py +++ b/vnpy/trader/gateway/ctpGateway/ctpGateway.py @@ -376,9 +376,33 @@ class CtpMdApi(MdApi): tick.askVolume1 = data['AskVolume1'] # 大商所日期转换 - if tick.exchange is EXCHANGE_DCE: + if tick.exchange == EXCHANGE_DCE: tick.date = datetime.now().strftime('%Y%m%d') - + + # 上交所,SEE,股票期权相关 + if tick.exchange == EXCHANGE_SSE: + tick.bidPrice2 = data['BidPrice2'] + tick.bidVolume2 = data['BidVolume2'] + tick.askPrice2 = data['AskPrice2'] + tick.askVolume2 = data['AskVolume2'] + + tick.bidPrice3 = data['BidPrice3'] + tick.bidVolume3 = data['BidVolume3'] + tick.askPrice3 = data['AskPrice3'] + tick.askVolume3 = data['AskVolume3'] + + tick.bidPrice4 = data['BidPrice4'] + tick.bidVolume4 = data['BidVolume4'] + tick.askPrice4 = data['AskPrice4'] + tick.askVolume4 = data['AskVolume4'] + + tick.bidPrice5 = data['BidPrice5'] + tick.bidVolume5 = data['BidVolume5'] + tick.askPrice5 = data['AskPrice5'] + tick.askVolume5 = data['AskVolume5'] + + tick.date = data['TradingDay'] + self.gateway.onTick(tick) #---------------------------------------------------------------------- diff --git a/vnpy/trader/gateway/okexFuturesGateway/OkexFuturesBase.py b/vnpy/trader/gateway/okexFuturesGateway/OkexFuturesBase.py index 9192cfd0..1e586f0f 100644 --- a/vnpy/trader/gateway/okexFuturesGateway/OkexFuturesBase.py +++ b/vnpy/trader/gateway/okexFuturesGateway/OkexFuturesBase.py @@ -121,7 +121,7 @@ class OkexFuturesWebSocketBase(WebsocketClient): Okex期货websocket客户端 实例化后使用init设置apiKey和secretKey(apiSecret) """ - host = 'wss://real.okex.com:10440/websocket/okexapi' + host = 'wss://real.okex.com:10440/websocket/okexapi?compress=true' def __init__(self): super(OkexFuturesWebSocketBase, self).__init__() diff --git a/vnpy/trader/gateway/okexFuturesGateway/okexFuturesGateway.py b/vnpy/trader/gateway/okexFuturesGateway/okexFuturesGateway.py index 1b6f14b5..8c621578 100644 --- a/vnpy/trader/gateway/okexFuturesGateway/okexFuturesGateway.py +++ b/vnpy/trader/gateway/okexFuturesGateway/okexFuturesGateway.py @@ -5,6 +5,7 @@ from __future__ import print_function import json import sys import traceback +import zlib from collections import defaultdict from enum import Enum @@ -141,6 +142,7 @@ class OkexFuturesGateway(VtGateway): self.webSocket = OkexFuturesWebSocketBase() self.webSocket.onPacket = self.onWebSocketPacket + self.webSocket.unpackData = self.webSocketUnpackData self.leverRate = 1 self.symbols = [] @@ -611,6 +613,12 @@ class OkexFuturesGateway(VtGateway): price=pos['short_avg_cost'], ) self.onPosition(vtPos) + + #---------------------------------------------------------------------- + @staticmethod + def webSocketUnpackData(data): + """重载websocket.unpackData""" + return json.loads(zlib.decompress(data, -zlib.MAX_WBITS)) #---------------------------------------------------------------------- def onWebSocketPacket(self, packets):