Merge branch 'dev' of https://github.com/vnpy/vnpy into dev
This commit is contained in:
commit
db7e10d3d9
@ -4,6 +4,15 @@ cache: pip
|
|||||||
python:
|
python:
|
||||||
- 2.7
|
- 2.7
|
||||||
- 3.6
|
- 3.6
|
||||||
|
before_install:
|
||||||
|
- sudo apt install build-essential
|
||||||
|
- wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
|
||||||
|
- tar -xzf ta-lib-0.4.0-src.tar.gz
|
||||||
|
- cd ta-lib/
|
||||||
|
- ./configure --prefix=/usr
|
||||||
|
- make
|
||||||
|
- sudo make install
|
||||||
|
- cd ..
|
||||||
install:
|
install:
|
||||||
- pip install -r requirements.txt
|
- pip install -r requirements.txt
|
||||||
- pip install flake8 # pytest # add another testing frameworks later
|
- pip install flake8 # pytest # add another testing frameworks later
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
# encoding: UTF-8
|
# encoding: UTF-8
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
import uuid
|
||||||
|
|
||||||
from simplejson import JSONDecodeError
|
from simplejson import JSONDecodeError
|
||||||
|
|
||||||
from Promise import Promise
|
from Promise import Promise
|
||||||
from vnpy.api.rest.RestClient import RestClient, Request
|
from vnpy.api.rest.RestClient import Request, RestClient
|
||||||
|
|
||||||
|
|
||||||
class FailedError(RuntimeError):
|
class FailedError(RuntimeError):
|
||||||
@ -26,7 +28,7 @@ class TestRestClient(RestClient):
|
|||||||
req.data = json.dumps(req.data)
|
req.data = json.dumps(req.data)
|
||||||
req.headers = {'Content-Type': 'application/json'}
|
req.headers = {'Content-Type': 'application/json'}
|
||||||
return req
|
return req
|
||||||
|
|
||||||
def onError(self, exceptionType, exceptionValue, tb, req):
|
def onError(self, exceptionType, exceptionValue, tb, req):
|
||||||
self.p.set_exception(exceptionType, exceptionValue, tb)
|
self.p.set_exception(exceptionType, exceptionValue, tb)
|
||||||
|
|
||||||
@ -83,4 +85,47 @@ class RestfulClientTest(unittest.TestCase):
|
|||||||
self.c.addRequest('GET', '/image/svg', callback)
|
self.c.addRequest('GET', '/image/svg', callback)
|
||||||
with self.assertRaises(JSONDecodeError):
|
with self.assertRaises(JSONDecodeError):
|
||||||
self.c.p.get(3)
|
self.c.p.get(3)
|
||||||
|
|
||||||
|
|
||||||
|
class RestfulClientErrorHandleTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.c = TestRestClient()
|
||||||
|
self.c.start()
|
||||||
|
|
||||||
|
self.org_sys_hook = sys.excepthook
|
||||||
|
self.org_sys_stderr_write = sys.stderr.write
|
||||||
|
|
||||||
|
sys.excepthook = self.nop
|
||||||
|
sys.stderr.write = self.nop
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.c.stop()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def nop(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_onError(self):
|
||||||
|
"""这个测试保证onError内不会再抛异常"""
|
||||||
|
target = uuid.uuid4()
|
||||||
|
|
||||||
|
def callback(data, req):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def onError(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
super(TestRestClient, self.c).onError(*args, **kwargs)
|
||||||
|
self.c.p.set_result(target)
|
||||||
|
except:
|
||||||
|
self.c.p.set_result(False)
|
||||||
|
self.c.onError = onError
|
||||||
|
|
||||||
|
self.c.addRequest('GET', '/image/svg', callback)
|
||||||
|
|
||||||
|
res = self.c.p.get(3)
|
||||||
|
self.assertEqual(target, res)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# encoding: UTF-8
|
# encoding: UTF-8
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
import uuid
|
||||||
|
|
||||||
from Promise import Promise
|
from Promise import Promise
|
||||||
from vnpy.api.websocket import WebsocketClient
|
from vnpy.api.websocket import WebsocketClient
|
||||||
@ -12,12 +14,14 @@ class TestWebsocketClient(WebsocketClient):
|
|||||||
super(TestWebsocketClient, self).__init__()
|
super(TestWebsocketClient, self).__init__()
|
||||||
self.init(host)
|
self.init(host)
|
||||||
self.p = Promise()
|
self.p = Promise()
|
||||||
|
self.cp = Promise()
|
||||||
|
|
||||||
def onPacket(self, packet):
|
def onPacket(self, packet):
|
||||||
self.p.set_result(packet)
|
self.p.set_result(packet)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def onConnected(self):
|
def onConnected(self):
|
||||||
|
self.cp.set_result(True)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def onError(self, exceptionType, exceptionValue, tb):
|
def onError(self, exceptionType, exceptionValue, tb):
|
||||||
@ -30,6 +34,7 @@ class WebsocketClientTest(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.c = TestWebsocketClient()
|
self.c = TestWebsocketClient()
|
||||||
self.c.start()
|
self.c.start()
|
||||||
|
self.c.cp.get(3) # wait for connected
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.c.stop()
|
self.c.stop()
|
||||||
@ -42,3 +47,73 @@ class WebsocketClientTest(unittest.TestCase):
|
|||||||
res = self.c.p.get(3)
|
res = self.c.p.get(3)
|
||||||
|
|
||||||
self.assertEqual(res, req)
|
self.assertEqual(res, req)
|
||||||
|
|
||||||
|
def test_parseError(self):
|
||||||
|
class CustomException(Exception): pass
|
||||||
|
|
||||||
|
def onPacket(packet):
|
||||||
|
raise CustomException("Just a exception")
|
||||||
|
|
||||||
|
self.c.onPacket = onPacket
|
||||||
|
req = {
|
||||||
|
'name': 'val'
|
||||||
|
}
|
||||||
|
self.c.sendPacket(req)
|
||||||
|
|
||||||
|
with self.assertRaises(CustomException):
|
||||||
|
self.c.p.get(3)
|
||||||
|
|
||||||
|
|
||||||
|
class WebsocketClientErrorHandleTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
|
||||||
|
self.c = TestWebsocketClient()
|
||||||
|
self.c.start()
|
||||||
|
self.c.cp.get(3) # wait for connected
|
||||||
|
|
||||||
|
self.org_sys_hook = sys.excepthook
|
||||||
|
self.org_sys_stderr_write = sys.stderr.write
|
||||||
|
|
||||||
|
sys.excepthook = self.nop
|
||||||
|
sys.stderr.write = self.nop
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def nop(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.c.stop()
|
||||||
|
sys.excepthook = self.org_sys_hook
|
||||||
|
sys.stderr.write = self.org_sys_stderr_write
|
||||||
|
|
||||||
|
def test_onError(self):
|
||||||
|
target= uuid.uuid4()
|
||||||
|
"""这个测试保证onError内不会再抛Exception"""
|
||||||
|
class CustomException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def onPacket(packet):
|
||||||
|
raise CustomException("Just a exception")
|
||||||
|
|
||||||
|
def onError(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
super(TestWebsocketClient, self.c).onError(*args, **kwargs)
|
||||||
|
self.c.p.set_result(target)
|
||||||
|
except:
|
||||||
|
self.c.p.set_result(False)
|
||||||
|
|
||||||
|
self.c.onPacket = onPacket
|
||||||
|
self.c.onError = onError
|
||||||
|
|
||||||
|
req = {
|
||||||
|
'name': 'val'
|
||||||
|
}
|
||||||
|
self.c.sendPacket(req)
|
||||||
|
|
||||||
|
res = self.c.p.get(3)
|
||||||
|
self.assertEqual(target, res)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import traceback
|
||||||
from Queue import Empty, Queue
|
from Queue import Empty, Queue
|
||||||
|
from datetime import datetime
|
||||||
from multiprocessing.dummy import Pool
|
from multiprocessing.dummy import Pool
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -34,6 +36,7 @@ class Request(object):
|
|||||||
self.headers = headers # type: dict
|
self.headers = headers # type: dict
|
||||||
|
|
||||||
self.onFailed = None # type: callable
|
self.onFailed = None # type: callable
|
||||||
|
self.onError = None # type: callable
|
||||||
self.extra = None # type: Any
|
self.extra = None # type: Any
|
||||||
|
|
||||||
self.response = None # type: requests.Response
|
self.response = None # type: requests.Response
|
||||||
@ -95,7 +98,8 @@ class RestClient(object):
|
|||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def start(self, n=3):
|
def start(self, n=3):
|
||||||
"""启动"""
|
"""启动"""
|
||||||
assert not self._active
|
if self._active:
|
||||||
|
return
|
||||||
|
|
||||||
self._active = True
|
self._active = True
|
||||||
self._pool = Pool(n)
|
self._pool = Pool(n)
|
||||||
@ -128,6 +132,7 @@ class RestClient(object):
|
|||||||
data=None, # type: dict
|
data=None, # type: dict
|
||||||
headers=None, # type: dict
|
headers=None, # type: dict
|
||||||
onFailed=None, # type: Callable[[int, Request], Any]
|
onFailed=None, # type: Callable[[int, Request], Any]
|
||||||
|
onError=None, # type: Callable[[type, Exception, traceback, Request], Any]
|
||||||
extra=None # type: Any
|
extra=None # type: Any
|
||||||
): # type: (...)->Request
|
): # type: (...)->Request
|
||||||
"""
|
"""
|
||||||
@ -139,6 +144,7 @@ class RestClient(object):
|
|||||||
:param data: dict for body
|
:param data: dict for body
|
||||||
:param headers: dict for headers
|
:param headers: dict for headers
|
||||||
:param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request)
|
:param onFailed: 请求失败后的回调(状态吗不为2xx时认为请求失败)(如果指定该值,默认的onFailed将不会被调用) type: (code, dict, Request)
|
||||||
|
:param onError: 请求出现Python错误后的回调(如果指定该值,默认的onError将不会被调用) type: (etype, evalue, tb, Request)
|
||||||
:param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。
|
:param extra: 返回值的extra字段会被设置为这个值。当然,你也可以在函数调用之后再设置这个字段。
|
||||||
:return: Request
|
:return: Request
|
||||||
"""
|
"""
|
||||||
@ -146,6 +152,7 @@ class RestClient(object):
|
|||||||
request = Request(method, path, params, data, headers, callback)
|
request = Request(method, path, params, data, headers, callback)
|
||||||
request.extra = extra
|
request.extra = extra
|
||||||
request.onFailed = onFailed
|
request.onFailed = onFailed
|
||||||
|
request.onError = onError
|
||||||
self._queue.put(request)
|
self._queue.put(request)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
@ -195,9 +202,29 @@ class RestClient(object):
|
|||||||
Python内部错误处理:默认行为是仍给excepthook
|
Python内部错误处理:默认行为是仍给excepthook
|
||||||
:param request 如果是在处理请求的时候出错,它的值就是对应的Request,否则为None
|
:param request 如果是在处理请求的时候出错,它的值就是对应的Request,否则为None
|
||||||
"""
|
"""
|
||||||
print("error in request : {}\n".format(request))
|
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request))
|
||||||
sys.excepthook(exceptionType, exceptionValue, tb)
|
sys.excepthook(exceptionType, exceptionValue, tb)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def exceptionDetail(self,
|
||||||
|
exceptionType, # type: type
|
||||||
|
exceptionValue, # type: Exception
|
||||||
|
tb,
|
||||||
|
request # type: Optional[Request]
|
||||||
|
):
|
||||||
|
text = "[{}]: Unhandled RestClient Error:{}\n".format(
|
||||||
|
datetime.now().isoformat(),
|
||||||
|
exceptionType
|
||||||
|
)
|
||||||
|
text += "request:{}\n".format(request)
|
||||||
|
text += "Exception trace: \n"
|
||||||
|
text += "".join(traceback.format_exception(
|
||||||
|
exceptionType,
|
||||||
|
exceptionValue,
|
||||||
|
tb,
|
||||||
|
))
|
||||||
|
return text
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def _processRequest(self, request, session): # type: (Request, requests.Session)->None
|
def _processRequest(self, request, session): # type: (Request, requests.Session)->None
|
||||||
"""
|
"""
|
||||||
@ -231,7 +258,10 @@ class RestClient(object):
|
|||||||
except:
|
except:
|
||||||
request.status = RequestStatus.error
|
request.status = RequestStatus.error
|
||||||
t, v, tb = sys.exc_info()
|
t, v, tb = sys.exc_info()
|
||||||
self.onError(t, v, tb, request)
|
if request.onError:
|
||||||
|
request.onError(t, v, tb, request)
|
||||||
|
else:
|
||||||
|
self.onError(t, v, tb, request)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def makeFullUrl(self, path):
|
def makeFullUrl(self, path):
|
||||||
@ -243,3 +273,4 @@ class RestClient(object):
|
|||||||
"""
|
"""
|
||||||
url = self.urlBase + path
|
url = self.urlBase + path
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
@ -3,13 +3,15 @@
|
|||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
|
|
||||||
import ssl
|
import ssl
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import websocket
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
from threading import Lock, Thread
|
from threading import Lock, Thread
|
||||||
|
|
||||||
|
import websocket
|
||||||
|
|
||||||
|
|
||||||
class WebsocketClient(object):
|
class WebsocketClient(object):
|
||||||
"""
|
"""
|
||||||
@ -50,7 +52,11 @@ class WebsocketClient(object):
|
|||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def start(self):
|
def start(self):
|
||||||
"""启动"""
|
"""
|
||||||
|
启动
|
||||||
|
:note 注意:启动之后不能立即发包,需要等待websocket连接成功。
|
||||||
|
websocket连接成功之后会响应onConnected函数
|
||||||
|
"""
|
||||||
|
|
||||||
self._active = True
|
self._active = True
|
||||||
self._workerThread = Thread(target=self._run)
|
self._workerThread = Thread(target=self._run)
|
||||||
@ -58,6 +64,10 @@ class WebsocketClient(object):
|
|||||||
|
|
||||||
self._pingThread = Thread(target=self._runPing)
|
self._pingThread = Thread(target=self._runPing)
|
||||||
self._pingThread.start()
|
self._pingThread.start()
|
||||||
|
|
||||||
|
# for debugging:
|
||||||
|
self._lastSentText = None
|
||||||
|
self._lastReceivedText = None
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@ -80,7 +90,9 @@ class WebsocketClient(object):
|
|||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def sendPacket(self, dictObj): # type: (dict)->None
|
def sendPacket(self, dictObj): # type: (dict)->None
|
||||||
"""发出请求:相当于sendText(json.dumps(dictObj))"""
|
"""发出请求:相当于sendText(json.dumps(dictObj))"""
|
||||||
return self._getWs().send(json.dumps(dictObj), opcode=websocket.ABNF.OPCODE_TEXT)
|
text = json.dumps(dictObj)
|
||||||
|
self._recordLastSentText(text)
|
||||||
|
return self._getWs().send(text, opcode=websocket.ABNF.OPCODE_TEXT)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def sendText(self, text): # type: (str)->None
|
def sendText(self, text): # type: (str)->None
|
||||||
@ -137,15 +149,15 @@ class WebsocketClient(object):
|
|||||||
try:
|
try:
|
||||||
ws = self._getWs()
|
ws = self._getWs()
|
||||||
if ws:
|
if ws:
|
||||||
stream = ws.recv()
|
text = ws.recv()
|
||||||
if not stream: # recv在阻塞的时候ws被关闭
|
if not text: # recv在阻塞的时候ws被关闭
|
||||||
self._reconnect()
|
self._reconnect()
|
||||||
continue
|
continue
|
||||||
|
self._recordLastReceivedText(text)
|
||||||
try:
|
try:
|
||||||
data = json.loads(stream)
|
data = self.unpackData(text)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print('websocket unable to parse data: ' + stream)
|
print('websocket unable to parse data: ' + text)
|
||||||
raise e
|
raise e
|
||||||
self.onPacket(data)
|
self.onPacket(data)
|
||||||
except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了
|
except websocket.WebSocketConnectionClosedException: # 在调用recv之前ws就被关闭了
|
||||||
@ -158,6 +170,17 @@ class WebsocketClient(object):
|
|||||||
et, ev, tb = sys.exc_info()
|
et, ev, tb = sys.exc_info()
|
||||||
self.onError(et, ev, tb)
|
self.onError(et, ev, tb)
|
||||||
self._reconnect()
|
self._reconnect()
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@staticmethod
|
||||||
|
def unpackData(data):
|
||||||
|
"""
|
||||||
|
解密数据,默认使用json解密为dict
|
||||||
|
解密后的数据将会传入onPacket
|
||||||
|
如果需要使用不同的解密方式,就重载这个函数。
|
||||||
|
:param data 收到的数据,可能是text frame,也可能是binary frame, 目前并没有区分这两者
|
||||||
|
"""
|
||||||
|
return json.loads(data)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def _runPing(self):
|
def _runPing(self):
|
||||||
@ -207,7 +230,43 @@ class WebsocketClient(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
@staticmethod
|
def onError(self, exceptionType, exceptionValue, tb):
|
||||||
def onError(exceptionType, exceptionValue, tb):
|
"""
|
||||||
"""Python错误回调"""
|
Python错误回调
|
||||||
|
todo: 以后详细的错误信息最好记录在文件里,用uuid来联系/区分具体错误
|
||||||
|
"""
|
||||||
|
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb))
|
||||||
|
|
||||||
|
# 丢给默认的错误处理函数(所以如果不重载onError,一般的结果是程序会崩溃)
|
||||||
return sys.excepthook(exceptionType, exceptionValue, tb)
|
return sys.excepthook(exceptionType, exceptionValue, tb)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def exceptionDetail(self, exceptionType, exceptionValue, tb):
|
||||||
|
"""打印详细的错误信息"""
|
||||||
|
text = "[{}]: Unhandled WebSocket Error:{}\n".format(
|
||||||
|
datetime.now().isoformat(),
|
||||||
|
exceptionType
|
||||||
|
)
|
||||||
|
text += "LastSentText:\n{}\n".format(self._lastSentText)
|
||||||
|
text += "LastReceivedText:\n{}\n".format(self._lastReceivedText)
|
||||||
|
text += "Exception trace: \n"
|
||||||
|
text += "".join(traceback.format_exception(
|
||||||
|
exceptionType,
|
||||||
|
exceptionValue,
|
||||||
|
tb,
|
||||||
|
))
|
||||||
|
return text
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def _recordLastSentText(self, text):
|
||||||
|
"""
|
||||||
|
用于Debug: 记录最后一次发送出去的text
|
||||||
|
"""
|
||||||
|
self._lastSentText = text[:200]
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def _recordLastReceivedText(self, text):
|
||||||
|
"""
|
||||||
|
用于Debug: 记录最后一次发送出去的text
|
||||||
|
"""
|
||||||
|
self._lastReceivedText = text[:200]
|
||||||
|
@ -12,6 +12,7 @@ import traceback
|
|||||||
# 用来保存策略类的字典
|
# 用来保存策略类的字典
|
||||||
STRATEGY_CLASS = {}
|
STRATEGY_CLASS = {}
|
||||||
|
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def loadStrategyModule(moduleName):
|
def loadStrategyModule(moduleName):
|
||||||
"""使用importlib动态载入模块"""
|
"""使用importlib动态载入模块"""
|
||||||
@ -34,7 +35,7 @@ path = os.path.abspath(os.path.dirname(__file__))
|
|||||||
for root, subdirs, files in os.walk(path):
|
for root, subdirs, files in os.walk(path):
|
||||||
for name in files:
|
for name in files:
|
||||||
# 只有文件名中包含strategy且以.py结尾的文件,才是策略文件
|
# 只有文件名中包含strategy且以.py结尾的文件,才是策略文件
|
||||||
if 'strategy' in name and name[-3:] == '.py':
|
if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name:
|
||||||
# 模块名称需要模块路径前缀
|
# 模块名称需要模块路径前缀
|
||||||
moduleName = 'vnpy.trader.app.ctaStrategy.strategy.' + name.replace('.py', '')
|
moduleName = 'vnpy.trader.app.ctaStrategy.strategy.' + name.replace('.py', '')
|
||||||
loadStrategyModule(moduleName)
|
loadStrategyModule(moduleName)
|
||||||
@ -45,7 +46,7 @@ workingPath = os.getcwd()
|
|||||||
for root, subdirs, files in os.walk(workingPath):
|
for root, subdirs, files in os.walk(workingPath):
|
||||||
for name in files:
|
for name in files:
|
||||||
# 只有文件名中包含strategy且以.py结尾的文件,才是策略文件
|
# 只有文件名中包含strategy且以.py结尾的文件,才是策略文件
|
||||||
if 'strategy' in name and name[-3:] == '.py':
|
if 'strategy' in name and name[-3:] == '.py' and '/' not in name and '\\' not in name:
|
||||||
# 模块名称无需前缀
|
# 模块名称无需前缀
|
||||||
moduleName = name.replace('.py', '')
|
moduleName = name.replace('.py', '')
|
||||||
loadStrategyModule(moduleName)
|
loadStrategyModule(moduleName)
|
||||||
|
@ -10,6 +10,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
@ -17,7 +18,9 @@ from copy import copy
|
|||||||
from math import pow
|
from math import pow
|
||||||
from urllib import urlencode
|
from urllib import urlencode
|
||||||
|
|
||||||
from vnpy.api.rest import RestClient
|
from requests import ConnectionError
|
||||||
|
|
||||||
|
from vnpy.api.rest import RestClient, Request
|
||||||
from vnpy.api.websocket import WebsocketClient
|
from vnpy.api.websocket import WebsocketClient
|
||||||
from vnpy.trader.vtGateway import *
|
from vnpy.trader.vtGateway import *
|
||||||
from vnpy.trader.vtFunction import getJsonPath, getTempPath
|
from vnpy.trader.vtFunction import getJsonPath, getTempPath
|
||||||
@ -67,6 +70,8 @@ class BitmexGateway(VtGateway):
|
|||||||
|
|
||||||
self.fileName = self.gatewayName + '_connect.json'
|
self.fileName = self.gatewayName + '_connect.json'
|
||||||
self.filePath = getJsonPath(self.fileName, __file__)
|
self.filePath = getJsonPath(self.fileName, __file__)
|
||||||
|
|
||||||
|
self.exchange = constant.EXCHANGE_BITMEX
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def connect(self):
|
def connect(self):
|
||||||
@ -172,7 +177,7 @@ class BitmexRestApi(RestClient):
|
|||||||
"""Constructor"""
|
"""Constructor"""
|
||||||
super(BitmexRestApi, self).__init__()
|
super(BitmexRestApi, self).__init__()
|
||||||
|
|
||||||
self.gateway = gateway # gateway对象
|
self.gateway = gateway # type: BitmexGateway # gateway对象
|
||||||
self.gatewayName = gateway.gatewayName # gateway对象名称
|
self.gatewayName = gateway.gatewayName # gateway对象名称
|
||||||
|
|
||||||
self.apiKey = ''
|
self.apiKey = ''
|
||||||
@ -240,11 +245,11 @@ class BitmexRestApi(RestClient):
|
|||||||
self.gateway.onLog(log)
|
self.gateway.onLog(log)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def sendOrder(self, orderReq):
|
def sendOrder(self, orderReq):# type: (VtOrderReq)->str
|
||||||
""""""
|
""""""
|
||||||
self.orderId += 1
|
self.orderId += 1
|
||||||
orderId = self.loginTime + self.orderId
|
orderId = str(self.loginTime + self.orderId)
|
||||||
vtOrderID = '.'.join([self.gatewayName, str(orderId)])
|
vtOrderID = '.'.join([self.gatewayName, orderId])
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'symbol': orderReq.symbol,
|
'symbol': orderReq.symbol,
|
||||||
@ -258,8 +263,22 @@ class BitmexRestApi(RestClient):
|
|||||||
# 只有限价单才有price字段
|
# 只有限价单才有price字段
|
||||||
if orderReq.priceType == PRICETYPE_LIMITPRICE:
|
if orderReq.priceType == PRICETYPE_LIMITPRICE:
|
||||||
data['price'] = orderReq.price
|
data['price'] = orderReq.price
|
||||||
|
|
||||||
|
vtOrder = VtOrderData.createFromGateway(
|
||||||
|
self.gateway,
|
||||||
|
orderId=orderId,
|
||||||
|
symbol=orderReq.symbol,
|
||||||
|
exchange=self.gateway.exchange,
|
||||||
|
price=orderReq.price,
|
||||||
|
volume=orderReq.volume,
|
||||||
|
direction=orderReq.direction,
|
||||||
|
offset=orderReq.offset,
|
||||||
|
)
|
||||||
|
|
||||||
self.addRequest('POST', '/order', self.onSendOrder, data=data)
|
self.addRequest('POST', '/order', callback=self.onSendOrder, data=data, extra=vtOrder,
|
||||||
|
onFailed=self.onSendOrderFailed,
|
||||||
|
onError=self.onSendOrderError,
|
||||||
|
)
|
||||||
return vtOrderID
|
return vtOrderID
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
@ -272,12 +291,46 @@ class BitmexRestApi(RestClient):
|
|||||||
else:
|
else:
|
||||||
params = {'orderID': orderID}
|
params = {'orderID': orderID}
|
||||||
|
|
||||||
self.addRequest('DELETE', '/order', self.onCancelOrder, params=params)
|
self.addRequest('DELETE', '/order', callback=self.onCancelOrder, params=params,
|
||||||
|
onError=self.onCancelOrderError,
|
||||||
|
)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def onSendOrderFailed(self, _, request):
|
||||||
|
"""
|
||||||
|
下单失败回调:服务器明确告知下单失败
|
||||||
|
"""
|
||||||
|
vtOrder = request.extra # type: VtOrderData
|
||||||
|
vtOrder.status = constant.STATUS_REJECTED
|
||||||
|
self.gateway.onOrder(vtOrder)
|
||||||
|
pass
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def onSendOrderError(self, exceptionType, exceptionValue, tb, request):
|
||||||
|
"""
|
||||||
|
下单失败回调:连接错误
|
||||||
|
"""
|
||||||
|
vtOrder = request.extra # type: VtOrderData
|
||||||
|
vtOrder.status = constant.STATUS_REJECTED
|
||||||
|
self.gateway.onOrder(vtOrder)
|
||||||
|
|
||||||
|
# 意料之中的错误只有ConnectionError,若还有其他错误,最好还是记录一下,用原来的onError记录即可
|
||||||
|
if not issubclass(exceptionType, ConnectionError):
|
||||||
|
self.onError(exceptionType, exceptionValue, tb, request)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def onSendOrder(self, data, request):
|
def onSendOrder(self, data, request):
|
||||||
""""""
|
""""""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
def onCancelOrderError(self, exceptionType, exceptionValue, tb, request):
|
||||||
|
"""
|
||||||
|
撤单失败回调:连接错误
|
||||||
|
"""
|
||||||
|
# 意料之中的错误只有ConnectionError,若还有其他错误,最好还是记录一下,用原来的onError记录即可
|
||||||
|
if not issubclass(exceptionType, ConnectionError):
|
||||||
|
self.onError(exceptionType, exceptionValue, tb, request)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def onCancelOrder(self, data, request):
|
def onCancelOrder(self, data, request):
|
||||||
@ -307,8 +360,8 @@ class BitmexRestApi(RestClient):
|
|||||||
e.errorID = exceptionType
|
e.errorID = exceptionType
|
||||||
e.errorMsg = exceptionValue
|
e.errorMsg = exceptionValue
|
||||||
self.gateway.onError(e)
|
self.gateway.onError(e)
|
||||||
|
|
||||||
traceback.print_exc()
|
sys.stderr.write(self.exceptionDetail(exceptionType, exceptionValue, tb, request))
|
||||||
|
|
||||||
|
|
||||||
########################################################################
|
########################################################################
|
||||||
@ -517,10 +570,13 @@ class BitmexWebsocketApi(WebsocketClient):
|
|||||||
trade.orderID = orderID
|
trade.orderID = orderID
|
||||||
trade.vtOrderID = '.'.join([trade.gatewayName, trade.orderID])
|
trade.vtOrderID = '.'.join([trade.gatewayName, trade.orderID])
|
||||||
|
|
||||||
|
|
||||||
trade.tradeID = tradeID
|
trade.tradeID = tradeID
|
||||||
trade.vtTradeID = '.'.join([trade.gatewayName, trade.tradeID])
|
trade.vtTradeID = '.'.join([trade.gatewayName, trade.tradeID])
|
||||||
|
|
||||||
|
if 'side' not in d:
|
||||||
|
print('no side : \n', d)
|
||||||
|
return
|
||||||
|
|
||||||
trade.direction = directionMapReverse[d['side']]
|
trade.direction = directionMapReverse[d['side']]
|
||||||
trade.price = d['lastPx']
|
trade.price = d['lastPx']
|
||||||
trade.volume = d['lastQty']
|
trade.volume = d['lastQty']
|
||||||
|
@ -376,9 +376,33 @@ class CtpMdApi(MdApi):
|
|||||||
tick.askVolume1 = data['AskVolume1']
|
tick.askVolume1 = data['AskVolume1']
|
||||||
|
|
||||||
# 大商所日期转换
|
# 大商所日期转换
|
||||||
if tick.exchange is EXCHANGE_DCE:
|
if tick.exchange == EXCHANGE_DCE:
|
||||||
tick.date = datetime.now().strftime('%Y%m%d')
|
tick.date = datetime.now().strftime('%Y%m%d')
|
||||||
|
|
||||||
|
# 上交所,SEE,股票期权相关
|
||||||
|
if tick.exchange == EXCHANGE_SSE:
|
||||||
|
tick.bidPrice2 = data['BidPrice2']
|
||||||
|
tick.bidVolume2 = data['BidVolume2']
|
||||||
|
tick.askPrice2 = data['AskPrice2']
|
||||||
|
tick.askVolume2 = data['AskVolume2']
|
||||||
|
|
||||||
|
tick.bidPrice3 = data['BidPrice3']
|
||||||
|
tick.bidVolume3 = data['BidVolume3']
|
||||||
|
tick.askPrice3 = data['AskPrice3']
|
||||||
|
tick.askVolume3 = data['AskVolume3']
|
||||||
|
|
||||||
|
tick.bidPrice4 = data['BidPrice4']
|
||||||
|
tick.bidVolume4 = data['BidVolume4']
|
||||||
|
tick.askPrice4 = data['AskPrice4']
|
||||||
|
tick.askVolume4 = data['AskVolume4']
|
||||||
|
|
||||||
|
tick.bidPrice5 = data['BidPrice5']
|
||||||
|
tick.bidVolume5 = data['BidVolume5']
|
||||||
|
tick.askPrice5 = data['AskPrice5']
|
||||||
|
tick.askVolume5 = data['AskVolume5']
|
||||||
|
|
||||||
|
tick.date = data['TradingDay']
|
||||||
|
|
||||||
self.gateway.onTick(tick)
|
self.gateway.onTick(tick)
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
|
@ -121,7 +121,7 @@ class OkexFuturesWebSocketBase(WebsocketClient):
|
|||||||
Okex期货websocket客户端
|
Okex期货websocket客户端
|
||||||
实例化后使用init设置apiKey和secretKey(apiSecret)
|
实例化后使用init设置apiKey和secretKey(apiSecret)
|
||||||
"""
|
"""
|
||||||
host = 'wss://real.okex.com:10440/websocket/okexapi'
|
host = 'wss://real.okex.com:10440/websocket/okexapi?compress=true'
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(OkexFuturesWebSocketBase, self).__init__()
|
super(OkexFuturesWebSocketBase, self).__init__()
|
||||||
|
@ -5,6 +5,7 @@ from __future__ import print_function
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import zlib
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -141,6 +142,7 @@ class OkexFuturesGateway(VtGateway):
|
|||||||
|
|
||||||
self.webSocket = OkexFuturesWebSocketBase()
|
self.webSocket = OkexFuturesWebSocketBase()
|
||||||
self.webSocket.onPacket = self.onWebSocketPacket
|
self.webSocket.onPacket = self.onWebSocketPacket
|
||||||
|
self.webSocket.unpackData = self.webSocketUnpackData
|
||||||
|
|
||||||
self.leverRate = 1
|
self.leverRate = 1
|
||||||
self.symbols = []
|
self.symbols = []
|
||||||
@ -611,6 +613,12 @@ class OkexFuturesGateway(VtGateway):
|
|||||||
price=pos['short_avg_cost'],
|
price=pos['short_avg_cost'],
|
||||||
)
|
)
|
||||||
self.onPosition(vtPos)
|
self.onPosition(vtPos)
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------
|
||||||
|
@staticmethod
|
||||||
|
def webSocketUnpackData(data):
|
||||||
|
"""重载websocket.unpackData"""
|
||||||
|
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
|
||||||
|
|
||||||
#----------------------------------------------------------------------
|
#----------------------------------------------------------------------
|
||||||
def onWebSocketPacket(self, packets):
|
def onWebSocketPacket(self, packets):
|
||||||
|
Loading…
Reference in New Issue
Block a user