From 6359fac19647f47a5975e559fe15d4c6836ae0d5 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 26 Mar 2019 15:53:15 +0800 Subject: [PATCH 01/49] Create genetic_algorithm.ipynb --- tests/backtesting/genetic_algorithm.ipynb | 189 ++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tests/backtesting/genetic_algorithm.ipynb diff --git a/tests/backtesting/genetic_algorithm.ipynb b/tests/backtesting/genetic_algorithm.ipynb new file mode 100644 index 00000000..c9e92e12 --- /dev/null +++ b/tests/backtesting/genetic_algorithm.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import multiprocessing\n", + "import numpy as np\n", + "from deap import creator, base, tools, algorithms\n", + "from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n", + "from boll_channel_strategy import BollChannelStrategy\n", + "from datetime import datetime\n", + "import multiprocessing #多进程\n", + "from scoop import futures #多进程" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def parameter_generate():\n", + " '''\n", + " 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n", + " '''\n", + " parameter_list = []\n", + " p1 = random.randrange(4,50,2) #布林带窗口\n", + " p2 = random.randrange(4,50,2) #布林带通道阈值\n", + " p3 = random.randrange(4,50,2) #CCI窗口\n", + " p4 = random.randrange(18,40,2) #ATR窗口 \n", + "\n", + " parameter_list.append(p1)\n", + " parameter_list.append(p2)\n", + " parameter_list.append(p3)\n", + " parameter_list.append(p4)\n", + "\n", + " return parameter_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def object_func(strategy_avg):\n", + " \"\"\"\n", + " 本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率\n", + " \"\"\"\n", + " # 创建回测引擎对象\n", + " engine = BacktestingEngine()\n", + " engine.set_parameters(\n", + " vt_symbol=\"IF88.CFFEX\",\n", + " interval=\"1m\",\n", + " start=datetime(2018, 9, 1),\n", + " end=datetime(2019, 1,1),\n", + " rate=0,\n", + " slippage=0,\n", + " size=300,\n", + " pricetick=0.2,\n", + " capital=1_000_000,\n", + " )\n", + "\n", + " setting = {'boll_window': strategy_avg[0], #布林带窗口\n", + " 'boll_dev': strategy_avg[1], #布林带通道阈值\n", + " 'cci_window': strategy_avg[2], #CCI窗口\n", + " 'atr_window': strategy_avg[3],} #ATR窗口 \n", + "\n", + " #加载策略 \n", + " #engine.initStrategy(TurtleTradingStrategy, setting)\n", + " engine.add_strategy(BollChannelStrategy, setting)\n", + " engine.load_data()\n", + " engine.run_backtesting()\n", + " engine.calculate_result()\n", + " result = engine.calculate_statistics(Output=False)\n", + "\n", + " return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n", + " sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n", + " return return_drawdown_ratio , sharpe_ratio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "#设置优化方向:最大化收益回撤比,最大化夏普比率\n", + "creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n", + "creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n", + "\n", + "def optimize():\n", + " \"\"\"\"\"\" \n", + " toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数\n", + "\n", + " # 初始化 \n", + " toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate() \n", + " toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n", + " toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n", + " toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n", + " toolbox.register(\"evaluate\", object_func) #注册评估:优化目标函数object_func() \n", + " toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n", + " #pool = multiprocessing.Pool()\n", + " #toolbox.register(\"map\", pool.map)\n", + " #toolbox.register(\"map\", futures.map)\n", + "\n", + " #遗传算法参数设置\n", + " MU = 40 #设置每一代选择的个体数\n", + " LAMBDA = 160 #设置每一代产生的子女数\n", + " pop = toolbox.population(400) #设置族群里面的个体数量\n", + " CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n", + " hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n", + "\n", + " #解的集合的描述统计信息\n", + " #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n", + " #收敛程度低可以增加算法的迭代次数\n", + " stats = tools.Statistics(lambda ind: ind.fitness.values)\n", + " np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n", + " stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n", + " stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n", + " stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n", + " stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n", + "\n", + " #运行算法\n", + " algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n", + " halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n", + "\n", + " return pop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "optimize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From fc1c76ce487dcc3c22602b1bc448fbe6ba41d068 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 26 Mar 2019 15:53:22 +0800 Subject: [PATCH 02/49] Update backtesting.py --- vnpy/app/cta_strategy/backtesting.py | 58 +++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index cc00cad2..3a99c3b3 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -293,7 +293,7 @@ class BacktestingEngine: self.output("逐日盯市盈亏计算完成") return self.daily_df - def calculate_statistics(self, df: DataFrame = None): + def calculate_statistics(self, df: DataFrame = None, Output=True): """""" self.output("开始计算策略统计指标") @@ -325,6 +325,7 @@ class BacktestingEngine: daily_return = 0 return_std = 0 sharpe_ratio = 0 + return_drawdown_ratio = 0 else: # Calculate balance related time series data df["balance"] = df["net_pnl"].cumsum() + self.capital @@ -373,38 +374,42 @@ class BacktestingEngine: else: sharpe_ratio = 0 + return_drawdown_ratio = -total_return / max_ddpercent + # Output - self.output("-" * 30) - self.output(f"首个交易日:\t{start_date}") - self.output(f"最后交易日:\t{end_date}") + if Output: + self.output("-" * 30) + self.output(f"首个交易日:\t{start_date}") + self.output(f"最后交易日:\t{end_date}") - self.output(f"总交易日:\t{total_days}") - self.output(f"盈利交易日:\t{profit_days}") - self.output(f"亏损交易日:\t{loss_days}") + self.output(f"总交易日:\t{total_days}") + self.output(f"盈利交易日:\t{profit_days}") + self.output(f"亏损交易日:\t{loss_days}") - self.output(f"起始资金:\t{self.capital:,.2f}") - self.output(f"结束资金:\t{end_balance:,.2f}") + self.output(f"起始资金:\t{self.capital:,.2f}") + self.output(f"结束资金:\t{end_balance:,.2f}") - self.output(f"总收益率:\t{total_return:,.2f}%") - self.output(f"年化收益:\t{annual_return:,.2f}%") - self.output(f"最大回撤: \t{max_drawdown:,.2f}") - self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") + self.output(f"总收益率:\t{total_return:,.2f}%") + self.output(f"年化收益:\t{annual_return:,.2f}%") + self.output(f"最大回撤: \t{max_drawdown:,.2f}") + self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") - self.output(f"总盈亏:\t{total_net_pnl:,.2f}") - self.output(f"总手续费:\t{total_commission:,.2f}") - self.output(f"总滑点:\t{total_slippage:,.2f}") - self.output(f"总成交金额:\t{total_turnover:,.2f}") - self.output(f"总成交笔数:\t{total_trade_count}") + self.output(f"总盈亏:\t{total_net_pnl:,.2f}") + self.output(f"总手续费:\t{total_commission:,.2f}") + self.output(f"总滑点:\t{total_slippage:,.2f}") + self.output(f"总成交金额:\t{total_turnover:,.2f}") + self.output(f"总成交笔数:\t{total_trade_count}") - self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") - self.output(f"日均手续费:\t{daily_commission:,.2f}") - self.output(f"日均滑点:\t{daily_slippage:,.2f}") - self.output(f"日均成交金额:\t{daily_turnover:,.2f}") - self.output(f"日均成交笔数:\t{daily_trade_count}") + self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") + self.output(f"日均手续费:\t{daily_commission:,.2f}") + self.output(f"日均滑点:\t{daily_slippage:,.2f}") + self.output(f"日均成交金额:\t{daily_turnover:,.2f}") + self.output(f"日均成交笔数:\t{daily_trade_count}") - self.output(f"日均收益率:\t{daily_return:,.2f}%") - self.output(f"收益标准差:\t{return_std:,.2f}%") - self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"日均收益率:\t{daily_return:,.2f}%") + self.output(f"收益标准差:\t{return_std:,.2f}%") + self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}") statistics = { "start_date": start_date, @@ -430,6 +435,7 @@ class BacktestingEngine: "daily_return": daily_return, "return_std": return_std, "sharpe_ratio": sharpe_ratio, + "return_drawdown_ratio": return_drawdown_ratio, } return statistics From ce64ede969f5f06dbc3a4a73c74af6d6e04bb806 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 26 Mar 2019 15:53:25 +0800 Subject: [PATCH 03/49] Update tiger_gateway.py --- vnpy/gateway/tiger/tiger_gateway.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index eecd9305..f391d7b2 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -1,5 +1,6 @@ # encoding: UTF-8 """ +Author: KeKe Please install tiger-api before use. pip install tigeropen """ From 3d99ccb5f8b71c9774bdc35845684345a476c236 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 2 Apr 2019 17:07:04 +0800 Subject: [PATCH 04/49] Create __init__.py --- vnpy/rpc/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 vnpy/rpc/__init__.py diff --git a/vnpy/rpc/__init__.py b/vnpy/rpc/__init__.py new file mode 100644 index 00000000..b903d2bb --- /dev/null +++ b/vnpy/rpc/__init__.py @@ -0,0 +1 @@ +from .vnrpc import RpcServer, RpcClient, RemoteException \ No newline at end of file From e4a19972960b2b2618a380239d8e9e0687296d0f Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 2 Apr 2019 17:07:07 +0800 Subject: [PATCH 05/49] Create test_client.py --- vnpy/rpc/test_client.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 vnpy/rpc/test_client.py diff --git a/vnpy/rpc/test_client.py b/vnpy/rpc/test_client.py new file mode 100644 index 00000000..17f42558 --- /dev/null +++ b/vnpy/rpc/test_client.py @@ -0,0 +1,35 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep + +from .vnrpc import RpcClient + + +class TestClient(RpcClient): + """ + Test RpcClient + """ + def __init__(self, req_address, sub_address): + """ + Constructor + """ + super(TestClient, self).__init__(req_address, sub_address) + + def callback(self, topic, data): + """ + Realize callable function + """ + print('client received topic:', topic, ', data:', data) + + +if __name__ == '__main__': + req_address = 'tcp://localhost:2014' + sub_address = 'tcp://localhost:0602' + + tc = TestClient(req_address, sub_address) + tc.subscribeTopic('') + tc.start() + + while 1: + print(tc.add(1, 3)) + sleep(2) From b3a72c1283ac678b6e7d3b4c3123791fc978d0d9 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 2 Apr 2019 17:07:10 +0800 Subject: [PATCH 06/49] Create test_server.py --- vnpy/rpc/test_server.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 vnpy/rpc/test_server.py diff --git a/vnpy/rpc/test_server.py b/vnpy/rpc/test_server.py new file mode 100644 index 00000000..ee4ca6e4 --- /dev/null +++ b/vnpy/rpc/test_server.py @@ -0,0 +1,40 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep, time + +from .vnrpc import RpcServer + + +class TestServer(RpcServer): + """ + Test RpcServer + """ + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(TestServer, self).__init__(rep_address, pub_address) + + self.register(self.add) + + def add(self, a, b): + """ + Test function + """ + print('receiving: %s, %s' % (a, b)) + return a + b + + +if __name__ == '__main__': + rep_address = 'tcp://*:2014' + pub_address = 'tcp://*:0602' + + ts = TestServer(rep_address, pub_address) + ts.start() + + while 1: + content = 'current server time is %s' % time() + print(content) + ts.publish('test', content) + sleep(2) From d5585ff269868ed5407bc573a851860fdb35a5ec Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Tue, 2 Apr 2019 17:07:14 +0800 Subject: [PATCH 07/49] Create vnrpc.py --- vnpy/rpc/vnrpc.py | 317 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 vnpy/rpc/vnrpc.py diff --git a/vnpy/rpc/vnrpc.py b/vnpy/rpc/vnrpc.py new file mode 100644 index 00000000..3debfa65 --- /dev/null +++ b/vnpy/rpc/vnrpc.py @@ -0,0 +1,317 @@ +import threading +import traceback +import signal + +import zmq +from msgpack import packb, unpackb +from json import dumps, loads + +import cPickle +p_dumps = cPickle.dumps +p_loads = cPickle.loads + + +# Achieve Ctrl-c interrupt recv +signal.signal(signal.SIGINT, signal.SIG_DFL) + + +class RpcObject(object): + """ + Referred to serialization of packing and unpacking, we offer 3 tools: + 1) maspack: higher performance, but usually requires the installation of msgpack related tools; + 2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries; + 3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly. + + Therefore, it is recommended to use msgpack. + Use json, if you want to communicate with some languages without providing msgpack. + Use cPickle, when the data being transferred contains many custom Python objects. + """ + + def __init__(self): + """ + Constructor + Use msgpack as default serialization tool + """ + self.use_msgpack() + + def pack(self, data): + """""" + pass + + def unpack(self, data): + """""" + pass + + def __json_pack(self, data): + """ + Pack with json + """ + return dumps(data) + + def __json_unpack(self, data): + """ + Unpack with json + """ + return loads(data) + + def __msgpack_pack(self, data): + """ + Pack with msgpack + """ + return packb(data) + + def __msgpack_unpack(self, data): + """ + Unpack with msgpack + """ + return unpackb(data) + + def __pickle_pack(self, data): + """ + Pack with cPickle + """ + return p_dumps(data) + + def __pickle_unpack(self, data): + """ + Unpack with cPickle + """ + return p_loads(data) + + def use_json(self): + """ + Use json as serialization tool + """ + self.pack = self.__json_pack + self.unpack = self.__json_unpack + + def use_msgpack(self): + """ + Use msgpack as serialization tool + """ + self.pack = self.__msgpack_pack + self.unpack = self.__msgpack_unpack + + def use_pickle(self): + """ + Use cPickle as serialization tool + """ + self.pack = self.__pickle_pack + self.unpack = self.__pickle_unpack + + +class RpcServer(RpcObject): + """""" + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(RpcServer, self).__init__() + + # Save functions dict: key is fuction name, value is fuction object + self.__functions = {} + + # Zmq port related + self.__context = zmq.Context() + + self.__socket_rep = self.__context.socket(zmq.REP) # Reply socket (Request–reply pattern) + self.__socket_rep.bind(rep_address) + + self.__socket_pub = self.__context.socket(zmq.PUB) # Publish socket (Publish–subscribe pattern) + self.__socket_pub.bind(pub_address) + + # Woker thread related + self.__active = False # RpcServer status + self.__thread = threading.Thread(target=self.run) # RpcServer thread + + def start(self): + """ + Start RpcServer + """ + # Start RpcServer status + self.__active = True + + # Start RpcServer thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self, join=False): + """ + Stop RpcServer + """ + # Stop RpcServer status + self.__active = False + + # Wait for RpcServer thread to exit + if join and self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcServer functions + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_rep.poll(1000): + continue + + # Receive request data from Reply socket + reqb = self.__socket_rep.recv() + + # Unpack request by deserialization + req = self.unpack(reqb) + + # Get function name and parameters + name, args, kwargs = req + + # Try to get and execute callable function object; capture exception information if it fails + try: + func = self.__functions[name] + r = func(*args, **kwargs) + rep = [True, r] + except Exception as e: + rep = [False, traceback.format_exc()] + + # Pack response by serialization + repb = self.pack(rep) + + # send callable response by Reply socket + self.__socket_rep.send(repb) + + def publish(self, topic, data): + """ + Publish data + """ + # Serialized data + datab = self.pack(data) + + # Send data by Publish socket + self.__socket_pub.send_multipart([topic, datab]) # topci must be ascii encoding + + def register(self, func): + """ + Register function + """ + self.__functions[func.__name__] = func + + +class RpcClient(RpcObject): + """""" + + def __init__(self, req_address, sub_address): + """Constructor""" + super(RpcClient, self).__init__() + + # zmq port related + self.__req_address = req_address + self.__sub_address = sub_address + + self.__context = zmq.Context() + self.__socket_req = self.__context.socket(zmq.REQ) # Request socket (Request–reply pattern) + self.__socket_sub = self.__context.socket(zmq.SUB) # Subscribe socket (Publish–subscribe pattern) + + # Woker thread relate, used to process data pushed from server + self.__active = False # RpcClient status + self.__thread = threading.Thread(target=self.run) # RpcClient thread + + def __getattr__(self, name): + """ + Realize remote call function + """ + # Perform remote call task + def dorpc(*args, **kwargs): + # Generate request + req = [name, args, kwargs] + + # Pack request by serialization + reqb = self.pack(req) + + # Send request and wait for response + self.__socket_req.send(reqb) + repb = self.__socket_req.recv() + + # Unpack response by deserialization + rep = self.unpack(repb) + + # Return response if successed; Trigger exception if failed + if rep[0]: + return rep[1] + else: + raise RemoteException(rep[1]) + + return dorpc + + def start(self): + """ + Start RpcClient + """ + # Connect zmq port + self.__socket_req.connect(self.__req_address) + self.__socket_sub.connect(self.__sub_address) + + # Start RpcClient status + self.__active = True + + # Start RpcClient thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self): + """ + Stop RpcClient + """ + # Stop RpcClient status + self.__active = False + + # Wait for RpcClient thread to exit + if self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcClient function + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_sub.poll(1000): + continue + + # Receive data from subscribe socket + topic, datab = self.__socket_sub.recv_multipart() + + # Unpack data by deserialization + data = self.unpack(datab) + + # Process data by callable function + self.callback(topic, data) + + def callback(self, topic, data): + """ + Callable function + """ + raise NotImplementedError + + def subscribeTopic(self, topic): + """ + Subscribe data + """ + self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic) + + +class RemoteException(Exception): + """ + RPC remote exception + """ + + def __init__(self, value): + """ + Constructor + """ + self.__value = value + + def __str__(self): + """ + Output error message + """ + return self.__value From aceb46c3fd3abbd1af540f719318af48e2eae618 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Wed, 3 Apr 2019 15:56:06 +0800 Subject: [PATCH 08/49] [Add]okex gateway --- vnpy/gateway/okex/__init__.py | 1 + vnpy/gateway/okex/okex_gateway.py | 629 ++++++++++++++++++++++++++++++ vnpy/trader/constant.py | 1 + 3 files changed, 631 insertions(+) create mode 100644 vnpy/gateway/okex/__init__.py create mode 100644 vnpy/gateway/okex/okex_gateway.py diff --git a/vnpy/gateway/okex/__init__.py b/vnpy/gateway/okex/__init__.py new file mode 100644 index 00000000..8cb9fb3b --- /dev/null +++ b/vnpy/gateway/okex/__init__.py @@ -0,0 +1 @@ +from .okex_gateway import OkexGateway diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py new file mode 100644 index 00000000..3edba175 --- /dev/null +++ b/vnpy/gateway/okex/okex_gateway.py @@ -0,0 +1,629 @@ +# encoding: UTF-8 +""" +""" + +import hashlib +import hmac +import sys +import time +import json +import base64 +from copy import copy +from datetime import datetime +from threading import Lock +from urllib.parse import urlencode + +from requests import ConnectionError + +from vnpy.api.rest import Request, RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + OrderType, + Product, + Status, + Offset +) +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + PositionData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, +) + +REST_HOST = "https://www.okex.com" +WEBSOCKET_HOST = "wss://real.okex.com:10440/websocket/okexapi?compress=true" + +STATUS_OKEX2VT = { + "ordering": Status.SUBMITTING, + "open": Status.NOTTRADED, + "part_filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelled": Status.CANCELLED, + "cancelling": Status.CANCELLED, + "failure": Status.REJECTED, +} + +DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"} +DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()} + +ORDERTYPE_VT2OKEX = { + OrderType.LIMIT: "limit", + OrderType.MARKET: "market" +} +ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()} + + +instruments = set() +currencies = set() + + +class OkexGateway(BaseGateway): + """ + VN Trader Gateway for OKEX connection. + """ + + default_setting = { + "API Key": "", + "Secret Key": "", + "Passphrase": "", + "会话数": 3, + "代理地址": "127.0.0.1", + "代理端口": 1080, + } + + def __init__(self, event_engine): + """Constructor""" + super(OkexGateway, self).__init__(event_engine, "OKEX") + + self.rest_api = OkexRestApi(self) + self.ws_api = OkexWebsocketApi(self) + + def connect(self, setting: dict): + """""" + key = setting["API KEY"] + secret = setting["Secret Key"] + passphrase = setting["Passphrase"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + self.rest_api.connect(key, secret, passphrase, + session_number, proxy_host, proxy_port) + + self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port) + + def subscribe(self, req: SubscribeRequest): + """""" + self.ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + pass + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.ws_api.stop() + + +class OkexRestApi(RestClient): + """ + OKEX REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(OkexRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.order_count = 1_000_000 + self.order_count_lock = Lock() + + self.connect_time = 0 + + def sign(self, request): + """ + Generate OKEX signature. + """ + # Sign + timestamp = str(time.time()) + request.data = json.dumps(request.data) + + if request.params: + path = request.path + '?' + urlencode(request.params) + else: + path = request.path + + msg = timestamp + request.method + path + request.data + signature = generate_signature(msg, self.secret) + + # Add headers + request.headers = { + 'OK-ACCESS-KEY': self.key, + 'OK-ACCESS-SIGN': signature, + 'OK-ACCESS-TIMESTAMP': timestamp, + 'OK-ACCESS-PASSPHRASE': self.passphrase, + 'Content-Type': 'application/json' + } + return request + + def connect( + self, + key: str, + secret: str, + passphrase: str + session_number: int, + proxy_host: str, + proxy_port: int, + ): + """ + Initialize connection to REST server. + """ + self.key = key.encode() + self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = ( + int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count + ) + + self.init(REST_HOST, proxy_host, proxy_port) + self.start(session_number) + self.gateway.write_log("REST API启动成功") + + def _new_order_id(self): + with self.order_count_lock: + self.order_count += 1 + return self.order_count + + def send_order(self, req: OrderRequest): + """""" + orderid = str(self.connect_time + self._new_order_id()) + + data = { + "client_oid": orderid, + "type": ORDERTYPE_VT2OKEX[req.type], + "side": DIRECTION_VT2OKEX[req.direction], + "instrument_id": req.symbol + } + + if req.type == OrderType.MARKET: + if req.direction == Direction.LONG: + data["notional"] = req.volume + else: + data["size"] = req.volume + else: + data["price"] = req.price + data["size"] = req.volume + + order = req.create_order_data(orderid, self.gateway_name) + + self.add_request( + "POST", + "/api/spot/v3/orders", + callback=self.on_send_order, + data=data, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, + ) + + self.gateway.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + data = { + "instrument_id": req.symbol, + "client_oid": req.orderid + } + + path = "/api/spot/v3/cancel_orders/" + req.orderid + self.add_request( + "POST", + path, + callback=self.on_cancel_order, + data=data, + on_error=self.on_cancel_order_error, + ) + + def query_contract(self): + """""" + data = { + "instrument_id": req.symbol, + "client_oid": req.orderid + } + + path = "/api/spot/v3/cancel_orders/" + req.orderid + self.add_request( + "POST", + path, + callback=self.on_cancel_order, + data=data, + on_error=self.on_cancel_order_error, + ) + + def on_query_contract(self, data, request): + """""" + for instrument_data in data: + symbol = instrument_data["instrument_id"] + contract = ContractData( + symbol=symbol, + exchange=Exchange.OKEX, + name=symbol, + product=Product.SPOT, + size=1, + pricetick=instrument_data["tick_size"] + + ) + self.gateway.on_contract(contract) + + instruments.add(instrument_data["instrument_id"]) + currencies.add(instrument_data["base_currency"]) + currencies.add(instrument_data["quote_currency"]) + + self.gateway.write_log("合约信息查询成功") + + def on_send_order_failed(self, status_code: str, request: Request): + """ + Callback when sending order failed on server. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_send_order(self, data, request): + """Websocket will push a new order status""" + pass + + def on_cancel_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when cancelling order failed on server. + """ + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_cancel_order(self, data, request): + """Websocket will push a new order status""" + pass + + def on_failed(self, status_code: int, request: Request): + """ + Callback to handle request failed. + """ + msg = f"请求失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback to handler request exception. + """ + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write( + self.exception_detail(exception_type, exception_value, tb, request) + ) + + +class OkexWebsocketApi(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(OkexWebsocketApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.callbacks = {} + + self.ticks = {} + self.accounts = {} + self.orders = {} + self.trades = set() + + def connect( + self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int + ): + """""" + self.key = key.encode() + self.secret = secret.encode() + + self.init(WEBSOCKET_HOST, proxy_host, proxy_port) + self.start() + + def subscribe(self, req: SubscribeRequest): + """ + Subscribe to tick data upate. + """ + tick = TickData( + symbol=req.symbol, + exchange=req.exchange, + name=req.symbol, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[req.symbol] = tick + + def on_connected(self): + """""" + self.gateway.write_log("Websocket API连接成功") + self.authenticate() + + def on_disconnected(self): + """""" + self.gateway.write_log("Websocket API连接断开") + + def on_packet(self, packet: dict): + """""" + if "error" in packet: + self.gateway.write_log("Websocket API报错:%s" % packet["error"]) + + if "not valid" in packet["error"]: + self.active = False + + elif "request" in packet: + req = packet["request"] + success = packet["success"] + + if success: + if req["op"] == "authKey": + self.gateway.write_log("Websocket API验证授权成功") + self.subscribe_topic() + + elif "table" in packet: + name = packet["table"] + callback = self.callbacks[name] + + if isinstance(packet["data"], list): + for d in packet["data"]: + callback(d) + else: + callback(packet["data"]) + + def on_error(self, exception_type: type, exception_value: Exception, tb): + """""" + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write(self.exception_detail( + exception_type, exception_value, tb)) + + def login(self): + """ + Need to login befores subscribe to websocket topic. + """ + timestamp = str(time.time()) + + msg = timestamp + 'GET' + '/users/self/verify' + signature = generate_signature(msg, self.secret) + + req = { + "op": "login", + "args": [ + self.key, + self.passphrase, + timestamp, + signature + ] + } + self.send_packet(req) + + self.callbacks['login'] = self.on_login + + def subscribe_topic(self): + """ + Subscribe to all private topics. + """ + for instrument_id in instruments: + channel = f"spot/order:{instrument_id}" + req = {"op": "subscribe", "args": [channel]} + self.send_packet(req) + self.callbacks[channel] = self.on_trade + + for currency in currencies: + channel = f"spot/account:{currency}" + req = {"op": "subscribe", "args": [channel]} + self.send_packet(req) + self.callbacks[channel] = self.on_account + + def on_login(self, d: dict): + """""" + data = d['data'] + + if data['success']: + self.gateway.write_log("Websocket接口登录成功") + self.subscribe_topic() + else: + self.gateway.write_log("Websocket接口登录失败") + + def on_tick(self, d): + """""" + symbol = d["symbol"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + tick.last_price = d["price"] + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_depth(self, d): + """""" + symbol = d["symbol"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + for n, buf in enumerate(d["bids"][:5]): + price, volume = buf + tick.__setattr__("bid_price_%s" % (n + 1), price) + tick.__setattr__("bid_volume_%s" % (n + 1), volume) + + for n, buf in enumerate(d["asks"][:5]): + price, volume = buf + tick.__setattr__("ask_price_%s" % (n + 1), price) + tick.__setattr__("ask_volume_%s" % (n + 1), volume) + + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_trade(self, d): + """""" + # Filter trade update with no trade volume and side (funding) + if not d["lastQty"] or not d["side"]: + return + + tradeid = d["execID"] + if tradeid in self.trades: + return + self.trades.add(tradeid) + + if d["clOrdID"]: + orderid = d["clOrdID"] + else: + orderid = d["orderID"] + + trade = TradeData( + symbol=d["symbol"], + exchange=Exchange.OKEX, + orderid=orderid, + tradeid=tradeid, + direction=DIRECTION_OKEX2VT[d["side"]], + price=d["lastPx"], + volume=d["lastQty"], + time=d["timestamp"][11:19], + gateway_name=self.gateway_name, + ) + + self.gateway.on_trade(trade) + + def on_order(self, d): + """""" + if "ordStatus" not in d: + return + + sysid = d["orderID"] + order = self.orders.get(sysid, None) + if not order: + if d["clOrdID"]: + orderid = d["clOrdID"] + else: + orderid = sysid + + # time = d["timestamp"][11:19] + + order = OrderData( + symbol=d["symbol"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[d["ordType"]], + orderid=orderid, + direction=DIRECTION_OKEX2VT[d["side"]], + price=d["price"], + volume=d["orderQty"], + time=d["timestamp"][11:19], + gateway_name=self.gateway_name, + ) + self.orders[sysid] = order + + order.traded = d.get("cumQty", order.traded) + order.status = STATUS_OKEX2VT.get(d["ordStatus"], order.status) + + self.gateway.on_order(copy(order)) + + def on_account(self, d): + """""" + accountid = str(d["account"]) + account = self.accounts.get(accountid, None) + if not account: + account = AccountData(accountid=accountid, + gateway_name=self.gateway_name) + self.accounts[accountid] = account + + account.balance = d.get("marginBalance", account.balance) + account.available = d.get("availableMargin", account.available) + account.frozen = account.balance - account.available + + self.gateway.on_account(copy(account)) + + def on_contract(self, d): + """""" + if "tickSize" not in d: + return + + if not d["lotSize"]: + return + + contract = ContractData( + symbol=d["symbol"], + exchange=Exchange.OKEX, + name=d["symbol"], + product=Product.FUTURES, + pricetick=d["tickSize"], + size=d["lotSize"], + stop_supported=True, + net_position=True, + gateway_name=self.gateway_name, + ) + + self.gateway.on_contract(contract) + + +def generate_signature(msg: str, secret_key: str): + """OKEX V3 signature""" + return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest()) diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 92492f5b..65cdf051 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -99,6 +99,7 @@ class Exchange(Enum): # CryptoCurrency BITMEX = "BITMEX" + OKEX = "OKEX" class Currency(Enum): From 455a1851a7e10766e3ab78a73a59f1616ea1ac92 Mon Sep 17 00:00:00 2001 From: qqqlyx Date: Wed, 3 Apr 2019 16:31:53 +0800 Subject: [PATCH 09/49] [Add] ping_interval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加了一个变量,用来控制websocket_client中ping的时间间隔。 --- vnpy/api/websocket/websocket_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index 61541594..563cc64f 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -50,9 +50,10 @@ class WebsocketClient(object): self._last_sent_text = None self._last_received_text = None - def init(self, host: str, proxy_host: str = "", proxy_port: int = 0): + def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60): """""" self.host = host + self.ping_interval = ping_interval # seconds if proxy_host and proxy_port: self.proxy_host = proxy_host @@ -202,7 +203,7 @@ class WebsocketClient(object): et, ev, tb = sys.exc_info() self.on_error(et, ev, tb) self._reconnect() - for i in range(60): + for i in range(self.ping_interval): if not self._active: break sleep(1) From 7024a7d2ca2939635436e23e32db41f044da826f Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Wed, 3 Apr 2019 17:00:52 +0800 Subject: [PATCH 10/49] [Mod]move TigerGateway subsribe into on_push_connected callback --- vnpy/gateway/tiger/tiger_gateway.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index eecd9305..521979f3 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -123,6 +123,9 @@ class TigerGateway(BaseGateway): self.contracts = {} self.symbol_names = {} + self.push_connected = False + self.subscribed_symbols = set() + def run(self): """""" while self.active: @@ -203,23 +206,33 @@ class TigerGateway(BaseGateway): """ protocol, host, port = self.client_config.socket_host_port self.push_client = PushClient(host, port, (protocol == 'ssl')) - self.push_client.connect( - self.client_config.tiger_id, self.client_config.private_key) self.push_client.quote_changed = self.on_quote_change self.push_client.asset_changed = self.on_asset_change self.push_client.position_changed = self.on_position_change self.push_client.order_changed = self.on_order_change - self.write_log("推送接口连接成功") + self.push_client.connect( + self.client_config.tiger_id, self.client_config.private_key) def subscribe(self, req: SubscribeRequest): """""" - self.push_client.subscribe_quote([req.symbol]) + self.subscribed_symbols.add(req.symbol) + + if self.push_connected: + self.push_client.subscribe_quote([req.symbol]) + + def on_push_connected(self): + """""" + self.push_connected = True + self.write_log("推送接口连接成功") + self.push_client.subscribe_asset() self.push_client.subscribe_position() self.push_client.subscribe_order() + self.push_client.subscribe_quote(list(self.subscribed_symbols)) + def on_quote_change(self, tiger_symbol: str, data: list, trading: bool): """""" data = dict(data) From b8c648b432f373df772594ac9550d8920a7a68d2 Mon Sep 17 00:00:00 2001 From: qqqlyx Date: Wed, 3 Apr 2019 17:46:48 +0800 Subject: [PATCH 11/49] update --- vnpy/api/websocket/websocket_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index 563cc64f..185ed046 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -45,13 +45,16 @@ class WebsocketClient(object): self.proxy_host = None self.proxy_port = None + self.ping_interval = 60 # seconds # For debugging self._last_sent_text = None self._last_received_text = None def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60): - """""" + """ + :param ping_interval: unit: seconds, type: int + """ self.host = host self.ping_interval = ping_interval # seconds From 7d86efce398c65f2b4f918db426c639f7351ab23 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 07:40:44 +0800 Subject: [PATCH 12/49] [Mod]complete trading test of okex gateway --- tests/trader/run.py | 14 +- vnpy/gateway/okex/okex_gateway.py | 431 +++++++++++++++++------------- 2 files changed, 254 insertions(+), 191 deletions(-) diff --git a/tests/trader/run.py b/tests/trader/run.py index 1a402f82..04e2aecc 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -10,6 +10,7 @@ from vnpy.gateway.ib import IbGateway from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway +from vnpy.gateway.okex import OkexGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp @@ -22,12 +23,13 @@ def main(): event_engine = EventEngine() main_engine = MainEngine(event_engine) - main_engine.add_gateway(CtpGateway) - main_engine.add_gateway(IbGateway) - main_engine.add_gateway(FutuGateway) - main_engine.add_gateway(BitmexGateway) - main_engine.add_gateway(TigerGateway) - main_engine.add_gateway(OesGateway) + # main_engine.add_gateway(CtpGateway) + # main_engine.add_gateway(IbGateway) + # main_engine.add_gateway(FutuGateway) + # main_engine.add_gateway(BitmexGateway) + # main_engine.add_gateway(TigerGateway) + # main_engine.add_gateway(OesGateway) + main_engine.add_gateway(OkexGateway) main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 3edba175..d59888b3 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -8,6 +8,7 @@ import sys import time import json import base64 +import zlib from copy import copy from datetime import datetime from threading import Lock @@ -39,7 +40,7 @@ from vnpy.trader.object import ( ) REST_HOST = "https://www.okex.com" -WEBSOCKET_HOST = "wss://real.okex.com:10440/websocket/okexapi?compress=true" +WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3" STATUS_OKEX2VT = { "ordering": Status.SUBMITTING, @@ -88,7 +89,7 @@ class OkexGateway(BaseGateway): def connect(self, setting: dict): """""" - key = setting["API KEY"] + key = setting["API Key"] secret = setting["Secret Key"] passphrase = setting["Passphrase"] session_number = setting["会话数"] @@ -142,7 +143,7 @@ class OkexRestApi(RestClient): self.secret = "" self.passphrase = "" - self.order_count = 1_000_000 + self.order_count = 10000 self.order_count_lock = Lock() self.connect_time = 0 @@ -152,7 +153,8 @@ class OkexRestApi(RestClient): Generate OKEX signature. """ # Sign - timestamp = str(time.time()) + # timestamp = str(time.time()) + timestamp = get_timestamp() request.data = json.dumps(request.data) if request.params: @@ -177,7 +179,7 @@ class OkexRestApi(RestClient): self, key: str, secret: str, - passphrase: str + passphrase: str, session_number: int, proxy_host: str, proxy_port: int, @@ -185,18 +187,21 @@ class OkexRestApi(RestClient): """ Initialize connection to REST server. """ - self.key = key.encode() + self.key = key self.secret = secret.encode() self.passphrase = passphrase - self.connect_time = ( - int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count - ) - + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + self.init(REST_HOST, proxy_host, proxy_port) self.start(session_number) self.gateway.write_log("REST API启动成功") + self.query_time() + self.query_contract() + self.query_account() + self.query_order() + def _new_order_id(self): with self.order_count_lock: self.order_count += 1 @@ -204,8 +209,8 @@ class OkexRestApi(RestClient): def send_order(self, req: OrderRequest): """""" - orderid = str(self.connect_time + self._new_order_id()) - + orderid = f"a{self.connect_time}{self._new_order_id()}" + data = { "client_oid": orderid, "type": ORDERTYPE_VT2OKEX[req.type], @@ -227,11 +232,11 @@ class OkexRestApi(RestClient): self.add_request( "POST", "/api/spot/v3/orders", - callback=self.on_send_order, - data=data, - extra=order, - on_failed=self.on_send_order_failed, - on_error=self.on_send_order_error, + callback = self.on_send_order, + data = data, + extra = order, + on_failed = self.on_send_order_failed, + on_error = self.on_send_order_error, ) self.gateway.on_order(order) @@ -239,7 +244,7 @@ class OkexRestApi(RestClient): def cancel_order(self, req: CancelRequest): """""" - data = { + data={ "instrument_id": req.symbol, "client_oid": req.orderid } @@ -248,25 +253,41 @@ class OkexRestApi(RestClient): self.add_request( "POST", path, - callback=self.on_cancel_order, - data=data, - on_error=self.on_cancel_order_error, + callback = self.on_cancel_order, + data = data, + on_error = self.on_cancel_order_error, ) def query_contract(self): """""" - data = { - "instrument_id": req.symbol, - "client_oid": req.orderid - } - - path = "/api/spot/v3/cancel_orders/" + req.orderid self.add_request( - "POST", - path, - callback=self.on_cancel_order, - data=data, - on_error=self.on_cancel_order_error, + "GET", + "/api/spot/v3/instruments", + callback = self.on_query_contract + ) + + def query_account(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/accounts", + callback = self.on_query_account + ) + + def query_order(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/orders_pending", + callback = self.on_query_order + ) + + def query_time(self): + """""" + self.add_request( + "GET", + "/api/general/v3/time", + callback=self.on_query_time ) def on_query_contract(self, data, request): @@ -279,8 +300,8 @@ class OkexRestApi(RestClient): name=symbol, product=Product.SPOT, size=1, - pricetick=instrument_data["tick_size"] - + pricetick = instrument_data["tick_size"], + gateway_name = self.gateway_name ) self.gateway.on_contract(contract) @@ -290,6 +311,48 @@ class OkexRestApi(RestClient): self.gateway.write_log("合约信息查询成功") + # Start websocket api after instruments data collected + self.gateway.ws_api.start() + + def on_query_account(self, data, request): + """""" + for account_data in data: + account = AccountData( + accountid=account_data["currency"], + balance=float(account_data["balance"]), + frozen=float(account_data["hold"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + self.gateway.write_log("账户资金查询成功") + + def on_query_order(self, data, request): + """""" + for order_data in data: + order = OrderData( + symbol=order_data["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[order_data["type"]], + orderid=order_data["client_oid"], + direction=DIRECTION_OKEX2VT[order_data["side"]], + price=float(order_data["price"]), + volume=float(order_data["size"]), + time=order_data["timestamp"][11:19], + status=STATUS_OKEX2VT[order_data["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_time(self, data, request): + """""" + server_time = data["iso"] + local_time = datetime.utcnow().isoformat() + msg = f"服务器时间:{server_time},本机时间:{local_time}" + self.gateway.write_log(msg) + def on_send_order_failed(self, status_code: str, request: Request): """ Callback when sending order failed on server. @@ -368,22 +431,33 @@ class OkexWebsocketApi(WebsocketClient): self.secret = "" self.passphrase = "" - self.callbacks = {} + self.trade_count = 10000 + self.connect_time = 0 + self.callbacks = {} self.ticks = {} - self.accounts = {} - self.orders = {} - self.trades = set() def connect( - self, key: str, secret: str, server: str, proxy_host: str, proxy_port: int + self, + key: str, + secret: str, + passphrase: str, + proxy_host: str, + proxy_port: int ): """""" - self.key = key.encode() + self.key = key self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) self.init(WEBSOCKET_HOST, proxy_host, proxy_port) - self.start() + # self.start() + + def unpack_data(self, data): + """""" + return json.loads(zlib.decompress(data, -zlib.MAX_WBITS)) def subscribe(self, req: SubscribeRequest): """ @@ -398,10 +472,22 @@ class OkexWebsocketApi(WebsocketClient): ) self.ticks[req.symbol] = tick + channel_ticker = f"spot/ticker:{req.symbol}" + channel_depth = f"spot/depth5:{req.symbol}" + + self.callbacks[channel_ticker] = self.on_ticker + self.callbacks[channel_depth] = self.on_depth + + req = { + "op": "subscribe", + "args": [channel_ticker, channel_depth] + } + self.send_packet(req) + def on_connected(self): """""" self.gateway.write_log("Websocket API连接成功") - self.authenticate() + self.login() def on_disconnected(self): """""" @@ -409,30 +495,27 @@ class OkexWebsocketApi(WebsocketClient): def on_packet(self, packet: dict): """""" - if "error" in packet: - self.gateway.write_log("Websocket API报错:%s" % packet["error"]) + if "event" in packet: + event = packet["event"] + if event == "subscribe": + return + elif event == "error": + msg = packet["message"] + self.gateway.write_log(f"Websocket API请求异常:{msg}") + elif event == "login": + self.on_login(packet) + else: + channel = packet["table"] + data = packet["data"] + callback = self.callbacks[channel] - if "not valid" in packet["error"]: - self.active = False - - elif "request" in packet: - req = packet["request"] - success = packet["success"] - - if success: - if req["op"] == "authKey": - self.gateway.write_log("Websocket API验证授权成功") - self.subscribe_topic() - - elif "table" in packet: - name = packet["table"] - callback = self.callbacks[name] - - if isinstance(packet["data"], list): - for d in packet["data"]: + try: + for d in data: callback(d) - else: - callback(packet["data"]) + except: + import traceback + traceback.print_exc() + print(packet) def on_error(self, exception_type: type, exception_value: Exception, tb): """""" @@ -457,173 +540,151 @@ class OkexWebsocketApi(WebsocketClient): self.key, self.passphrase, timestamp, - signature + signature.decode("utf-8") ] } self.send_packet(req) - self.callbacks['login'] = self.on_login def subscribe_topic(self): """ Subscribe to all private topics. """ + self.callbacks["spot/ticker"] = self.on_ticker + self.callbacks["spot/depth5"] = self.on_depth + self.callbacks["spot/account"] = self.on_account + self.callbacks["spot/order"] = self.on_order + + # Subscribe to order update + channels = [] for instrument_id in instruments: channel = f"spot/order:{instrument_id}" - req = {"op": "subscribe", "args": [channel]} - self.send_packet(req) - self.callbacks[channel] = self.on_trade + channels.append(channel) + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to account update + channels = [] for currency in currencies: channel = f"spot/account:{currency}" - req = {"op": "subscribe", "args": [channel]} - self.send_packet(req) - self.callbacks[channel] = self.on_account + channels.append(channel) - def on_login(self, d: dict): + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + def on_login(self, data: dict): """""" - data = d['data'] + success = data.get("success", False) - if data['success']: - self.gateway.write_log("Websocket接口登录成功") + if success: + self.gateway.write_log("Websocket API登录成功") self.subscribe_topic() else: - self.gateway.write_log("Websocket接口登录失败") + self.gateway.write_log("Websocket API登录失败") - def on_tick(self, d): + def on_ticker(self, d): """""" - symbol = d["symbol"] + symbol = d["instrument_id"] tick = self.ticks.get(symbol, None) if not tick: return - tick.last_price = d["price"] + tick.last_price = d["last"] + tick.open = d["open_24h"] + tick.high = d["high_24h"] + tick.low = d["low_24h"] + tick.volume = d["base_volume_24h"] tick.datetime = datetime.strptime( d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") self.gateway.on_tick(copy(tick)) def on_depth(self, d): """""" - symbol = d["symbol"] - tick = self.ticks.get(symbol, None) - if not tick: - return + for tick_data in d: + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return - for n, buf in enumerate(d["bids"][:5]): - price, volume = buf - tick.__setattr__("bid_price_%s" % (n + 1), price) - tick.__setattr__("bid_volume_%s" % (n + 1), volume) + bids = d["bids"] + asks = d["asks"] + for n, buf in enumerate(bids): + price, volume, _ = buf + tick.__setattr__("bid_price_%s" % (n + 1), price) + tick.__setattr__("bid_volume_%s" % (n + 1), volume) - for n, buf in enumerate(d["asks"][:5]): - price, volume = buf - tick.__setattr__("ask_price_%s" % (n + 1), price) - tick.__setattr__("ask_volume_%s" % (n + 1), volume) + for n, buf in enumerate(asks): + price, volume, _ = buf + tick.__setattr__("ask_price_%s" % (n + 1), price) + tick.__setattr__("ask_volume_%s" % (n + 1), volume) - tick.datetime = datetime.strptime( - d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") - self.gateway.on_tick(copy(tick)) - - def on_trade(self, d): - """""" - # Filter trade update with no trade volume and side (funding) - if not d["lastQty"] or not d["side"]: - return - - tradeid = d["execID"] - if tradeid in self.trades: - return - self.trades.add(tradeid) - - if d["clOrdID"]: - orderid = d["clOrdID"] - else: - orderid = d["orderID"] - - trade = TradeData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - orderid=orderid, - tradeid=tradeid, - direction=DIRECTION_OKEX2VT[d["side"]], - price=d["lastPx"], - volume=d["lastQty"], - time=d["timestamp"][11:19], - gateway_name=self.gateway_name, - ) - - self.gateway.on_trade(trade) + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) def on_order(self, d): """""" - if "ordStatus" not in d: + order = OrderData( + symbol=d["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[d["type"]], + orderid=d["client_oid"], + direction=DIRECTION_OKEX2VT[d["side"]], + price=d["price"], + volume=d["size"], + traded=d["filled_size"], + time=d["timestamp"][11:19], + status=STATUS_OKEX2VT[d["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(copy(order)) + + trade_volume = float(d.get("last_fill_qty", 0)) + if not trade_volume: return - sysid = d["orderID"] - order = self.orders.get(sysid, None) - if not order: - if d["clOrdID"]: - orderid = d["clOrdID"] - else: - orderid = sysid + self.trade_count += 1 + tradeid = f"{self.connect_time}{self.trade_count}" - # time = d["timestamp"][11:19] - - order = OrderData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - type=ORDERTYPE_OKEX2VT[d["ordType"]], - orderid=orderid, - direction=DIRECTION_OKEX2VT[d["side"]], - price=d["price"], - volume=d["orderQty"], - time=d["timestamp"][11:19], - gateway_name=self.gateway_name, - ) - self.orders[sysid] = order - - order.traded = d.get("cumQty", order.traded) - order.status = STATUS_OKEX2VT.get(d["ordStatus"], order.status) - - self.gateway.on_order(copy(order)) + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=tradeid, + direction=order.direction, + price=float(d["last_fill_px"]), + volume=float(trade_volume), + time=d["last_fill_time"][11:19], + gateway_name=self.gateway_name + ) + self.gateway.on_trade(trade) def on_account(self, d): """""" - accountid = str(d["account"]) - account = self.accounts.get(accountid, None) - if not account: - account = AccountData(accountid=accountid, - gateway_name=self.gateway_name) - self.accounts[accountid] = account - - account.balance = d.get("marginBalance", account.balance) - account.available = d.get("availableMargin", account.available) - account.frozen = account.balance - account.available - - self.gateway.on_account(copy(account)) - - def on_contract(self, d): - """""" - if "tickSize" not in d: - return - - if not d["lotSize"]: - return - - contract = ContractData( - symbol=d["symbol"], - exchange=Exchange.OKEX, - name=d["symbol"], - product=Product.FUTURES, - pricetick=d["tickSize"], - size=d["lotSize"], - stop_supported=True, - net_position=True, - gateway_name=self.gateway_name, + account = AccountData( + accountid=d["currency"], + balance=float(d["balance"]), + frozen=float(d["hold"]), + gateway_name=self.gateway_name ) - - self.gateway.on_contract(contract) + + self.gateway.on_account(copy(account)) def generate_signature(msg: str, secret_key: str): """OKEX V3 signature""" return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest()) + + +def get_timestamp(): + """""" + now = datetime.utcnow() + timestamp = now.isoformat("T", "milliseconds") + return timestamp + "Z" From 8e19d218687702b0f0dfdfa401614f3e2f6f5db1 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 07:44:23 +0800 Subject: [PATCH 13/49] [Mod]add support for order rejected --- vnpy/gateway/okex/okex_gateway.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index d59888b3..9e13dc99 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -380,7 +380,14 @@ class OkexRestApi(RestClient): def on_send_order(self, data, request): """Websocket will push a new order status""" - pass + order = request.extra + + error_msg = data["error_message"] + if error_msg: + order.status = Status.REJECTED + self.gateway.on_order(order) + + self.gateway.write_log(f"委托失败:{error_msg}") def on_cancel_order_error( self, exception_type: type, exception_value: Exception, tb, request: Request From a2da6d7ec2570f8701d66bc66425166894d2be4e Mon Sep 17 00:00:00 2001 From: CHEN Jie Date: Thu, 4 Apr 2019 09:32:47 +0800 Subject: [PATCH 14/49] fix a bug in CTP login using UserProductInfo and AuthCode --- vnpy/gateway/ctp/ctp_gateway.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index fe2784a7..cb3014e0 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -405,7 +405,7 @@ class CtpTdApi(TdApi): """""" if not error['ErrorID']: self.authStatus = True - self.writeLog("交易授权验证成功") + self.gateway.write_log("交易授权验证成功") self.login() else: self.gateway.write_error("交易授权验证失败", error) @@ -418,7 +418,7 @@ class CtpTdApi(TdApi): self.login_status = True self.gateway.write_log("交易登录成功") - # Confirm settelment + # Confirm settlement req = { "BrokerID": self.brokerid, "InvestorID": self.userid @@ -662,7 +662,7 @@ class CtpTdApi(TdApi): "UserID": self.userid, "BrokerID": self.brokerid, "AuthCode": self.auth_code, - "ProductInfo": self.product_info + "UserProductInfo": self.product_info } self.reqid += 1 @@ -678,7 +678,8 @@ class CtpTdApi(TdApi): req = { "UserID": self.userid, "Password": self.password, - "BrokerID": self.brokerid + "BrokerID": self.brokerid, + "UserProductInfo": self.product_info } self.reqid += 1 From 0ebe533d4ffeb3129d4a39bbe0df40f6d7a7dd95 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 15:57:19 +0800 Subject: [PATCH 15/49] [Mod]improve code quality of okex gateway --- tests/trader/run.py | 12 +++--- vnpy/api/websocket/websocket_client.py | 2 +- vnpy/gateway/okex/okex_gateway.py | 58 +++++++++++++------------- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/tests/trader/run.py b/tests/trader/run.py index 04e2aecc..d77acee9 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -23,12 +23,12 @@ def main(): event_engine = EventEngine() main_engine = MainEngine(event_engine) - # main_engine.add_gateway(CtpGateway) - # main_engine.add_gateway(IbGateway) - # main_engine.add_gateway(FutuGateway) - # main_engine.add_gateway(BitmexGateway) - # main_engine.add_gateway(TigerGateway) - # main_engine.add_gateway(OesGateway) + main_engine.add_gateway(CtpGateway) + main_engine.add_gateway(IbGateway) + main_engine.add_gateway(FutuGateway) + main_engine.add_gateway(BitmexGateway) + main_engine.add_gateway(TigerGateway) + main_engine.add_gateway(OesGateway) main_engine.add_gateway(OkexGateway) main_engine.add_app(CtaStrategyApp) diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index 111bb3ab..0fe28ee7 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -48,7 +48,7 @@ class WebsocketClient(object): self.proxy_host = None self.proxy_port = None - self.ping_interval = 60 # seconds + self.ping_interval = 60 # seconds # For debugging self._last_sent_text = None diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 9e13dc99..cc1e124c 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -23,15 +23,13 @@ from vnpy.trader.constant import ( Exchange, OrderType, Product, - Status, - Offset + Status ) from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import ( TickData, OrderData, TradeData, - PositionData, AccountData, ContractData, OrderRequest, @@ -192,7 +190,7 @@ class OkexRestApi(RestClient): self.passphrase = passphrase self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) - + self.init(REST_HOST, proxy_host, proxy_port) self.start(session_number) self.gateway.write_log("REST API启动成功") @@ -210,7 +208,7 @@ class OkexRestApi(RestClient): def send_order(self, req: OrderRequest): """""" orderid = f"a{self.connect_time}{self._new_order_id()}" - + data = { "client_oid": orderid, "type": ORDERTYPE_VT2OKEX[req.type], @@ -232,11 +230,11 @@ class OkexRestApi(RestClient): self.add_request( "POST", "/api/spot/v3/orders", - callback = self.on_send_order, - data = data, - extra = order, - on_failed = self.on_send_order_failed, - on_error = self.on_send_order_error, + callback=self.on_send_order, + data=data, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, ) self.gateway.on_order(order) @@ -244,7 +242,7 @@ class OkexRestApi(RestClient): def cancel_order(self, req: CancelRequest): """""" - data={ + data = { "instrument_id": req.symbol, "client_oid": req.orderid } @@ -253,9 +251,9 @@ class OkexRestApi(RestClient): self.add_request( "POST", path, - callback = self.on_cancel_order, - data = data, - on_error = self.on_cancel_order_error, + callback=self.on_cancel_order, + data=data, + on_error=self.on_cancel_order_error, ) def query_contract(self): @@ -263,7 +261,7 @@ class OkexRestApi(RestClient): self.add_request( "GET", "/api/spot/v3/instruments", - callback = self.on_query_contract + callback=self.on_query_contract ) def query_account(self): @@ -271,7 +269,7 @@ class OkexRestApi(RestClient): self.add_request( "GET", "/api/spot/v3/accounts", - callback = self.on_query_account + callback=self.on_query_account ) def query_order(self): @@ -279,7 +277,7 @@ class OkexRestApi(RestClient): self.add_request( "GET", "/api/spot/v3/orders_pending", - callback = self.on_query_order + callback=self.on_query_order ) def query_time(self): @@ -300,8 +298,8 @@ class OkexRestApi(RestClient): name=symbol, product=Product.SPOT, size=1, - pricetick = instrument_data["tick_size"], - gateway_name = self.gateway_name + pricetick=instrument_data["tick_size"], + gateway_name=self.gateway_name ) self.gateway.on_contract(contract) @@ -386,8 +384,8 @@ class OkexRestApi(RestClient): if error_msg: order.status = Status.REJECTED self.gateway.on_order(order) - - self.gateway.write_log(f"委托失败:{error_msg}") + + self.gateway.write_log(f"委托失败:{error_msg}") def on_cancel_order_error( self, exception_type: type, exception_value: Exception, tb, request: Request @@ -430,6 +428,7 @@ class OkexWebsocketApi(WebsocketClient): def __init__(self, gateway): """""" super(OkexWebsocketApi, self).__init__() + self.ping_interval = 20 # OKEX use 30 seconds for ping self.gateway = gateway self.gateway_name = gateway.gateway_name @@ -514,15 +513,11 @@ class OkexWebsocketApi(WebsocketClient): else: channel = packet["table"] data = packet["data"] - callback = self.callbacks[channel] + callback = self.callbacks.get(channel, None) - try: + if callback: for d in data: callback(d) - except: - import traceback - traceback.print_exc() - print(packet) def on_error(self, exception_type: type, exception_value: Exception, tb): """""" @@ -586,6 +581,13 @@ class OkexWebsocketApi(WebsocketClient): } self.send_packet(req) + # Subscribe to BTC/USDT trade for keep connection alive + req = { + "op": "subscribe", + "args": ["spot/trade:BTC-USDT"] + } + self.send_packet(req) + def on_login(self, data: dict): """""" success = data.get("success", False) @@ -681,7 +683,7 @@ class OkexWebsocketApi(WebsocketClient): frozen=float(d["hold"]), gateway_name=self.gateway_name ) - + self.gateway.on_account(copy(account)) From 1429856d5c9d39d1c2c00a8b75a7742327513b57 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 15:57:49 +0800 Subject: [PATCH 16/49] [Mod]add unicode to bytes transfer in rpc --- vnpy/rpc/test_client.py | 5 +++-- vnpy/rpc/test_server.py | 4 ++-- vnpy/rpc/vnrpc.py | 36 ++++++++++++++++++++++++------------ 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/vnpy/rpc/test_client.py b/vnpy/rpc/test_client.py index 17f42558..7a693a52 100644 --- a/vnpy/rpc/test_client.py +++ b/vnpy/rpc/test_client.py @@ -2,13 +2,14 @@ from __future__ import print_function from __future__ import absolute_import from time import sleep -from .vnrpc import RpcClient +from vnpy.rpc import RpcClient class TestClient(RpcClient): """ Test RpcClient - """ + """ + def __init__(self, req_address, sub_address): """ Constructor diff --git a/vnpy/rpc/test_server.py b/vnpy/rpc/test_server.py index ee4ca6e4..660168fc 100644 --- a/vnpy/rpc/test_server.py +++ b/vnpy/rpc/test_server.py @@ -2,14 +2,14 @@ from __future__ import print_function from __future__ import absolute_import from time import sleep, time -from .vnrpc import RpcServer +from vnpy.rpc import RpcServer class TestServer(RpcServer): """ Test RpcServer """ - + def __init__(self, rep_address, pub_address): """ Constructor diff --git a/vnpy/rpc/vnrpc.py b/vnpy/rpc/vnrpc.py index 3debfa65..1cdd835a 100644 --- a/vnpy/rpc/vnrpc.py +++ b/vnpy/rpc/vnrpc.py @@ -6,9 +6,9 @@ import zmq from msgpack import packb, unpackb from json import dumps, loads -import cPickle -p_dumps = cPickle.dumps -p_loads = cPickle.loads +import pickle +p_dumps = pickle.dumps +p_loads = pickle.loads # Achieve Ctrl-c interrupt recv @@ -21,7 +21,7 @@ class RpcObject(object): 1) maspack: higher performance, but usually requires the installation of msgpack related tools; 2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries; 3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly. - + Therefore, it is recommended to use msgpack. Use json, if you want to communicate with some languages without providing msgpack. Use cPickle, when the data being transferred contains many custom Python objects. @@ -115,10 +115,12 @@ class RpcServer(RpcObject): # Zmq port related self.__context = zmq.Context() - self.__socket_rep = self.__context.socket(zmq.REP) # Reply socket (Request–reply pattern) + self.__socket_rep = self.__context.socket( + zmq.REP) # Reply socket (Request–reply pattern) self.__socket_rep.bind(rep_address) - self.__socket_pub = self.__context.socket(zmq.PUB) # Publish socket (Publish–subscribe pattern) + # Publish socket (Publish–subscribe pattern) + self.__socket_pub = self.__context.socket(zmq.PUB) self.__socket_pub.bind(pub_address) # Woker thread related @@ -166,11 +168,13 @@ class RpcServer(RpcObject): name, args, kwargs = req # Try to get and execute callable function object; capture exception information if it fails + name = name.decode("UTF-8") + try: func = self.__functions[name] r = func(*args, **kwargs) rep = [True, r] - except Exception as e: + except Exception as e: # noqa rep = [False, traceback.format_exc()] # Pack response by serialization @@ -184,10 +188,12 @@ class RpcServer(RpcObject): Publish data """ # Serialized data + topic = bytes(topic, "UTF-8") datab = self.pack(data) # Send data by Publish socket - self.__socket_pub.send_multipart([topic, datab]) # topci must be ascii encoding + # topci must be ascii encoding + self.__socket_pub.send_multipart([topic, datab]) def register(self, func): """ @@ -208,12 +214,15 @@ class RpcClient(RpcObject): self.__sub_address = sub_address self.__context = zmq.Context() - self.__socket_req = self.__context.socket(zmq.REQ) # Request socket (Request–reply pattern) - self.__socket_sub = self.__context.socket(zmq.SUB) # Subscribe socket (Publish–subscribe pattern) + # Request socket (Request–reply pattern) + self.__socket_req = self.__context.socket(zmq.REQ) + # Subscribe socket (Publish–subscribe pattern) + self.__socket_sub = self.__context.socket(zmq.SUB) # Woker thread relate, used to process data pushed from server self.__active = False # RpcClient status - self.__thread = threading.Thread(target=self.run) # RpcClient thread + self.__thread = threading.Thread( + target=self.run) # RpcClient thread def __getattr__(self, name): """ @@ -238,7 +247,7 @@ class RpcClient(RpcObject): if rep[0]: return rep[1] else: - raise RemoteException(rep[1]) + raise RemoteException(rep[1].decode("UTF-8")) return dorpc @@ -284,6 +293,8 @@ class RpcClient(RpcObject): data = self.unpack(datab) # Process data by callable function + topic = topic.decode("UTF-8") + self.callback(topic, data) def callback(self, topic, data): @@ -296,6 +307,7 @@ class RpcClient(RpcObject): """ Subscribe data """ + topic = bytes(topic, "UTF-8") self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic) From 9164bda6b72a7bf77f61101aee4caf09221e4ebc Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Thu, 4 Apr 2019 17:43:14 +0800 Subject: [PATCH 17/49] Update constant.py --- vnpy/trader/constant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 65cdf051..f4727ba3 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -100,6 +100,7 @@ class Exchange(Enum): # CryptoCurrency BITMEX = "BITMEX" OKEX = "OKEX" + HUOBI = "HUOBI" class Currency(Enum): From 42d0cd555caa18675209b0ed0d32a2086158f961 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Thu, 4 Apr 2019 17:43:21 +0800 Subject: [PATCH 18/49] Create __init__.py --- vnpy/gateway/huobi/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 vnpy/gateway/huobi/__init__.py diff --git a/vnpy/gateway/huobi/__init__.py b/vnpy/gateway/huobi/__init__.py new file mode 100644 index 00000000..9e3d74e8 --- /dev/null +++ b/vnpy/gateway/huobi/__init__.py @@ -0,0 +1,4 @@ +from .huobi_gateway import HuobiGateway + + + From a54c000ca8804399969d42447850c4550627d214 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Thu, 4 Apr 2019 17:43:32 +0800 Subject: [PATCH 19/49] Create huobi_gateway.py --- vnpy/gateway/huobi/huobi_gateway.py | 838 ++++++++++++++++++++++++++++ 1 file changed, 838 insertions(+) create mode 100644 vnpy/gateway/huobi/huobi_gateway.py diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py new file mode 100644 index 00000000..bf03f958 --- /dev/null +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -0,0 +1,838 @@ +# encoding: UTF-8 + +""" +火币交易接口 +""" + +import hashlib +import hmac +from copy import copy +from datetime import datetime + +from vnpy.api.rest import Request, RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + Product, + Status, +) +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, + LogData, +) + +import re +import urllib +import base64 +import json +import zlib +from vnpy.event.engine import Event, EVENT_TIMER +from vnpy.trader.event import EVENT_LOG + + +REST_HOST = "https://api.huobi.pro" +WEBSOCKET_MARKET_HOST = "wss://api.huobi.pro/ws" # 行情 +WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # 资金和委托 + +STATUS_HUOBI2VT = { + "submitted": Status.SUBMITTING, + "partial-filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelling": Status.CANCELLED, + "partial-canceled": Status.CANCELLED, + "canceled": Status.CANCELLED, +} + + +class HuobiGateway(BaseGateway): + """ + VN Trader Gateway for Huobi connection. + """ + + default_setting = { + "ID": "", + "Secret": "", + "Symbols": "", + } + + def __init__(self, event_engine): + """Constructor""" + super(HuobiGateway, self).__init__(event_engine, "HUOBI") + + self.local_id = 10000 + + self.accountDict = {} + self.orderDict = {} + self.localOrderDict = {} + self.orderLocalDict = {} + + self.qry_enabled = False + + self.rest_api = HuobiRestApi(self) + self.trade_ws_api = HuobiTradeWebsocketApi(self) + self.market_ws_api = HuobiMarketWebsocketApi(self) + + def connect(self, setting: dict): + """""" + key = setting["ID"] + secret = setting["Secret"] + symbols = setting["Symbols"] + + self.rest_api.connect(symbols, secret, key) + self.trade_ws_api.connect(symbols, secret, key) + self.market_ws_api.connect(symbols, secret, key) + # websocket will push all account status on connected, including asset, position and orders. + + def subscribe(self, req: SubscribeRequest): + """""" + self.ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + self.rest_api.query_account() + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.trade_ws_api.stop() + self.market_ws_api.stop() + + def init_query(self): + """初始化连续查询""" + if self.qry_enabled: + # 需要循环的查询函数列表 + self.qry_functionList = [self.qry_info] + + self.qry_count = 0 # 查询触发倒计时 + self.qry_trigger = 1 # 查询触发点 + self.qry_next_function = 0 # 上次运行的查询函数索引 + + self.start_query() + + def query(self, event): + """注册到事件处理引擎上的查询函数""" + self.qry_count += 1 + + if self.qry_count > self.qry_trigger: + # 清空倒计时 + self.qry_count = 0 + + # 执行查询函数 + function = self.qry_functionList[self.qry_next_function] + function() + + # 计算下次查询函数的索引,如果超过了列表长度,则重新设为0 + self.qry_next_function += 1 + if self.qry_next_function == len(self.qry_functionList): + self.qry_next_function = 0 + + def start_query(self): + """启动连续查询""" + self.event_engine.register(EVENT_TIMER, self.query) + + def set_qry_enabled(self, qry_enabled): + """设置是否要启动循环查询""" + self.qry_enabled = qry_enabled + + def write_log(self, msg): + """""" + log = LogData() + log.log_content = msg + log.gateway_name = self.gateway_name + + event = Event(EVENT_LOG) + event.dict_["data"] = log + self.event_engine.put(event) + + +class HuobiRestApi(RestClient): + """ + HUOBI REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(HuobiRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.symbols = [] + self.key = "" + self.secret = "" + self.sign_host = "" + + self.account_id = "" + self.cancelReqDict = {} + self.orderBufDict = {} + + self.accountDict = gateway.accountDict + self.orderDict = gateway.orderDict + self.orderLocalDict = gateway.orderLocalDict + self.localOrderDict = gateway.localOrderDict + + self.account_id = "" + self.cancelReqDict = {} + self.orderBufDict = {} + + def sign(self, request): + """ + Generate HUOBI signature. + """ + + request.headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36" + } + params_with_signature = create_signature(self.key, request.method, self.sign_host, request.path, self.secret, request.params) + request.params = params_with_signature + + if request.method == "POST": + request.headers["Content-Type"] = "application/json" + + if request.data: + request.data = json.dumps(request.data) + return request + + def connect( + self, + key: str, + secret: str, + symbols, + session_number=3, + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret + self.symbols = symbols + + host, path = _split_url(REST_HOST) + self.init(REST_HOST) + + self.sign_host = host + self.start(session_number) + + self.gateway.write_log("REST API启动成功") + + def query_account(self): + """""" + self.add_request( + method="GET", + path="/v1/account/accounts", + callable=self.on_query_account + ) + + def query_account_balance(self): + """""" + # path = "/v1/account/accounts/%s/balance" %self.account_id + path = f"/v1/account/accounts/{self.account_id}/balance" + self.add_request( + method="GET", + path=path, + callable=self.on_query_account_balance + ) + + def query_order(self): + """""" + path = "/v1/order/orders" + + today_date = datetime.now().strftime("%Y-%m-%d") + states_active = "submitted, partial-filled" + + for symbol in self.symbols: + params = { + "symbol": symbol, + "states": states_active, + "end_date": today_date + } + self.add_request( + method="GET", + path=path, + callable=self.on_query_order, + params=params + ) + + def query_contract(self): + """""" + self.add_request( + method="GET", + path="/v1/common/symbols", + callable=self.on_query_contract + ) + + def send_order(self, req: OrderRequest): + """""" + self.gateway.local_id += 1 + + local_id = str(self.gateway.local_id) + + if req.direction == Direction.LONG: + type_ = "buy-limit" + else: + type_ = "sell-limit" + + data = { + "account-id": self.account_id, + "amount": str(req.volume), + "symbol": req.symbol, + "type": type_, + "price": str(req.price), + "source": "api" + } + path = "/v1/order/orders/place" + + self.add_request( + method="POST", + path=path, + callable=self.on_send_order, + data=data, + extra=local_id, + ) + + order = OrderData( + symbol=req.symbol, + exchange=Exchange.HUOBI, + orderid=local_id, + direction=req.direction, + price=req.price, + volume=req.volume, + time=datetime.now(), + gateway_name=self.gateway_name, + ) + + self.orderBufDict[local_id] = order + + self.gateway.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + local_id = req.orderid + order_id = self.localOrderDict.get(local_id, None) + + if order_id: + path = f"/v1/order/orders/{order_id}/submitcancel" + self.add_request( + method="POST", + path=path, + callable=self.on_cancel_order, + ) + + if local_id in self.cancelReqDict: + del self.cancelReqDict[local_id] + else: + self.cancelReqDict[local_id] = req + + def on_query_account(self, data, request): # type: (dict, Request)->None + """""" + for d in data["data"]: + if str(d["type"]) == "spot": + self.account_id = str(d["id"]) + self.gateway.write_log(f"账户代码{self.account_id}查询成功") + + self.query_account_balance() + + def on_query_account_balance(self, data, request): # type: (dict, Request)->None + """""" + status = data.get("status", None) + if status == "error": + msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) + self.gateway.write_log(msg) + return + + self.gateway.write_log(u"资金信息查询成功") + + for d in data["data"]["list"]: + currency = d["currency"] + account = self.accountDict.get(currency, None) + + if not account: + account = AccountData( + account_id=d["currency"], + gateway_name=self.gateway_name, + available=float(d["balance"]) if d["type"] == "trade" else 0.0, + margin=float(d["balance"]) if d["type"] == "frozen" else 0.0, + balance=account.margin + account.available, + ) + self.accountDict[currency] = account + + for account in self.accountDict.values(): + self.gateway.on_account(account) + + self.query_order() + + def on_query_order(self, data, request): # type: (dict, Request)->None + """""" + status = data.get("status", None) + if status == "error": + msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) + self.gateway.write_log(msg) + return + + symbol = request.params["symbol"] + self.gateway.write_log(f"{symbol}委托信息查询成功") + + data["data"].reverse() + for d in data["data"]: + order_id = str(d["id"]) + self.gateway.local_id += 1 + local_id = str(self.gateway.local_id) + + self.orderLocalDict[order_id] = local_id + self.localOrderDict[local_id] = order_id + + if "buy" in d["type"]: + direction = Direction.LONG + else: + direction = Direction.SHORT + + if d["canceled-at"]: + time = datetime.fromtimestamp(d["canceled-at"] / 1000).strftime("%H:%M:%S") + else: + time = datetime.fromtimestamp(d["created-at"] / 1000).strftime("%H:%M:%S") + + order = OrderData( + orderid=local_id, + symbol=d["symbol"], + exchange=Exchange.HUOBI, + price=float(d["price"]), + volume=float(d["amount"]), + direction=direction, + traded=float(d["field-amount"]), + status=STATUS_HUOBI2VT.get(d["state"], None), + time=time, + gateway_name=self.gateway_name, + + ) + + self.orderDict[order_id] = order + self.gateway.on_order(order) + + def on_query_contract(self, data, request): # type: (dict, Request)->None + """""" + status = data.get("status", None) + if status == "error": + msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) + self.gateway.write_log(msg) + return + + self.gateway.write_log("合约信息查询成功") + + for d in data["data"]: + contract = ContractData( + symbol=d["base-currency"] + d["quote-currency"], + exchange=Exchange.HUOBI, + name="/".join([d["base-currency"].upper(), d["quote-currency"].upper()]), + pricetick=1 / pow(10, d["price-precision"]), + size=1 / pow(10, d["amount-precision"]), + product=Product.SPOT, + gateway_name=self.gateway_name, + ) + self.gateway.on_contract(contract) + + self.query_account() + + def on_send_order(self, data, request): # type: (dict, Request)->None + """""" + local_id = request.extra + order = self.orderBufDict[local_id] + + status = data.get("status", None) + + if status == "error": + msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) + self.gateway.write_log(msg) + + order.status = Status.REJECTED + self.gateway.on_order(order) + return + + order_id = str(data["data"]) + + self.localOrderDict[local_id] = order_id + self.orderDict[order_id] = order + + req = self.cancelReqDict.get(local_id, None) + if req: + self.cancel_order(req) + + def on_cancel_order(self, data, request): # type: (dict, Request)->None + """""" + status = data.get("status", None) + if status == "error": + msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) + self.gateway.write_log(msg) + return + + self.gateway.write_log(f"委托撤单成功:{data}") + + +class HuobiWebsocketApiBase(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(HuobiWebsocketApiBase, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.sign_host = "" + self.path = "" + + def connect(self, key: str, secret: str, url: str): + """""" + self.key = key + self.secret = secret + host, path = _split_url(url) + + self.init(url) + self.sign_host = host + self.path = path + self.start() + + def login(self): + """""" + params = {"op": "auth", } + params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) + return self.send_packet(params) + + def on_login(self, packet): + """""" + pass + + @staticmethod + def unpack_data(data): + """""" + return json.loads(zlib.decompress(data, 31)) + + def on_packet(self, packet): + """""" + if "ping" in packet: + self.send_packet({"pong": packet["ping"]}) + return + + if "err-msg" in packet: + return self.on_error_msg(packet) + + if "op" in packet and packet["op"] == "auth": + return self.on_login() + + self.on_data(packet) + + def on_data(self, packet): + """""" + print("data : {}".format(packet)) + + def on_error_msg(self, packet): + """""" + msg = packet["err-msg"] + if msg == "invalid pong": + return + + self.gateway.write_log(packet["err-msg"]) + + +class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): + """""" + def __init__(self, gateway): + """""" + super(HuobiTradeWebsocketApi, self).__init__(gateway) + + self.req_id = 10000 + + self.accountDict = gateway.accountDict + self.orderDict = gateway.orderDict + self.orderLocalDict = gateway.orderLocalDict + self.localOrderDict = gateway.localOrderDict + + def connect(self, symbols, key, secret): + """""" + self.symbols = symbols + super(HuobiTradeWebsocketApi, self).connect(key, secret, WEBSOCKET_TRADE_HOST) + + def subscribe_topic(self): + """""" + # 订阅资金变动 + self.req_id += 1 + req = { + "op": "sub", + "cid": str(self.req_id), + "topic": "accounts", + } + self.send_packet(req) + + # 订阅委托变动 + for symbol in self.symbols: + self.req_id += 1 + req = { + "op": "sub", + "cid": str(self.req_id), + "topic": f"orders.{symbol}" + } + self.send_packet(req) + + def on_connected(self): + """""" + self.login() + + def on_login(self): + """""" + self.gateway.write_log("交易Websocket服务器登录成功") + + self.subscribe_topic() + + def on_data(self, packet): # type: (dict)->None + """""" + op = packet.get("op", None) + if op != "notify": + return + + topic = packet["topic"] + if topic == "accounts": + self.on_account(packet["data"]) + elif "orders" in topic: + self.on_order(packet["data"]) + + def on_account(self, data): + """""" + for d in data["list"]: + account = self.accountDict.get(d["currency"], None) + if not account: + continue + + if d["type"] == "trade": + account.available = float(d["balance"]) + elif d["type"] == "frozen": + account.margin = float(d["balance"]) + + account.balance = account.margin + account.available + self.gateway.on_account(account) + + def on_order(self, data: list): + """""" + order_id = str(data["order-id"]) + order = self.orderDict.get(order_id, None) + + if not order: + local_id = self._new_order_id() + local_id = str(local_id) + + self.orderLocalDict[order_id] = local_id + self.localOrderDict[local_id] = order_id + + if "buy" in data["order-type"]: + direction = Direction.LONG + else: + direction = Direction.SHORT + + order = OrderData( + orderid=local_id, + symbol=data["symbol"], + exchange=Exchange.HUOBI, + price=float(data["order-price"]), + volume=float(data["order-amount"]), + direction=direction, + status=STATUS_HUOBI2VT.get(data["order-state"], None), + time=datetime.fromtimestamp(data["created-at"] / 1000).strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + order.traded += float(data['filled-amount']) + self.orderDict[order_id] = order + self.gateway.onOrder(order) + + if float(data["filled-amount"]): + trade = TradeData( + orderid=order.orderid, + tradeid=str(data["seq-id"]), + symbol=data["symbol"], + exchange=Exchange.HUOBI, + direction=order.direction, + price=float(data["price"]), + volume=float(data["filled-amount"]), + time=datetime.now().strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + self.gateway.onTrade(trade) + + +class HuobiMarketWebsocketApi(HuobiWebsocketApiBase): + """""" + def __init__(self, gateway): + """""" + super(HuobiMarketWebsocketApi, self).__init__(gateway) + + self.req_id = 10000 + self.tickDict = {} + + def connect(self, symbols, key, secret): + """""" + self.symbols = symbols + super(HuobiMarketWebsocketApi, self).connect(key, secret, WEBSOCKET_MARKET_HOST) + + def on_connected(self): + """""" + self.subscribe_topic() + + def subscribe_topic(self): # type:()->None + """""" + for symbol in self.symbols: + # 创建Tick对象 + tick = TickData( + symbol=symbol, + exchange=Exchange.HUOBI, + gateway_name=self.gateway_name, + ) + + self.tickDict[symbol] = tick + + # 订阅深度和成交 + self.req_id += 1 + req = { + "sub": "market.{symbol}.depth.step0", + "id": str(self.req_id) + } + self.send_packet(req) + + self.req_id += 1 + req = { + "sub": "market.{symbol}.detail", + "id": str(self.req_id) + } + self.send_packet(req) + + def on_data(self, packet): # type: (dict)->None + """""" + if "ch" in packet: + if "depth.step" in packet["ch"]: + self.on_market_depth(packet) + elif "detail" in packet["ch"]: + self.on_market_detail(packet) + elif "err-code" in packet: + self.gateway.write_log("错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"])) + + def on_market_depth(self, data): + """行情深度推送 """ + symbol = data["ch"].split(".")[1] + + tick = self.tickDict.get(symbol, None) + if not tick: + return + + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + tick.date = tick.datetime.strftime("%Y%m%d") + tick.time = tick.datetime.strftime("%H:%M:%S.%f") + + bids = data["tick"]["bids"] + for n in range(5): + l = bids[n] + tick.__setattr__("bid_price_" + str(n + 1), float(l[0])) + tick.__setattr__("bid_volume_" + str(n + 1), float(l[1])) + + asks = data["tick"]["asks"] + for n in range(5): + l = asks[n] + tick.__setattr__("ask_price_" + str(n + 1), float(l[0])) + tick.__setattr__("ask_volume_" + str(n + 1), float(l[1])) + + if tick.last_price: + newtick = copy(tick) + self.gateway.on_tick(newtick) + + def on_market_detail(self, data): + """市场细节推送""" + symbol = data["ch"].split(".")[1] + + tick = self.tickDict.get(symbol, None) + if not tick: + return + + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + tick.date = tick.datetime.strftime("%Y%m%d") + tick.time = tick.datetime.strftime("%H:%M:%S.%f") + + t = data["tick"] + tick.open_price = float(t["open"]) + tick.high_price = float(t["high"]) + tick.low_price = float(t["low"]) + tick.last_price = float(t["close"]) + tick.volume = float(t["vol"]) + tick.pre_close = float(tick.open_price) + + if tick.bid_price_1: + newtick = copy(tick) + self.gateway.on_tick(newtick) + + +def print_dict(d): + """""" + print("-" * 30) + l = d.keys() + l.sort() + for k in l: + print(type(k), k, d[k]) + + +def _split_url(url): + """ + 将url拆分为host和path + :return: host, path + """ + m = re.match("\w+://([^/]*)(.*)", url) + if m: + return m.group(1), m.group(2) + + +def create_signature(api_key, method, host, path, secret_key, get_params=None): + """ + 创建签名 + :param get_params: dict 使用GET方法时附带的额外参数(urlparams) + :return: + """ + sortedParams = [ + ("AccessKeyId", api_key), + ("SignatureMethod", "HmacSHA256"), + ("SignatureVersion", "2"), + ("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")) + ] + if get_params: + sortedParams.extend(get_params.items()) + sortedParams = list(sorted(sortedParams)) + encodeParams = urllib.urlencode(sortedParams) + + payload = [method, host, path, encodeParams] + payload = "\n".join(payload) + payload = payload.encode(encoding="UTF8") + + secret_key = secret_key.encode(encoding="UTF8") + + digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() + signature = base64.b64encode(digest) + + params = dict(sortedParams) + params["Signature"] = signature + return params From 3964142b2ed7f9dfec9611b0b64262a29524d53c Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 4 Apr 2019 23:52:35 +0800 Subject: [PATCH 20/49] [Mod]complete test for huobi gateway --- tests/trader/run.py | 2 + vnpy/gateway/huobi/huobi_gateway.py | 801 +++++++++++++--------------- 2 files changed, 365 insertions(+), 438 deletions(-) diff --git a/tests/trader/run.py b/tests/trader/run.py index d77acee9..0b7fb501 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -11,6 +11,7 @@ from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway from vnpy.gateway.okex import OkexGateway +from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp @@ -30,6 +31,7 @@ def main(): main_engine.add_gateway(TigerGateway) main_engine.add_gateway(OesGateway) main_engine.add_gateway(OkexGateway) + main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index bf03f958..59b8a908 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -4,18 +4,25 @@ 火币交易接口 """ +import re +import urllib +import base64 +import json +import zlib import hashlib import hmac from copy import copy from datetime import datetime -from vnpy.api.rest import Request, RestClient +from vnpy.event import Event +from vnpy.api.rest import RestClient from vnpy.api.websocket import WebsocketClient from vnpy.trader.constant import ( Direction, Exchange, Product, Status, + OrderType ) from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import ( @@ -26,25 +33,17 @@ from vnpy.trader.object import ( ContractData, OrderRequest, CancelRequest, - SubscribeRequest, - LogData, + SubscribeRequest ) - -import re -import urllib -import base64 -import json -import zlib -from vnpy.event.engine import Event, EVENT_TIMER -from vnpy.trader.event import EVENT_LOG +from vnpy.trader.event import EVENT_TIMER -REST_HOST = "https://api.huobi.pro" -WEBSOCKET_MARKET_HOST = "wss://api.huobi.pro/ws" # 行情 -WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # 资金和委托 +REST_HOST = "https://api.huobipro.com" +WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data +WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order STATUS_HUOBI2VT = { - "submitted": Status.SUBMITTING, + "submitted": Status.NOTTRADED, "partial-filled": Status.PARTTRADED, "filled": Status.ALLTRADED, "cancelling": Status.CANCELLED, @@ -52,6 +51,17 @@ STATUS_HUOBI2VT = { "canceled": Status.CANCELLED, } +ORDERTYPE_VT2HUOBI = { + (Direction.LONG, OrderType.MARKET): "buy-market", + (Direction.SHORT, OrderType.MARKET): "sell-market", + (Direction.LONG, OrderType.LIMIT): "buy-limit", + (Direction.SHORT, OrderType.LIMIT): "sell-limit", +} +ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()} + + +huobi_symbols = set() + class HuobiGateway(BaseGateway): """ @@ -59,42 +69,47 @@ class HuobiGateway(BaseGateway): """ default_setting = { - "ID": "", - "Secret": "", - "Symbols": "", + "API Key": "", + "Secret Key": "", + "会话数": 3, + "代理地址": "127.0.0.1", + "代理端口": 1080, } def __init__(self, event_engine): """Constructor""" super(HuobiGateway, self).__init__(event_engine, "HUOBI") - self.local_id = 10000 - - self.accountDict = {} - self.orderDict = {} - self.localOrderDict = {} - self.orderLocalDict = {} - - self.qry_enabled = False + self.order_count = 100000 + + self.local_huobi_map = {} # local orderid: huobi orderid + self.huobi_local_map = {} # huobi orderid: local orderid + self.local_order_map = {} # local orderid: order + self.huobi_order_data = {} # huobi orderid: data self.rest_api = HuobiRestApi(self) self.trade_ws_api = HuobiTradeWebsocketApi(self) - self.market_ws_api = HuobiMarketWebsocketApi(self) + self.market_ws_api = HuobiDataWebsocketApi(self) def connect(self, setting: dict): """""" - key = setting["ID"] - secret = setting["Secret"] - symbols = setting["Symbols"] - - self.rest_api.connect(symbols, secret, key) - self.trade_ws_api.connect(symbols, secret, key) - self.market_ws_api.connect(symbols, secret, key) - # websocket will push all account status on connected, including asset, position and orders. + key = setting["API Key"] + secret = setting["Secret Key"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + self.rest_api.connect(key, secret, session_number, + proxy_host, proxy_port) + self.trade_ws_api.connect(key, secret, proxy_host, proxy_port) + self.market_ws_api.connect(key, secret, proxy_host, proxy_port) + + self.init_query() def subscribe(self, req: SubscribeRequest): """""" - self.ws_api.subscribe(req) + self.market_ws_api.subscribe(req) + self.trade_ws_api.subscribe(req) def send_order(self, req: OrderRequest): """""" @@ -106,7 +121,7 @@ class HuobiGateway(BaseGateway): def query_account(self): """""" - self.rest_api.query_account() + self.rest_api.query_account_balance() def query_position(self): """""" @@ -118,53 +133,62 @@ class HuobiGateway(BaseGateway): self.trade_ws_api.stop() self.market_ws_api.stop() - def init_query(self): - """初始化连续查询""" - if self.qry_enabled: - # 需要循环的查询函数列表 - self.qry_functionList = [self.qry_info] - - self.qry_count = 0 # 查询触发倒计时 - self.qry_trigger = 1 # 查询触发点 - self.qry_next_function = 0 # 上次运行的查询函数索引 - - self.start_query() - - def query(self, event): - """注册到事件处理引擎上的查询函数""" - self.qry_count += 1 - - if self.qry_count > self.qry_trigger: - # 清空倒计时 - self.qry_count = 0 - - # 执行查询函数 - function = self.qry_functionList[self.qry_next_function] - function() - - # 计算下次查询函数的索引,如果超过了列表长度,则重新设为0 - self.qry_next_function += 1 - if self.qry_next_function == len(self.qry_functionList): - self.qry_next_function = 0 - - def start_query(self): - """启动连续查询""" - self.event_engine.register(EVENT_TIMER, self.query) - - def set_qry_enabled(self, qry_enabled): - """设置是否要启动循环查询""" - self.qry_enabled = qry_enabled - - def write_log(self, msg): + def process_timer_event(self, event: Event): """""" - log = LogData() - log.log_content = msg - log.gateway_name = self.gateway_name - - event = Event(EVENT_LOG) - event.dict_["data"] = log - self.event_engine.put(event) + self.count += 1 + if self.count < 3: + return + self.query_account() + + def init_query(self): + """""" + self.count = 0 + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + + def get_local_orderid(self, huobi_orderid: str): + """""" + local_orderid = self.huobi_local_map.get(huobi_orderid, None) + + if not local_orderid: + local_orderid = self.new_local_orderid() + self.update_orderid_map(local_orderid, huobi_orderid) + + return local_orderid + + def get_huobi_orderid(self, local_orderid: str): + """""" + huobi_orderid = self.local_huobi_map.get(local_orderid, "") + return huobi_orderid + + def new_local_orderid(self): + """""" + self.order_count += 1 + return str(self.order_count) + + def update_orderid_map(self, local_orderid: str, huobi_orderid: str): + """""" + self.huobi_local_map[huobi_orderid] = local_orderid + self.local_huobi_map[local_orderid] = huobi_orderid + + if huobi_orderid in self.huobi_order_data: + data = self.huobi_order_data.pop(huobi_orderid) + self.trade_ws_api.on_order(data) + + def on_order(self, order: OrderData): + """""" + self.local_order_map[order.orderid] = order + + super().on_order(copy(order)) + + def get_order(self, huobi_orderid: str): + """""" + local_orderid = self.huobi_local_map.get(huobi_orderid, None) + if not local_orderid: + return None + else: + return self.local_order_map[local_orderid] + class HuobiRestApi(RestClient): """ @@ -178,315 +202,283 @@ class HuobiRestApi(RestClient): self.gateway = gateway self.gateway_name = gateway.gateway_name - self.symbols = [] + self.host = "" self.key = "" self.secret = "" - self.sign_host = "" - self.account_id = "" - self.cancelReqDict = {} - self.orderBufDict = {} - self.accountDict = gateway.accountDict - self.orderDict = gateway.orderDict - self.orderLocalDict = gateway.orderLocalDict - self.localOrderDict = gateway.localOrderDict - - self.account_id = "" - self.cancelReqDict = {} - self.orderBufDict = {} + self.cancel_requests = {} + self.orders = {} def sign(self, request): """ Generate HUOBI signature. """ - request.headers = { "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36" } - params_with_signature = create_signature(self.key, request.method, self.sign_host, request.path, self.secret, request.params) + params_with_signature = create_signature( + self.key, + request.method, + self.host, + request.path, + self.secret, + request.params + ) request.params = params_with_signature - + if request.method == "POST": request.headers["Content-Type"] = "application/json" - + if request.data: - request.data = json.dumps(request.data) + request.data = json.dumps(request.data) + return request def connect( self, key: str, secret: str, - symbols, - session_number=3, + session_number: int, + proxy_host: str, + proxy_port: int ): """ Initialize connection to REST server. """ self.key = key self.secret = secret - self.symbols = symbols - host, path = _split_url(REST_HOST) - self.init(REST_HOST) - - self.sign_host = host + self.host, _ = _split_url(REST_HOST) + + self.init(REST_HOST, proxy_host, proxy_port) self.start(session_number) self.gateway.write_log("REST API启动成功") + self.query_contract() + self.query_account() + self.query_order() + def query_account(self): """""" self.add_request( method="GET", path="/v1/account/accounts", - callable=self.on_query_account + callback=self.on_query_account ) def query_account_balance(self): """""" - # path = "/v1/account/accounts/%s/balance" %self.account_id path = f"/v1/account/accounts/{self.account_id}/balance" self.add_request( - method="GET", - path=path, - callable=self.on_query_account_balance + method="GET", + path=path, + callback=self.on_query_account_balance ) def query_order(self): """""" - path = "/v1/order/orders" - - today_date = datetime.now().strftime("%Y-%m-%d") - states_active = "submitted, partial-filled" - - for symbol in self.symbols: - params = { - "symbol": symbol, - "states": states_active, - "end_date": today_date - } - self.add_request( - method="GET", - path=path, - callable=self.on_query_order, - params=params - ) - + self.add_request( + method="GET", + path="/v1/order/openOrders", + callback=self.on_query_order + ) + def query_contract(self): """""" self.add_request( - method="GET", - path="/v1/common/symbols", - callable=self.on_query_contract + method="GET", + path="/v1/common/symbols", + callback=self.on_query_contract ) def send_order(self, req: OrderRequest): """""" - self.gateway.local_id += 1 + huobi_type = ORDERTYPE_VT2HUOBI.get( + (req.direction, req.type), "" + ) - local_id = str(self.gateway.local_id) - - if req.direction == Direction.LONG: - type_ = "buy-limit" - else: - type_ = "sell-limit" + local_orderid = self.gateway.new_local_orderid() + order = req.create_order_data( + local_orderid, + self.gateway_name + ) + order.time = datetime.now().strftime("%H:%M:%S") data = { "account-id": self.account_id, "amount": str(req.volume), "symbol": req.symbol, - "type": type_, + "type": huobi_type, "price": str(req.price), "source": "api" } - path = "/v1/order/orders/place" - + self.add_request( - method="POST", - path=path, - callable=self.on_send_order, - data=data, - extra=local_id, + method="POST", + path="/v1/order/orders/place", + callback=self.on_send_order, + data=data, + extra=order, ) - order = OrderData( - symbol=req.symbol, - exchange=Exchange.HUOBI, - orderid=local_id, - direction=req.direction, - price=req.price, - volume=req.volume, - time=datetime.now(), - gateway_name=self.gateway_name, - ) - - self.orderBufDict[local_id] = order - self.gateway.on_order(order) return order.vt_orderid def cancel_order(self, req: CancelRequest): """""" local_id = req.orderid - order_id = self.localOrderDict.get(local_id, None) + huobi_orderid = self.gateway.get_huobi_orderid(local_id) - if order_id: - path = f"/v1/order/orders/{order_id}/submitcancel" - self.add_request( - method="POST", - path=path, - callable=self.on_cancel_order, - ) - - if local_id in self.cancelReqDict: - del self.cancelReqDict[local_id] - else: - self.cancelReqDict[local_id] = req + if not huobi_orderid: + self.cancel_requests[local_id] = req + return - def on_query_account(self, data, request): # type: (dict, Request)->None + path = f"/v1/order/orders/{huobi_orderid}/submitcancel" + self.add_request( + method="POST", + path=path, + callback=self.on_cancel_order, + extra=req + ) + + if local_id in self.cancel_requests: + self.cancel_requests.pop(local_id) + + def on_query_account(self, data, request): """""" + if self.check_error(data, "查询账户"): + return + for d in data["data"]: - if str(d["type"]) == "spot": - self.account_id = str(d["id"]) + if d["type"] == "spot": + self.account_id = d["id"] self.gateway.write_log(f"账户代码{self.account_id}查询成功") self.query_account_balance() - def on_query_account_balance(self, data, request): # type: (dict, Request)->None + def on_query_account_balance(self, data, request): """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "查询账户资金"): return - - self.gateway.write_log(u"资金信息查询成功") - + + buf = {} for d in data["data"]["list"]: currency = d["currency"] - account = self.accountDict.get(currency, None) + currency_data = buf.setdefault(currency, {}) + currency_data[d["type"]] = float(d["balance"]) - if not account: - account = AccountData( - account_id=d["currency"], - gateway_name=self.gateway_name, - available=float(d["balance"]) if d["type"] == "trade" else 0.0, - margin=float(d["balance"]) if d["type"] == "frozen" else 0.0, - balance=account.margin + account.available, - ) - self.accountDict[currency] = account + for currency, currency_data in buf.items(): + account = AccountData( + accountid=currency, + balance=currency_data["trade"] + currency_data["frozen"], + frozen=currency_data["frozen"], + gateway_name=self.gateway_name, + ) - for account in self.accountDict.values(): - self.gateway.on_account(account) - - self.query_order() + if account.balance: + self.gateway.on_account(account) - def on_query_order(self, data, request): # type: (dict, Request)->None - """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + def on_query_order(self, data, request): + """""" + if self.check_error(data, "查询委托"): return - - symbol = request.params["symbol"] - self.gateway.write_log(f"{symbol}委托信息查询成功") - - data["data"].reverse() + for d in data["data"]: - order_id = str(d["id"]) - self.gateway.local_id += 1 - local_id = str(self.gateway.local_id) + huobi_orderid = d["id"] + local_orderid = self.gateway.get_local_orderid(huobi_orderid) - self.orderLocalDict[order_id] = local_id - self.localOrderDict[local_id] = order_id - - if "buy" in d["type"]: - direction = Direction.LONG - else: - direction = Direction.SHORT - - if d["canceled-at"]: - time = datetime.fromtimestamp(d["canceled-at"] / 1000).strftime("%H:%M:%S") - else: - time = datetime.fromtimestamp(d["created-at"] / 1000).strftime("%H:%M:%S") + direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]] + dt = datetime.fromtimestamp(d["created-at"] / 1000) + time = dt.strftime("%H:%M:%S") order = OrderData( - orderid=local_id, + orderid=local_orderid, symbol=d["symbol"], exchange=Exchange.HUOBI, price=float(d["price"]), volume=float(d["amount"]), + type=order_type, direction=direction, - traded=float(d["field-amount"]), + traded=float(d["filled-amount"]), status=STATUS_HUOBI2VT.get(d["state"], None), time=time, gateway_name=self.gateway_name, - ) - self.orderDict[order_id] = order self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") def on_query_contract(self, data, request): # type: (dict, Request)->None """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "查询合约"): return - - self.gateway.write_log("合约信息查询成功") - + for d in data["data"]: + base_currency = d["base-currency"] + quote_currency = d["quote-currency"] + name = f"{base_currency.upper()}/{quote_currency.upper()}" + pricetick = 1 / pow(10, d["price-precision"]) + size = 1 / pow(10, d["amount-precision"]) + contract = ContractData( - symbol=d["base-currency"] + d["quote-currency"], + symbol=d["symbol"], exchange=Exchange.HUOBI, - name="/".join([d["base-currency"].upper(), d["quote-currency"].upper()]), - pricetick=1 / pow(10, d["price-precision"]), - size=1 / pow(10, d["amount-precision"]), + name=name, + pricetick=pricetick, + size=size, product=Product.SPOT, gateway_name=self.gateway_name, ) self.gateway.on_contract(contract) - self.query_account() + huobi_symbols.add(contract.symbol) - def on_send_order(self, data, request): # type: (dict, Request)->None + self.gateway.write_log("合约信息查询成功") + + def on_send_order(self, data, request): """""" - local_id = request.extra - order = self.orderBufDict[local_id] + order = request.extra - status = data.get("status", None) - - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) - + if self.check_error(data, "委托"): order.status = Status.REJECTED self.gateway.on_order(order) return + + huobi_orderid = str(data["data"]) + self.gateway.update_orderid_map(order.orderid, huobi_orderid) - order_id = str(data["data"]) - - self.localOrderDict[local_id] = order_id - self.orderDict[order_id] = order - - req = self.cancelReqDict.get(local_id, None) + req = self.cancel_requests.get(order.orderid, None) if req: self.cancel_order(req) - def on_cancel_order(self, data, request): # type: (dict, Request)->None + def on_cancel_order(self, data, request): """""" - status = data.get("status", None) - if status == "error": - msg = "错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"]) - self.gateway.write_log(msg) + if self.check_error(data, "撤单"): return + + cancel_request = request.extra + local_orderid = cancel_request.orderid + huobi_orderid = self.gateway.get_huobi_orderid(local_orderid) + + order = self.gateway.get_order(huobi_orderid) + order.status = Status.CANCELLED - self.gateway.write_log(f"委托撤单成功:{data}") + self.gateway.on_order(copy(order)) + self.gateway.write_log(f"委托撤单成功:{order.orderid}") + + def check_error(self, data: dict, func: str = ""): + """""" + if data["status"] != "error": + return False + + error_code = data["err-code"] + error_msg = data["err-msg"] + + self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}") + return True class HuobiWebsocketApiBase(WebsocketClient): @@ -504,20 +496,28 @@ class HuobiWebsocketApiBase(WebsocketClient): self.sign_host = "" self.path = "" - def connect(self, key: str, secret: str, url: str): + def connect( + self, + key: str, + secret: str, + url: str, + proxy_host: str, + proxy_port: int + ): """""" self.key = key self.secret = secret - host, path = _split_url(url) - self.init(url) + host, path = _split_url(url) self.sign_host = host self.path = path + + self.init(url, proxy_host, proxy_port) self.start() def login(self): """""" - params = {"op": "auth", } + params = {"op": "auth"} params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) return self.send_packet(params) @@ -534,15 +534,12 @@ class HuobiWebsocketApiBase(WebsocketClient): """""" if "ping" in packet: self.send_packet({"pong": packet["ping"]}) - return - - if "err-msg" in packet: + elif "err-msg" in packet: return self.on_error_msg(packet) - - if "op" in packet and packet["op"] == "auth": + elif "op" in packet and packet["op"] == "auth": return self.on_login() - - self.on_data(packet) + else: + self.on_data(packet) def on_data(self, packet): """""" @@ -561,50 +558,32 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): """""" def __init__(self, gateway): """""" - super(HuobiTradeWebsocketApi, self).__init__(gateway) + super().__init__(gateway) - self.req_id = 10000 - - self.accountDict = gateway.accountDict - self.orderDict = gateway.orderDict - self.orderLocalDict = gateway.orderLocalDict - self.localOrderDict = gateway.localOrderDict + self.req_id = 0 - def connect(self, symbols, key, secret): + def connect(self, key, secret, proxy_host, proxy_port): """""" - self.symbols = symbols - super(HuobiTradeWebsocketApi, self).connect(key, secret, WEBSOCKET_TRADE_HOST) + super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port) - def subscribe_topic(self): + def subscribe(self, req: SubscribeRequest): """""" - # 订阅资金变动 self.req_id += 1 req = { "op": "sub", "cid": str(self.req_id), - "topic": "accounts", + "topic": f"orders.{req.symbol}" } self.send_packet(req) - - # 订阅委托变动 - for symbol in self.symbols: - self.req_id += 1 - req = { - "op": "sub", - "cid": str(self.req_id), - "topic": f"orders.{symbol}" - } - self.send_packet(req) def on_connected(self): """""" + self.gateway.write_log("交易Websocket API连接成功") self.login() def on_login(self): """""" - self.gateway.write_log("交易Websocket服务器登录成功") - - self.subscribe_topic() + self.gateway.write_log("交易Websocket API登录成功") def on_data(self, packet): # type: (dict)->None """""" @@ -613,188 +592,133 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): return topic = packet["topic"] - if topic == "accounts": - self.on_account(packet["data"]) - elif "orders" in topic: + if "orders" in topic: self.on_order(packet["data"]) - def on_account(self, data): + def on_order(self, data: dict): """""" - for d in data["list"]: - account = self.accountDict.get(d["currency"], None) - if not account: - continue - - if d["type"] == "trade": - account.available = float(d["balance"]) - elif d["type"] == "frozen": - account.margin = float(d["balance"]) - - account.balance = account.margin + account.available - self.gateway.on_account(account) - - def on_order(self, data: list): - """""" - order_id = str(data["order-id"]) - order = self.orderDict.get(order_id, None) - + huobi_orderid = str(data["order-id"]) + order = self.gateway.get_order(huobi_orderid) if not order: - local_id = self._new_order_id() - local_id = str(local_id) - - self.orderLocalDict[order_id] = local_id - self.localOrderDict[local_id] = order_id - - if "buy" in data["order-type"]: - direction = Direction.LONG - else: - direction = Direction.SHORT - - order = OrderData( - orderid=local_id, - symbol=data["symbol"], - exchange=Exchange.HUOBI, - price=float(data["order-price"]), - volume=float(data["order-amount"]), - direction=direction, - status=STATUS_HUOBI2VT.get(data["order-state"], None), - time=datetime.fromtimestamp(data["created-at"] / 1000).strftime("%H:%M:%S"), - gateway_name=self.gateway_name, - ) - order.traded += float(data['filled-amount']) - self.orderDict[order_id] = order - self.gateway.onOrder(order) + self.gateway.huobi_order_data[huobi_orderid] = data + return - if float(data["filled-amount"]): + traded_volume = float(data["filled-amount"]) + order.traded += traded_volume + order.status = STATUS_HUOBI2VT.get(data["order-state"], None) + self.gateway.on_order(order) + + if traded_volume: trade = TradeData( + symbol=order.symbol, + exchange=Exchange.HUOBI, orderid=order.orderid, tradeid=str(data["seq-id"]), - symbol=data["symbol"], - exchange=Exchange.HUOBI, direction=order.direction, price=float(data["price"]), volume=float(data["filled-amount"]), time=datetime.now().strftime("%H:%M:%S"), gateway_name=self.gateway_name, ) - self.gateway.onTrade(trade) + self.gateway.on_trade(trade) -class HuobiMarketWebsocketApi(HuobiWebsocketApiBase): +class HuobiDataWebsocketApi(HuobiWebsocketApiBase): """""" + def __init__(self, gateway): """""" - super(HuobiMarketWebsocketApi, self).__init__(gateway) + super().__init__(gateway) - self.req_id = 10000 - self.tickDict = {} + self.req_id = 0 + self.ticks = {} - def connect(self, symbols, key, secret): + def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int): """""" - self.symbols = symbols - super(HuobiMarketWebsocketApi, self).connect(key, secret, WEBSOCKET_MARKET_HOST) + super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port) def on_connected(self): """""" - self.subscribe_topic() - - def subscribe_topic(self): # type:()->None + self.gateway.write_log("行情Websocket API连接成功") + + def subscribe(self, req: SubscribeRequest): """""" - for symbol in self.symbols: - # 创建Tick对象 - tick = TickData( - symbol=symbol, - exchange=Exchange.HUOBI, - gateway_name=self.gateway_name, - ) + symbol = req.symbol - self.tickDict[symbol] = tick + # Create tick data buffer + tick = TickData( + symbol=symbol, + exchange=Exchange.HUOBI, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[symbol] = tick - # 订阅深度和成交 - self.req_id += 1 - req = { - "sub": "market.{symbol}.depth.step0", - "id": str(self.req_id) - } - self.send_packet(req) - - self.req_id += 1 - req = { - "sub": "market.{symbol}.detail", - "id": str(self.req_id) - } - self.send_packet(req) + # Subscribe to market depth update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.depth.step0", + "id": str(self.req_id) + } + self.send_packet(req) + + # Subscribe to market detail update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.detail", + "id": str(self.req_id) + } + self.send_packet(req) def on_data(self, packet): # type: (dict)->None """""" - if "ch" in packet: - if "depth.step" in packet["ch"]: + channel = packet.get("ch", None) + if channel: + if "depth.step" in channel: self.on_market_depth(packet) - elif "detail" in packet["ch"]: + elif "detail" in channel: self.on_market_detail(packet) elif "err-code" in packet: - self.gateway.write_log("错误代码:%s, 错误信息:%s" % (data["err-code"], data["err-msg"])) + code = packet["err-code"] + msg = packet["err-msg"] + self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}") def on_market_depth(self, data): """行情深度推送 """ symbol = data["ch"].split(".")[1] - - tick = self.tickDict.get(symbol, None) - if not tick: - return - + tick = self.ticks[symbol] tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) - tick.date = tick.datetime.strftime("%Y%m%d") - tick.time = tick.datetime.strftime("%H:%M:%S.%f") - + bids = data["tick"]["bids"] for n in range(5): - l = bids[n] - tick.__setattr__("bid_price_" + str(n + 1), float(l[0])) - tick.__setattr__("bid_volume_" + str(n + 1), float(l[1])) + price, volume = bids[n] + tick.__setattr__("bid_price_" + str(n + 1), float(price)) + tick.__setattr__("bid_volume_" + str(n + 1), float(volume)) asks = data["tick"]["asks"] for n in range(5): - l = asks[n] - tick.__setattr__("ask_price_" + str(n + 1), float(l[0])) - tick.__setattr__("ask_volume_" + str(n + 1), float(l[1])) + price, volume = asks[n] + tick.__setattr__("ask_price_" + str(n + 1), float(price)) + tick.__setattr__("ask_volume_" + str(n + 1), float(volume)) if tick.last_price: - newtick = copy(tick) - self.gateway.on_tick(newtick) + self.gateway.on_tick(copy(tick)) def on_market_detail(self, data): """市场细节推送""" symbol = data["ch"].split(".")[1] - - tick = self.tickDict.get(symbol, None) - if not tick: - return - + tick = self.ticks[symbol] tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) - tick.date = tick.datetime.strftime("%Y%m%d") - tick.time = tick.datetime.strftime("%H:%M:%S.%f") - - t = data["tick"] - tick.open_price = float(t["open"]) - tick.high_price = float(t["high"]) - tick.low_price = float(t["low"]) - tick.last_price = float(t["close"]) - tick.volume = float(t["vol"]) - tick.pre_close = float(tick.open_price) + + tick_data = data["tick"] + tick.open_price = float(tick_data["open"]) + tick.high_price = float(tick_data["high"]) + tick.low_price = float(tick_data["low"]) + tick.last_price = float(tick_data["close"]) + tick.volume = float(tick_data["vol"]) if tick.bid_price_1: - newtick = copy(tick) - self.gateway.on_tick(newtick) - - -def print_dict(d): - """""" - print("-" * 30) - l = d.keys() - l.sort() - for k in l: - print(type(k), k, d[k]) + self.gateway.on_tick(copy(tick)) def _split_url(url): @@ -802,9 +726,9 @@ def _split_url(url): 将url拆分为host和path :return: host, path """ - m = re.match("\w+://([^/]*)(.*)", url) - if m: - return m.group(1), m.group(2) + result = re.match("\w+://([^/]*)(.*)", url) # noqa + if result: + return result.group(1), result.group(2) def create_signature(api_key, method, host, path, secret_key, get_params=None): @@ -813,18 +737,19 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None): :param get_params: dict 使用GET方法时附带的额外参数(urlparams) :return: """ - sortedParams = [ + sorted_params = [ ("AccessKeyId", api_key), ("SignatureMethod", "HmacSHA256"), ("SignatureVersion", "2"), ("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")) ] + if get_params: - sortedParams.extend(get_params.items()) - sortedParams = list(sorted(sortedParams)) - encodeParams = urllib.urlencode(sortedParams) + sorted_params.extend(list(get_params.items())) + sorted_params = list(sorted(sorted_params)) + encode_params = urllib.parse.urlencode(sorted_params) - payload = [method, host, path, encodeParams] + payload = [method, host, path, encode_params] payload = "\n".join(payload) payload = payload.encode(encoding="UTF8") @@ -833,6 +758,6 @@ def create_signature(api_key, method, host, path, secret_key, get_params=None): digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() signature = base64.b64encode(digest) - params = dict(sortedParams) - params["Signature"] = signature + params = dict(sorted_params) + params["Signature"] = signature.decode("UTF8") return params From a90433141f723138d11b3082b809f41731ac8732 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Fri, 5 Apr 2019 10:31:48 +0800 Subject: [PATCH 21/49] [Add]ssleay32.dll problem in docs --- docs/install.md | 9 +++++++++ vnpy/gateway/huobi/huobi_gateway.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/docs/install.md b/docs/install.md index 48a78528..f09fdb75 100644 --- a/docs/install.md +++ b/docs/install.md @@ -5,6 +5,15 @@ ### 使用VNConda +**ssleay32.dll问题** + +如果电脑上之前安装过其他的Python环境或者应用软件,有可能会存在SSL相关动态链接库路径被修改的问题,在运行VN Station时弹出如下图所示的错误: + +![ssleay32.dll](https://user-images.githubusercontent.com/7112268/55474371-8bd06a00-5643-11e9-8b35-f064a45edfd1.png) + +解决方法: +1. 找到你的VNConda目录 +2. 将VNConda\Lib\site-packages\PyQt5\Qt\bin目录的两个动态库libeay32.dll和ssleay32.dll拷贝到VNConda\下即可 ### 手动安装 diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index 59b8a908..b6dfd4bb 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -61,6 +61,7 @@ ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()} huobi_symbols = set() +symbol_name_map = {} class HuobiGateway(BaseGateway): @@ -435,6 +436,7 @@ class HuobiRestApi(RestClient): self.gateway.on_contract(contract) huobi_symbols.add(contract.symbol) + symbol_name_map[contract.symbol] = contract.name self.gateway.write_log("合约信息查询成功") @@ -648,6 +650,7 @@ class HuobiDataWebsocketApi(HuobiWebsocketApiBase): # Create tick data buffer tick = TickData( symbol=symbol, + name=symbol_name_map.get(symbol, ""), exchange=Exchange.HUOBI, datetime=datetime.now(), gateway_name=self.gateway_name, From 20f80327079b588c3127038b4f09008ae000a0a6 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Fri, 5 Apr 2019 11:06:50 +0800 Subject: [PATCH 22/49] [Fix]Close #1559 --- vnpy/app/csv_loader/engine.py | 45 +++++++++++++++++++------------- vnpy/app/csv_loader/ui/widget.py | 7 ++--- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index af4a49a8..2c384d02 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -23,9 +23,11 @@ Sample csv file: import csv from datetime import datetime +from peewee import chunked + from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData +from vnpy.trader.database import DbBarData, DB from vnpy.trader.engine import BaseEngine, MainEngine @@ -75,30 +77,37 @@ class CsvLoaderEngine(BaseEngine): with open(file_path, 'rt') as f: reader = csv.DictReader(f) + db_bars = [] + for item in reader: - db_bar = DbBarData() + dt = datetime.strptime(item[datetime_head], datetime_format) - db_bar.symbol = symbol - db_bar.exchange = exchange.value - db_bar.datetime = datetime.strptime( - item[datetime_head], datetime_format - ) - db_bar.interval = interval.value - db_bar.volume = item[volume_head] - db_bar.open_price = item[open_head] - db_bar.high_price = item[high_head] - db_bar.low_price = item[low_head] - db_bar.close_price = item[close_head] - db_bar.vt_symbol = vt_symbol - db_bar.gateway_name = "DB" + db_bar = { + "symbol": symbol, + "exchange": exchange.value, + "datetime": dt, + "interval": interval.value, + "volume": item[volume_head], + "open_price": item[open_head], + "high_price": item[high_head], + "low_price": item[low_head], + "close_price": item[close_head], + "vt_symbol": vt_symbol, + "gateway_name": "DB" + } - db_bar.replace() + db_bars.append(db_bar) # do some statistics count += 1 if not start: - start = db_bar.datetime + start = db_bar["datetime"] - end = db_bar.datetime + end = db_bar["datetime"] + + # Insert into DB + with DB.atomic(): + for batch in chunked(db_bars, 500): + DbBarData.insert_many(batch).on_conflict_replace().execute() return start, end, count diff --git a/vnpy/app/csv_loader/ui/widget.py b/vnpy/app/csv_loader/ui/widget.py index 6e89741c..30160d77 100644 --- a/vnpy/app/csv_loader/ui/widget.py +++ b/vnpy/app/csv_loader/ui/widget.py @@ -27,8 +27,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): self.setFixedWidth(300) self.setWindowFlags( - (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) - & ~QtCore.Qt.WindowMaximizeButtonHint) + (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) & + ~QtCore.Qt.WindowMaximizeButtonHint) file_button = QtWidgets.QPushButton("选择文件") file_button.clicked.connect(self.select_file) @@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): def select_file(self): """""" - result: str = QtWidgets.QFileDialog.getOpenFileName(self) + result: str = QtWidgets.QFileDialog.getOpenFileName( + self, filter="CSV (*.csv)") filename = result[0] if filename: self.file_edit.setText(filename) From 81454e7dfb86c73d2270f5eeaee373fb4a57ec14 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Fri, 5 Apr 2019 11:08:31 +0800 Subject: [PATCH 23/49] [Fix]Close #1551 --- vnpy/app/cta_strategy/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index cb29a652..9bff635b 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -541,7 +541,7 @@ class CtaEngine(BaseEngine): DbBarData.select() .where( (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval) + & (DbBarData.interval == interval.value) & (DbBarData.datetime >= start) & (DbBarData.datetime <= end) ) From f16a87990569ad37b5e8155e71d72bb3b1d4c455 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Fri, 5 Apr 2019 12:32:48 +0800 Subject: [PATCH 24/49] [Add]LocalOrderManager and use it for async order id map in HuobiGateway --- vnpy/app/csv_loader/ui/widget.py | 4 +- vnpy/gateway/huobi/huobi_gateway.py | 133 +++++++++------------------- vnpy/trader/gateway.py | 122 +++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 94 deletions(-) diff --git a/vnpy/app/csv_loader/ui/widget.py b/vnpy/app/csv_loader/ui/widget.py index 30160d77..79ef9e0c 100644 --- a/vnpy/app/csv_loader/ui/widget.py +++ b/vnpy/app/csv_loader/ui/widget.py @@ -27,8 +27,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): self.setFixedWidth(300) self.setWindowFlags( - (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) & - ~QtCore.Qt.WindowMaximizeButtonHint) + (self.windowFlags() | QtCore.Qt.CustomizeWindowHint) + & ~QtCore.Qt.WindowMaximizeButtonHint) file_button = QtWidgets.QPushButton("选择文件") file_button.clicked.connect(self.select_file) diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index b6dfd4bb..5572f5fe 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -24,7 +24,7 @@ from vnpy.trader.constant import ( Status, OrderType ) -from vnpy.trader.gateway import BaseGateway +from vnpy.trader.gateway import BaseGateway, LocalOrderManager from vnpy.trader.object import ( TickData, OrderData, @@ -81,12 +81,7 @@ class HuobiGateway(BaseGateway): """Constructor""" super(HuobiGateway, self).__init__(event_engine, "HUOBI") - self.order_count = 100000 - - self.local_huobi_map = {} # local orderid: huobi orderid - self.huobi_local_map = {} # huobi orderid: local orderid - self.local_order_map = {} # local orderid: order - self.huobi_order_data = {} # huobi orderid: data + self.order_manager = LocalOrderManager(self) self.rest_api = HuobiRestApi(self) self.trade_ws_api = HuobiTradeWebsocketApi(self) @@ -147,49 +142,6 @@ class HuobiGateway(BaseGateway): self.count = 0 self.event_engine.register(EVENT_TIMER, self.process_timer_event) - def get_local_orderid(self, huobi_orderid: str): - """""" - local_orderid = self.huobi_local_map.get(huobi_orderid, None) - - if not local_orderid: - local_orderid = self.new_local_orderid() - self.update_orderid_map(local_orderid, huobi_orderid) - - return local_orderid - - def get_huobi_orderid(self, local_orderid: str): - """""" - huobi_orderid = self.local_huobi_map.get(local_orderid, "") - return huobi_orderid - - def new_local_orderid(self): - """""" - self.order_count += 1 - return str(self.order_count) - - def update_orderid_map(self, local_orderid: str, huobi_orderid: str): - """""" - self.huobi_local_map[huobi_orderid] = local_orderid - self.local_huobi_map[local_orderid] = huobi_orderid - - if huobi_orderid in self.huobi_order_data: - data = self.huobi_order_data.pop(huobi_orderid) - self.trade_ws_api.on_order(data) - - def on_order(self, order: OrderData): - """""" - self.local_order_map[order.orderid] = order - - super().on_order(copy(order)) - - def get_order(self, huobi_orderid: str): - """""" - local_orderid = self.huobi_local_map.get(huobi_orderid, None) - if not local_orderid: - return None - else: - return self.local_order_map[local_orderid] - class HuobiRestApi(RestClient): """ @@ -202,6 +154,7 @@ class HuobiRestApi(RestClient): self.gateway = gateway self.gateway_name = gateway.gateway_name + self.order_manager = gateway.order_manager self.host = "" self.key = "" @@ -300,7 +253,7 @@ class HuobiRestApi(RestClient): (req.direction, req.type), "" ) - local_orderid = self.gateway.new_local_orderid() + local_orderid = self.order_manager.new_local_orderid() order = req.create_order_data( local_orderid, self.gateway_name @@ -324,19 +277,14 @@ class HuobiRestApi(RestClient): extra=order, ) - self.gateway.on_order(order) + self.order_manager.on_order(order) return order.vt_orderid def cancel_order(self, req: CancelRequest): """""" - local_id = req.orderid - huobi_orderid = self.gateway.get_huobi_orderid(local_id) + sys_orderid = self.order_manager.get_sys_orderid(req.orderid) - if not huobi_orderid: - self.cancel_requests[local_id] = req - return - - path = f"/v1/order/orders/{huobi_orderid}/submitcancel" + path = f"/v1/order/orders/{sys_orderid}/submitcancel" self.add_request( method="POST", path=path, @@ -344,9 +292,6 @@ class HuobiRestApi(RestClient): extra=req ) - if local_id in self.cancel_requests: - self.cancel_requests.pop(local_id) - def on_query_account(self, data, request): """""" if self.check_error(data, "查询账户"): @@ -387,8 +332,8 @@ class HuobiRestApi(RestClient): return for d in data["data"]: - huobi_orderid = d["id"] - local_orderid = self.gateway.get_local_orderid(huobi_orderid) + sys_orderid = d["id"] + local_orderid = self.order_manager.get_local_orderid(sys_orderid) direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]] dt = datetime.fromtimestamp(d["created-at"] / 1000) @@ -408,7 +353,7 @@ class HuobiRestApi(RestClient): gateway_name=self.gateway_name, ) - self.gateway.on_order(order) + self.order_manager.on_order(order) self.gateway.write_log("委托信息查询成功") @@ -446,15 +391,11 @@ class HuobiRestApi(RestClient): if self.check_error(data, "委托"): order.status = Status.REJECTED - self.gateway.on_order(order) + self.order_manager.on_order(order) return - huobi_orderid = str(data["data"]) - self.gateway.update_orderid_map(order.orderid, huobi_orderid) - - req = self.cancel_requests.get(order.orderid, None) - if req: - self.cancel_order(req) + sys_orderid = data["data"] + self.order_manager.update_orderid_map(order.orderid, sys_orderid) def on_cancel_order(self, data, request): """""" @@ -463,12 +404,11 @@ class HuobiRestApi(RestClient): cancel_request = request.extra local_orderid = cancel_request.orderid - huobi_orderid = self.gateway.get_huobi_orderid(local_orderid) - order = self.gateway.get_order(huobi_orderid) + order = self.order_manager.get_order_with_local_orderid(local_orderid) order.status = Status.CANCELLED - self.gateway.on_order(copy(order)) + self.order_manager.on_order(order) self.gateway.write_log(f"委托撤单成功:{order.orderid}") def check_error(self, data: dict, func: str = ""): @@ -562,6 +502,9 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): """""" super().__init__(gateway) + self.order_manager = gateway.order_manager + self.order_manager.push_data_callback = self.on_data + self.req_id = 0 def connect(self, key, secret, proxy_host, proxy_port): @@ -599,30 +542,36 @@ class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): def on_order(self, data: dict): """""" - huobi_orderid = str(data["order-id"]) - order = self.gateway.get_order(huobi_orderid) + sys_orderid = str(data["order-id"]) + + order = self.order_manager.get_order_with_sys_orderid(sys_orderid) if not order: - self.gateway.huobi_order_data[huobi_orderid] = data + self.order_manager.add_push_data(sys_orderid, data) return traded_volume = float(data["filled-amount"]) + + # Push order event order.traded += traded_volume order.status = STATUS_HUOBI2VT.get(data["order-state"], None) - self.gateway.on_order(order) + self.order_manager.on_order(order) - if traded_volume: - trade = TradeData( - symbol=order.symbol, - exchange=Exchange.HUOBI, - orderid=order.orderid, - tradeid=str(data["seq-id"]), - direction=order.direction, - price=float(data["price"]), - volume=float(data["filled-amount"]), - time=datetime.now().strftime("%H:%M:%S"), - gateway_name=self.gateway_name, - ) - self.gateway.on_trade(trade) + # Push trade event + if not traded_volume: + return + + trade = TradeData( + symbol=order.symbol, + exchange=Exchange.HUOBI, + orderid=order.orderid, + tradeid=str(data["seq-id"]), + direction=order.direction, + price=float(data["price"]), + volume=float(data["filled-amount"]), + time=datetime.now().strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + self.gateway.on_trade(trade) class HuobiDataWebsocketApi(HuobiWebsocketApiBase): diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index d1a6248f..ce7aa591 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any +from copy import copy from vnpy.event import Event, EventEngine from .event import ( @@ -227,3 +228,124 @@ class BaseGateway(ABC): Return default setting dict. """ return self.default_setting + + +class LocalOrderManager: + """ + Management tool to support use local order id for trading. + """ + + def __init__(self, gateway: BaseGateway): + """""" + self.gateway = gateway + + # For generating local orderid + self.order_prefix = "" + self.order_count = 0 + self.orders = {} # local_orderid:order + + # Map between local and system orderid + self.local_sys_orderid_map = {} + self.sys_local_orderid_map = {} + + # Push order data buf + self.push_data_buf = {} # sys_orderid:data + + # Callback for processing push order data + self.push_data_callback = None + + # Cancel request buf + self.cancel_request_buf = {} # local_orderid:req + + def new_local_orderid(self): + """ + Generate a new local orderid. + """ + self.order_count += 1 + local_orderid = str(self.order_count).rjust(8, "0") + return local_orderid + + def get_local_orderid(self, sys_orderid: str): + """ + Get local orderid with sys orderid. + """ + local_orderid = self.sys_local_orderid_map.get(sys_orderid, "") + + if not local_orderid: + local_orderid = self.new_local_orderid() + self.update_orderid_map(local_orderid, sys_orderid) + + return local_orderid + + def get_sys_orderid(self, local_orderid: str): + """ + Get sys orderid with local orderid. + """ + sys_orderid = self.local_sys_orderid_map.get(local_orderid, "") + return sys_orderid + + def update_orderid_map(self, local_orderid: str, sys_orderid: str): + """ + Update orderid map. + """ + self.sys_local_orderid_map[sys_orderid] = local_orderid + self.local_sys_orderid_map[local_orderid] = sys_orderid + + self.check_cancel_request(local_orderid) + self.check_push_data(sys_orderid) + + def check_push_data(self, sys_orderid: str): + """ + Check if any order push data waiting. + """ + if sys_orderid not in self.push_data_buf: + return + + data = self.push_data_buf.pop(sys_orderid) + if self.push_data_callback: + self.push_data_callback(data) + + def add_push_data(self, sys_orderid: str, data: dict): + """ + Add push data into buf. + """ + self.push_data_buf[sys_orderid] = data + + def get_order_with_sys_orderid(self, sys_orderid: str): + """""" + local_orderid = self.sys_local_orderid_map.get(sys_orderid, None) + if not local_orderid: + return None + else: + return self.get_order_with_local_orderid(local_orderid) + + def get_order_with_local_orderid(self, local_orderid: str): + """""" + order = self.orders[local_orderid] + return copy(order) + + def on_order(self, order: OrderData): + """ + Keep an order buf before pushing it to gateway. + """ + self.orders[order.orderid] = copy(order) + self.gateway.on_order(order) + + def cancel_order(self, req: CancelRequest): + """ + """ + sys_orderid = self.get_sys_orderid(req.orderid) + if not sys_orderid: + self.cancel_request_buf[req.orderid] = req + return + + self.gateway.cancel_order(req) + + def check_cancel_request(self, local_orderid: str): + """ + """ + if local_orderid not in self.cancel_request_buf: + return + + req = self.cancel_request_buf.pop(local_orderid) + self.gateway.cancel_order(req) From 36295b96c56b71d8c9bb1edb6178b88a3226f88c Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sat, 6 Apr 2019 13:32:56 +0800 Subject: [PATCH 25/49] [Add]return engine/gateway object when adding related class to MainEngine --- vnpy/trader/engine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 251d2853..a2e098ee 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -52,6 +52,7 @@ class MainEngine: """ engine = engine_class(self, self.event_engine) self.engines[engine.engine_name] = engine + return engine def add_gateway(self, gateway_class: BaseGateway): """ @@ -59,6 +60,7 @@ class MainEngine: """ gateway = gateway_class(self.event_engine) self.gateways[gateway.gateway_name] = gateway + return gateway def add_app(self, app_class: BaseApp): """ @@ -67,7 +69,8 @@ class MainEngine: app = app_class() self.apps[app.app_name] = app - self.add_engine(app.engine_class) + engine = self.add_engine(app.engine_class) + return engine def init_engines(self): """ From 566638426cfef26874240275f91dbb6923388fa1 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sat, 6 Apr 2019 14:37:15 +0800 Subject: [PATCH 26/49] [Fix]close #1498 --- vnpy/gateway/ctp/ctp_gateway.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index cb3014e0..f1a25811 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -549,13 +549,16 @@ class CtpTdApi(TdApi): product=product, size=data["VolumeMultiple"], pricetick=data["PriceTick"], - option_underlying=data["UnderlyingInstrID"], - option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), - option_strike=data["StrikePrice"], - option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"), gateway_name=self.gateway_name ) + # For option only + if data["OptionsType"]: + contract.option_underlying = data["UnderlyingInstrID"], + contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), + contract.option_strike = data["StrikePrice"], + contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), + self.gateway.on_contract(contract) symbol_exchange_map[contract.symbol] = contract.exchange From be62354c1f4abe00c2a577ce1679777e67314b30 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sat, 6 Apr 2019 20:36:04 +0800 Subject: [PATCH 27/49] [Add]algo trading app --- vnpy/app/algo_trading/__init__.py | 17 +++ vnpy/app/algo_trading/algos/__init__.py | 59 ++++++++++ vnpy/app/algo_trading/algos/iceberg_algo.py | 0 vnpy/app/algo_trading/algos/sniper_algo.py | 0 vnpy/app/algo_trading/algos/twap_algo.py | 0 vnpy/app/algo_trading/engine.py | 65 +++++++++++ vnpy/app/algo_trading/template.py | 123 ++++++++++++++++++++ vnpy/app/algo_trading/ui/algo.ico | Bin 0 -> 67646 bytes vnpy/app/algo_trading/ui/widget.py | 0 9 files changed, 264 insertions(+) create mode 100644 vnpy/app/algo_trading/__init__.py create mode 100644 vnpy/app/algo_trading/algos/__init__.py create mode 100644 vnpy/app/algo_trading/algos/iceberg_algo.py create mode 100644 vnpy/app/algo_trading/algos/sniper_algo.py create mode 100644 vnpy/app/algo_trading/algos/twap_algo.py create mode 100644 vnpy/app/algo_trading/engine.py create mode 100644 vnpy/app/algo_trading/template.py create mode 100644 vnpy/app/algo_trading/ui/algo.ico create mode 100644 vnpy/app/algo_trading/ui/widget.py diff --git a/vnpy/app/algo_trading/__init__.py b/vnpy/app/algo_trading/__init__.py new file mode 100644 index 00000000..e5be3d37 --- /dev/null +++ b/vnpy/app/algo_trading/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import AlgoEngine, APP_NAME + + +class CtaStrategyApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "算法交易" + engine_class = AlgoEngine + widget_name = "AlgoManager" + icon_name = "algo.ico" diff --git a/vnpy/app/algo_trading/algos/__init__.py b/vnpy/app/algo_trading/algos/__init__.py new file mode 100644 index 00000000..d6fa977a --- /dev/null +++ b/vnpy/app/algo_trading/algos/__init__.py @@ -0,0 +1,59 @@ +# encoding: UTF-8 + +''' +动态载入所有的策略类 +''' +from __future__ import print_function + +import os +import importlib +import traceback + + +# 用来保存算法类和控件类的字典 +ALGO_DICT = {} +WIDGET_DICT = {} + + +#---------------------------------------------------------------------- +def loadAlgoModule(path, prefix): + """使用importlib动态载入算法""" + for root, subdirs, files in os.walk(path): + for name in files: + # 只有文件名以Algo.py结尾的才是算法文件 + if len(name)>7 and name[-7:] == 'Algo.py': + try: + # 模块名称需要模块路径前缀 + moduleName = prefix + name.replace('.py', '') + module = importlib.import_module(moduleName) + + # 获取算法类和控件类 + algo = None + widget = None + + for k in dir(module): + # 以Algo结尾的类,是算法 + if k[-4:] == 'Algo': + algo = module.__getattribute__(k) + + # 以Widget结尾的类,是控件 + if k[-6:] == 'Widget': + widget = module.__getattribute__(k) + + # 保存到字典中 + if algo and widget: + ALGO_DICT[algo.templateName] = algo + WIDGET_DICT[algo.templateName] = widget + except: + print ('-' * 20) + print ('Failed to import strategy file %s:' %moduleName) + traceback.print_exc() + + +# 遍历algo目录下的文件 +path1 = os.path.abspath(os.path.dirname(__file__)) +loadAlgoModule(path1, 'vnpy.trader.app.algoTrading.algo.') + +# 遍历工作目录下的文件 +path2 = os.getcwd() +loadAlgoModule(path2, '') \ No newline at end of file diff --git a/vnpy/app/algo_trading/algos/iceberg_algo.py b/vnpy/app/algo_trading/algos/iceberg_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/algos/sniper_algo.py b/vnpy/app/algo_trading/algos/sniper_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/algos/twap_algo.py b/vnpy/app/algo_trading/algos/twap_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py new file mode 100644 index 00000000..3c5224dd --- /dev/null +++ b/vnpy/app/algo_trading/engine.py @@ -0,0 +1,65 @@ + +from vnpy.event import EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.event import (EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) + + +class AlgoEngine(BaseEngine): + """""" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """Constructor""" + super().__init__(main_engine, event_engine) + + self.algos = {} + self.symbol_algo_map = {} + self.orderid_algo_map = {} + + self.register_event() + + def register_event(self): + """""" + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + + def process_tick_event(self): + """""" + pass + + def process_timer_event(self): + """""" + pass + + def process_trade_event(self): + """""" + pass + + def process_order_event(self): + """""" + pass + + def start_algo(self, setting: dict): + """""" + pass + + def stop_algo(self, algo_name: dict): + """""" + pass + + def stop_all(self): + """""" + pass + + def subscribe(self, algo, vt_symbol): + """""" + pass + + def send_order( + self, + algo, + vt_symbol + ): + """""" + pass diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py new file mode 100644 index 00000000..1ba0c5a4 --- /dev/null +++ b/vnpy/app/algo_trading/template.py @@ -0,0 +1,123 @@ +from vnpy.trader.engine import BaseEngine +from vnpy.trader.object import TickData, OrderData, TradeData +from vnpy.trader.constant import OrderType, Offset + +class AlgoTemplate: + """""" + count = 0 + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """Constructor""" + self.algo_engine = algo_engine + self.algo_name = algo_name + + self.active = False + self.active_orders = {} # vt_orderid:order + + @staticmethod + def new(cls, algo_engine:BaseEngine, setting: dict): + """Create new algo instance""" + cls.count += 1 + algo_name = f"{cls.__name__}_{cls.count}" + algo = cls(algo_engine, algo_name, setting) + + def update_tick(self, tick: TickData): + """""" + if self.active: + self.on_tick(tick) + + def update_order(self, order: OrderData): + """""" + if self.active: + if order.is_active(): + self.active_orders[order.vt_orderid] = order + elif order.vt_orderid in self.active_orders: + self.active_orders.pop(order.vt_orderid) + + self.on_order(order) + + def update_trade(self, trade: TradeData): + """""" + if self.active: + self.on_trade(trade) + + def update_timer(self): + """""" + if self.active: + self.on_timer() + + def on_start(self): + """""" + pass + + def on_stop(self): + """""" + pass + + def on_tick(self, tick: TickData): + """""" + pass + + def on_order(self, order: OrderData): + """""" + pass + + def on_trade(self, trade: TradeData): + """""" + pass + + def on_timer(self): + """""" + pass + + def start(self): + """""" + pass + + def stop(self): + """""" + pass + + def buy( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + return self.algo_engine.buy( + vt_symbol, + price, + volume, + order_type, + offset + ) + + def sell( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + return self.algo_engine.buy( + vt_symbol, + price, + volume, + order_type, + offset + ) + + + + + \ No newline at end of file diff --git a/vnpy/app/algo_trading/ui/algo.ico b/vnpy/app/algo_trading/ui/algo.ico new file mode 100644 index 0000000000000000000000000000000000000000..83114df81f5fca19c9462e166dd97f5f032d20b2 GIT binary patch literal 67646 zcmeI53zQYbna6v0M2IGsY(!T*4h98cVssN1cawDo?hG(5n8U{0Ka z?wUl+W>*958XsBL=m_zVBa2EB#?|!@v%1F(C@8`NjCv3}5TlMTOz-dO>AF+5=DvDv z-|m@vyW!7Q)m>Fz)%XAE(bat$%PQedbF;<2Bdjq)O040QWsL!xl9SGP2})??S(byY zc=T^%fye@p1tJSX7KkhmS)lMOu=U23RxT)f3l(k)4c@;R)BTjCGa6Q2KIvcz-llBh=0WWAc(pq)1i1mR)Yhex4f%-?@`7K;EVyu zbQZ!nojW&tI+@1}p&lA<0|-w5Pl2N# z6ZiV{mq|Yxd;&Or5Tw0?A20x{0LMVCysP}3l=W>e5Cqwzz;r3TkTTxC19V=!geC7{ zAZvy|eL%PYgTb$X;?6HAW55s)7$-p~Qam9~1zUg@zLTZhMHycMVetQHIQo4a_}@*L zpDkei2f^1(;A4=CQ?KV*DCe^vFdqHq0Dh0-^~cLe3zROkoSV#@|-H)cFL4}qCn;V;RXx`uK>SzSLN)b%riis z{N!7i%S|cLk9D8v2OX4qalrh`zTXRZ;9u61@+=|WlmuG7%S{ZqADH~*8W;X+Tpf9D zmv95*9{mn*9H_pVtCaiydw|?q=Di(3DsZ`xwcooSSH8{s!v9oPM#mMKnk5_o>8C5fQJ^^Nt?ZzT?*f@u0%=2# zN?dMELOuX;<=f0JbIN(HjzF3v9J!FS+9uFj-c>%?OB~^r8Auy~RHC>+-VRLua*OY( zdvN2fy3dJ@#WTtSX_n&2mE({*f#TAue1p7|Zds0B1=5HBm1z9WCoIqLa^u|1|L~4F z``%|}-nRCZ+97LySYyqrjs?~#jjut3F9aLG`ygBX<=u~8f{TE&=0oPK{==Vs*D{ee zWk$go|1t;s+zEeFtHp|m|K<4WI%f$y6@9hU(yv$y6eU&}8GG@d1 zelQD^2>*hh4xK;3@?7{laP5Puu6XwD6$?5ZT6KjjAbkAi*Dcxhn;TcI;=1hdJI{HM zc^UH!u1_W&16P67U@^D^oCut8Aqac9J|KJbG7t3n8BojiS@z98nQSkZ>~Jsq3-*C} z;a*1&?Ghb|N0BexdcDZdqG4gh}- z9tV4X?czA{HP8b73Y37n`+>{9?9&bf*MPmi+gjk&x+AyVed*pMmsK2xXT`me;sY`c z#C;$i_RUl#`ehti4cg~cjkg7akAh|}AXENat}7mq!oBqOH-O?jncU9xIUv41$VGp; zyl}0a1BsM(fGz@$f_9L}4`iR`xedGOrYv7Han!u3*pt|@6DaPLe3NsRfzW&a-Esc4 z`UTCujPI{!;9mK|i=?|i`9L}jbG-(Kx(`TLp5LzovMziE=;$hm|5=5lV#1?_k%Uvt1);15`T+m)J#sJpaST)i$$v5@OXd^F(rNsQFPxpm2V^b? zbUlE+P{x179dbFyrq7#cE>mSBvw2@Thn@9tAl-T(&^2Kn;J*PLk9zch4@f_uKk+aA zX!wAv2U4vG16vP-!N1sH`n-kKUv7HSI&Je?7oYmnfgK+Cf1dQC`y2mBd_cy5P}c)t z@h?7b;q8xEvu94L|MI#g4xF;(kU74mi{A|QgM|P7cu3nd$a>%@km}wruj>Ky20H(1 z?G_tMo?kV7&&-BK``r6~YoGF~_tJ5$u00L!rTww*>%MB(LGBAu-5ch0J%HZ8_&3^o zDe=9)*$b;$w%Q(j-?Mgxao|aKFIU`)?9Yc`k5JYF)FJQu8|{>SKNai&Y9G{Q9I$6z z-f)~b;g#V}AKV7F9|7s}e+JKjwcsp}_r3oheaf&+;OhbO2hG3q`!{s%Rd{}7+19%H zx{H_xM#1wqPz&Ul&`99S_sAgW2}Bp|4{bew?!5AEv~4o+_?S;P-v0@=f>9GzHmiNX zlrj)f0#-uXbsq?9J%HZ4@o%(A`h6SFIalFdbAF6U-#DeJ2loNf5sEt1epjYj56C`Y z9T4u~e$$&uyJa}fe+er-a(e!!*U1g{nC~@Rzao5|o~JkMD3m<99SF;MU^7T|Zzy|2 z;sbI1CG7PekMI58sM8eUN#}dcZv_$e25E-<&L*C>{Q=7Mm4Bn|GT$ex^NGtCBmUe9 zMuE8h8fGZ!sBRa+p9gu_AD~>{_&4e-+_wQWr|2?azL$K_(bEEkJw2@ldX|>_3_5(_ z->7Re-zPV>=d*@g#Rq&{5BR~qQO9)i{clJc&G#O?z_72c>w#SPH|iG6_a3>;Ho0ME z@qtw1fT!mLx$tk)DbsvET2G&CyQ6g77SQcSIMaGS-Yu88-C19Emy#HA)2v3e1GPrW zc>k};b)}5?$=3N&UUYw-X!4w5hTT1_2ku%@Wj(yCCIOrqw5T?=W?@aN9m#J$%5B0& z7uLpZyrpT=8Nd4>VN9blwNS>*-(| zxZ(b*t9RXVMYZkV$}aiuS5?2kwY8ueoB`s#FsJ*S=2Y{&^!rv2jrqwb%=s*~)%#(_ z^}v=@ldZ?DsT}~17l7-)E8qig6xbUK$T@Hfd@|>w$esr33+ARC^P~AbJ#D^TP2vMR_6e_^ zQf~ibVdbCTPS6|P^*rydUsQEuDP^KtM<;RF*Ng7&ePuXvVTKPRtOu4(t+4N0ROO8C zK5*~&fy@>EwwT`vPpi=RkLLT#ROKy&;R9#St%_}?-yfr|+RWK{|Mev-W&SHZz`W9d zF6sBte4jUVvnyD9U|RV&>*j_@*2)NI6oLDVOYX;Ci>j@t;IrBS?j>G7h&)RaJ=%)Pe8Xsv^_xNYRgWTWct?wIU z?)+o2+0b?t5(?^DSLc=?DKf_8`(>~R@I|;nFy!l%I?VffmiPQ zO8=;;M>tj}UaiXR$c-WL|0KF){aL(3gw%4eIgs_Y^g+M+v)6K-Mz`~vI3zO^ty=6Z z{7c_^1bF4muk=r$Tiz`wS{}k`xY*o9*yF!bkFe}eta>&6?;z|~fA(6=J*pnX%0pP~ zR(40u1P6gv?s83k7@dtE>i;1*fouD-kgtPW`ShB9KRVBmIz&e}77*J5S${taI)PUX zb4mXM`i6pV8Xsv@>pzj@;3JR=pI-9HJEXNB8v8@`Lzjb-kwM-EZbIKF5&t2RkNQVe z)L6WmXFa&I=I(GLl*yXR?G|kZ+z5f-C^NHOxINw_0cU;1cEj;PG z`4#?Q(jwVci~Sv==UAy*5Ch^u>9UI=WkLK7DM@KjlsQ)`N@8R z-N|pYTA8E22>q0EfGIaT$`p9PAyWMx}dX8N|`=xb+G zSPy)oTJFz>fWHNIf&T`70IAjo@*dgq;2v-Z7|K0`^{tr`WS+_Rn^V1x#mkicE!XWz zIO}gOJbF#7zDFeC)4)Wq2K)fr3T|`5-$hn|(?LQ$O@F%IoXUD2t8Y;FZ=u30QM0>R z>!-b7i0S_JUC-Sr4T8zid2xlPa|pit~zL=d9;@$(QfL zs9Uri@X_FgZPngRvTrSg`t1dyp5g;NzCpn`wLegAr1#Ncg;kVcCtv3Kd|Ak->qz3! z{y+}KHtd;czV8#?(KhOwZaok^FG%AXh8?Bf?*;mrT7~=Ex0Z~$%l}^>|Hq}OuP&qg z0mBCj+x6M`-l%qusL_sSJvQMb?wE9|8@=~Ab zIj>qenrbG;+m9{M&G%cm-Y4e!Ooq`pC0z9D0iExj;T=}(>=(obGR^m2(CwOugRrmb zbtNq80l7ES*8wVA&}%}2@6QpR0lp9J20yNz^IFsFDJ!nVHoKG9NABS=ha+)4#AZC{APuNvz=LAVS&2i^g;hC`FyiZ^a({V!qF0Puii61F{}SwMUdN7c{L+=qkc8tXR@c^19D%GYCX{P%xG53 zYIxY;^ONu@T;C41fe*lO(Ba~34>#{Efz*JtKl`M@rFa{V)ZsGquK zV%@a0@jeqq8}gyk@PTyCjP}i|iq$ndQa%a}{|N}^b{5z{n)%>F5ch+8v?X1c8vo}J zemQ&o#SeD=v1-oi)ky2|OBbh8(=V?XK9GDp!29b>-;BN0@M!tV@F@K}8}3z_L!{pb zMgzwOf;N=Ke~fU4H~yWm;G3$9pz9ulKEnqR*8^2Cd+x%R-SChn&K38%+zkIGhv5T? zJ7i@qxRHiuO zJpV$9_`v+H#H^<4V^-rW<+EAyA8dHsH~z(U<^|2av`g`aTn3JTY<=KP(hdevOi>&Q z&g5QC)@r{jzYh5^$feJld7DD!|1)rTUzYsK{@~fdy`WeQhx@Lz`W-%VFTM~8|H7T( z5qUmX2hxoLk8yn*C@Sv7PA>nlzxO~P@o)IRFt}|5EkO2yWu12j$ogP8I0FJ zD@6W_ky9^icKH`SD5&+H7aJ6FT8;l@gtL9tW8UYbiifuTFJ?}?v|Zz0p7*@~QgPXv z^K+#9=;+%PaD71b>ZQMa2ztY{ndf=R5ufVYjgM-l)_@{qAK*WL$)(r$dz9gf{m96| z{lEgklZ!7Pe+#^FuF|)Y_ckE+6i%7_p!r2#Ma30T=7ICU1|VE2UbB%0NPjK(3lMSN z&%8=;=E`A6d4|6oWa9oPY2-aXc_(xrh`8@(+^gy;?p!HjfxOrGBk)tO3G4wMf!~1^ zupY>>!{I>9Arb$|p#4;uToNAyP6yJ*X8@Ta;_rA7jyUh9T&JqG;xt>Cs+{P2WP!*6 zkp&_PL>7oF5Lv+X=QwN)ODL$#I>C*1Iv1L)5;xxJ1TegcTQ?Buj$2ll9`CTMQX@`* zout2k)3(iZG$*CI^AU#cS;vG@}f^BZRQxp Date: Sun, 7 Apr 2019 16:31:20 +0800 Subject: [PATCH 28/49] [Mod]complete twap algo and app test --- tests/trader/run.py | 2 + vnpy/app/algo_trading/__init__.py | 3 +- vnpy/app/algo_trading/algos/__init__.py | 59 --- vnpy/app/algo_trading/algos/twap_algo.py | 112 +++++ vnpy/app/algo_trading/engine.py | 258 ++++++++-- vnpy/app/algo_trading/template.py | 140 ++++-- vnpy/app/algo_trading/ui/__init__.py | 1 + vnpy/app/algo_trading/ui/display.py | 15 + vnpy/app/algo_trading/ui/widget.py | 571 +++++++++++++++++++++++ vnpy/app/cta_strategy/engine.py | 3 +- 10 files changed, 1028 insertions(+), 136 deletions(-) create mode 100644 vnpy/app/algo_trading/ui/__init__.py create mode 100644 vnpy/app/algo_trading/ui/display.py diff --git a/tests/trader/run.py b/tests/trader/run.py index 0b7fb501..1ab073c5 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -15,6 +15,7 @@ from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp +from vnpy.app.algo_trading import AlgoTradingApp def main(): @@ -35,6 +36,7 @@ def main(): main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) + main_engine.add_app(AlgoTradingApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/app/algo_trading/__init__.py b/vnpy/app/algo_trading/__init__.py index e5be3d37..dba58a66 100644 --- a/vnpy/app/algo_trading/__init__.py +++ b/vnpy/app/algo_trading/__init__.py @@ -3,9 +3,10 @@ from pathlib import Path from vnpy.trader.app import BaseApp from .engine import AlgoEngine, APP_NAME +from .template import AlgoTemplate -class CtaStrategyApp(BaseApp): +class AlgoTradingApp(BaseApp): """""" app_name = APP_NAME diff --git a/vnpy/app/algo_trading/algos/__init__.py b/vnpy/app/algo_trading/algos/__init__.py index d6fa977a..e69de29b 100644 --- a/vnpy/app/algo_trading/algos/__init__.py +++ b/vnpy/app/algo_trading/algos/__init__.py @@ -1,59 +0,0 @@ -# encoding: UTF-8 - -''' -动态载入所有的策略类 -''' -from __future__ import print_function - -import os -import importlib -import traceback - - -# 用来保存算法类和控件类的字典 -ALGO_DICT = {} -WIDGET_DICT = {} - - -#---------------------------------------------------------------------- -def loadAlgoModule(path, prefix): - """使用importlib动态载入算法""" - for root, subdirs, files in os.walk(path): - for name in files: - # 只有文件名以Algo.py结尾的才是算法文件 - if len(name)>7 and name[-7:] == 'Algo.py': - try: - # 模块名称需要模块路径前缀 - moduleName = prefix + name.replace('.py', '') - module = importlib.import_module(moduleName) - - # 获取算法类和控件类 - algo = None - widget = None - - for k in dir(module): - # 以Algo结尾的类,是算法 - if k[-4:] == 'Algo': - algo = module.__getattribute__(k) - - # 以Widget结尾的类,是控件 - if k[-6:] == 'Widget': - widget = module.__getattribute__(k) - - # 保存到字典中 - if algo and widget: - ALGO_DICT[algo.templateName] = algo - WIDGET_DICT[algo.templateName] = widget - except: - print ('-' * 20) - print ('Failed to import strategy file %s:' %moduleName) - traceback.print_exc() - - -# 遍历algo目录下的文件 -path1 = os.path.abspath(os.path.dirname(__file__)) -loadAlgoModule(path1, 'vnpy.trader.app.algoTrading.algo.') - -# 遍历工作目录下的文件 -path2 = os.getcwd() -loadAlgoModule(path2, '') \ No newline at end of file diff --git a/vnpy/app/algo_trading/algos/twap_algo.py b/vnpy/app/algo_trading/algos/twap_algo.py index e69de29b..32fc6a65 100644 --- a/vnpy/app/algo_trading/algos/twap_algo.py +++ b/vnpy/app/algo_trading/algos/twap_algo.py @@ -0,0 +1,112 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class TwapAlgo(AlgoTemplate): + """""" + + display_name = "TWAP 时间加权平均" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "time": 600, + "interval": 60, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.time = setting["time"] + self.interval = setting["interval"] + self.offset = Offset(setting["offset"]) + + # Variables + self.order_volume = self.volume / (self.time / self.interval) + self.timer_count = 0 + self.total_count = 0 + self.traded = 0 + + self.variables.extend([ + "traded", + "order_volume", + "timer_count", + "total_count" + ]) + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.stop() + else: + self.put_variables_event() + + def on_timer(self): + """""" + self.timer_count += 1 + self.total_count += 1 + self.put_variables_event() + + if self.total_count >= self.time: + self.write_log("执行时间已结束,停止算法") + self.stop() + return + + if self.timer_count < self.interval: + return + self.timer_count = 0 + + tick = self.get_tick(self.vt_symbol) + if not tick: + return + + self.cancel_all() + + left_volume = self.volume - self.traded + order_volume = min(self.order_volume, left_volume) + + if self.direction == Direction.LONG: + if tick.ask_price_1 <= self.price: + self.buy(self.vt_symbol, self.price, + order_volume, offset=self.offset) + self.write_log( + f"委托买入{self.vt_symbol}:{order_volume}@{self.price}") + else: + if tick.bid_price_1 >= self.price: + self.sell(self.vt_symbol, self.price, + order_volume, offset=self.offset) + self.write_log( + f"委托卖出{self.vt_symbol}:{order_volume}@{self.price}") + + def get_default_setting(self): + """""" + return self.default_setting diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py index 3c5224dd..fe551043 100644 --- a/vnpy/app/algo_trading/engine.py +++ b/vnpy/app/algo_trading/engine.py @@ -1,65 +1,257 @@ -from vnpy.event import EventEngine +from vnpy.event import EventEngine, Event from vnpy.trader.engine import BaseEngine, MainEngine -from vnpy.trader.event import (EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) +from vnpy.trader.event import ( + EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) +from vnpy.trader.constant import (Direction, Offset, OrderType) +from vnpy.trader.object import (SubscribeRequest, OrderRequest) +from vnpy.trader.utility import load_json, save_json + +from .template import AlgoTemplate + + +APP_NAME = "AlgoTrading" + +EVENT_ALGO_LOG = "eAlgoLog" +EVENT_ALGO_SETTING = "eAlgoSetting" +EVENT_ALGO_VARIABLES = "eAlgoVariables" +EVENT_ALGO_PARAMETERS = "eAlgoParameters" class AlgoEngine(BaseEngine): """""" + setting_filename = "algo_trading_setting.json" def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """Constructor""" - super().__init__(main_engine, event_engine) - + super().__init__(main_engine, event_engine, APP_NAME) + self.algos = {} self.symbol_algo_map = {} self.orderid_algo_map = {} - + + self.algo_templates = {} + self.algo_settings = {} + + self.load_algo_template() self.register_event() - + + def init_engine(self): + """""" + self.write_log("算法交易引擎启动") + self.load_algo_setting() + + def load_algo_template(self): + """""" + from .algos.twap_algo import TwapAlgo + + self.algo_templates[TwapAlgo.__name__] = TwapAlgo + + def load_algo_setting(self): + """""" + self.algo_settings = load_json(self.setting_filename) + + for setting_name, setting in self.algo_settings.items(): + self.put_setting_event(setting_name, setting) + + self.write_log("算法配置载入成功") + + def save_algo_setting(self): + """""" + save_json(self.setting_filename, self.algo_settings) + def register_event(self): """""" self.event_engine.register(EVENT_TICK, self.process_tick_event) self.event_engine.register(EVENT_TIMER, self.process_timer_event) self.event_engine.register(EVENT_ORDER, self.process_order_event) self.event_engine.register(EVENT_TRADE, self.process_trade_event) - - def process_tick_event(self): + + def process_tick_event(self, event: Event): """""" - pass - - def process_timer_event(self): + tick = event.data + + algos = self.symbol_algo_map.get(tick.vt_symbol, None) + if algos: + for algo in algos: + algo.update_tick(tick) + + def process_timer_event(self, event: Event): """""" - pass - - def process_trade_event(self): + for algo in self.algos.values(): + algo.update_timer() + + def process_trade_event(self, event: Event): """""" - pass - - def process_order_event(self): + trade = event.data + + algo = self.orderid_algo_map.get(trade.vt_orderid, None) + if algo: + algo.update_trade(trade) + + def process_order_event(self, event: Event): """""" - pass - + order = event.data + + algo = self.orderid_algo_map.get(order.vt_orderid, None) + if algo: + algo.update_order(order) + def start_algo(self, setting: dict): """""" - pass - - def stop_algo(self, algo_name: dict): + template_name = setting["template_name"] + algo_template = self.algo_templates[template_name] + + algo = algo_template.new(self, setting) + algo.start() + + self.algos[algo.algo_name] = algo + return algo.algo_name + + def stop_algo(self, algo_name: str): """""" - pass - + algo = self.algos.get(algo_name, None) + if algo: + algo.stop() + self.algos.pop(algo_name) + def stop_all(self): """""" - pass - - def subscribe(self, algo, vt_symbol): + for algo_name in list(self.algos.keys()): + self.stop_algo(algo_name) + + def subscribe(self, algo: AlgoTemplate, vt_symbol: str): """""" - pass - + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo) + return + + algos = self.symbol_algo_map.setdefault(vt_symbol, set()) + + if not algos: + req = SubscribeRequest( + symbol=contract.symbol, + exchange=contract.exchange + ) + self.main_engine.subscribe(req, contract.gateway_name) + + algos.add(algo) + def send_order( - self, - algo, - vt_symbol + self, + algo: AlgoTemplate, + vt_symbol: str, + direction: Direction, + price: float, + volume: float, + order_type: OrderType, + offset: Offset ): """""" - pass + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo) + return + + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + type=order_type, + volume=volume, + price=price, + offset=offset + ) + vt_orderid = self.main_engine.send_order(req, contract.gateway_name) + + self.orderid_algo_map[vt_orderid] = algo + return vt_orderid + + def cancel_order(self, algo: AlgoTemplate, vt_orderid: str): + """""" + order = self.main_engine.get_order(vt_orderid) + + if not order: + self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo) + return + + req = order.create_cancel_request() + self.main_engine.cancel_order(req, order.gateway_name) + + def get_tick(self, algo: AlgoTemplate, vt_symbol: str): + """""" + tick = self.main_engine.get_tick(vt_symbol) + + if not tick: + self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo) + + return tick + + def get_contract(self, algo: AlgoTemplate, vt_symbol: str): + """""" + contract = self.main_engine.get_contract(vt_symbol) + + if not contract: + self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo) + + return contract + + def write_log(self, msg: str, algo: AlgoTemplate = None): + """""" + if algo: + msg = f"{algo.algo_name}:{msg}" + + event = Event(EVENT_ALGO_LOG) + event.data = msg + self.event_engine.put(event) + + def put_setting_event(self, setting_name: str, setting: dict): + """""" + event = Event(EVENT_ALGO_SETTING) + event.data = { + "setting_name": setting_name, + "setting": setting + } + self.event_engine.put(event) + + def update_algo_setting(self, setting_name: str, setting: dict): + """""" + self.algo_settings[setting_name] = setting + + self.save_algo_setting() + + self.put_setting_event(setting_name, setting) + + def remove_algo_setting(self, setting_name: str): + """""" + if setting_name not in self.algo_settings: + return + self.algo_settings.pop(setting_name) + + event = Event(EVENT_ALGO_SETTING) + event.data = { + "setting_name": setting_name, + "setting": None + } + self.event_engine.put(event) + + self.save_algo_setting() + + def put_parameters_event(self, algo: AlgoTemplate, parameters: dict): + """""" + event = Event(EVENT_ALGO_PARAMETERS) + event.data = { + "algo_name": algo.algo_name, + "parameters": parameters + } + self.event_engine.put(event) + + def put_variables_event(self, algo: AlgoTemplate, variables: dict): + """""" + event = Event(EVENT_ALGO_VARIABLES) + event.data = { + "algo_name": algo.algo_name, + "variables": variables + } + self.event_engine.put(event) diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py index 1ba0c5a4..758e526f 100644 --- a/vnpy/app/algo_trading/template.py +++ b/vnpy/app/algo_trading/template.py @@ -1,30 +1,38 @@ from vnpy.trader.engine import BaseEngine from vnpy.trader.object import TickData, OrderData, TradeData -from vnpy.trader.constant import OrderType, Offset +from vnpy.trader.constant import OrderType, Offset, Direction + class AlgoTemplate: """""" - count = 0 + + _count = 0 + display_name = "" + default_setting = {} + variables = [] def __init__( - self, - algo_engine: BaseEngine, + self, + algo_engine: BaseEngine, algo_name: str, setting: dict ): """Constructor""" self.algo_engine = algo_engine self.algo_name = algo_name - - self.active = False - self.active_orders = {} # vt_orderid:order - @staticmethod - def new(cls, algo_engine:BaseEngine, setting: dict): + self.active = False + self.active_orders = {} # vt_orderid:order + + self.variables.insert(0, "active") + + @classmethod + def new(cls, algo_engine: BaseEngine, setting: dict): """Create new algo instance""" - cls.count += 1 - algo_name = f"{cls.__name__}_{cls.count}" + cls._count += 1 + algo_name = f"{cls.__name__}_{cls._count}" algo = cls(algo_engine, algo_name, setting) + return algo def update_tick(self, tick: TickData): """""" @@ -38,27 +46,27 @@ class AlgoTemplate: self.active_orders[order.vt_orderid] = order elif order.vt_orderid in self.active_orders: self.active_orders.pop(order.vt_orderid) - + self.on_order(order) - + def update_trade(self, trade: TradeData): """""" if self.active: self.on_trade(trade) - + def update_timer(self): """""" if self.active: self.on_timer() - + def on_start(self): """""" pass - + def on_stop(self): """""" pass - + def on_tick(self, tick: TickData): """""" pass @@ -66,58 +74,106 @@ class AlgoTemplate: def on_order(self, order: OrderData): """""" pass - + def on_trade(self, trade: TradeData): """""" pass - + def on_timer(self): """""" - pass - + pass + def start(self): """""" - pass - + self.active = True + self.on_start() + self.put_variables_event() + def stop(self): """""" - pass - + self.active = False + self.cancel_all() + self.on_stop() + self.put_variables_event() + + def subscribe(self, vt_symbol): + """""" + self.algo_engine.subscribe(self, vt_symbol) + def buy( - self, - vt_symbol, - price, - volume, + self, + vt_symbol, + price, + volume, order_type: OrderType = OrderType.LIMIT, offset: Offset = Offset.NONE ): """""" - return self.algo_engine.buy( + return self.algo_engine.send_order( + self, vt_symbol, + Direction.LONG, price, volume, order_type, offset ) - + def sell( - self, - vt_symbol, - price, - volume, + self, + vt_symbol, + price, + volume, order_type: OrderType = OrderType.LIMIT, offset: Offset = Offset.NONE ): """""" - return self.algo_engine.buy( + return self.algo_engine.send_order( + self, vt_symbol, + Direction.SHORT, price, volume, order_type, offset - ) - - - - - \ No newline at end of file + ) + + def cancel_order(self, vt_orderid: str): + """""" + self.algo_engine.cancel_order(self, vt_orderid) + + def cancel_all(self): + """""" + if not self.active_orders: + return + + for vt_orderid in self.active_orders.keys(): + self.cancel_order(vt_orderid) + + def get_tick(self, vt_symbol: str): + """""" + return self.algo_engine.get_tick(self, vt_symbol) + + def get_contract(self, vt_symbol: str): + """""" + return self.algo_engine.get_contract(self, vt_symbol) + + def write_log(self, msg: str): + """""" + self.algo_engine.write_log(msg, self) + + def put_parameters_event(self): + """""" + parameters = {} + for name in self.default_setting.keys(): + parameters[name] = getattr(self, name) + + self.algo_engine.put_parameters_event(self, parameters) + + def put_variables_event(self): + """""" + variables = {} + for name in self.variables: + variables[name] = getattr(self, name) + + self.algo_engine.put_variables_event(self, variables) diff --git a/vnpy/app/algo_trading/ui/__init__.py b/vnpy/app/algo_trading/ui/__init__.py new file mode 100644 index 00000000..9ac801bb --- /dev/null +++ b/vnpy/app/algo_trading/ui/__init__.py @@ -0,0 +1 @@ +from .widget import AlgoManager diff --git a/vnpy/app/algo_trading/ui/display.py b/vnpy/app/algo_trading/ui/display.py new file mode 100644 index 00000000..8d6b991b --- /dev/null +++ b/vnpy/app/algo_trading/ui/display.py @@ -0,0 +1,15 @@ +NAME_DISPLAY_MAP = { + "vt_symbol": "本地代码", + "direction": "方向", + "price": "价格", + "volume": "数量", + "time": "执行时间(秒)", + "interval": "每轮间隔(秒)", + "offset": "开平", + "active": "算法状态", + "traded": "成交数量", + "order_volume": "单笔委托", + "timer_count": "本轮读秒", + "total_count": "累计读秒", + "template_name": "算法模板" +} diff --git a/vnpy/app/algo_trading/ui/widget.py b/vnpy/app/algo_trading/ui/widget.py index e69de29b..57976d92 100644 --- a/vnpy/app/algo_trading/ui/widget.py +++ b/vnpy/app/algo_trading/ui/widget.py @@ -0,0 +1,571 @@ +""" +Widget for algo trading. +""" + +from functools import partial +from datetime import datetime + +from vnpy.event import EventEngine, Event +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtWidgets, QtCore + +from ..engine import ( + AlgoEngine, + AlgoTemplate, + APP_NAME, + EVENT_ALGO_LOG, + EVENT_ALGO_PARAMETERS, + EVENT_ALGO_VARIABLES, + EVENT_ALGO_SETTING +) +from .display import NAME_DISPLAY_MAP + + +class AlgoWidget(QtWidgets.QWidget): + """ + Start connection of a certain gateway. + """ + + def __init__( + self, + algo_engine: AlgoEngine, + algo_template: AlgoTemplate + ): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.template_name = algo_template.__name__ + self.default_setting = algo_template.default_setting + + self.widgets = {} + + self.init_ui() + + def init_ui(self): + """ + Initialize line edits and form layout based on setting. + """ + self.setMaximumWidth(400) + + form = QtWidgets.QFormLayout() + + for field_name, field_value in self.default_setting.items(): + field_type = type(field_value) + + if field_type == list: + widget = QtWidgets.QComboBox() + widget.addItems(field_value) + else: + widget = QtWidgets.QLineEdit() + + display_name = NAME_DISPLAY_MAP.get(field_name, field_name) + + form.addRow(display_name, widget) + self.widgets[field_name] = (widget, field_type) + + start_algo_button = QtWidgets.QPushButton("启动算法") + start_algo_button.clicked.connect(self.start_algo) + form.addRow(start_algo_button) + + form.addRow(QtWidgets.QLabel("")) + + self.setting_name_line = QtWidgets.QLineEdit() + form.addRow("配置名称", self.setting_name_line) + + save_setting_button = QtWidgets.QPushButton("保存配置") + save_setting_button.clicked.connect(self.save_setting) + form.addRow(save_setting_button) + + self.setLayout(form) + + def get_setting(self): + """ + Get setting value from line edits. + """ + setting = {"template_name": self.template_name} + + for field_name, tp in self.widgets.items(): + widget, field_type = tp + if field_type == list: + field_value = str(widget.currentText()) + else: + try: + field_value = field_type(widget.text()) + except ValueError: + display_name = NAME_DISPLAY_MAP.get(field_name, field_name) + QtWidgets.QMessageBox.warning( + self, + "参数错误", + f"{display_name}参数类型应为{field_type},请检查!" + ) + return None + + setting[field_name] = field_value + + return setting + + def start_algo(self): + """ + Start algo trading. + """ + setting = self.get_setting() + if setting: + self.algo_engine.start_algo(setting) + + def update_setting(self, setting_name: str, setting: dict): + """ + Update setting into widgets. + """ + self.setting_name_line.setText(setting_name) + + for name, tp in self.widgets.items(): + widget, _ = tp + value = setting[name] + + if isinstance(widget, QtWidgets.QLineEdit): + widget.setText(str(value)) + elif isinstance(widget, QtWidgets.QComboBox): + ix = widget.findText(value) + widget.setCurrentIndex(ix) + + def save_setting(self): + """ + Save algo setting + """ + setting_name = self.setting_name_line.text() + if not setting_name: + return + + setting = self.get_setting() + if setting: + self.algo_engine.update_algo_setting(setting_name, setting) + + +class AlgoMonitor(QtWidgets.QTableWidget): + """""" + parameters_signal = QtCore.pyqtSignal(Event) + variables_signal = QtCore.pyqtSignal(Event) + + def __init__( + self, + algo_engine: AlgoEngine, + event_engine: EventEngine, + mode_active: bool + ): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.event_engine = event_engine + self.mode_active = mode_active + + self.algo_cells = {} + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "", + "算法", + "参数", + "状态" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + for column in range(2, 4): + self.horizontalHeader().setSectionResizeMode( + column, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + if not self.mode_active: + self.hideColumn(0) + + def register_event(self): + """""" + self.parameters_signal.connect(self.process_parameters_event) + self.variables_signal.connect(self.process_variables_event) + + self.event_engine.register( + EVENT_ALGO_PARAMETERS, self.parameters_signal.emit) + self.event_engine.register( + EVENT_ALGO_VARIABLES, self.variables_signal.emit) + + def process_parameters_event(self, event): + """""" + data = event.data + algo_name = data["algo_name"] + parameters = data["parameters"] + + cells = self.get_algo_cells(algo_name) + text = to_text(parameters) + cells["parameters"].setText(text) + + def process_variables_event(self, event): + """""" + data = event.data + algo_name = data["algo_name"] + variables = data["variables"] + + cells = self.get_algo_cells(algo_name) + variables_cell = cells["variables"] + text = to_text(variables) + variables_cell.setText(text) + + row = self.row(variables_cell) + active = variables["active"] + + if self.mode_active: + if active: + self.showRow(row) + else: + self.hideRow(row) + else: + if active: + self.hideRow(row) + else: + self.showRow(row) + + def stop_algo(self, algo_name: str): + """""" + self.algo_engine.stop_algo(algo_name) + + def get_algo_cells(self, algo_name: str): + """""" + cells = self.algo_cells.get(algo_name, None) + + if not cells: + stop_func = partial(self.stop_algo, algo_name=algo_name) + stop_button = QtWidgets.QPushButton("停止") + stop_button.clicked.connect(stop_func) + + name_cell = QtWidgets.QTableWidgetItem(algo_name) + parameters_cell = QtWidgets.QTableWidgetItem() + variables_cell = QtWidgets.QTableWidgetItem() + + self.insertRow(0) + self.setCellWidget(0, 0, stop_button) + self.setItem(0, 1, name_cell) + self.setItem(0, 2, parameters_cell) + self.setItem(0, 3, variables_cell) + + cells = { + "name": name_cell, + "parameters": parameters_cell, + "variables": variables_cell + } + self.algo_cells[algo_name] = cells + + return cells + + +class ActiveAlgoMonitor(AlgoMonitor): + """ + Monitor for active algos. + """ + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__(algo_engine, event_engine, True) + + +class InactiveAlgoMonitor(AlgoMonitor): + """ + Monitor for inactive algos. + """ + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__(algo_engine, event_engine, False) + + +class SettingMonitor(QtWidgets.QTableWidget): + """""" + setting_signal = QtCore.pyqtSignal(Event) + use_signal = QtCore.pyqtSignal(dict) + + def __init__(self, algo_engine: AlgoEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.algo_engine = algo_engine + self.event_engine = event_engine + + self.settings = {} + self.setting_cells = {} + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "", + "", + "名称", + "配置" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + self.horizontalHeader().setSectionResizeMode( + 3, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + def register_event(self): + """""" + self.setting_signal.connect(self.process_setting_event) + + self.event_engine.register( + EVENT_ALGO_SETTING, self.setting_signal.emit) + + def process_setting_event(self, event): + """""" + data = event.data + setting_name = data["setting_name"] + setting = data["setting"] + cells = self.get_setting_cells(setting_name) + + if setting: + self.settings[setting_name] = setting + + cells["setting"].setText(to_text(setting)) + else: + if setting_name in self.settings: + self.settings.pop(setting_name) + + row = self.row(cells["setting"]) + self.removeRow(row) + + self.setting_cells.pop(setting_name) + + def get_setting_cells(self, setting_name: str): + """""" + cells = self.setting_cells.get(setting_name, None) + + if not cells: + use_func = partial(self.use_setting, setting_name=setting_name) + use_button = QtWidgets.QPushButton("使用") + use_button.clicked.connect(use_func) + + remove_func = partial(self.remove_setting, + setting_name=setting_name) + remove_button = QtWidgets.QPushButton("移除") + remove_button.clicked.connect(remove_func) + + name_cell = QtWidgets.QTableWidgetItem(setting_name) + setting_cell = QtWidgets.QTableWidgetItem() + + self.insertRow(0) + self.setCellWidget(0, 0, use_button) + self.setCellWidget(0, 1, remove_button) + self.setItem(0, 2, name_cell) + self.setItem(0, 3, setting_cell) + + cells = { + "name": name_cell, + "setting": setting_cell + } + self.setting_cells[setting_name] = cells + + return cells + + def use_setting(self, setting_name: str): + """""" + setting = self.settings[setting_name] + setting["setting_name"] = setting_name + self.use_signal.emit(setting) + + def remove_setting(self, setting_name: str): + """""" + self.algo_engine.remove_algo_setting(setting_name) + + +class LogMonitor(QtWidgets.QTableWidget): + """""" + signal = QtCore.pyqtSignal(Event) + + def __init__(self, event_engine: EventEngine): + """""" + super().__init__() + + self.event_engine = event_engine + + self.init_ui() + self.register_event() + + def init_ui(self): + """""" + labels = [ + "时间", + "信息" + ] + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.ResizeToContents + ) + + self.horizontalHeader().setSectionResizeMode( + 1, + QtWidgets.QHeaderView.Stretch + ) + self.setWordWrap(True) + + def register_event(self): + """""" + self.signal.connect(self.process_log_event) + + self.event_engine.register(EVENT_ALGO_LOG, self.signal.emit) + + def process_log_event(self, event): + """""" + msg = event.data + timestamp = datetime.now().strftime("%H:%M:%S") + + timestamp_cell = QtWidgets.QTableWidgetItem(timestamp) + msg_cell = QtWidgets.QTableWidgetItem(msg) + + self.insertRow(0) + self.setItem(0, 0, timestamp_cell) + self.setItem(0, 1, msg_cell) + + +class AlgoManager(QtWidgets.QWidget): + """""" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + self.algo_engine = main_engine.get_engine(APP_NAME) + + self.algo_widgets = {} + + self.init_ui() + self.algo_engine.init_engine() + + def init_ui(self): + """""" + self.setWindowTitle("算法交易") + + # Left side control widgets + self.template_combo = QtWidgets.QComboBox() + self.template_combo.currentIndexChanged.connect(self.show_algo_widget) + + form = QtWidgets.QFormLayout() + form.addRow("算法", self.template_combo) + widget = QtWidgets.QWidget() + widget.setLayout(form) + + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(widget) + + for algo_template in self.algo_engine.algo_templates.values(): + widget = AlgoWidget(self.algo_engine, algo_template) + vbox.addWidget(widget) + + template_name = algo_template.__name__ + display_name = algo_template.display_name + + self.algo_widgets[template_name] = widget + self.template_combo.addItem(display_name, template_name) + + vbox.addStretch() + + stop_all_button = QtWidgets.QPushButton("全部停止") + stop_all_button.setFixedHeight(stop_all_button.sizeHint().height() * 2) + stop_all_button.clicked.connect(self.algo_engine.stop_all) + + vbox.addWidget(stop_all_button) + + # Right side monitor widgets + active_algo_monitor = ActiveAlgoMonitor( + self.algo_engine, self.event_engine + ) + inactive_algo_monitor = InactiveAlgoMonitor( + self.algo_engine, self.event_engine + ) + tab1 = QtWidgets.QTabWidget() + tab1.addTab(active_algo_monitor, "执行中") + tab1.addTab(inactive_algo_monitor, "已结束") + + log_monitor = LogMonitor(self.event_engine) + tab2 = QtWidgets.QTabWidget() + tab2.addTab(log_monitor, "日志") + + setting_monitor = SettingMonitor(self.algo_engine, self.event_engine) + setting_monitor.use_signal.connect(self.use_setting) + tab3 = QtWidgets.QTabWidget() + tab3.addTab(setting_monitor, "配置") + + grid = QtWidgets.QGridLayout() + grid.addWidget(tab1, 0, 0, 1, 2) + grid.addWidget(tab2, 1, 0) + grid.addWidget(tab3, 1, 1) + + hbox2 = QtWidgets.QHBoxLayout() + hbox2.addLayout(vbox) + hbox2.addLayout(grid) + self.setLayout(hbox2) + + self.show_algo_widget() + + def show_algo_widget(self): + """""" + ix = self.template_combo.currentIndex() + current_name = self.template_combo.itemData(ix) + + for template_name, widget in self.algo_widgets.items(): + if template_name == current_name: + widget.show() + else: + widget.hide() + + def use_setting(self, setting: dict): + """""" + setting_name = setting["setting_name"] + template_name = setting["template_name"] + + widget = self.algo_widgets[template_name] + widget.update_setting(setting_name, setting) + + def show(self): + """""" + self.showMaximized() + + +def to_text(data: dict): + """ + Convert dict data into string. + """ + buf = [] + for key, value in data.items(): + key = NAME_DISPLAY_MAP.get(key, key) + buf.append(f"{key}:{value}") + text = ",".join(buf) + return text diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 9bff635b..37e5973a 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -40,6 +40,7 @@ from vnpy.trader.database import DbTickData, DbBarData from vnpy.trader.setting import SETTINGS from .base import ( + APP_NAME, EVENT_CTA_LOG, EVENT_CTA_STRATEGY, EVENT_CTA_STOPORDER, @@ -73,7 +74,7 @@ class CtaEngine(BaseEngine): def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" super(CtaEngine, self).__init__( - main_engine, event_engine, "CtaStrategy") + main_engine, event_engine, APP_NAME) self.strategy_setting = {} # strategy_name: dict self.strategy_data = {} # strategy_name: dict From 4472e026c93da9c7b623a7c35f014a4bcff7957d Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sun, 7 Apr 2019 16:41:30 +0800 Subject: [PATCH 29/49] [Add]on_send_order_fail callback for HuobiGateway --- vnpy/gateway/huobi/huobi_gateway.py | 69 ++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index 5572f5fe..2123fc2b 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -15,7 +15,7 @@ from copy import copy from datetime import datetime from vnpy.event import Event -from vnpy.api.rest import RestClient +from vnpy.api.rest import RestClient, Request from vnpy.api.websocket import WebsocketClient from vnpy.trader.constant import ( Direction, @@ -268,13 +268,15 @@ class HuobiRestApi(RestClient): "price": str(req.price), "source": "api" } - + self.add_request( method="POST", path="/v1/order/orders/place", callback=self.on_send_order, data=data, extra=order, + on_error=self.on_send_order_error, + on_failed=self.on_send_order_failed ) self.order_manager.on_order(order) @@ -283,15 +285,15 @@ class HuobiRestApi(RestClient): def cancel_order(self, req: CancelRequest): """""" sys_orderid = self.order_manager.get_sys_orderid(req.orderid) - - path = f"/v1/order/orders/{sys_orderid}/submitcancel" + + path = f"/v1/order/orders/{sys_orderid}/submitcancel" self.add_request( - method="POST", - path=path, + method="POST", + path=path, callback=self.on_cancel_order, extra=req ) - + def on_query_account(self, data, request): """""" if self.check_error(data, "查询账户"): @@ -300,8 +302,8 @@ class HuobiRestApi(RestClient): for d in data["data"]: if d["type"] == "spot": self.account_id = d["id"] - self.gateway.write_log(f"账户代码{self.account_id}查询成功") - + self.gateway.write_log(f"账户代码{self.account_id}查询成功") + self.query_account_balance() def on_query_account_balance(self, data, request): @@ -327,7 +329,7 @@ class HuobiRestApi(RestClient): self.gateway.on_account(account) def on_query_order(self, data, request): - """""" + """""" if self.check_error(data, "查询委托"): return @@ -354,7 +356,7 @@ class HuobiRestApi(RestClient): ) self.order_manager.on_order(order) - + self.gateway.write_log("委托信息查询成功") def on_query_contract(self, data, request): # type: (dict, Request)->None @@ -379,7 +381,7 @@ class HuobiRestApi(RestClient): gateway_name=self.gateway_name, ) self.gateway.on_contract(contract) - + huobi_symbols.add(contract.symbol) symbol_name_map[contract.symbol] = contract.name @@ -388,7 +390,7 @@ class HuobiRestApi(RestClient): def on_send_order(self, data, request): """""" order = request.extra - + if self.check_error(data, "委托"): order.status = Status.REJECTED self.order_manager.on_order(order) @@ -396,21 +398,46 @@ class HuobiRestApi(RestClient): sys_orderid = data["data"] self.order_manager.update_orderid_map(order.orderid, sys_orderid) - + + def on_send_order_failed(self, status_code: str, request: Request): + """ + Callback when sending order failed on server. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + def on_cancel_order(self, data, request): """""" - if self.check_error(data, "撤单"): - return - cancel_request = request.extra local_orderid = cancel_request.orderid - order = self.order_manager.get_order_with_local_orderid(local_orderid) - order.status = Status.CANCELLED + + if self.check_error(data, "撤单"): + order.status = Status.REJECTED + else: + order.status = Status.CANCELLED + self.gateway.write_log(f"委托撤单成功:{order.orderid}") self.order_manager.on_order(order) - self.gateway.write_log(f"委托撤单成功:{order.orderid}") - + def check_error(self, data: dict, func: str = ""): """""" if data["status"] != "error": From 2f674f2019bf6356880f882d00aa18f634bac3ab Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sun, 7 Apr 2019 22:14:04 +0800 Subject: [PATCH 30/49] [Add]SniperAlgo and IcebergAlgo --- vnpy/app/algo_trading/algos/iceberg_algo.py | 137 ++++++++++++++++++++ vnpy/app/algo_trading/algos/sniper_algo.py | 101 +++++++++++++++ vnpy/app/algo_trading/algos/twap_algo.py | 23 ++-- vnpy/app/algo_trading/engine.py | 10 +- vnpy/app/algo_trading/template.py | 8 ++ vnpy/app/algo_trading/ui/display.py | 3 +- vnpy/app/algo_trading/ui/widget.py | 4 + 7 files changed, 269 insertions(+), 17 deletions(-) diff --git a/vnpy/app/algo_trading/algos/iceberg_algo.py b/vnpy/app/algo_trading/algos/iceberg_algo.py index e69de29b..74ab7f44 100644 --- a/vnpy/app/algo_trading/algos/iceberg_algo.py +++ b/vnpy/app/algo_trading/algos/iceberg_algo.py @@ -0,0 +1,137 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData, OrderData, TickData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class IcebergAlgo(AlgoTemplate): + """""" + + display_name = "Iceberg 冰山" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "display_volume": 0.0, + "interval": 0, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + variables = [ + "traded", + "timer_count", + "vt_orderid" + ] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.display_volume = setting["display_volume"] + self.interval = setting["interval"] + self.offset = Offset(setting["offset"]) + + # Variables + self.timer_count = 0 + self.vt_orderid = "" + self.traded = 0 + + self.last_tick = None + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_stop(self): + """""" + self.write_log("停止算法") + + def on_tick(self, tick: TickData): + """""" + self.last_tick = tick + + def on_order(self, order: OrderData): + """""" + msg = f"委托号:{order.vt_orderid},委托状态:{order.status.value}" + self.write_log(msg) + + if not order.is_active(): + self.vt_orderid = "" + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") + self.stop() + else: + self.put_variables_event() + + def on_timer(self): + """""" + self.timer_count += 1 + + if self.timer_count < self.interval: + self.put_variables_event() + return + + self.timer_count = 0 + + contract = self.get_contract(self.vt_symbol) + if not contract: + return + + # If order already finished, just send new order + if not self.vt_orderid: + order_volume = self.volume - self.traded + order_volume = min(order_volume, self.display_volume) + + if self.direction == Direction.LONG: + self.vt_orderid = self.buy( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + else: + self.vt_orderid = self.sell( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + # Otherwise check for cancel + else: + if self.direction == Direction.LONG: + if self.last_tick.ask_price_1 <= self.price: + self.cancel_order(self.vt_orderid) + self.vt_orderid = "" + self.write_log(u"最新Tick卖一价,低于买入委托价格,之前委托可能丢失,强制撤单") + else: + if self.last_tick.bid_price_1 >= self.price: + self.cancel_order(self.vt_orderid) + self.vt_orderid = "" + self.write_log(u"最新Tick买一价,高于卖出委托价格,之前委托可能丢失,强制撤单") + + self.put_variables_event() diff --git a/vnpy/app/algo_trading/algos/sniper_algo.py b/vnpy/app/algo_trading/algos/sniper_algo.py index e69de29b..5f4b1c56 100644 --- a/vnpy/app/algo_trading/algos/sniper_algo.py +++ b/vnpy/app/algo_trading/algos/sniper_algo.py @@ -0,0 +1,101 @@ +from vnpy.trader.constant import Offset, Direction +from vnpy.trader.object import TradeData, OrderData, TickData +from vnpy.trader.engine import BaseEngine + +from vnpy.app.algo_trading import AlgoTemplate + + +class SniperAlgo(AlgoTemplate): + """""" + + display_name = "Sniper 狙击手" + + default_setting = { + "vt_symbol": "", + "direction": [Direction.LONG.value, Direction.SHORT.value], + "price": 0.0, + "volume": 0.0, + "offset": [ + Offset.NONE.value, + Offset.OPEN.value, + Offset.CLOSE.value, + Offset.CLOSETODAY.value, + Offset.CLOSEYESTERDAY.value + ] + } + + variables = [ + "traded", + "vt_orderid" + ] + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """""" + super().__init__(algo_engine, algo_name, setting) + + # Parameters + self.vt_symbol = setting["vt_symbol"] + self.direction = Direction(setting["direction"]) + self.price = setting["price"] + self.volume = setting["volume"] + self.offset = Offset(setting["offset"]) + + # Variables + self.vt_orderid = "" + self.traded = 0 + + self.subscribe(self.vt_symbol) + self.put_parameters_event() + self.put_variables_event() + + def on_tick(self, tick: TickData): + """""" + if self.vt_orderid: + self.cancel_all() + return + + if self.direction == Direction.LONG: + if tick.ask_price_1 <= self.price: + order_volume = self.volume - self.traded + order_volume = min(order_volume, tick.ask_volume_1) + + self.vt_orderid = self.buy( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + else: + if tick.bid_price_1 >= self.price: + order_volume = self.volume - self.traded + order_volume = min(order_volume, tick.bid_volume_1) + + self.vt_orderid = self.sell( + self.vt_symbol, + self.price, + order_volume, + offset=self.offset + ) + + self.put_variables_event() + + def on_order(self, order: OrderData): + """""" + if not order.is_active(): + self.vt_orderid = "" + self.put_variables_event() + + def on_trade(self, trade: TradeData): + """""" + self.traded += trade.volume + + if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") + self.stop() + else: + self.put_variables_event() diff --git a/vnpy/app/algo_trading/algos/twap_algo.py b/vnpy/app/algo_trading/algos/twap_algo.py index 32fc6a65..ebd24833 100644 --- a/vnpy/app/algo_trading/algos/twap_algo.py +++ b/vnpy/app/algo_trading/algos/twap_algo.py @@ -26,6 +26,13 @@ class TwapAlgo(AlgoTemplate): ] } + variables = [ + "traded", + "order_volume", + "timer_count", + "total_count" + ] + def __init__( self, algo_engine: BaseEngine, @@ -50,13 +57,6 @@ class TwapAlgo(AlgoTemplate): self.total_count = 0 self.traded = 0 - self.variables.extend([ - "traded", - "order_volume", - "timer_count", - "total_count" - ]) - self.subscribe(self.vt_symbol) self.put_parameters_event() self.put_variables_event() @@ -66,6 +66,7 @@ class TwapAlgo(AlgoTemplate): self.traded += trade.volume if self.traded >= self.volume: + self.write_log(f"已交易数量:{self.traded},总数量:{self.volume}") self.stop() else: self.put_variables_event() @@ -98,15 +99,7 @@ class TwapAlgo(AlgoTemplate): if tick.ask_price_1 <= self.price: self.buy(self.vt_symbol, self.price, order_volume, offset=self.offset) - self.write_log( - f"委托买入{self.vt_symbol}:{order_volume}@{self.price}") else: if tick.bid_price_1 >= self.price: self.sell(self.vt_symbol, self.price, order_volume, offset=self.offset) - self.write_log( - f"委托卖出{self.vt_symbol}:{order_volume}@{self.price}") - - def get_default_setting(self): - """""" - return self.default_setting diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py index fe551043..8607cc71 100644 --- a/vnpy/app/algo_trading/engine.py +++ b/vnpy/app/algo_trading/engine.py @@ -44,8 +44,16 @@ class AlgoEngine(BaseEngine): def load_algo_template(self): """""" from .algos.twap_algo import TwapAlgo + from .algos.iceberg_algo import IcebergAlgo + from .algos.sniper_algo import SniperAlgo - self.algo_templates[TwapAlgo.__name__] = TwapAlgo + self.add_algo_template(TwapAlgo) + self.add_algo_template(IcebergAlgo) + self.add_algo_template(SniperAlgo) + + def add_algo_template(self, template: AlgoTemplate): + """""" + self.algo_templates[template.__name__] = template def load_algo_setting(self): """""" diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py index 758e526f..609641af 100644 --- a/vnpy/app/algo_trading/template.py +++ b/vnpy/app/algo_trading/template.py @@ -96,6 +96,8 @@ class AlgoTemplate: self.on_stop() self.put_variables_event() + self.write_log("停止算法") + def subscribe(self, vt_symbol): """""" self.algo_engine.subscribe(self, vt_symbol) @@ -109,6 +111,9 @@ class AlgoTemplate: offset: Offset = Offset.NONE ): """""" + msg = f"委托买入{vt_symbol}:{volume}@{price}" + self.write_log(msg) + return self.algo_engine.send_order( self, vt_symbol, @@ -128,6 +133,9 @@ class AlgoTemplate: offset: Offset = Offset.NONE ): """""" + msg = f"委托卖出{vt_symbol}:{volume}@{price}" + self.write_log(msg) + return self.algo_engine.send_order( self, vt_symbol, diff --git a/vnpy/app/algo_trading/ui/display.py b/vnpy/app/algo_trading/ui/display.py index 8d6b991b..4f7b9502 100644 --- a/vnpy/app/algo_trading/ui/display.py +++ b/vnpy/app/algo_trading/ui/display.py @@ -11,5 +11,6 @@ NAME_DISPLAY_MAP = { "order_volume": "单笔委托", "timer_count": "本轮读秒", "total_count": "累计读秒", - "template_name": "算法模板" + "template_name": "算法模板", + "display_volume": "挂出数量" } diff --git a/vnpy/app/algo_trading/ui/widget.py b/vnpy/app/algo_trading/ui/widget.py index 57976d92..111d599a 100644 --- a/vnpy/app/algo_trading/ui/widget.py +++ b/vnpy/app/algo_trading/ui/widget.py @@ -554,6 +554,10 @@ class AlgoManager(QtWidgets.QWidget): widget = self.algo_widgets[template_name] widget.update_setting(setting_name, setting) + ix = self.template_combo.findData(template_name) + self.template_combo.setCurrentIndex(ix) + self.show_algo_widget() + def show(self): """""" self.showMaximized() From 0e67e75d6c77957425f86d623ac085882910636b Mon Sep 17 00:00:00 2001 From: nanoric Date: Mon, 8 Apr 2019 04:56:42 -0400 Subject: [PATCH 31/49] [Add] virtual decorator --- vnpy/app/algo_trading/template.py | 6 ++++++ vnpy/app/cta_strategy/template.py | 14 ++++++++++++++ vnpy/trader/utility.py | 11 +++++++++++ 3 files changed, 31 insertions(+) diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py index 609641af..3b6620b7 100644 --- a/vnpy/app/algo_trading/template.py +++ b/vnpy/app/algo_trading/template.py @@ -1,6 +1,7 @@ from vnpy.trader.engine import BaseEngine from vnpy.trader.object import TickData, OrderData, TradeData from vnpy.trader.constant import OrderType, Offset, Direction +from vnpy.trader.utility import virtual class AlgoTemplate: @@ -63,22 +64,27 @@ class AlgoTemplate: """""" pass + @virtual def on_stop(self): """""" pass + @virtual def on_tick(self, tick: TickData): """""" pass + @virtual def on_order(self, order: OrderData): """""" pass + @virtual def on_trade(self, trade: TradeData): """""" pass + @virtual def on_timer(self): """""" pass diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index 20cf88d2..0fde32c0 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -4,6 +4,7 @@ from typing import Any, Callable from vnpy.trader.constant import Interval, Direction, Offset from vnpy.trader.object import BarData, TickData, OrderData, TradeData +from vnpy.trader.utility import virtual from .base import StopOrder, EngineType @@ -87,48 +88,56 @@ class CtaTemplate(ABC): } return strategy_data + @virtual def on_init(self): """ Callback when strategy is inited. """ pass + @virtual def on_start(self): """ Callback when strategy is started. """ pass + @virtual def on_stop(self): """ Callback when strategy is stopped. """ pass + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. """ pass + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. """ pass + @virtual def on_trade(self, trade: TradeData): """ Callback of new trade data update. """ pass + @virtual def on_order(self, order: OrderData): """ Callback of new order data update. """ pass + @virtual def on_stop_order(self, stop_order: StopOrder): """ Callback of stop order update. @@ -255,12 +264,14 @@ class CtaSignal(ABC): """""" self.signal_pos = 0 + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. """ pass + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. @@ -292,6 +303,7 @@ class TargetPosTemplate(CtaTemplate): ) self.variables.append("target_pos") + @virtual def on_tick(self, tick: TickData): """ Callback of new tick data update. @@ -301,12 +313,14 @@ class TargetPosTemplate(CtaTemplate): if self.trading: self.trade() + @virtual def on_bar(self, bar: BarData): """ Callback of new bar data update. """ self.last_bar = bar + @virtual def on_order(self, order: OrderData): """ Callback of new order data update. diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index ca788227..9951e1f1 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -385,3 +385,14 @@ class ArrayManager(object): if array: return up, down return up[-1], down[-1] + + +def virtual(func: "callable"): + """ + mark a function as "virtual", which means that this function can be override. + any base class should use this or @abstractmethod to decorate all functions + that can be (re)implemented by subclasses. + """ + return func + + From 0260aa078403932a054b3a795771d4ff9ac0c842 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Mon, 8 Apr 2019 18:32:59 +0800 Subject: [PATCH 32/49] [Fix]convert str to float for price and volume data --- vnpy/gateway/okex/okex_gateway.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index cc1e124c..4a1b2a93 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -298,7 +298,7 @@ class OkexRestApi(RestClient): name=symbol, product=Product.SPOT, size=1, - pricetick=instrument_data["tick_size"], + pricetick=float(instrument_data["tick_size"]), gateway_name=self.gateway_name ) self.gateway.on_contract(contract) @@ -336,6 +336,7 @@ class OkexRestApi(RestClient): direction=DIRECTION_OKEX2VT[order_data["side"]], price=float(order_data["price"]), volume=float(order_data["size"]), + traded=float(order_data["filled_size"]), time=order_data["timestamp"][11:19], status=STATUS_OKEX2VT[order_data["status"]], gateway_name=self.gateway_name, @@ -605,11 +606,11 @@ class OkexWebsocketApi(WebsocketClient): if not tick: return - tick.last_price = d["last"] - tick.open = d["open_24h"] - tick.high = d["high_24h"] - tick.low = d["low_24h"] - tick.volume = d["base_volume_24h"] + tick.last_price = float(d["last"]) + tick.open = float(d["open_24h"]) + tick.high = float(d["high_24h"]) + tick.low = float(d["low_24h"]) + tick.volume = float(d["base_volume_24h"]) tick.datetime = datetime.strptime( d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") self.gateway.on_tick(copy(tick)) @@ -626,13 +627,13 @@ class OkexWebsocketApi(WebsocketClient): asks = d["asks"] for n, buf in enumerate(bids): price, volume, _ = buf - tick.__setattr__("bid_price_%s" % (n + 1), price) - tick.__setattr__("bid_volume_%s" % (n + 1), volume) + tick.__setattr__("bid_price_%s" % (n + 1), float(price)) + tick.__setattr__("bid_volume_%s" % (n + 1), float(volume)) for n, buf in enumerate(asks): price, volume, _ = buf - tick.__setattr__("ask_price_%s" % (n + 1), price) - tick.__setattr__("ask_volume_%s" % (n + 1), volume) + tick.__setattr__("ask_price_%s" % (n + 1), float(price)) + tick.__setattr__("ask_volume_%s" % (n + 1), float(volume)) tick.datetime = datetime.strptime( d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") @@ -646,16 +647,16 @@ class OkexWebsocketApi(WebsocketClient): type=ORDERTYPE_OKEX2VT[d["type"]], orderid=d["client_oid"], direction=DIRECTION_OKEX2VT[d["side"]], - price=d["price"], - volume=d["size"], - traded=d["filled_size"], + price=float(d["price"]), + volume=float(d["size"]), + traded=float(d["filled_size"]), time=d["timestamp"][11:19], status=STATUS_OKEX2VT[d["status"]], gateway_name=self.gateway_name, ) self.gateway.on_order(copy(order)) - trade_volume = float(d.get("last_fill_qty", 0)) + trade_volume = d.get("last_fill_qty", 0) if not trade_volume: return From f53f5ea7dd8ecc13700398dd94164dece511dcbe Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Tue, 9 Apr 2019 11:27:12 +0800 Subject: [PATCH 33/49] [Mod]remove default proxy setting --- vnpy/gateway/bitmex/bitmex_gateway.py | 9 +++++++-- vnpy/gateway/huobi/huobi_gateway.py | 9 +++++++-- vnpy/gateway/okex/okex_gateway.py | 10 +++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/vnpy/gateway/bitmex/bitmex_gateway.py b/vnpy/gateway/bitmex/bitmex_gateway.py index 3c3e6879..11579c4b 100644 --- a/vnpy/gateway/bitmex/bitmex_gateway.py +++ b/vnpy/gateway/bitmex/bitmex_gateway.py @@ -71,8 +71,8 @@ class BitmexGateway(BaseGateway): "Secret": "", "会话数": 3, "服务器": ["REAL", "TESTNET"], - "代理地址": "127.0.0.1", - "代理端口": 1080, + "代理地址": "", + "代理端口": "", } def __init__(self, event_engine): @@ -91,6 +91,11 @@ class BitmexGateway(BaseGateway): proxy_host = setting["代理地址"] proxy_port = setting["代理端口"] + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + self.rest_api.connect(key, secret, session_number, server, proxy_host, proxy_port) diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py index 2123fc2b..474e6f69 100644 --- a/vnpy/gateway/huobi/huobi_gateway.py +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -73,8 +73,8 @@ class HuobiGateway(BaseGateway): "API Key": "", "Secret Key": "", "会话数": 3, - "代理地址": "127.0.0.1", - "代理端口": 1080, + "代理地址": "", + "代理端口": "", } def __init__(self, event_engine): @@ -95,6 +95,11 @@ class HuobiGateway(BaseGateway): proxy_host = setting["代理地址"] proxy_port = setting["代理端口"] + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + self.rest_api.connect(key, secret, session_number, proxy_host, proxy_port) self.trade_ws_api.connect(key, secret, proxy_host, proxy_port) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 4a1b2a93..ded30c13 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -74,8 +74,8 @@ class OkexGateway(BaseGateway): "Secret Key": "", "Passphrase": "", "会话数": 3, - "代理地址": "127.0.0.1", - "代理端口": 1080, + "代理地址": "", + "代理端口": "", } def __init__(self, event_engine): @@ -94,9 +94,13 @@ class OkexGateway(BaseGateway): proxy_host = setting["代理地址"] proxy_port = setting["代理端口"] + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + self.rest_api.connect(key, secret, passphrase, session_number, proxy_host, proxy_port) - self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port) def subscribe(self, req: SubscribeRequest): From 1a79ba37db118b570c94960dce239c19267239fb Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Tue, 9 Apr 2019 18:21:28 +0800 Subject: [PATCH 34/49] [Mod]add on_cancel_order_failed callback for OkexGateway --- vnpy/gateway/okex/okex_gateway.py | 21 +++++++++++++++++++++ vnpy/trader/utility.py | 2 -- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index ded30c13..9ca18217 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -85,6 +85,8 @@ class OkexGateway(BaseGateway): self.rest_api = OkexRestApi(self) self.ws_api = OkexWebsocketApi(self) + self.orders = {} + def connect(self, setting: dict): """""" key = setting["API Key"] @@ -128,6 +130,15 @@ class OkexGateway(BaseGateway): self.rest_api.stop() self.ws_api.stop() + def on_order(self, order: OrderData): + """""" + self.orders[order.vt_orderid] = order + super().on_order(order) + + def get_order(self, vt_orderid: str): + """""" + return self.orders.get(vt_orderid, None) + class OkexRestApi(RestClient): """ @@ -258,6 +269,8 @@ class OkexRestApi(RestClient): callback=self.on_cancel_order, data=data, on_error=self.on_cancel_order_error, + on_failed=self.on_cancel_order_failed, + extra=req ) def query_contract(self): @@ -406,6 +419,14 @@ class OkexRestApi(RestClient): """Websocket will push a new order status""" pass + def on_cancel_order_failed(self, status_code: int, request: Request): + """If cancel failed, mark order status to be rejected.""" + req = request.extra + order = self.gateway.get_order(req.vt_orderid) + if order: + order.status = Status.REJECTED + self.gateway.on_order(order) + def on_failed(self, status_code: int, request: Request): """ Callback to handle request failed. diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 9951e1f1..89c412a1 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -394,5 +394,3 @@ def virtual(func: "callable"): that can be (re)implemented by subclasses. """ return func - - From fd60b4db4bf4104bbf280fd81fb379021639ec6c Mon Sep 17 00:00:00 2001 From: qqqlyx Date: Wed, 10 Apr 2019 10:16:37 +0800 Subject: [PATCH 35/49] up --- vnpy/gateway/okex/okex_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 9ca18217..bc5f2dfb 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -682,7 +682,7 @@ class OkexWebsocketApi(WebsocketClient): self.gateway.on_order(copy(order)) trade_volume = d.get("last_fill_qty", 0) - if not trade_volume: + if float(trade_volume) == 0: return self.trade_count += 1 From ee93bf709b847ef3a5531669a1355d9c4193ef5e Mon Sep 17 00:00:00 2001 From: qqqlyx Date: Wed, 10 Apr 2019 10:20:45 +0800 Subject: [PATCH 36/49] up --- vnpy/gateway/okex/okex_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index bc5f2dfb..10087748 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -682,7 +682,7 @@ class OkexWebsocketApi(WebsocketClient): self.gateway.on_order(copy(order)) trade_volume = d.get("last_fill_qty", 0) - if float(trade_volume) == 0: + if not trade_volume or float(trade_volume) == 0: return self.trade_count += 1 From 219fbe2b5f95948875d3c9be13da13f179cdfd2d Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 09:18:41 +0800 Subject: [PATCH 37/49] [Fix]close #1569 --- vnpy/app/csv_loader/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index 2c384d02..c0547036 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -107,7 +107,7 @@ class CsvLoaderEngine(BaseEngine): # Insert into DB with DB.atomic(): - for batch in chunked(db_bars, 500): + for batch in chunked(db_bars, 50): DbBarData.insert_many(batch).on_conflict_replace().execute() return start, end, count From f7db834e8b565c3880126f04aca9c5d3ac828526 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 09:19:12 +0800 Subject: [PATCH 38/49] [Mod]change key used for get_order --- vnpy/gateway/okex/okex_gateway.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py index 9ca18217..e4376f65 100644 --- a/vnpy/gateway/okex/okex_gateway.py +++ b/vnpy/gateway/okex/okex_gateway.py @@ -132,12 +132,12 @@ class OkexGateway(BaseGateway): def on_order(self, order: OrderData): """""" - self.orders[order.vt_orderid] = order + self.orders[order.orderid] = order super().on_order(order) - def get_order(self, vt_orderid: str): + def get_order(self, orderid: str): """""" - return self.orders.get(vt_orderid, None) + return self.orders.get(orderid, None) class OkexRestApi(RestClient): @@ -422,7 +422,7 @@ class OkexRestApi(RestClient): def on_cancel_order_failed(self, status_code: int, request: Request): """If cancel failed, mark order status to be rejected.""" req = request.extra - order = self.gateway.get_order(req.vt_orderid) + order = self.gateway.get_order(req.orderid) if order: order.status = Status.REJECTED self.gateway.on_order(order) From 70adba7637b051a59a83e31a2503f9f8c7dbd628 Mon Sep 17 00:00:00 2001 From: nanoric Date: Thu, 11 Apr 2019 00:43:46 -0400 Subject: [PATCH 39/49] =?UTF-8?q?[Mod]=20=E6=94=B9=E5=8F=98Singleton?= =?UTF-8?q?=E7=9A=84=E4=BD=BF=E7=94=A8=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/trader/engine.py | 4 +--- vnpy/trader/utility.py | 7 ++++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index a2e098ee..af136dab 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -197,13 +197,11 @@ class BaseEngine(ABC): pass -class LogEngine(BaseEngine): +class LogEngine(BaseEngine, metaclass=Singleton): """ Processes log event and output with logging module. """ - __metaclass__ = Singleton - def __init__(self, main_engine: MainEngine, event_engine: EventEngine): """""" super(LogEngine, self).__init__(main_engine, event_engine, "log") diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 9951e1f1..c221277a 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -14,10 +14,11 @@ from .object import BarData, TickData class Singleton(type): """ - Singleton metaclass, + Singleton metaclass, - class A: - __metaclass__ = Singleton + usage: + class A(metaclass=Singleton): + ... """ _instances = {} From 95df49c54c7aefee9c49dee9ae234000c24345f0 Mon Sep 17 00:00:00 2001 From: nanoric Date: Thu, 11 Apr 2019 00:45:15 -0400 Subject: [PATCH 40/49] =?UTF-8?q?[Mod]=20=E6=94=B9=E5=8F=98=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E7=9B=B8=E5=85=B3=E5=87=BD=E6=95=B0=E7=9A=84=E5=91=BD?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/trader/ui/mainwindow.py | 4 ++-- vnpy/trader/utility.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vnpy/trader/ui/mainwindow.py b/vnpy/trader/ui/mainwindow.py index d8c6280b..b1972f42 100644 --- a/vnpy/trader/ui/mainwindow.py +++ b/vnpy/trader/ui/mainwindow.py @@ -24,7 +24,7 @@ from .widget import ( AboutDialog, ) from ..engine import MainEngine -from ..utility import get_icon_path, TRADER_PATH +from ..utility import get_icon_path, TRADER_DIR class MainWindow(QtWidgets.QMainWindow): @@ -38,7 +38,7 @@ class MainWindow(QtWidgets.QMainWindow): self.main_engine = main_engine self.event_engine = event_engine - self.window_title = f"VN Trader [{TRADER_PATH}]" + self.window_title = f"VN Trader [{TRADER_DIR}]" self.connect_dialogs = {} self.widgets = {} diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index c221277a..6f023259 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -31,7 +31,7 @@ class Singleton(type): return cls._instances[cls] -def get_path(temp_name: str): +def _get_trader_dir(temp_name: str): """ Get path where trader is running in. """ @@ -54,21 +54,21 @@ def get_path(temp_name: str): return home_path, temp_path -TRADER_PATH, TEMP_PATH = get_path(".vntrader") +TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader") def get_file_path(filename: str): """ Get path for temp file with filename. """ - return TEMP_PATH.joinpath(filename) + return TEMP_DIR.joinpath(filename) def get_folder_path(folder_name: str): """ Get path for temp folder with folder name. """ - folder_path = TEMP_PATH.joinpath(folder_name) + folder_path = TEMP_DIR.joinpath(folder_name) if not folder_path.exists(): folder_path.mkdir() return folder_path @@ -395,5 +395,3 @@ def virtual(func: "callable"): that can be (re)implemented by subclasses. """ return func - - From bee61d79b0fb841df2c555767793528ed73b95b4 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 14:01:29 +0800 Subject: [PATCH 41/49] [Add]CtaBackteserApp for GUI based backtesting --- tests/trader/run.py | 2 + vnpy/app/cta_backtester/__init__.py | 17 + vnpy/app/cta_backtester/engine.py | 199 +++++++++ vnpy/app/cta_backtester/ui/__init__.py | 1 + vnpy/app/cta_backtester/ui/backtester.ico | Bin 0 -> 64478 bytes vnpy/app/cta_backtester/ui/widget.py | 476 ++++++++++++++++++++++ vnpy/app/cta_strategy/__init__.py | 1 + vnpy/app/cta_strategy/backtesting.py | 5 +- 8 files changed, 699 insertions(+), 2 deletions(-) create mode 100644 vnpy/app/cta_backtester/__init__.py create mode 100644 vnpy/app/cta_backtester/engine.py create mode 100644 vnpy/app/cta_backtester/ui/__init__.py create mode 100644 vnpy/app/cta_backtester/ui/backtester.ico create mode 100644 vnpy/app/cta_backtester/ui/widget.py diff --git a/tests/trader/run.py b/tests/trader/run.py index 1ab073c5..2cda6ae2 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -16,6 +16,7 @@ from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.algo_trading import AlgoTradingApp +from vnpy.app.cta_backtester import CtaBacktesterApp def main(): @@ -35,6 +36,7 @@ def main(): main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) + main_engine.add_app(CtaBacktesterApp) main_engine.add_app(CsvLoaderApp) main_engine.add_app(AlgoTradingApp) diff --git a/vnpy/app/cta_backtester/__init__.py b/vnpy/app/cta_backtester/__init__.py new file mode 100644 index 00000000..c589e02f --- /dev/null +++ b/vnpy/app/cta_backtester/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import BacktesterEngine, APP_NAME + + +class CtaBacktesterApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "CTA回测" + engine_class = BacktesterEngine + widget_name = "BacktesterManager" + icon_name = "backtester.ico" diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py new file mode 100644 index 00000000..621330d2 --- /dev/null +++ b/vnpy/app/cta_backtester/engine.py @@ -0,0 +1,199 @@ +import os +import importlib +from datetime import datetime +from threading import Thread +from pathlib import Path + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.app.cta_strategy import ( + CtaTemplate, + BacktestingEngine, + OptimizationSetting +) + + +APP_NAME = "CtaBacktester" + +EVENT_BACKTESTER_LOG = "eBacktesterLog" +EVENT_BACKTESTER_BACKTESTING_FINISHED = "eBacktesterBacktestingFinished" +EVENT_BACKTESTER_OPTIMIZATION_FINISHED = "eBacktesterOptimizationFinished" + + +class BacktesterEngine(BaseEngine): + """ + For running CTA strategy backtesting. + """ + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.classes = {} + self.backtesting_engine = None + self.thread = None + + self.result_df = None + self.result_statistics = None + + self.load_strategy_class() + + def init_engine(self): + """""" + self.write_log("初始化CTA回测引擎") + + self.backtesting_engine = BacktestingEngine() + # Redirect log from backtesting engine outside. + self.backtesting_engine.output = self.write_log + + self.write_log("策略文件加载完成") + + def write_log(self, msg: str): + """""" + event = Event(EVENT_BACKTESTER_LOG) + event.data = msg + self.event_engine.put(event) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + app_path = Path(__file__).parent.parent + path1 = app_path.joinpath("cta_strategy", "strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_strategy.strategies") + + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + self.classes[value.__name__] = value + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg) + + def get_strategy_class_names(self): + """""" + return list(self.classes.keys()) + + def run_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + """""" + self.result_df = None + self.result_statistics = None + + engine = self.backtesting_engine + engine.clear_data() + + engine.set_parameters( + vt_symbol=vt_symbol, + interval=interval, + start=start, + end=end, + rate=rate, + slippage=slippage, + size=size, + pricetick=pricetick, + capital=capital + ) + + strategy_class = self.classes[class_name] + engine.add_strategy( + strategy_class, + setting + ) + + engine.load_data() + engine.run_backtesting() + self.result_df = engine.calculate_result() + self.result_statistics = engine.calculate_statistics(output=False) + + # Clear thread object handler. + self.thread = None + + # Put backtesting done event + event = Event(EVENT_BACKTESTER_BACKTESTING_FINISHED) + self.event_engine.put(event) + + def start_backtesting( + self, + class_name: str, + vt_symbol: str, + interval: str, + start: datetime, + end: datetime, + rate: float, + slippage: float, + size: int, + pricetick: float, + capital: int, + setting: dict + ): + if self.thread: + self.write_log("已有回测在运行中,请等待完成") + return False + + self.thread = Thread( + target=self.run_backtesting, + args=( + class_name, + vt_symbol, + interval, + start, + end, + rate, + slippage, + size, + pricetick, + capital, + setting + ) + ) + self.thread.start() + + return True + + def get_result_df(self): + """""" + return self.result_df + + def get_result_statistics(self): + """""" + return self.result_statistics + + def get_default_setting(self, class_name: str): + """""" + strategy_class = self.classes[class_name] + return strategy_class.get_class_parameters() diff --git a/vnpy/app/cta_backtester/ui/__init__.py b/vnpy/app/cta_backtester/ui/__init__.py new file mode 100644 index 00000000..a02dafb3 --- /dev/null +++ b/vnpy/app/cta_backtester/ui/__init__.py @@ -0,0 +1 @@ +from .widget import BacktesterManager diff --git a/vnpy/app/cta_backtester/ui/backtester.ico b/vnpy/app/cta_backtester/ui/backtester.ico new file mode 100644 index 0000000000000000000000000000000000000000..647b8ca91d4d0658d1ea00133f8f1c1e49bbd007 GIT binary patch literal 64478 zcmeI5d#oH)9mluah6`9IZ%cXHLcw6e8zU$}D~J`O;VBBH)L2Bs#%QX58s#>Cny5kf z!^B5Id?fh7$RGTHG4v)xgGNN9D1t%-wSuUCtu3^4{d{NV%$+@VchBrRc4vEkgP)!` zGjo3DobT_P-^`rdy*3(i_;2zP4gOu!cwl>@ad4y2I2I;hGw67uOb8&m-kRAq? zz{LewLHT{a*;Xq13!n`97h-3>*ruyP_78%KplJJLL$+iyZC`b&?xOakfwV~5S3S~9 z+Ld8nzKrZ=^#5G0Uj@w)cnKS)!fxR7ll@Df4EwV2CKv-}U;E3UEc>$M+gF`s*_Rg5 z#I;{qNi%6zMjflL6WQ;>|0_$Xdoy-U0pI@`*Jbog8lB{wD=#^&*FbrBR3!Y)Z%|wJB$Bejr?)t@)rN| zYbOmY{)hFm_#f8D%|}1A_#f8J;(u5pe_Xk|#sB=;NkfbOVf`%rhc$BZ(N8V@hqbf# zAJ)hpS1xbyKfiX;(BgksKa2lijof_nQ;Yv$?dll+%l!T4nZ%j&@AK^Uf9+$UZvNNr zCy#*};QE5BrTn3g{{6K1zux|LzjLlCHUHPt|DoTX7d=lL(3mp+oB!?i|I=fjHuwLc z^E*`D=6}b(7h(RNo(ul_V<>OyKl!zjwr%S_q4`_%yv=``|MvVhV+>gQ&mY4q{^!@u z;(s>nX3hG4(ffO-yxsq2#D2}eJn{bu%58-g3-S-jp9s<4{|ouA=y~d~^`DtBr(geT zEqx*A?_R#OAZJrvf6Fw}-yb#qN8?v>Wte~eQFBQC(mHBU)R*M5nWFv;m8UM7{{`lM zNZ-1jH~$yA|JU8Wq4t^o&Hs#nL4E$Wow`En$63zT=KU`;zv_N|fb;(}>Us(`733Mp zpAWV8``>lW9fIEd#zJ-2J4NN2>Hu5+%c*p%pVsFOfb&3~trS5vWGng(OU`CxE=E1G*8lfhytI*pUSh`j zRkbePtLjqD%|`!wKHrV}9&i}E1tNV86({UuWAf*tFO%Cv(5i0f6ih5c(?;L z!jrHG{vPiu|7AD|bZs;a%GACNwq}d}o&Vx=Z_?{PzTO7SF8q(<8XK3xUJ#50b?|9F zZP?ZQpY~b2uh{-2*b3eJ?jL`VYd;P9!Y(xiM#2AS&C~mx#(DaDhkecAtKk*!eH_Yb zyj%ctVVAiV)Ybn^V=vY7SaL7+&L#6QbzHuh2Mgd>P|W`)gnZrY{O>4x44D6WuXW^J z|99HPsr;>b^+GrrPKQ%K&t$KLwtF&p`P*H;gniA|&%@)O`~1t$&9_;O@1XoA;Z?9p zjDb4Mf2V1j>iz3=unsiuUIOXA9Tbb#!-qhAJ_3$|^I#Gl0(ZS){Z7c@-_SL(GX)2N z`M-M%=zIPhau#XN4Q4*)|ilg%JOx$dKE_whsgg*M|tz+{}%FD)~7?) zT!o$J`{h}+N$Gh6``?Cc`VafK?jg%T_ok7L0V)6MH`7mmem`A==O}+UWcI()kyJ7N zV=!zV_o`p<-TN1HD5h?|#FN3=IwBPpV-S94W zBa}V{(l+D$lS#h_--Idn4{U*-!FNEPb1j9WE!wZsmTKD|*qb+6GaBU*hSpFk~`jh9pKIi`d+ze~rBG7vA z+u&CC9y|(pYQ0w=4jUhZJ)m#hT^TQZ zJ9*hn*OO&`Gi6r6IP~2X=c_8xxILODXG7kYtaZ6x!h7M(un0Q&JL)^q;Pe;ym(0FBKHK=CG-hwLZqjQ8i0o(?y_ z-LM|?8K}lm);hag_c?56JSBC8v>C+tRC^Y&|6S$`+{4*4@Y2zo~U3LFRWhu+yV_V0n8!y`~e{PElOEOs>x_kv^$Xn&Ynsx6wo zhr$~8J^1bS^Xs|h5|D<;+%MV}Y&mT-Pxbup7x3#Yn%{~|^-<$N&%RaoyHgwgN?p2F zUI}`aaD6Q*TcPslcCfE|fnwliKyzUSD8`GI;C@&O$HEwt-iMUUx$Ntm{t4)$=cpdj zy-n|j?*hfn(&kUzfj}D)KeE3Y=w3JhOW;gUoRrTMBihC;`MlM3Ehn2XYRmC(A87oI zs?R&M`x*N4ThKVr`=$CC^}DMu6`yeRd#T!$jM4PDDw~t4M|B?ocYxlfI%zq|$8`VE zbEd|D=3QG{C9nRs`JhVOoo(uLbFeLS=LeB`Z8(tp&7kM{QTBJIw(I`-0%#n(6Z9UW zdxO?zP6FMN7sB44XQL!mus;Y_46P3JFNNad2VqdpY@K{nJ>1E@INPxkc=r8%PeXO+Gfwr!0CzYRLJ9Uh1GgVwi_d-DKX zF|;b^XEN5=-yN2N=83P-Y{+XIh_A!p5NR_s=M1o7*H==Ffs3HiJedt+!1eLNXfO`5 zaSXhces2S(MOCTp!B2wLM08))8d6nt^=nsaSj%8G`d_i{R~7g|&s*~ILofv&2d#nZ z3tAKT8fXn?13U#<+jDD*53>I>yaN6DGVk?zCOjAh7>S2lmHG;LA6N%@b#vvkuApbC zU%-7Jn`goP5Y4L&sn$vt!^NQIQLT^Z*;HHxcZ0^oAns3!OL`}mjWIy|mhYS|`%NFC z>}_x^=zY7Rzujo>cD}>$JvnwI{1`M2`fXRg*Q@=r(f@i5yaM#z+^=RTBOm!|ZiRHj zUV02^9cLB%2>u5BwlVAVibI9k?Z)`1q^`fDnzvtoe%=Rlt;!^GzU+QuJ3S6`uUZX1 zgnrtY)&K9nZnQ2~c0cNDk3Kt%wHERL&~sTxGhL%=Ri^SezK-dfu5sgLV48V#9!P!-=>4teGvZ^|I2i`k zmzs|Q#leMe21sL}eI1KpB3B) z8RG3iy5v<#&x%$*Kt8H=W1L^g0f`?sB9cwbc9rPBwu#8dv<1qvMAf02$g#y| zo#UF5?!eOMv>M;qCR%w{VeGi10WP0zfJ!L;9$ARTH~F!sP@Io?DH-B?OI$xcm7H%m ugI2`tX*mn4F0V>eYg<05B|6jk$Ir(uoxZ{iY= 0: + profit_pnl_height.append(pnl) + profit_pnl_x.append(count) + else: + loss_pnl_height.append(pnl) + loss_pnl_x.append(count) + + self.profit_pnl_bar.setOpts(x=profit_pnl_x, height=profit_pnl_height) + self.loss_pnl_bar.setOpts(x=loss_pnl_x, height=loss_pnl_height) + + # Set data for pnl distribution + hist, x = np.histogram(df["net_pnl"], bins="auto") + x = x[:-1] + self.distribution_curve.setData(x, hist) + + +class DateAxis(pg.AxisItem): + """Axis for showing date data""" + + def __init__(self, dates: dict, *args, **kwargs): + """""" + super().__init__(*args, **kwargs) + self.dates = dates + + def tickStrings(self, values, scale, spacing): + """""" + strings = [] + for v in values: + dt = self.dates.get(v, "") + strings.append(str(dt)) + return strings diff --git a/vnpy/app/cta_strategy/__init__.py b/vnpy/app/cta_strategy/__init__.py index 9d079f9b..e753746b 100644 --- a/vnpy/app/cta_strategy/__init__.py +++ b/vnpy/app/cta_strategy/__init__.py @@ -7,6 +7,7 @@ from vnpy.trader.utility import BarGenerator, ArrayManager from .base import APP_NAME, StopOrder from .engine import CtaEngine +from .backtesting import BacktestingEngine, OptimizationSetting from .template import CtaTemplate, CtaSignal, TargetPosTemplate diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 39768844..5712f798 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -293,7 +293,7 @@ class BacktestingEngine: self.output("逐日盯市盈亏计算完成") return self.daily_df - def calculate_statistics(self, df: DataFrame = None, Output=True): + def calculate_statistics(self, df: DataFrame = None, output=True): """""" self.output("开始计算策略统计指标") @@ -377,7 +377,7 @@ class BacktestingEngine: return_drawdown_ratio = -total_return / max_ddpercent # Output - if Output: + if output: self.output("-" * 30) self.output(f"首个交易日:\t{start_date}") self.output(f"最后交易日:\t{end_date}") @@ -417,6 +417,7 @@ class BacktestingEngine: "total_days": total_days, "profit_days": profit_days, "loss_days": loss_days, + "capital": self.capital, "end_balance": end_balance, "max_drawdown": max_drawdown, "max_ddpercent": max_ddpercent, From 8684e4c15ce9269e1a867a03eb0c3811fe7dcf4a Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 14:04:35 +0800 Subject: [PATCH 42/49] [Fix]add traceback import --- vnpy/app/cta_backtester/engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py index 621330d2..adcab59d 100644 --- a/vnpy/app/cta_backtester/engine.py +++ b/vnpy/app/cta_backtester/engine.py @@ -1,5 +1,6 @@ import os import importlib +import traceback from datetime import datetime from threading import Thread from pathlib import Path @@ -8,8 +9,7 @@ from vnpy.event import Event, EventEngine from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.app.cta_strategy import ( CtaTemplate, - BacktestingEngine, - OptimizationSetting + BacktestingEngine ) From b83599bc54a0485deeee02d78551da6714beeff2 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 15:25:54 +0800 Subject: [PATCH 43/49] [Fix]close #1580 --- vnpy/gateway/ctp/ctp_gateway.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index f1a25811..057e6db8 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -553,7 +553,7 @@ class CtpTdApi(TdApi): ) # For option only - if data["OptionsType"]: + if contract.product == Product.OPTION: contract.option_underlying = data["UnderlyingInstrID"], contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), contract.option_strike = data["StrikePrice"], From 1d6506c5f3a57b37e30351c7b7e414351a54c867 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 16:54:43 +0800 Subject: [PATCH 44/49] [Del]remove Singleton --- vnpy/trader/engine.py | 4 ++-- vnpy/trader/utility.py | 19 ------------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index af136dab..2d4a98e0 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -24,7 +24,7 @@ from .event import ( from .gateway import BaseGateway from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest from .setting import SETTINGS -from .utility import Singleton, get_folder_path +from .utility import get_folder_path class MainEngine: @@ -197,7 +197,7 @@ class BaseEngine(ABC): pass -class LogEngine(BaseEngine, metaclass=Singleton): +class LogEngine(BaseEngine): """ Processes log event and output with logging module. """ diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 6f023259..1f40aa77 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -12,25 +12,6 @@ import talib from .object import BarData, TickData -class Singleton(type): - """ - Singleton metaclass, - - usage: - class A(metaclass=Singleton): - ... - """ - _instances = {} - - def __call__(cls, *args, **kwargs): - """""" - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__( - *args, **kwargs - ) - return cls._instances[cls] - - def _get_trader_dir(temp_name: str): """ Get path where trader is running in. From b3961dbb84cbeb1c75c6dc63f55e1e0b37c811b8 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Thu, 11 Apr 2019 16:56:19 +0800 Subject: [PATCH 45/49] [Add]history data cache in cta backtesting to improve speed --- vnpy/app/cta_backtester/engine.py | 1 + vnpy/app/cta_strategy/backtesting.py | 71 ++++++++++++++++++++-------- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py index adcab59d..4dc80589 100644 --- a/vnpy/app/cta_backtester/engine.py +++ b/vnpy/app/cta_backtester/engine.py @@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine): self.write_log("已有回测在运行中,请等待完成") return False + self.write_log("-" * 40) self.thread = Thread( target=self.run_backtesting, args=( diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 5712f798..c1ba8e67 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -2,6 +2,7 @@ from collections import defaultdict from datetime import date, datetime from typing import Callable from itertools import product +from functools import lru_cache import multiprocessing import numpy as np @@ -197,28 +198,18 @@ class BacktestingEngine: self.output("开始加载历史数据") if self.mode == BacktestingMode.BAR: - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == self.vt_symbol) - & (DbBarData.interval == self.interval) - & (DbBarData.datetime >= self.start) - & (DbBarData.datetime <= self.end) - ) - .order_by(DbBarData.datetime) + self.history_data = load_bar_data( + self.vt_symbol, + self.interval, + self.start, + self.end ) - self.history_data = [db_bar.to_bar() for db_bar in s] else: - s = ( - DbTickData.select() - .where( - (DbTickData.vt_symbol == self.vt_symbol) - & (DbTickData.datetime >= self.start) - & (DbTickData.datetime <= self.end) - ) - .order_by(DbTickData.datetime) + self.history_data = load_tick_data( + self.vt_symbol, + self.start, + self.end ) - self.history_data = [db_tick.to_tick() for db_tick in s] self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") @@ -970,3 +961,45 @@ def optimize( target_value = statistics[target_name] return (str(setting), target_value, statistics) + + +@lru_cache(maxsize=10) +def load_bar_data( + vt_symbol: str, + interval: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbBarData.select() + .where( + (DbBarData.vt_symbol == vt_symbol) + & (DbBarData.interval == interval) + & (DbBarData.datetime >= start) + & (DbBarData.datetime <= end) + ) + .order_by(DbBarData.datetime) + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + +@lru_cache(maxsize=10) +def load_tick_data( + vt_symbol: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbTickData.select() + .where( + (DbTickData.vt_symbol == vt_symbol) + & (DbTickData.datetime >= start) + & (DbTickData.datetime <= end) + ) + .order_by(DbTickData.datetime) + ) + data = [db_tick.db_tick() for db_tick in s] + return data \ No newline at end of file From 7490a5c0ad5a56aec66f4488d6cddd51647976d0 Mon Sep 17 00:00:00 2001 From: nanoric Date: Thu, 11 Apr 2019 05:26:51 -0400 Subject: [PATCH 46/49] [Add] add support for mysql --- requirements.txt | 1 + vnpy/trader/database.py | 42 +++++++++++++++++++++++++++++++++++++---- vnpy/trader/engine.py | 2 +- vnpy/trader/setting.py | 11 +++++++++-- vnpy/trader/utility.py | 7 +++++++ 5 files changed, 56 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2b64ee1d..ee935ce6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ PyQt5<5.12 +pyqtgraph dataclasses; python_version<="3.6" qdarkstyle requests diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 87fb3eae..1e90dcbc 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,13 +1,47 @@ """""" -from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField +from peewee import CharField, DateTimeField, FloatField, Model, MySQLDatabase, PostgresqlDatabase, \ + SqliteDatabase from .constant import Exchange, Interval from .object import BarData, TickData -from .utility import get_file_path +from .setting import SETTINGS +from .utility import resolve_path -DB_NAME = "database.db" -DB = SqliteDatabase(str(get_file_path(DB_NAME))) + +def init(): + db_settings = SETTINGS['database'] + driver = db_settings["driver"] + + init_funcs = { + "sqlite": init_sqlite, + "mysql": init_mysql, + "postgresql": init_postgresql, + } + + assert driver in init_funcs + del db_settings['driver'] + return init_funcs[driver](db_settings) + + +def init_sqlite(settings: dict): + global DB + database = settings['database'] + + DB = SqliteDatabase(str(resolve_path(database))) + + +def init_mysql(settings: dict): + global DB + DB = MySQLDatabase(**settings) + + +def init_postgresql(settings: dict): + global DB + DB = PostgresqlDatabase(**settings) + + +init() class DbBarData(Model): diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index af136dab..f7cf95df 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -197,7 +197,7 @@ class BaseEngine(ABC): pass -class LogEngine(BaseEngine, metaclass=Singleton): +class LogEngine(BaseEngine): """ Processes log event and output with logging module. """ diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index ae2011f9..dc82fad7 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -23,10 +23,17 @@ SETTINGS = { "email.receiver": "", "rqdata.username": "", - "rqdata.password": "" + "rqdata.password": "", + "database": { + "driver": "sqlite", + "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath + "host": "localhost", + "port": 3306, + "user": "root", + "password": "" + } } - # Load global setting from json file. SETTING_FILENAME = "vt_setting.json" SETTINGS.update(load_json(SETTING_FILENAME)) diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 6f023259..90dc6b2e 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -3,6 +3,7 @@ General utility functions. """ import json +import os from pathlib import Path from typing import Callable @@ -31,6 +32,12 @@ class Singleton(type): return cls._instances[cls] +def resolve_path(pattern: str): + env = dict(os.environ) + env.update({"VNPY_TEMP": str(TEMP_DIR)}) + return pattern.format(**env) + + def _get_trader_dir(temp_name: str): """ Get path where trader is running in. From 3134b692a6f290e6aba822cc546b42684099c3d6 Mon Sep 17 00:00:00 2001 From: nanoric Date: Thu, 11 Apr 2019 05:28:05 -0400 Subject: [PATCH 47/49] [Add] added comment --- vnpy/trader/setting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vnpy/trader/setting.py b/vnpy/trader/setting.py index dc82fad7..b27dbf96 100644 --- a/vnpy/trader/setting.py +++ b/vnpy/trader/setting.py @@ -25,7 +25,7 @@ SETTINGS = { "rqdata.username": "", "rqdata.password": "", "database": { - "driver": "sqlite", + "driver": "sqlite", # sqlite, mysql, postgresql "database": "{VNPY_TEMP}/database.db", # for sqlite, use this as filepath "host": "localhost", "port": 3306, From 4c83b08315c76ef3ae2e17d4f3a8009976358718 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Sat, 13 Apr 2019 08:44:43 +0800 Subject: [PATCH 48/49] [Fix]spread contract with no size info will cause no position update --- vnpy/gateway/ctp/ctp_gateway.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index 057e6db8..6dd92450 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -487,11 +487,6 @@ class CtpTdApi(TdApi): ) self.positions[key] = position - # Get contract size, return if size value not collected - size = symbol_size_map.get(position.symbol, None) - if not size: - return - # For SHFE position data update if position.exchange == Exchange.SHFE: if data["YdPosition"] and not data["TodayPosition"]: @@ -500,6 +495,9 @@ class CtpTdApi(TdApi): else: position.yd_volume = data["Position"] - data["TodayPosition"] + # Get contract size (spread contract has no size value) + size = symbol_size_map.get(position.symbol, 0) + # Calculate previous position cost cost = position.price * position.volume * size @@ -508,7 +506,7 @@ class CtpTdApi(TdApi): position.pnl += data["PositionProfit"] # Calculate average position price - if position.volume: + if position.volume and size: cost += data["PositionCost"] position.price = cost / (position.volume * size) From c765a8f123c7b2769491e1440ad1f16a398214f7 Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Sat, 13 Apr 2019 18:37:09 +0800 Subject: [PATCH 49/49] Update engine.py --- vnpy/app/csv_loader/engine.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index c0547036..fa88afe7 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -41,17 +41,17 @@ class CsvLoaderEngine(BaseEngine): """""" super().__init__(main_engine, event_engine, APP_NAME) - self.file_path: str = '' + self.file_path: str = "" self.symbol: str = "" self.exchange: Exchange = Exchange.SSE self.interval: Interval = Interval.MINUTE - self.datetime_head: str = '' - self.open_head: str = '' - self.close_head: str = '' - self.low_head: str = '' - self.high_head: str = '' - self.volume_head: str = '' + self.datetime_head: str = "" + self.open_head: str = "" + self.close_head: str = "" + self.low_head: str = "" + self.high_head: str = "" + self.volume_head: str = "" def load( self, @@ -74,7 +74,7 @@ class CsvLoaderEngine(BaseEngine): end = None count = 0 - with open(file_path, 'rt') as f: + with open(file_path, "rt") as f: reader = csv.DictReader(f) db_bars = []