diff --git a/tests/Promise.py b/tests/Promise.py new file mode 100644 index 00000000..a27046cf --- /dev/null +++ b/tests/Promise.py @@ -0,0 +1,30 @@ +# encoding: UTF-8 +from Queue import Queue + +from enum import Enum + + +class PromiseResultType(Enum): + Result = 1 + Exception = 2 + + +class Promise(object): + """ + 用队列实现的一个简单的Promise类型 + """ + + def __init__(self): + self._queue = Queue() + + def set_result(self, val): + self._queue.put((PromiseResultType.Result, val)) + + def get(self, timeout=None): + res = self._queue.get(timeout=timeout) + if res[0] == PromiseResultType.Result: + return res[1] + raise res[2] + + def set_exception(self, val): + self._queue.put((PromiseResultType.Exception, val)) diff --git a/tests/all_test.py b/tests/all_test.py new file mode 100644 index 00000000..cb451564 --- /dev/null +++ b/tests/all_test.py @@ -0,0 +1,7 @@ +import unittest + +# noinspection PyUnresolvedReferences +from restful.RestfulClientTest import * + +if __name__ == "__main__": + unittest.main() diff --git a/tests/restful/RestfulClientTest.py b/tests/restful/RestfulClientTest.py new file mode 100644 index 00000000..5d3421a7 --- /dev/null +++ b/tests/restful/RestfulClientTest.py @@ -0,0 +1,65 @@ +# encoding: UTF-8 + +import json +import traceback +import unittest + +from Promise import Promise +from vnpy.restful.RestfulClient import RestfulClient, requestsSessionProvider + + +class TestRestfulClient(RestfulClient): + + def __init__(self): + urlBase = 'https://httpbin.org' + super(TestRestfulClient, self).__init__(urlBase, requestsSessionProvider) + + self.p = Promise() + + def beforeRequest(self, method, path, params, data): + data = json.dumps(data) + return method, path, params, data, {'Content-Type': 'application/json'} + + def onError(self, exceptionType, exceptionValue, tb, req): + traceback.print_exception(exceptionType, exceptionValue, tb) + self.p.set_exception(exceptionValue) + + +class RestfulClientTest(unittest.TestCase): + + def setUp(self): + self.c = TestRestfulClient() + self.c.start() + + def tearDown(self): + self.c.stop() + + def test_addReq_get(self): + args = {'user': 'username', + 'pw': 'password'} + + def callback(code, data, req): + if code == 200: + self.c.p.set_result(data['args']) + return + self.c.p.set_result(False) + + self.c.addReq('GET', '/get', callback, params=args) + res = self.c.p.get(3) + + self.assertEqual(args, res) + + def test_addReq_post(self): + body = {'user': 'username', + 'pw': 'password'} + + def callback(code, data, req): + if code == 200: + self.c.p.set_result(data['json']) + return + self.c.p.set_result(False) + + self.c.addReq('POST', '/post', callback, data=body) + res = self.c.p.get(3) + + self.assertEqual(body, res) diff --git a/tests/restful/__init__.py b/tests/restful/__init__.py new file mode 100644 index 00000000..e69de29b