From 6bf7837055b13990446a8012cd988c76b3d0ffe8 Mon Sep 17 00:00:00 2001 From: nanoric Date: Sun, 7 Oct 2018 03:33:33 -0400 Subject: [PATCH] =?UTF-8?q?[Add]=20=E5=A2=9E=E5=8A=A0RestfulClient?= =?UTF-8?q?=E7=9A=84=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/Promise.py | 30 ++++++++++++++ tests/all_test.py | 7 ++++ tests/restful/RestfulClientTest.py | 65 ++++++++++++++++++++++++++++++ tests/restful/__init__.py | 0 4 files changed, 102 insertions(+) create mode 100644 tests/Promise.py create mode 100644 tests/all_test.py create mode 100644 tests/restful/RestfulClientTest.py create mode 100644 tests/restful/__init__.py diff --git a/tests/Promise.py b/tests/Promise.py new file mode 100644 index 00000000..a27046cf --- /dev/null +++ b/tests/Promise.py @@ -0,0 +1,30 @@ +# encoding: UTF-8 +from Queue import Queue + +from enum import Enum + + +class PromiseResultType(Enum): + Result = 1 + Exception = 2 + + +class Promise(object): + """ + 用队列实现的一个简单的Promise类型 + """ + + def __init__(self): + self._queue = Queue() + + def set_result(self, val): + self._queue.put((PromiseResultType.Result, val)) + + def get(self, timeout=None): + res = self._queue.get(timeout=timeout) + if res[0] == PromiseResultType.Result: + return res[1] + raise res[2] + + def set_exception(self, val): + self._queue.put((PromiseResultType.Exception, val)) diff --git a/tests/all_test.py b/tests/all_test.py new file mode 100644 index 00000000..cb451564 --- /dev/null +++ b/tests/all_test.py @@ -0,0 +1,7 @@ +import unittest + +# noinspection PyUnresolvedReferences +from restful.RestfulClientTest import * + +if __name__ == "__main__": + unittest.main() diff --git a/tests/restful/RestfulClientTest.py b/tests/restful/RestfulClientTest.py new file mode 100644 index 00000000..5d3421a7 --- /dev/null +++ b/tests/restful/RestfulClientTest.py @@ -0,0 +1,65 @@ +# encoding: UTF-8 + +import json +import traceback +import unittest + +from Promise import Promise +from vnpy.restful.RestfulClient import RestfulClient, requestsSessionProvider + + +class TestRestfulClient(RestfulClient): + + def __init__(self): + urlBase = 'https://httpbin.org' + super(TestRestfulClient, self).__init__(urlBase, requestsSessionProvider) + + self.p = Promise() + + def beforeRequest(self, method, path, params, data): + data = json.dumps(data) + return method, path, params, data, {'Content-Type': 'application/json'} + + def onError(self, exceptionType, exceptionValue, tb, req): + traceback.print_exception(exceptionType, exceptionValue, tb) + self.p.set_exception(exceptionValue) + + +class RestfulClientTest(unittest.TestCase): + + def setUp(self): + self.c = TestRestfulClient() + self.c.start() + + def tearDown(self): + self.c.stop() + + def test_addReq_get(self): + args = {'user': 'username', + 'pw': 'password'} + + def callback(code, data, req): + if code == 200: + self.c.p.set_result(data['args']) + return + self.c.p.set_result(False) + + self.c.addReq('GET', '/get', callback, params=args) + res = self.c.p.get(3) + + self.assertEqual(args, res) + + def test_addReq_post(self): + body = {'user': 'username', + 'pw': 'password'} + + def callback(code, data, req): + if code == 200: + self.c.p.set_result(data['json']) + return + self.c.p.set_result(False) + + self.c.addReq('POST', '/post', callback, data=body) + res = self.c.p.get(3) + + self.assertEqual(body, res) diff --git a/tests/restful/__init__.py b/tests/restful/__init__.py new file mode 100644 index 00000000..e69de29b