diff --git a/.travis.yml b/.travis.yml index d412e7a8..8622616d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,6 +29,8 @@ matrix: - choco install python3 --version 3.7.2 install: - python -m pip install --upgrade pip wheel setuptools + - pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl + - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - pip install -r requirements.txt - pip install . @@ -43,8 +45,10 @@ matrix: - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90 - # Linux install script + # update pip & setuptools - python -m pip install --upgrade pip wheel setuptools + # Linux install script + - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - bash ./install.sh - name: "pip install under Ubuntu: gcc-7" @@ -58,8 +62,10 @@ matrix: - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90 - # Linux install script + # update pip & setuptools - python -m pip install --upgrade pip wheel setuptools + # Linux install script + - pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - bash ./install.sh - name: "sdist install under Windows" diff --git a/README.md b/README.md index ef84a6f4..39a31811 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,18 @@ # By Traders, For Traders.

- +

-vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1月正式发布以来,在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台,目前国内外金融机构用户已经超过300家,包括:私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。 +

+ + + + + +

+ +vn.py是一套基于Python的开源量化交易系统开发框架,于2015年1月正式发布,在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台,目前国内外金融机构用户已经超过300家,包括:私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。 2.0版本基于Python 3.7全新重构开发,目前功能还在逐步完善中。如需Python 2上的版本请点击:[长期支持版本v1.9.2 LTS](https://github.com/vnpy/vnpy/tree/v1.9.2-LTS)。 @@ -14,17 +22,23 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1 2. 覆盖国内外所有交易品种的交易接口(vnpy.gateway): - * CTP(ctpGateway):国内期货、期权 + * CTP(ctp):国内期货、期权 - * 富途证券(futuGateway):港股、美股 + * 宽睿(oes):A股 - * Interactive Brokers(ibGateway):全球证券、期货、期权、外汇等 + * 富途证券(futu):港股、美股 - * BitMEX (bitmexGateway):数字货币期货、期权、永续合约 + * 老虎证券(tiger):全球证券、期货、期权、外汇等 + + * Interactive Brokers(ib):全球证券、期货、期权、外汇等 + + * BitMEX (bitmex):数字货币期货、期权、永续合约 3. 开箱即用的各类量化策略交易应用(vnpy.app): - * CtaStrategy:CTA策略引擎模块,在保持易用性的同时,允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制(降低交易滑点、实现高频策略) + * cta_strategy:CTA策略引擎模块,在保持易用性的同时,允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制(降低交易滑点、实现高频策略) + + * csv_loader:CSV历史数据加载器,用于加载CSV格式文件中的历史数据到平台数据库中,用于策略的回测研究以及实盘初始化等功能,支持自定义数据表头格式 4. Python交易API接口封装(vnpy.api),提供上述交易接口的底层对接实现。 @@ -36,11 +50,9 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1 ## 环境准备 -* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0-Windows-x86_64.exe),内置了最新版的vn.py,无需手动安装 +* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe),内置了最新版的vn.py框架以及VN Station量化管理平台,无需手动安装 * 支持的系统版本:Windows 7以上/Windows Server 2008以上/Ubuntu 18.04 LTS * 支持的Python版本:Python 3.7 64位(**注意必须是Python 3.7 64位版本**) -* 如需使用IB API,请在[Interactive Brokers Github](https://interactivebrokers.github.io/#)页面下载安装**IB API Latest** - ## 安装步骤 @@ -59,15 +71,20 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1 1. 在[SimNow](http://www.simnow.com.cn/)注册CTP仿真账号,并在[该页面](http://www.simnow.com.cn/product.action)获取经纪商代码以及交易行情服务器地址。 -2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码,论坛最新的注册邀请码为**El86Pa1p** +2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码(论坛账号密码即是) 3. 启动VN Station(安装VNConda后会在桌面自动创建快捷方式),输入上一步的账号密码登录 -4. 点击底部的**VN Trader**按钮,选择运行目录(默认在系统用户目录即可)后,在对话框中勾选CTP接口以及CtaStrategy应用,点击右下方的**启动**按钮,开始你的交易!!! +4. 点击底部的**VN Trader Lite**按钮,开始你的交易!!! -5. 在VN Trader的运行过程中请勿关闭VN Station(会自动退出) +注意: +* 在VN Trader的运行过程中请勿关闭VN Station(会自动退出) +* 如需要灵活配置量化交易应用组件,请使用**VN Trader Pro** -6. 如选择了VNConda以外的安装方式(不推荐新手),可以在任意目录下创建run.py,写入以下示例代码后运行: + +## 脚本运行 + +除了基于VN Station的图形化启动方式外,也可以在任意目录下创建run.py,写入以下示例代码: ```Python from vnpy.event import EventEngine @@ -77,7 +94,7 @@ from vnpy.gateway.ctp import CtpGateway from vnpy.app.cta_strategy import CtaStrategyApp def main(): - """启动VN Trader""" + """Start VN Trader""" qapp = create_qapp() event_engine = EventEngine() @@ -95,11 +112,13 @@ if __name__ == "__main__": main() ``` +在该目录下打开CMD(按住Shift->点击鼠标右键->在此处打开命令窗口/PowerShell)后运行下列命令启动VN Trader: + python run.py ## 贡献代码 -vn.py使用github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程: +vn.py使用Github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程: 1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下,较小的improvement(如文档改进,bugfix等)直接发PR即可 diff --git a/docs/conf.py b/docs/conf.py index 40ed53f2..7266f6b4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,7 +28,6 @@ version = '2.0' # The full version, including alpha/beta/rc tags release = '2.0-DEV' - # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. @@ -75,19 +74,20 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None - # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' - # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -# html_theme_options = {} +html_theme_options = { + "base_bg": "inherit", + "narrow_sidebar_bg": "inherit", +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -110,7 +110,6 @@ html_static_path = ['_static'] # Output file base name for HTML help builder. htmlhelp_basename = 'vnpydoc' - # -- Options for LaTeX output ------------------------------------------------ latex_elements = { @@ -139,7 +138,6 @@ latex_documents = [ 'vn.py Team', 'manual'), ] - # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples @@ -149,7 +147,6 @@ man_pages = [ [author], 1) ] - # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples @@ -161,7 +158,6 @@ texinfo_documents = [ 'Miscellaneous'), ] - # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. diff --git a/docs/csv_loader.md b/docs/csv_loader.md new file mode 100644 index 00000000..e69de29b diff --git a/docs/cta_strategy.md b/docs/cta_strategy.md index e10b99d0..62d6106d 100644 --- a/docs/cta_strategy.md +++ b/docs/cta_strategy.md @@ -1 +1,23 @@ -# Introduction +# CTA策略模块 + + +## 模块构成 + + +## 历史数据 + + + +## 策略开发 + + +## 回测研究 + + + +## 参数优化 + + + +## 实盘运行 + diff --git a/docs/gateway.md b/docs/gateway.md new file mode 100644 index 00000000..e69de29b diff --git a/docs/index.md b/docs/index.md index b5eea238..c9ce72c3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,6 +1,15 @@ # vn.py文档 -* [vn.py简介](introduction.md) -* [项目安装](install.md) -* [基本使用](quickstart.md) -* [CTA策略模块](cta_strategy.md) \ No newline at end of file +* 快速入门 + * [项目简介](introduction.md) + * [环境安装](install.md) + * [基本使用](quickstart.md) + +* 应用模块 + * [CTA策略](cta_strategy.md) + * [CSV载入](csv_loader.md) + +* [交易接口](gateway.md) + +* [RPC应用](rpc.md) +* [贡献代码](contribution.md) \ No newline at end of file diff --git a/docs/install.md b/docs/install.md index bf0de3b5..f09fdb75 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,8 +1,33 @@ # 安装指南 +## Windows + +### 使用VNConda + +**ssleay32.dll问题** + +如果电脑上之前安装过其他的Python环境或者应用软件,有可能会存在SSL相关动态链接库路径被修改的问题,在运行VN Station时弹出如下图所示的错误: + +![ssleay32.dll](https://user-images.githubusercontent.com/7112268/55474371-8bd06a00-5643-11e9-8b35-f064a45edfd1.png) + +解决方法: +1. 找到你的VNConda目录 +2. 将VNConda\Lib\site-packages\PyQt5\Qt\bin目录的两个动态库libeay32.dll和ssleay32.dll拷贝到VNConda\下即可 + +### 手动安装 + + + ## Ubuntu -如果是英文系统,请先运行下列命令安装中文编码: + +### 安装脚本 + +### TA-Lib + +### 中文编码 + +如果是英文系统(如阿里云),请先运行下列命令安装中文编码: ``` sudo locale-gen zh_CN.GB18030 diff --git a/docs/introduction.md b/docs/introduction.md index e10b99d0..02971c01 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -1 +1,15 @@ # Introduction + + +## 目标用户 + + +## 应用场景 + + + +## 支持的接口 + + + +## 支持的应用 \ No newline at end of file diff --git a/docs/quickstart.md b/docs/quickstart.md index e10b99d0..9536f20e 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -1 +1,20 @@ -# Introduction +# 基本使用 + + +## 启动VN Trader + + +## 连接接口 + + +## 订阅行情 + + +## 委托交易 + + +## 数据监控 + + +## 应用模块 + diff --git a/docs/rpc.md b/docs/rpc.md new file mode 100644 index 00000000..e69de29b diff --git a/install.bat b/install.bat index 4f06d206..e45576f0 100644 --- a/install.bat +++ b/install.bat @@ -1,6 +1,6 @@ ::Install talib and ibapi -pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl -pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl +pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl +pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl ::Install Python Modules pip install -r requirements.txt diff --git a/install.sh b/install.sh index e7801f91..2cf2043c 100644 --- a/install.sh +++ b/install.sh @@ -13,6 +13,10 @@ popd # old versions of ta-lib imports numpy in setup.py pip install numpy +# Install extra packages +pip install ta-lib +pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl + # Install Python Modules pip install -r requirements.txt diff --git a/requirements.txt b/requirements.txt index a1878cac..2b64ee1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,11 +10,5 @@ matplotlib seaborn futu-api tigeropen - -# ta-lib -ta-lib; platform_system=="Linux" -https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl; platform_system=="Windows" - -# ibapi -https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl - +ta-lib +ibapi diff --git a/tests/backtesting/genetic_algorithm.ipynb b/tests/backtesting/genetic_algorithm.ipynb new file mode 100644 index 00000000..c9e92e12 --- /dev/null +++ b/tests/backtesting/genetic_algorithm.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import multiprocessing\n", + "import numpy as np\n", + "from deap import creator, base, tools, algorithms\n", + "from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n", + "from boll_channel_strategy import BollChannelStrategy\n", + "from datetime import datetime\n", + "import multiprocessing #多进程\n", + "from scoop import futures #多进程" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def parameter_generate():\n", + " '''\n", + " 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n", + " '''\n", + " parameter_list = []\n", + " p1 = random.randrange(4,50,2) #布林带窗口\n", + " p2 = random.randrange(4,50,2) #布林带通道阈值\n", + " p3 = random.randrange(4,50,2) #CCI窗口\n", + " p4 = random.randrange(18,40,2) #ATR窗口 \n", + "\n", + " parameter_list.append(p1)\n", + " parameter_list.append(p2)\n", + " parameter_list.append(p3)\n", + " parameter_list.append(p4)\n", + "\n", + " return parameter_list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def object_func(strategy_avg):\n", + " \"\"\"\n", + " 本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率\n", + " \"\"\"\n", + " # 创建回测引擎对象\n", + " engine = BacktestingEngine()\n", + " engine.set_parameters(\n", + " vt_symbol=\"IF88.CFFEX\",\n", + " interval=\"1m\",\n", + " start=datetime(2018, 9, 1),\n", + " end=datetime(2019, 1,1),\n", + " rate=0,\n", + " slippage=0,\n", + " size=300,\n", + " pricetick=0.2,\n", + " capital=1_000_000,\n", + " )\n", + "\n", + " setting = {'boll_window': strategy_avg[0], #布林带窗口\n", + " 'boll_dev': strategy_avg[1], #布林带通道阈值\n", + " 'cci_window': strategy_avg[2], #CCI窗口\n", + " 'atr_window': strategy_avg[3],} #ATR窗口 \n", + "\n", + " #加载策略 \n", + " #engine.initStrategy(TurtleTradingStrategy, setting)\n", + " engine.add_strategy(BollChannelStrategy, setting)\n", + " engine.load_data()\n", + " engine.run_backtesting()\n", + " engine.calculate_result()\n", + " result = engine.calculate_statistics(Output=False)\n", + "\n", + " return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n", + " sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n", + " return return_drawdown_ratio , sharpe_ratio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "#设置优化方向:最大化收益回撤比,最大化夏普比率\n", + "creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n", + "creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n", + "\n", + "def optimize():\n", + " \"\"\"\"\"\" \n", + " toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数\n", + "\n", + " # 初始化 \n", + " toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate() \n", + " toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n", + " toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n", + " toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n", + " toolbox.register(\"evaluate\", object_func) #注册评估:优化目标函数object_func() \n", + " toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n", + " #pool = multiprocessing.Pool()\n", + " #toolbox.register(\"map\", pool.map)\n", + " #toolbox.register(\"map\", futures.map)\n", + "\n", + " #遗传算法参数设置\n", + " MU = 40 #设置每一代选择的个体数\n", + " LAMBDA = 160 #设置每一代产生的子女数\n", + " pop = toolbox.population(400) #设置族群里面的个体数量\n", + " CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n", + " hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n", + "\n", + " #解的集合的描述统计信息\n", + " #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n", + " #收敛程度低可以增加算法的迭代次数\n", + " stats = tools.Statistics(lambda ind: ind.fitness.values)\n", + " np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n", + " stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n", + " stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n", + " stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n", + " stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n", + "\n", + " #运行算法\n", + " algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n", + " halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n", + "\n", + " return pop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "optimize()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/trader/run.py b/tests/trader/run.py index 1a402f82..0b7fb501 100644 --- a/tests/trader/run.py +++ b/tests/trader/run.py @@ -10,6 +10,8 @@ from vnpy.gateway.ib import IbGateway from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.oes import OesGateway +from vnpy.gateway.okex import OkexGateway +from vnpy.gateway.huobi import HuobiGateway from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.csv_loader import CsvLoaderApp @@ -28,6 +30,8 @@ def main(): main_engine.add_gateway(BitmexGateway) main_engine.add_gateway(TigerGateway) main_engine.add_gateway(OesGateway) + main_engine.add_gateway(OkexGateway) + main_engine.add_gateway(HuobiGateway) main_engine.add_app(CtaStrategyApp) main_engine.add_app(CsvLoaderApp) diff --git a/vnpy/__init__.py b/vnpy/__init__.py index cc9bec87..159d48b8 100644 --- a/vnpy/__init__.py +++ b/vnpy/__init__.py @@ -1 +1 @@ -__version__ = "2.0.1b0" +__version__ = "2.0.1" diff --git a/vnpy/api/apex/FixApi.dll b/vnpy/api/apex/FixApi.dll index bd6dd2a9..9b9003a8 100644 Binary files a/vnpy/api/apex/FixApi.dll and b/vnpy/api/apex/FixApi.dll differ diff --git a/vnpy/api/apex/__init__.py b/vnpy/api/apex/__init__.py index b3d19ec5..69f6f329 100644 --- a/vnpy/api/apex/__init__.py +++ b/vnpy/api/apex/__init__.py @@ -1,2 +1,2 @@ -from .vnapex import ApexApi +from .vnapex import * from .fiddef import * diff --git a/vnpy/api/apex/vnapex.py b/vnpy/api/apex/vnapex.py index 2a425636..600fc63e 100644 --- a/vnpy/api/apex/vnapex.py +++ b/vnpy/api/apex/vnapex.py @@ -31,7 +31,7 @@ class ApexApi: def set_app_info(self, name: str, version: str): """设置应用信息""" - n = APEX.Fix_SetAppInfo(c_char_p(name), c_char_p(version)) + n = APEX.Fix_SetAppInfo(to_bytes(name), to_bytes(version)) return bool(n) def uninitialize(self): @@ -42,19 +42,19 @@ class ApexApi: def set_default_info(self, user: str, wtfs: str, fbdm: str, dest: str): """设置默认信息""" n = APEX.Fix_SetDefaultInfo( - c_char_p(user), - c_char_p(wtfs), - c_char_p(fbdm), - c_char_p(dest) + to_bytes(user), + to_bytes(wtfs), + to_bytes(fbdm), + to_bytes(dest) ) return bool(n) def connect(self, address: str, khh: str, pwd: str, timeout: int): """连接交易""" conn = APEX.Fix_Connect( - c_char_p(address), - c_char_p(khh), - c_char_p(pwd), + to_bytes(address), + to_bytes(khh), + to_bytes(pwd), timeout ) return conn @@ -66,13 +66,13 @@ class ApexApi: ): """连接交易""" conn = APEX.Fix_ConnectEx( - c_char_p(address), - c_char_p(khh), - c_char_p(pwd), - c_char_p(file_cert), - c_char_p(cert_pwd), - c_char_p(file_ca), - c_char_p(procotol), + to_bytes(address), + to_bytes(khh), + to_bytes(pwd), + to_bytes(file_cert), + to_bytes(cert_pwd), + to_bytes(file_ca), + to_bytes(procotol), verify, timeout ) @@ -100,27 +100,27 @@ class ApexApi: def set_wtfs(self, sess: int, wtfs: str): """设置委托方式""" - n = APEX.Fix_SetWTFS(sess, c_char_p(wtfs)) + n = APEX.Fix_SetWTFS(sess, to_bytes(wtfs)) return bool(n) def set_fbdm(self, sess: int, fbdm: str): """设置来源营业部""" - n = APEX.Fix_SetFBDM(sess, c_char_p(fbdm)) + n = APEX.Fix_SetFBDM(sess, to_bytes(fbdm)) return bool(n) def set_dest_fbdm(self, sess: int, fbdm: str): """设置目标营业部""" - n = APEX.Fix_SetDestFBDM(sess, c_char_p(fbdm)) + n = APEX.Fix_SetDestFBDM(sess, to_bytes(fbdm)) return bool(n) def set_node(self, sess: int, node: str): """设置业务站点""" - n = APEX.Fix_SetNode(sess, c_char_p(node)) + n = APEX.Fix_SetNode(sess, to_bytes(node)) return bool(n) def set_gydm(self, sess: int, gydm: str): """设置柜员号""" - n = APEX.Fix_SetGYDM(sess, c_char_p(gydm)) + n = APEX.Fix_SetGYDM(sess, to_bytes(gydm)) return bool(n) def create_head(self, sess: int, func: int): @@ -170,7 +170,7 @@ class ApexApi: def get_err_msg(self, sess: int): """获取错误信息""" size = 256 - out = create_string_buffer("", size) + out = create_string_buffer(b"", size) APEX.Fix_GetErrMsg(sess, out, size) return out.value @@ -182,7 +182,7 @@ class ApexApi: def get_item(self, sess: int, fid: int, row: int): """获取字符串内容""" size = 256 - out = create_string_buffer("", size) + out = create_string_buffer(b"", size) APEX.Fix_GetItem(sess, fid, out, size, row) return out.value @@ -210,19 +210,21 @@ class ApexApi: def get_token(self, sess: int): """获取业务令牌""" size = 256 - out = create_string_buffer("", size) + out = create_string_buffer(b"", size) APEX.Fix_GetToken(sess, out, size) return out.value - def encode(self, data): + def encode(self, data: str): """加密""" + data = to_bytes(data) buf = create_string_buffer(data, 512) APEX.Fix_Encode(buf) - return buf.value + return to_unicode(buf.value) def add_backup_svc_addr(self, address: str): """设置业务令牌""" + address = to_bytes(address) n = APEX.Fix_AddBackupSvrAddr(address) return bool(n) @@ -238,9 +240,11 @@ class ApexApi: def subscribe_by_customer(self, conn: int, svc: int, khh: str, pwd: str): """订阅数据""" - func = APEX[93] - n = func(conn, svc, self.push_call_func, c_char_p(""), khh, pwd) - return bool(n) + func = APEX[108] + n = func(conn, svc, self.push_call_func, + to_bytes(""), to_bytes(khh), to_bytes(pwd)) + + return n def unsubscribe_by_handle(self, handle: int): """退订推送""" @@ -254,22 +258,22 @@ class ApexApi: def get_val_with_id_by_index(self, sess: int, row: int, col: int): """根据行列获取数据""" s = 256 - buf = create_string_buffer("", s) + buf = create_string_buffer(b"", s) fid = c_long(0) size = c_int(s) APEX.Fix_GetValWithIdByIndex( sess, row, col, byref(fid), buf, byref(size)) - return fid.value, buf.value + return fid.value, to_unicode(buf.value) def set_system_no(self, sess: int, val: str): """设置系统编号""" - n = APEX.Fix_SetSystemNo(sess, c_char_p(val)) + n = APEX.Fix_SetSystemNo(sess, to_bytes(val)) return bool(n) def set_default_system_no(self, val: str): """设置默认系统编号""" - n = APEX.Fix_SetDefaultSystemNo(c_char_p(val)) + n = APEX.Fix_SetDefaultSystemNo(to_bytes(val)) return bool(n) def set_auto_reconnect(self, conn: int, reconnect: int): @@ -291,23 +295,23 @@ class ApexApi: """获取缓存数据""" size = 1024 outlen = c_int(size) - buf = create_string_buffer("", size) + buf = create_string_buffer(b"", size) APEX.Fix_GetItemBuf(sess, buf, byref(outlen), row) return buf def set_item(self, sess: int, fid: int, val: str): """设置请求内容""" - n = APEX.Fix_SetString(sess, fid, c_char_p(val)) + n = APEX.Fix_SetString(sess, fid, to_bytes(val)) return bool(n) def get_last_err_msg(self): """获取错误信息""" size = 256 - out = create_string_buffer("", size) + out = create_string_buffer(b"", size) APEX.Fix_GetLastErrMsg(out, size) - return out.value + return to_unicode(out.value) def reg_reply_call_func(self, sess: int = 0): """注册回调函数""" @@ -328,3 +332,21 @@ class ApexApi: def on_conn(self, conn: int, event, recv): """连接回调(需要继承)""" return True + + +def to_bytes(data: str): + """ + 将unicode字符串转换为bytes + """ + try: + bytes_data = data.encode("GBK") + return bytes_data + except AttributeError: + return data + + +def to_unicode(data: bytes): + """ + 将bytes字符串转换为unicode + """ + return data.decode("GBK") diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py index 5fde1d82..0fe28ee7 100644 --- a/vnpy/api/websocket/websocket_client.py +++ b/vnpy/api/websocket/websocket_client.py @@ -48,14 +48,18 @@ class WebsocketClient(object): self.proxy_host = None self.proxy_port = None + self.ping_interval = 60 # seconds # For debugging self._last_sent_text = None self._last_received_text = None - def init(self, host: str, proxy_host: str = "", proxy_port: int = 0): - """""" + def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60): + """ + :param ping_interval: unit: seconds, type: int + """ self.host = host + self.ping_interval = ping_interval # seconds if proxy_host and proxy_port: self.proxy_host = proxy_host @@ -206,7 +210,7 @@ class WebsocketClient(object): et, ev, tb = sys.exc_info() self.on_error(et, ev, tb) self._reconnect() - for i in range(60): + for i in range(self.ping_interval): if not self._active: break sleep(1) diff --git a/vnpy/app/algo_trading/__init__.py b/vnpy/app/algo_trading/__init__.py new file mode 100644 index 00000000..e5be3d37 --- /dev/null +++ b/vnpy/app/algo_trading/__init__.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp + +from .engine import AlgoEngine, APP_NAME + + +class CtaStrategyApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "算法交易" + engine_class = AlgoEngine + widget_name = "AlgoManager" + icon_name = "algo.ico" diff --git a/vnpy/app/algo_trading/algos/__init__.py b/vnpy/app/algo_trading/algos/__init__.py new file mode 100644 index 00000000..d6fa977a --- /dev/null +++ b/vnpy/app/algo_trading/algos/__init__.py @@ -0,0 +1,59 @@ +# encoding: UTF-8 + +''' +动态载入所有的策略类 +''' +from __future__ import print_function + +import os +import importlib +import traceback + + +# 用来保存算法类和控件类的字典 +ALGO_DICT = {} +WIDGET_DICT = {} + + +#---------------------------------------------------------------------- +def loadAlgoModule(path, prefix): + """使用importlib动态载入算法""" + for root, subdirs, files in os.walk(path): + for name in files: + # 只有文件名以Algo.py结尾的才是算法文件 + if len(name)>7 and name[-7:] == 'Algo.py': + try: + # 模块名称需要模块路径前缀 + moduleName = prefix + name.replace('.py', '') + module = importlib.import_module(moduleName) + + # 获取算法类和控件类 + algo = None + widget = None + + for k in dir(module): + # 以Algo结尾的类,是算法 + if k[-4:] == 'Algo': + algo = module.__getattribute__(k) + + # 以Widget结尾的类,是控件 + if k[-6:] == 'Widget': + widget = module.__getattribute__(k) + + # 保存到字典中 + if algo and widget: + ALGO_DICT[algo.templateName] = algo + WIDGET_DICT[algo.templateName] = widget + except: + print ('-' * 20) + print ('Failed to import strategy file %s:' %moduleName) + traceback.print_exc() + + +# 遍历algo目录下的文件 +path1 = os.path.abspath(os.path.dirname(__file__)) +loadAlgoModule(path1, 'vnpy.trader.app.algoTrading.algo.') + +# 遍历工作目录下的文件 +path2 = os.getcwd() +loadAlgoModule(path2, '') \ No newline at end of file diff --git a/vnpy/app/algo_trading/algos/iceberg_algo.py b/vnpy/app/algo_trading/algos/iceberg_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/algos/sniper_algo.py b/vnpy/app/algo_trading/algos/sniper_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/algos/twap_algo.py b/vnpy/app/algo_trading/algos/twap_algo.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py new file mode 100644 index 00000000..3c5224dd --- /dev/null +++ b/vnpy/app/algo_trading/engine.py @@ -0,0 +1,65 @@ + +from vnpy.event import EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.event import (EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) + + +class AlgoEngine(BaseEngine): + """""" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """Constructor""" + super().__init__(main_engine, event_engine) + + self.algos = {} + self.symbol_algo_map = {} + self.orderid_algo_map = {} + + self.register_event() + + def register_event(self): + """""" + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + + def process_tick_event(self): + """""" + pass + + def process_timer_event(self): + """""" + pass + + def process_trade_event(self): + """""" + pass + + def process_order_event(self): + """""" + pass + + def start_algo(self, setting: dict): + """""" + pass + + def stop_algo(self, algo_name: dict): + """""" + pass + + def stop_all(self): + """""" + pass + + def subscribe(self, algo, vt_symbol): + """""" + pass + + def send_order( + self, + algo, + vt_symbol + ): + """""" + pass diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py new file mode 100644 index 00000000..1ba0c5a4 --- /dev/null +++ b/vnpy/app/algo_trading/template.py @@ -0,0 +1,123 @@ +from vnpy.trader.engine import BaseEngine +from vnpy.trader.object import TickData, OrderData, TradeData +from vnpy.trader.constant import OrderType, Offset + +class AlgoTemplate: + """""" + count = 0 + + def __init__( + self, + algo_engine: BaseEngine, + algo_name: str, + setting: dict + ): + """Constructor""" + self.algo_engine = algo_engine + self.algo_name = algo_name + + self.active = False + self.active_orders = {} # vt_orderid:order + + @staticmethod + def new(cls, algo_engine:BaseEngine, setting: dict): + """Create new algo instance""" + cls.count += 1 + algo_name = f"{cls.__name__}_{cls.count}" + algo = cls(algo_engine, algo_name, setting) + + def update_tick(self, tick: TickData): + """""" + if self.active: + self.on_tick(tick) + + def update_order(self, order: OrderData): + """""" + if self.active: + if order.is_active(): + self.active_orders[order.vt_orderid] = order + elif order.vt_orderid in self.active_orders: + self.active_orders.pop(order.vt_orderid) + + self.on_order(order) + + def update_trade(self, trade: TradeData): + """""" + if self.active: + self.on_trade(trade) + + def update_timer(self): + """""" + if self.active: + self.on_timer() + + def on_start(self): + """""" + pass + + def on_stop(self): + """""" + pass + + def on_tick(self, tick: TickData): + """""" + pass + + def on_order(self, order: OrderData): + """""" + pass + + def on_trade(self, trade: TradeData): + """""" + pass + + def on_timer(self): + """""" + pass + + def start(self): + """""" + pass + + def stop(self): + """""" + pass + + def buy( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + return self.algo_engine.buy( + vt_symbol, + price, + volume, + order_type, + offset + ) + + def sell( + self, + vt_symbol, + price, + volume, + order_type: OrderType = OrderType.LIMIT, + offset: Offset = Offset.NONE + ): + """""" + return self.algo_engine.buy( + vt_symbol, + price, + volume, + order_type, + offset + ) + + + + + \ No newline at end of file diff --git a/vnpy/app/algo_trading/ui/algo.ico b/vnpy/app/algo_trading/ui/algo.ico new file mode 100644 index 00000000..83114df8 Binary files /dev/null and b/vnpy/app/algo_trading/ui/algo.ico differ diff --git a/vnpy/app/algo_trading/ui/widget.py b/vnpy/app/algo_trading/ui/widget.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/csv_loader/engine.py b/vnpy/app/csv_loader/engine.py index af4a49a8..2c384d02 100644 --- a/vnpy/app/csv_loader/engine.py +++ b/vnpy/app/csv_loader/engine.py @@ -23,9 +23,11 @@ Sample csv file: import csv from datetime import datetime +from peewee import chunked + from vnpy.event import EventEngine from vnpy.trader.constant import Exchange, Interval -from vnpy.trader.database import DbBarData +from vnpy.trader.database import DbBarData, DB from vnpy.trader.engine import BaseEngine, MainEngine @@ -75,30 +77,37 @@ class CsvLoaderEngine(BaseEngine): with open(file_path, 'rt') as f: reader = csv.DictReader(f) + db_bars = [] + for item in reader: - db_bar = DbBarData() + dt = datetime.strptime(item[datetime_head], datetime_format) - db_bar.symbol = symbol - db_bar.exchange = exchange.value - db_bar.datetime = datetime.strptime( - item[datetime_head], datetime_format - ) - db_bar.interval = interval.value - db_bar.volume = item[volume_head] - db_bar.open_price = item[open_head] - db_bar.high_price = item[high_head] - db_bar.low_price = item[low_head] - db_bar.close_price = item[close_head] - db_bar.vt_symbol = vt_symbol - db_bar.gateway_name = "DB" + db_bar = { + "symbol": symbol, + "exchange": exchange.value, + "datetime": dt, + "interval": interval.value, + "volume": item[volume_head], + "open_price": item[open_head], + "high_price": item[high_head], + "low_price": item[low_head], + "close_price": item[close_head], + "vt_symbol": vt_symbol, + "gateway_name": "DB" + } - db_bar.replace() + db_bars.append(db_bar) # do some statistics count += 1 if not start: - start = db_bar.datetime + start = db_bar["datetime"] - end = db_bar.datetime + end = db_bar["datetime"] + + # Insert into DB + with DB.atomic(): + for batch in chunked(db_bars, 500): + DbBarData.insert_many(batch).on_conflict_replace().execute() return start, end, count diff --git a/vnpy/app/csv_loader/ui/widget.py b/vnpy/app/csv_loader/ui/widget.py index 6e89741c..79ef9e0c 100644 --- a/vnpy/app/csv_loader/ui/widget.py +++ b/vnpy/app/csv_loader/ui/widget.py @@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget): def select_file(self): """""" - result: str = QtWidgets.QFileDialog.getOpenFileName(self) + result: str = QtWidgets.QFileDialog.getOpenFileName( + self, filter="CSV (*.csv)") filename = result[0] if filename: self.file_edit.setText(filename) diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index ab168b72..39768844 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -35,7 +35,7 @@ class OptimizationSetting: def __init__(self): """""" self.params = {} - self.target = "" + self.target_name = "" def add_parameter( self, name: str, start: float, end: float = None, step: float = None @@ -62,9 +62,9 @@ class OptimizationSetting: self.params[name] = value_list - def set_target(self, target: str): + def set_target(self, target_name: str): """""" - self.target = target + self.target_name = target_name def generate_setting(self): """""" @@ -293,7 +293,7 @@ class BacktestingEngine: self.output("逐日盯市盈亏计算完成") return self.daily_df - def calculate_statistics(self, df: DataFrame = None): + def calculate_statistics(self, df: DataFrame = None, Output=True): """""" self.output("开始计算策略统计指标") @@ -325,6 +325,7 @@ class BacktestingEngine: daily_return = 0 return_std = 0 sharpe_ratio = 0 + return_drawdown_ratio = 0 else: # Calculate balance related time series data df["balance"] = df["net_pnl"].cumsum() + self.capital @@ -373,38 +374,42 @@ class BacktestingEngine: else: sharpe_ratio = 0 + return_drawdown_ratio = -total_return / max_ddpercent + # Output - self.output("-" * 30) - self.output(f"首个交易日:\t{start_date}") - self.output(f"最后交易日:\t{end_date}") + if Output: + self.output("-" * 30) + self.output(f"首个交易日:\t{start_date}") + self.output(f"最后交易日:\t{end_date}") - self.output(f"总交易日:\t{total_days}") - self.output(f"盈利交易日:\t{profit_days}") - self.output(f"亏损交易日:\t{loss_days}") + self.output(f"总交易日:\t{total_days}") + self.output(f"盈利交易日:\t{profit_days}") + self.output(f"亏损交易日:\t{loss_days}") - self.output(f"起始资金:\t{self.capital:,.2f}") - self.output(f"结束资金:\t{end_balance:,.2f}") + self.output(f"起始资金:\t{self.capital:,.2f}") + self.output(f"结束资金:\t{end_balance:,.2f}") - self.output(f"总收益率:\t{total_return:,.2f}%") - self.output(f"年化收益:\t{annual_return:,.2f}%") - self.output(f"最大回撤: \t{max_drawdown:,.2f}") - self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") + self.output(f"总收益率:\t{total_return:,.2f}%") + self.output(f"年化收益:\t{annual_return:,.2f}%") + self.output(f"最大回撤: \t{max_drawdown:,.2f}") + self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") - self.output(f"总盈亏:\t{total_net_pnl:,.2f}") - self.output(f"总手续费:\t{total_commission:,.2f}") - self.output(f"总滑点:\t{total_slippage:,.2f}") - self.output(f"总成交金额:\t{total_turnover:,.2f}") - self.output(f"总成交笔数:\t{total_trade_count}") + self.output(f"总盈亏:\t{total_net_pnl:,.2f}") + self.output(f"总手续费:\t{total_commission:,.2f}") + self.output(f"总滑点:\t{total_slippage:,.2f}") + self.output(f"总成交金额:\t{total_turnover:,.2f}") + self.output(f"总成交笔数:\t{total_trade_count}") - self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") - self.output(f"日均手续费:\t{daily_commission:,.2f}") - self.output(f"日均滑点:\t{daily_slippage:,.2f}") - self.output(f"日均成交金额:\t{daily_turnover:,.2f}") - self.output(f"日均成交笔数:\t{daily_trade_count}") + self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") + self.output(f"日均手续费:\t{daily_commission:,.2f}") + self.output(f"日均滑点:\t{daily_slippage:,.2f}") + self.output(f"日均成交金额:\t{daily_turnover:,.2f}") + self.output(f"日均成交笔数:\t{daily_trade_count}") - self.output(f"日均收益率:\t{daily_return:,.2f}%") - self.output(f"收益标准差:\t{return_std:,.2f}%") - self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"日均收益率:\t{daily_return:,.2f}%") + self.output(f"收益标准差:\t{return_std:,.2f}%") + self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}") statistics = { "start_date": start_date, @@ -430,6 +435,7 @@ class BacktestingEngine: "daily_return": daily_return, "return_std": return_std, "sharpe_ratio": sharpe_ratio, + "return_drawdown_ratio": return_drawdown_ratio, } return statistics @@ -473,7 +479,7 @@ class BacktestingEngine: return if not target_name: - self.output("优化目标为设置,请检查") + self.output("优化目标未设置,请检查") return # Use multiprocessing pool for running backtesting with different setting diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 311981f1..9bff635b 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -151,13 +151,14 @@ class CtaEngine(BaseEngine): Query bar data from RQData. """ symbol, exchange_str = vt_symbol.split(".") - if symbol.upper() not in self.rq_symbols: + rq_symbol = to_rq_symbol(vt_symbol) + if rq_symbol not in self.rq_symbols: return None end += timedelta(1) # For querying night trading period data df = self.rq_client.get_price( - symbol.upper(), + rq_symbol, frequency=interval.value, fields=["open", "high", "low", "close", "volume"], start_date=start, @@ -540,7 +541,7 @@ class CtaEngine(BaseEngine): DbBarData.select() .where( (DbBarData.vt_symbol == vt_symbol) - & (DbBarData.interval == interval) + & (DbBarData.interval == interval.value) & (DbBarData.datetime >= start) & (DbBarData.datetime <= end) ) @@ -910,3 +911,29 @@ class CtaEngine(BaseEngine): subject = "CTA策略引擎" self.main_engine.send_email(subject, msg) + + +def to_rq_symbol(vt_symbol: str): + """ + CZCE product of RQData has symbol like "TA1905" while + vt symbol is "TA905.CZCE" so need to add "1" in symbol. + """ + symbol, exchange_str = vt_symbol.split(".") + if exchange_str != "CZCE": + return symbol.upper() + + for count, word in enumerate(symbol): + if word.isdigit(): + break + + product = symbol[:count] + year = symbol[count] + month = symbol[count + 1:] + + if year == "9": + year = "1" + year + else: + year = "2" + year + + rq_symbol = f"{product}{year}{month}".upper() + return rq_symbol diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index e4f309ff..20cf88d2 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -2,7 +2,7 @@ from abc import ABC from typing import Any, Callable -from vnpy.trader.constant import Interval, Status, Direction, Offset +from vnpy.trader.constant import Interval, Direction, Offset from vnpy.trader.object import BarData, TickData, OrderData, TradeData from .base import StopOrder, EngineType diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index fe2784a7..f1a25811 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -405,7 +405,7 @@ class CtpTdApi(TdApi): """""" if not error['ErrorID']: self.authStatus = True - self.writeLog("交易授权验证成功") + self.gateway.write_log("交易授权验证成功") self.login() else: self.gateway.write_error("交易授权验证失败", error) @@ -418,7 +418,7 @@ class CtpTdApi(TdApi): self.login_status = True self.gateway.write_log("交易登录成功") - # Confirm settelment + # Confirm settlement req = { "BrokerID": self.brokerid, "InvestorID": self.userid @@ -549,13 +549,16 @@ class CtpTdApi(TdApi): product=product, size=data["VolumeMultiple"], pricetick=data["PriceTick"], - option_underlying=data["UnderlyingInstrID"], - option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), - option_strike=data["StrikePrice"], - option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"), gateway_name=self.gateway_name ) + # For option only + if data["OptionsType"]: + contract.option_underlying = data["UnderlyingInstrID"], + contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), + contract.option_strike = data["StrikePrice"], + contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"), + self.gateway.on_contract(contract) symbol_exchange_map[contract.symbol] = contract.exchange @@ -662,7 +665,7 @@ class CtpTdApi(TdApi): "UserID": self.userid, "BrokerID": self.brokerid, "AuthCode": self.auth_code, - "ProductInfo": self.product_info + "UserProductInfo": self.product_info } self.reqid += 1 @@ -678,7 +681,8 @@ class CtpTdApi(TdApi): req = { "UserID": self.userid, "Password": self.password, - "BrokerID": self.brokerid + "BrokerID": self.brokerid, + "UserProductInfo": self.product_info } self.reqid += 1 diff --git a/vnpy/gateway/huobi/__init__.py b/vnpy/gateway/huobi/__init__.py new file mode 100644 index 00000000..9e3d74e8 --- /dev/null +++ b/vnpy/gateway/huobi/__init__.py @@ -0,0 +1,4 @@ +from .huobi_gateway import HuobiGateway + + + diff --git a/vnpy/gateway/huobi/huobi_gateway.py b/vnpy/gateway/huobi/huobi_gateway.py new file mode 100644 index 00000000..5572f5fe --- /dev/null +++ b/vnpy/gateway/huobi/huobi_gateway.py @@ -0,0 +1,715 @@ +# encoding: UTF-8 + +""" +火币交易接口 +""" + +import re +import urllib +import base64 +import json +import zlib +import hashlib +import hmac +from copy import copy +from datetime import datetime + +from vnpy.event import Event +from vnpy.api.rest import RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + Product, + Status, + OrderType +) +from vnpy.trader.gateway import BaseGateway, LocalOrderManager +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest +) +from vnpy.trader.event import EVENT_TIMER + + +REST_HOST = "https://api.huobipro.com" +WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data +WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order + +STATUS_HUOBI2VT = { + "submitted": Status.NOTTRADED, + "partial-filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelling": Status.CANCELLED, + "partial-canceled": Status.CANCELLED, + "canceled": Status.CANCELLED, +} + +ORDERTYPE_VT2HUOBI = { + (Direction.LONG, OrderType.MARKET): "buy-market", + (Direction.SHORT, OrderType.MARKET): "sell-market", + (Direction.LONG, OrderType.LIMIT): "buy-limit", + (Direction.SHORT, OrderType.LIMIT): "sell-limit", +} +ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()} + + +huobi_symbols = set() +symbol_name_map = {} + + +class HuobiGateway(BaseGateway): + """ + VN Trader Gateway for Huobi connection. + """ + + default_setting = { + "API Key": "", + "Secret Key": "", + "会话数": 3, + "代理地址": "127.0.0.1", + "代理端口": 1080, + } + + def __init__(self, event_engine): + """Constructor""" + super(HuobiGateway, self).__init__(event_engine, "HUOBI") + + self.order_manager = LocalOrderManager(self) + + self.rest_api = HuobiRestApi(self) + self.trade_ws_api = HuobiTradeWebsocketApi(self) + self.market_ws_api = HuobiDataWebsocketApi(self) + + def connect(self, setting: dict): + """""" + key = setting["API Key"] + secret = setting["Secret Key"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + self.rest_api.connect(key, secret, session_number, + proxy_host, proxy_port) + self.trade_ws_api.connect(key, secret, proxy_host, proxy_port) + self.market_ws_api.connect(key, secret, proxy_host, proxy_port) + + self.init_query() + + def subscribe(self, req: SubscribeRequest): + """""" + self.market_ws_api.subscribe(req) + self.trade_ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + self.rest_api.query_account_balance() + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.trade_ws_api.stop() + self.market_ws_api.stop() + + def process_timer_event(self, event: Event): + """""" + self.count += 1 + if self.count < 3: + return + + self.query_account() + + def init_query(self): + """""" + self.count = 0 + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + + +class HuobiRestApi(RestClient): + """ + HUOBI REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(HuobiRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + self.order_manager = gateway.order_manager + + self.host = "" + self.key = "" + self.secret = "" + self.account_id = "" + + self.cancel_requests = {} + self.orders = {} + + def sign(self, request): + """ + Generate HUOBI signature. + """ + request.headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36" + } + params_with_signature = create_signature( + self.key, + request.method, + self.host, + request.path, + self.secret, + request.params + ) + request.params = params_with_signature + + if request.method == "POST": + request.headers["Content-Type"] = "application/json" + + if request.data: + request.data = json.dumps(request.data) + + return request + + def connect( + self, + key: str, + secret: str, + session_number: int, + proxy_host: str, + proxy_port: int + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret + + self.host, _ = _split_url(REST_HOST) + + self.init(REST_HOST, proxy_host, proxy_port) + self.start(session_number) + + self.gateway.write_log("REST API启动成功") + + self.query_contract() + self.query_account() + self.query_order() + + def query_account(self): + """""" + self.add_request( + method="GET", + path="/v1/account/accounts", + callback=self.on_query_account + ) + + def query_account_balance(self): + """""" + path = f"/v1/account/accounts/{self.account_id}/balance" + self.add_request( + method="GET", + path=path, + callback=self.on_query_account_balance + ) + + def query_order(self): + """""" + self.add_request( + method="GET", + path="/v1/order/openOrders", + callback=self.on_query_order + ) + + def query_contract(self): + """""" + self.add_request( + method="GET", + path="/v1/common/symbols", + callback=self.on_query_contract + ) + + def send_order(self, req: OrderRequest): + """""" + huobi_type = ORDERTYPE_VT2HUOBI.get( + (req.direction, req.type), "" + ) + + local_orderid = self.order_manager.new_local_orderid() + order = req.create_order_data( + local_orderid, + self.gateway_name + ) + order.time = datetime.now().strftime("%H:%M:%S") + + data = { + "account-id": self.account_id, + "amount": str(req.volume), + "symbol": req.symbol, + "type": huobi_type, + "price": str(req.price), + "source": "api" + } + + self.add_request( + method="POST", + path="/v1/order/orders/place", + callback=self.on_send_order, + data=data, + extra=order, + ) + + self.order_manager.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + sys_orderid = self.order_manager.get_sys_orderid(req.orderid) + + path = f"/v1/order/orders/{sys_orderid}/submitcancel" + self.add_request( + method="POST", + path=path, + callback=self.on_cancel_order, + extra=req + ) + + def on_query_account(self, data, request): + """""" + if self.check_error(data, "查询账户"): + return + + for d in data["data"]: + if d["type"] == "spot": + self.account_id = d["id"] + self.gateway.write_log(f"账户代码{self.account_id}查询成功") + + self.query_account_balance() + + def on_query_account_balance(self, data, request): + """""" + if self.check_error(data, "查询账户资金"): + return + + buf = {} + for d in data["data"]["list"]: + currency = d["currency"] + currency_data = buf.setdefault(currency, {}) + currency_data[d["type"]] = float(d["balance"]) + + for currency, currency_data in buf.items(): + account = AccountData( + accountid=currency, + balance=currency_data["trade"] + currency_data["frozen"], + frozen=currency_data["frozen"], + gateway_name=self.gateway_name, + ) + + if account.balance: + self.gateway.on_account(account) + + def on_query_order(self, data, request): + """""" + if self.check_error(data, "查询委托"): + return + + for d in data["data"]: + sys_orderid = d["id"] + local_orderid = self.order_manager.get_local_orderid(sys_orderid) + + direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]] + dt = datetime.fromtimestamp(d["created-at"] / 1000) + time = dt.strftime("%H:%M:%S") + + order = OrderData( + orderid=local_orderid, + symbol=d["symbol"], + exchange=Exchange.HUOBI, + price=float(d["price"]), + volume=float(d["amount"]), + type=order_type, + direction=direction, + traded=float(d["filled-amount"]), + status=STATUS_HUOBI2VT.get(d["state"], None), + time=time, + gateway_name=self.gateway_name, + ) + + self.order_manager.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_contract(self, data, request): # type: (dict, Request)->None + """""" + if self.check_error(data, "查询合约"): + return + + for d in data["data"]: + base_currency = d["base-currency"] + quote_currency = d["quote-currency"] + name = f"{base_currency.upper()}/{quote_currency.upper()}" + pricetick = 1 / pow(10, d["price-precision"]) + size = 1 / pow(10, d["amount-precision"]) + + contract = ContractData( + symbol=d["symbol"], + exchange=Exchange.HUOBI, + name=name, + pricetick=pricetick, + size=size, + product=Product.SPOT, + gateway_name=self.gateway_name, + ) + self.gateway.on_contract(contract) + + huobi_symbols.add(contract.symbol) + symbol_name_map[contract.symbol] = contract.name + + self.gateway.write_log("合约信息查询成功") + + def on_send_order(self, data, request): + """""" + order = request.extra + + if self.check_error(data, "委托"): + order.status = Status.REJECTED + self.order_manager.on_order(order) + return + + sys_orderid = data["data"] + self.order_manager.update_orderid_map(order.orderid, sys_orderid) + + def on_cancel_order(self, data, request): + """""" + if self.check_error(data, "撤单"): + return + + cancel_request = request.extra + local_orderid = cancel_request.orderid + + order = self.order_manager.get_order_with_local_orderid(local_orderid) + order.status = Status.CANCELLED + + self.order_manager.on_order(order) + self.gateway.write_log(f"委托撤单成功:{order.orderid}") + + def check_error(self, data: dict, func: str = ""): + """""" + if data["status"] != "error": + return False + + error_code = data["err-code"] + error_msg = data["err-msg"] + + self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}") + return True + + +class HuobiWebsocketApiBase(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(HuobiWebsocketApiBase, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.sign_host = "" + self.path = "" + + def connect( + self, + key: str, + secret: str, + url: str, + proxy_host: str, + proxy_port: int + ): + """""" + self.key = key + self.secret = secret + + host, path = _split_url(url) + self.sign_host = host + self.path = path + + self.init(url, proxy_host, proxy_port) + self.start() + + def login(self): + """""" + params = {"op": "auth"} + params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret)) + return self.send_packet(params) + + def on_login(self, packet): + """""" + pass + + @staticmethod + def unpack_data(data): + """""" + return json.loads(zlib.decompress(data, 31)) + + def on_packet(self, packet): + """""" + if "ping" in packet: + self.send_packet({"pong": packet["ping"]}) + elif "err-msg" in packet: + return self.on_error_msg(packet) + elif "op" in packet and packet["op"] == "auth": + return self.on_login() + else: + self.on_data(packet) + + def on_data(self, packet): + """""" + print("data : {}".format(packet)) + + def on_error_msg(self, packet): + """""" + msg = packet["err-msg"] + if msg == "invalid pong": + return + + self.gateway.write_log(packet["err-msg"]) + + +class HuobiTradeWebsocketApi(HuobiWebsocketApiBase): + """""" + def __init__(self, gateway): + """""" + super().__init__(gateway) + + self.order_manager = gateway.order_manager + self.order_manager.push_data_callback = self.on_data + + self.req_id = 0 + + def connect(self, key, secret, proxy_host, proxy_port): + """""" + super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port) + + def subscribe(self, req: SubscribeRequest): + """""" + self.req_id += 1 + req = { + "op": "sub", + "cid": str(self.req_id), + "topic": f"orders.{req.symbol}" + } + self.send_packet(req) + + def on_connected(self): + """""" + self.gateway.write_log("交易Websocket API连接成功") + self.login() + + def on_login(self): + """""" + self.gateway.write_log("交易Websocket API登录成功") + + def on_data(self, packet): # type: (dict)->None + """""" + op = packet.get("op", None) + if op != "notify": + return + + topic = packet["topic"] + if "orders" in topic: + self.on_order(packet["data"]) + + def on_order(self, data: dict): + """""" + sys_orderid = str(data["order-id"]) + + order = self.order_manager.get_order_with_sys_orderid(sys_orderid) + if not order: + self.order_manager.add_push_data(sys_orderid, data) + return + + traded_volume = float(data["filled-amount"]) + + # Push order event + order.traded += traded_volume + order.status = STATUS_HUOBI2VT.get(data["order-state"], None) + self.order_manager.on_order(order) + + # Push trade event + if not traded_volume: + return + + trade = TradeData( + symbol=order.symbol, + exchange=Exchange.HUOBI, + orderid=order.orderid, + tradeid=str(data["seq-id"]), + direction=order.direction, + price=float(data["price"]), + volume=float(data["filled-amount"]), + time=datetime.now().strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + self.gateway.on_trade(trade) + + +class HuobiDataWebsocketApi(HuobiWebsocketApiBase): + """""" + + def __init__(self, gateway): + """""" + super().__init__(gateway) + + self.req_id = 0 + self.ticks = {} + + def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int): + """""" + super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port) + + def on_connected(self): + """""" + self.gateway.write_log("行情Websocket API连接成功") + + def subscribe(self, req: SubscribeRequest): + """""" + symbol = req.symbol + + # Create tick data buffer + tick = TickData( + symbol=symbol, + name=symbol_name_map.get(symbol, ""), + exchange=Exchange.HUOBI, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[symbol] = tick + + # Subscribe to market depth update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.depth.step0", + "id": str(self.req_id) + } + self.send_packet(req) + + # Subscribe to market detail update + self.req_id += 1 + req = { + "sub": f"market.{symbol}.detail", + "id": str(self.req_id) + } + self.send_packet(req) + + def on_data(self, packet): # type: (dict)->None + """""" + channel = packet.get("ch", None) + if channel: + if "depth.step" in channel: + self.on_market_depth(packet) + elif "detail" in channel: + self.on_market_detail(packet) + elif "err-code" in packet: + code = packet["err-code"] + msg = packet["err-msg"] + self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}") + + def on_market_depth(self, data): + """行情深度推送 """ + symbol = data["ch"].split(".")[1] + tick = self.ticks[symbol] + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + + bids = data["tick"]["bids"] + for n in range(5): + price, volume = bids[n] + tick.__setattr__("bid_price_" + str(n + 1), float(price)) + tick.__setattr__("bid_volume_" + str(n + 1), float(volume)) + + asks = data["tick"]["asks"] + for n in range(5): + price, volume = asks[n] + tick.__setattr__("ask_price_" + str(n + 1), float(price)) + tick.__setattr__("ask_volume_" + str(n + 1), float(volume)) + + if tick.last_price: + self.gateway.on_tick(copy(tick)) + + def on_market_detail(self, data): + """市场细节推送""" + symbol = data["ch"].split(".")[1] + tick = self.ticks[symbol] + tick.datetime = datetime.fromtimestamp(data["ts"] / 1000) + + tick_data = data["tick"] + tick.open_price = float(tick_data["open"]) + tick.high_price = float(tick_data["high"]) + tick.low_price = float(tick_data["low"]) + tick.last_price = float(tick_data["close"]) + tick.volume = float(tick_data["vol"]) + + if tick.bid_price_1: + self.gateway.on_tick(copy(tick)) + + +def _split_url(url): + """ + 将url拆分为host和path + :return: host, path + """ + result = re.match("\w+://([^/]*)(.*)", url) # noqa + if result: + return result.group(1), result.group(2) + + +def create_signature(api_key, method, host, path, secret_key, get_params=None): + """ + 创建签名 + :param get_params: dict 使用GET方法时附带的额外参数(urlparams) + :return: + """ + sorted_params = [ + ("AccessKeyId", api_key), + ("SignatureMethod", "HmacSHA256"), + ("SignatureVersion", "2"), + ("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")) + ] + + if get_params: + sorted_params.extend(list(get_params.items())) + sorted_params = list(sorted(sorted_params)) + encode_params = urllib.parse.urlencode(sorted_params) + + payload = [method, host, path, encode_params] + payload = "\n".join(payload) + payload = payload.encode(encoding="UTF8") + + secret_key = secret_key.encode(encoding="UTF8") + + digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest() + signature = base64.b64encode(digest) + + params = dict(sorted_params) + params["Signature"] = signature.decode("UTF8") + return params diff --git a/vnpy/gateway/okex/__init__.py b/vnpy/gateway/okex/__init__.py new file mode 100644 index 00000000..8cb9fb3b --- /dev/null +++ b/vnpy/gateway/okex/__init__.py @@ -0,0 +1 @@ +from .okex_gateway import OkexGateway diff --git a/vnpy/gateway/okex/okex_gateway.py b/vnpy/gateway/okex/okex_gateway.py new file mode 100644 index 00000000..cc1e124c --- /dev/null +++ b/vnpy/gateway/okex/okex_gateway.py @@ -0,0 +1,699 @@ +# encoding: UTF-8 +""" +""" + +import hashlib +import hmac +import sys +import time +import json +import base64 +import zlib +from copy import copy +from datetime import datetime +from threading import Lock +from urllib.parse import urlencode + +from requests import ConnectionError + +from vnpy.api.rest import Request, RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.trader.constant import ( + Direction, + Exchange, + OrderType, + Product, + Status +) +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, +) + +REST_HOST = "https://www.okex.com" +WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3" + +STATUS_OKEX2VT = { + "ordering": Status.SUBMITTING, + "open": Status.NOTTRADED, + "part_filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "cancelled": Status.CANCELLED, + "cancelling": Status.CANCELLED, + "failure": Status.REJECTED, +} + +DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"} +DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()} + +ORDERTYPE_VT2OKEX = { + OrderType.LIMIT: "limit", + OrderType.MARKET: "market" +} +ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()} + + +instruments = set() +currencies = set() + + +class OkexGateway(BaseGateway): + """ + VN Trader Gateway for OKEX connection. + """ + + default_setting = { + "API Key": "", + "Secret Key": "", + "Passphrase": "", + "会话数": 3, + "代理地址": "127.0.0.1", + "代理端口": 1080, + } + + def __init__(self, event_engine): + """Constructor""" + super(OkexGateway, self).__init__(event_engine, "OKEX") + + self.rest_api = OkexRestApi(self) + self.ws_api = OkexWebsocketApi(self) + + def connect(self, setting: dict): + """""" + key = setting["API Key"] + secret = setting["Secret Key"] + passphrase = setting["Passphrase"] + session_number = setting["会话数"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + self.rest_api.connect(key, secret, passphrase, + session_number, proxy_host, proxy_port) + + self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port) + + def subscribe(self, req: SubscribeRequest): + """""" + self.ws_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + pass + + def query_position(self): + """""" + pass + + def close(self): + """""" + self.rest_api.stop() + self.ws_api.stop() + + +class OkexRestApi(RestClient): + """ + OKEX REST API + """ + + def __init__(self, gateway: BaseGateway): + """""" + super(OkexRestApi, self).__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.order_count = 10000 + self.order_count_lock = Lock() + + self.connect_time = 0 + + def sign(self, request): + """ + Generate OKEX signature. + """ + # Sign + # timestamp = str(time.time()) + timestamp = get_timestamp() + request.data = json.dumps(request.data) + + if request.params: + path = request.path + '?' + urlencode(request.params) + else: + path = request.path + + msg = timestamp + request.method + path + request.data + signature = generate_signature(msg, self.secret) + + # Add headers + request.headers = { + 'OK-ACCESS-KEY': self.key, + 'OK-ACCESS-SIGN': signature, + 'OK-ACCESS-TIMESTAMP': timestamp, + 'OK-ACCESS-PASSPHRASE': self.passphrase, + 'Content-Type': 'application/json' + } + return request + + def connect( + self, + key: str, + secret: str, + passphrase: str, + session_number: int, + proxy_host: str, + proxy_port: int, + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + + self.init(REST_HOST, proxy_host, proxy_port) + self.start(session_number) + self.gateway.write_log("REST API启动成功") + + self.query_time() + self.query_contract() + self.query_account() + self.query_order() + + def _new_order_id(self): + with self.order_count_lock: + self.order_count += 1 + return self.order_count + + def send_order(self, req: OrderRequest): + """""" + orderid = f"a{self.connect_time}{self._new_order_id()}" + + data = { + "client_oid": orderid, + "type": ORDERTYPE_VT2OKEX[req.type], + "side": DIRECTION_VT2OKEX[req.direction], + "instrument_id": req.symbol + } + + if req.type == OrderType.MARKET: + if req.direction == Direction.LONG: + data["notional"] = req.volume + else: + data["size"] = req.volume + else: + data["price"] = req.price + data["size"] = req.volume + + order = req.create_order_data(orderid, self.gateway_name) + + self.add_request( + "POST", + "/api/spot/v3/orders", + callback=self.on_send_order, + data=data, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, + ) + + self.gateway.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + data = { + "instrument_id": req.symbol, + "client_oid": req.orderid + } + + path = "/api/spot/v3/cancel_orders/" + req.orderid + self.add_request( + "POST", + path, + callback=self.on_cancel_order, + data=data, + on_error=self.on_cancel_order_error, + ) + + def query_contract(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/instruments", + callback=self.on_query_contract + ) + + def query_account(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/accounts", + callback=self.on_query_account + ) + + def query_order(self): + """""" + self.add_request( + "GET", + "/api/spot/v3/orders_pending", + callback=self.on_query_order + ) + + def query_time(self): + """""" + self.add_request( + "GET", + "/api/general/v3/time", + callback=self.on_query_time + ) + + def on_query_contract(self, data, request): + """""" + for instrument_data in data: + symbol = instrument_data["instrument_id"] + contract = ContractData( + symbol=symbol, + exchange=Exchange.OKEX, + name=symbol, + product=Product.SPOT, + size=1, + pricetick=instrument_data["tick_size"], + gateway_name=self.gateway_name + ) + self.gateway.on_contract(contract) + + instruments.add(instrument_data["instrument_id"]) + currencies.add(instrument_data["base_currency"]) + currencies.add(instrument_data["quote_currency"]) + + self.gateway.write_log("合约信息查询成功") + + # Start websocket api after instruments data collected + self.gateway.ws_api.start() + + def on_query_account(self, data, request): + """""" + for account_data in data: + account = AccountData( + accountid=account_data["currency"], + balance=float(account_data["balance"]), + frozen=float(account_data["hold"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + self.gateway.write_log("账户资金查询成功") + + def on_query_order(self, data, request): + """""" + for order_data in data: + order = OrderData( + symbol=order_data["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[order_data["type"]], + orderid=order_data["client_oid"], + direction=DIRECTION_OKEX2VT[order_data["side"]], + price=float(order_data["price"]), + volume=float(order_data["size"]), + time=order_data["timestamp"][11:19], + status=STATUS_OKEX2VT[order_data["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + self.gateway.write_log("委托信息查询成功") + + def on_query_time(self, data, request): + """""" + server_time = data["iso"] + local_time = datetime.utcnow().isoformat() + msg = f"服务器时间:{server_time},本机时间:{local_time}" + self.gateway.write_log(msg) + + def on_send_order_failed(self, status_code: str, request: Request): + """ + Callback when sending order failed on server. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_send_order(self, data, request): + """Websocket will push a new order status""" + order = request.extra + + error_msg = data["error_message"] + if error_msg: + order.status = Status.REJECTED + self.gateway.on_order(order) + + self.gateway.write_log(f"委托失败:{error_msg}") + + def on_cancel_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when cancelling order failed on server. + """ + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_cancel_order(self, data, request): + """Websocket will push a new order status""" + pass + + def on_failed(self, status_code: int, request: Request): + """ + Callback to handle request failed. + """ + msg = f"请求失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback to handler request exception. + """ + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write( + self.exception_detail(exception_type, exception_value, tb, request) + ) + + +class OkexWebsocketApi(WebsocketClient): + """""" + + def __init__(self, gateway): + """""" + super(OkexWebsocketApi, self).__init__() + self.ping_interval = 20 # OKEX use 30 seconds for ping + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + self.passphrase = "" + + self.trade_count = 10000 + self.connect_time = 0 + + self.callbacks = {} + self.ticks = {} + + def connect( + self, + key: str, + secret: str, + passphrase: str, + proxy_host: str, + proxy_port: int + ): + """""" + self.key = key + self.secret = secret.encode() + self.passphrase = passphrase + + self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S")) + + self.init(WEBSOCKET_HOST, proxy_host, proxy_port) + # self.start() + + def unpack_data(self, data): + """""" + return json.loads(zlib.decompress(data, -zlib.MAX_WBITS)) + + def subscribe(self, req: SubscribeRequest): + """ + Subscribe to tick data upate. + """ + tick = TickData( + symbol=req.symbol, + exchange=req.exchange, + name=req.symbol, + datetime=datetime.now(), + gateway_name=self.gateway_name, + ) + self.ticks[req.symbol] = tick + + channel_ticker = f"spot/ticker:{req.symbol}" + channel_depth = f"spot/depth5:{req.symbol}" + + self.callbacks[channel_ticker] = self.on_ticker + self.callbacks[channel_depth] = self.on_depth + + req = { + "op": "subscribe", + "args": [channel_ticker, channel_depth] + } + self.send_packet(req) + + def on_connected(self): + """""" + self.gateway.write_log("Websocket API连接成功") + self.login() + + def on_disconnected(self): + """""" + self.gateway.write_log("Websocket API连接断开") + + def on_packet(self, packet: dict): + """""" + if "event" in packet: + event = packet["event"] + if event == "subscribe": + return + elif event == "error": + msg = packet["message"] + self.gateway.write_log(f"Websocket API请求异常:{msg}") + elif event == "login": + self.on_login(packet) + else: + channel = packet["table"] + data = packet["data"] + callback = self.callbacks.get(channel, None) + + if callback: + for d in data: + callback(d) + + def on_error(self, exception_type: type, exception_value: Exception, tb): + """""" + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write(self.exception_detail( + exception_type, exception_value, tb)) + + def login(self): + """ + Need to login befores subscribe to websocket topic. + """ + timestamp = str(time.time()) + + msg = timestamp + 'GET' + '/users/self/verify' + signature = generate_signature(msg, self.secret) + + req = { + "op": "login", + "args": [ + self.key, + self.passphrase, + timestamp, + signature.decode("utf-8") + ] + } + self.send_packet(req) + self.callbacks['login'] = self.on_login + + def subscribe_topic(self): + """ + Subscribe to all private topics. + """ + self.callbacks["spot/ticker"] = self.on_ticker + self.callbacks["spot/depth5"] = self.on_depth + self.callbacks["spot/account"] = self.on_account + self.callbacks["spot/order"] = self.on_order + + # Subscribe to order update + channels = [] + for instrument_id in instruments: + channel = f"spot/order:{instrument_id}" + channels.append(channel) + + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to account update + channels = [] + for currency in currencies: + channel = f"spot/account:{currency}" + channels.append(channel) + + req = { + "op": "subscribe", + "args": channels + } + self.send_packet(req) + + # Subscribe to BTC/USDT trade for keep connection alive + req = { + "op": "subscribe", + "args": ["spot/trade:BTC-USDT"] + } + self.send_packet(req) + + def on_login(self, data: dict): + """""" + success = data.get("success", False) + + if success: + self.gateway.write_log("Websocket API登录成功") + self.subscribe_topic() + else: + self.gateway.write_log("Websocket API登录失败") + + def on_ticker(self, d): + """""" + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + tick.last_price = d["last"] + tick.open = d["open_24h"] + tick.high = d["high_24h"] + tick.low = d["low_24h"] + tick.volume = d["base_volume_24h"] + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_depth(self, d): + """""" + for tick_data in d: + symbol = d["instrument_id"] + tick = self.ticks.get(symbol, None) + if not tick: + return + + bids = d["bids"] + asks = d["asks"] + for n, buf in enumerate(bids): + price, volume, _ = buf + tick.__setattr__("bid_price_%s" % (n + 1), price) + tick.__setattr__("bid_volume_%s" % (n + 1), volume) + + for n, buf in enumerate(asks): + price, volume, _ = buf + tick.__setattr__("ask_price_%s" % (n + 1), price) + tick.__setattr__("ask_volume_%s" % (n + 1), volume) + + tick.datetime = datetime.strptime( + d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ") + self.gateway.on_tick(copy(tick)) + + def on_order(self, d): + """""" + order = OrderData( + symbol=d["instrument_id"], + exchange=Exchange.OKEX, + type=ORDERTYPE_OKEX2VT[d["type"]], + orderid=d["client_oid"], + direction=DIRECTION_OKEX2VT[d["side"]], + price=d["price"], + volume=d["size"], + traded=d["filled_size"], + time=d["timestamp"][11:19], + status=STATUS_OKEX2VT[d["status"]], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(copy(order)) + + trade_volume = float(d.get("last_fill_qty", 0)) + if not trade_volume: + return + + self.trade_count += 1 + tradeid = f"{self.connect_time}{self.trade_count}" + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=tradeid, + direction=order.direction, + price=float(d["last_fill_px"]), + volume=float(trade_volume), + time=d["last_fill_time"][11:19], + gateway_name=self.gateway_name + ) + self.gateway.on_trade(trade) + + def on_account(self, d): + """""" + account = AccountData( + accountid=d["currency"], + balance=float(d["balance"]), + frozen=float(d["hold"]), + gateway_name=self.gateway_name + ) + + self.gateway.on_account(copy(account)) + + +def generate_signature(msg: str, secret_key: str): + """OKEX V3 signature""" + return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest()) + + +def get_timestamp(): + """""" + now = datetime.utcnow() + timestamp = now.isoformat("T", "milliseconds") + return timestamp + "Z" diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index eecd9305..997d3f02 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -1,5 +1,6 @@ # encoding: UTF-8 """ +Author: KeKe Please install tiger-api before use. pip install tigeropen """ @@ -123,6 +124,9 @@ class TigerGateway(BaseGateway): self.contracts = {} self.symbol_names = {} + self.push_connected = False + self.subscribed_symbols = set() + def run(self): """""" while self.active: @@ -203,23 +207,33 @@ class TigerGateway(BaseGateway): """ protocol, host, port = self.client_config.socket_host_port self.push_client = PushClient(host, port, (protocol == 'ssl')) - self.push_client.connect( - self.client_config.tiger_id, self.client_config.private_key) self.push_client.quote_changed = self.on_quote_change self.push_client.asset_changed = self.on_asset_change self.push_client.position_changed = self.on_position_change self.push_client.order_changed = self.on_order_change - self.write_log("推送接口连接成功") + self.push_client.connect( + self.client_config.tiger_id, self.client_config.private_key) def subscribe(self, req: SubscribeRequest): """""" - self.push_client.subscribe_quote([req.symbol]) + self.subscribed_symbols.add(req.symbol) + + if self.push_connected: + self.push_client.subscribe_quote([req.symbol]) + + def on_push_connected(self): + """""" + self.push_connected = True + self.write_log("推送接口连接成功") + self.push_client.subscribe_asset() self.push_client.subscribe_position() self.push_client.subscribe_order() + self.push_client.subscribe_quote(list(self.subscribed_symbols)) + def on_quote_change(self, tiger_symbol: str, data: list, trading: bool): """""" data = dict(data) diff --git a/vnpy/rpc/__init__.py b/vnpy/rpc/__init__.py new file mode 100644 index 00000000..b903d2bb --- /dev/null +++ b/vnpy/rpc/__init__.py @@ -0,0 +1 @@ +from .vnrpc import RpcServer, RpcClient, RemoteException \ No newline at end of file diff --git a/vnpy/rpc/test_client.py b/vnpy/rpc/test_client.py new file mode 100644 index 00000000..7a693a52 --- /dev/null +++ b/vnpy/rpc/test_client.py @@ -0,0 +1,36 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep + +from vnpy.rpc import RpcClient + + +class TestClient(RpcClient): + """ + Test RpcClient + """ + + def __init__(self, req_address, sub_address): + """ + Constructor + """ + super(TestClient, self).__init__(req_address, sub_address) + + def callback(self, topic, data): + """ + Realize callable function + """ + print('client received topic:', topic, ', data:', data) + + +if __name__ == '__main__': + req_address = 'tcp://localhost:2014' + sub_address = 'tcp://localhost:0602' + + tc = TestClient(req_address, sub_address) + tc.subscribeTopic('') + tc.start() + + while 1: + print(tc.add(1, 3)) + sleep(2) diff --git a/vnpy/rpc/test_server.py b/vnpy/rpc/test_server.py new file mode 100644 index 00000000..660168fc --- /dev/null +++ b/vnpy/rpc/test_server.py @@ -0,0 +1,40 @@ +from __future__ import print_function +from __future__ import absolute_import +from time import sleep, time + +from vnpy.rpc import RpcServer + + +class TestServer(RpcServer): + """ + Test RpcServer + """ + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(TestServer, self).__init__(rep_address, pub_address) + + self.register(self.add) + + def add(self, a, b): + """ + Test function + """ + print('receiving: %s, %s' % (a, b)) + return a + b + + +if __name__ == '__main__': + rep_address = 'tcp://*:2014' + pub_address = 'tcp://*:0602' + + ts = TestServer(rep_address, pub_address) + ts.start() + + while 1: + content = 'current server time is %s' % time() + print(content) + ts.publish('test', content) + sleep(2) diff --git a/vnpy/rpc/vnrpc.py b/vnpy/rpc/vnrpc.py new file mode 100644 index 00000000..1cdd835a --- /dev/null +++ b/vnpy/rpc/vnrpc.py @@ -0,0 +1,329 @@ +import threading +import traceback +import signal + +import zmq +from msgpack import packb, unpackb +from json import dumps, loads + +import pickle +p_dumps = pickle.dumps +p_loads = pickle.loads + + +# Achieve Ctrl-c interrupt recv +signal.signal(signal.SIGINT, signal.SIG_DFL) + + +class RpcObject(object): + """ + Referred to serialization of packing and unpacking, we offer 3 tools: + 1) maspack: higher performance, but usually requires the installation of msgpack related tools; + 2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries; + 3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly. + + Therefore, it is recommended to use msgpack. + Use json, if you want to communicate with some languages without providing msgpack. + Use cPickle, when the data being transferred contains many custom Python objects. + """ + + def __init__(self): + """ + Constructor + Use msgpack as default serialization tool + """ + self.use_msgpack() + + def pack(self, data): + """""" + pass + + def unpack(self, data): + """""" + pass + + def __json_pack(self, data): + """ + Pack with json + """ + return dumps(data) + + def __json_unpack(self, data): + """ + Unpack with json + """ + return loads(data) + + def __msgpack_pack(self, data): + """ + Pack with msgpack + """ + return packb(data) + + def __msgpack_unpack(self, data): + """ + Unpack with msgpack + """ + return unpackb(data) + + def __pickle_pack(self, data): + """ + Pack with cPickle + """ + return p_dumps(data) + + def __pickle_unpack(self, data): + """ + Unpack with cPickle + """ + return p_loads(data) + + def use_json(self): + """ + Use json as serialization tool + """ + self.pack = self.__json_pack + self.unpack = self.__json_unpack + + def use_msgpack(self): + """ + Use msgpack as serialization tool + """ + self.pack = self.__msgpack_pack + self.unpack = self.__msgpack_unpack + + def use_pickle(self): + """ + Use cPickle as serialization tool + """ + self.pack = self.__pickle_pack + self.unpack = self.__pickle_unpack + + +class RpcServer(RpcObject): + """""" + + def __init__(self, rep_address, pub_address): + """ + Constructor + """ + super(RpcServer, self).__init__() + + # Save functions dict: key is fuction name, value is fuction object + self.__functions = {} + + # Zmq port related + self.__context = zmq.Context() + + self.__socket_rep = self.__context.socket( + zmq.REP) # Reply socket (Request–reply pattern) + self.__socket_rep.bind(rep_address) + + # Publish socket (Publish–subscribe pattern) + self.__socket_pub = self.__context.socket(zmq.PUB) + self.__socket_pub.bind(pub_address) + + # Woker thread related + self.__active = False # RpcServer status + self.__thread = threading.Thread(target=self.run) # RpcServer thread + + def start(self): + """ + Start RpcServer + """ + # Start RpcServer status + self.__active = True + + # Start RpcServer thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self, join=False): + """ + Stop RpcServer + """ + # Stop RpcServer status + self.__active = False + + # Wait for RpcServer thread to exit + if join and self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcServer functions + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_rep.poll(1000): + continue + + # Receive request data from Reply socket + reqb = self.__socket_rep.recv() + + # Unpack request by deserialization + req = self.unpack(reqb) + + # Get function name and parameters + name, args, kwargs = req + + # Try to get and execute callable function object; capture exception information if it fails + name = name.decode("UTF-8") + + try: + func = self.__functions[name] + r = func(*args, **kwargs) + rep = [True, r] + except Exception as e: # noqa + rep = [False, traceback.format_exc()] + + # Pack response by serialization + repb = self.pack(rep) + + # send callable response by Reply socket + self.__socket_rep.send(repb) + + def publish(self, topic, data): + """ + Publish data + """ + # Serialized data + topic = bytes(topic, "UTF-8") + datab = self.pack(data) + + # Send data by Publish socket + # topci must be ascii encoding + self.__socket_pub.send_multipart([topic, datab]) + + def register(self, func): + """ + Register function + """ + self.__functions[func.__name__] = func + + +class RpcClient(RpcObject): + """""" + + def __init__(self, req_address, sub_address): + """Constructor""" + super(RpcClient, self).__init__() + + # zmq port related + self.__req_address = req_address + self.__sub_address = sub_address + + self.__context = zmq.Context() + # Request socket (Request–reply pattern) + self.__socket_req = self.__context.socket(zmq.REQ) + # Subscribe socket (Publish–subscribe pattern) + self.__socket_sub = self.__context.socket(zmq.SUB) + + # Woker thread relate, used to process data pushed from server + self.__active = False # RpcClient status + self.__thread = threading.Thread( + target=self.run) # RpcClient thread + + def __getattr__(self, name): + """ + Realize remote call function + """ + # Perform remote call task + def dorpc(*args, **kwargs): + # Generate request + req = [name, args, kwargs] + + # Pack request by serialization + reqb = self.pack(req) + + # Send request and wait for response + self.__socket_req.send(reqb) + repb = self.__socket_req.recv() + + # Unpack response by deserialization + rep = self.unpack(repb) + + # Return response if successed; Trigger exception if failed + if rep[0]: + return rep[1] + else: + raise RemoteException(rep[1].decode("UTF-8")) + + return dorpc + + def start(self): + """ + Start RpcClient + """ + # Connect zmq port + self.__socket_req.connect(self.__req_address) + self.__socket_sub.connect(self.__sub_address) + + # Start RpcClient status + self.__active = True + + # Start RpcClient thread + if not self.__thread.isAlive(): + self.__thread.start() + + def stop(self): + """ + Stop RpcClient + """ + # Stop RpcClient status + self.__active = False + + # Wait for RpcClient thread to exit + if self.__thread.isAlive(): + self.__thread.join() + + def run(self): + """ + Run RpcClient function + """ + while self.__active: + # Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds) + if not self.__socket_sub.poll(1000): + continue + + # Receive data from subscribe socket + topic, datab = self.__socket_sub.recv_multipart() + + # Unpack data by deserialization + data = self.unpack(datab) + + # Process data by callable function + topic = topic.decode("UTF-8") + + self.callback(topic, data) + + def callback(self, topic, data): + """ + Callable function + """ + raise NotImplementedError + + def subscribeTopic(self, topic): + """ + Subscribe data + """ + topic = bytes(topic, "UTF-8") + self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic) + + +class RemoteException(Exception): + """ + RPC remote exception + """ + + def __init__(self, value): + """ + Constructor + """ + self.__value = value + + def __str__(self): + """ + Output error message + """ + return self.__value diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 92492f5b..f4727ba3 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -99,6 +99,8 @@ class Exchange(Enum): # CryptoCurrency BITMEX = "BITMEX" + OKEX = "OKEX" + HUOBI = "HUOBI" class Currency(Enum): diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 251d2853..a2e098ee 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -52,6 +52,7 @@ class MainEngine: """ engine = engine_class(self, self.event_engine) self.engines[engine.engine_name] = engine + return engine def add_gateway(self, gateway_class: BaseGateway): """ @@ -59,6 +60,7 @@ class MainEngine: """ gateway = gateway_class(self.event_engine) self.gateways[gateway.gateway_name] = gateway + return gateway def add_app(self, app_class: BaseApp): """ @@ -67,7 +69,8 @@ class MainEngine: app = app_class() self.apps[app.app_name] = app - self.add_engine(app.engine_class) + engine = self.add_engine(app.engine_class) + return engine def init_engines(self): """ diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index d1a6248f..ce7aa591 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from typing import Any +from copy import copy from vnpy.event import Event, EventEngine from .event import ( @@ -227,3 +228,124 @@ class BaseGateway(ABC): Return default setting dict. """ return self.default_setting + + +class LocalOrderManager: + """ + Management tool to support use local order id for trading. + """ + + def __init__(self, gateway: BaseGateway): + """""" + self.gateway = gateway + + # For generating local orderid + self.order_prefix = "" + self.order_count = 0 + self.orders = {} # local_orderid:order + + # Map between local and system orderid + self.local_sys_orderid_map = {} + self.sys_local_orderid_map = {} + + # Push order data buf + self.push_data_buf = {} # sys_orderid:data + + # Callback for processing push order data + self.push_data_callback = None + + # Cancel request buf + self.cancel_request_buf = {} # local_orderid:req + + def new_local_orderid(self): + """ + Generate a new local orderid. + """ + self.order_count += 1 + local_orderid = str(self.order_count).rjust(8, "0") + return local_orderid + + def get_local_orderid(self, sys_orderid: str): + """ + Get local orderid with sys orderid. + """ + local_orderid = self.sys_local_orderid_map.get(sys_orderid, "") + + if not local_orderid: + local_orderid = self.new_local_orderid() + self.update_orderid_map(local_orderid, sys_orderid) + + return local_orderid + + def get_sys_orderid(self, local_orderid: str): + """ + Get sys orderid with local orderid. + """ + sys_orderid = self.local_sys_orderid_map.get(local_orderid, "") + return sys_orderid + + def update_orderid_map(self, local_orderid: str, sys_orderid: str): + """ + Update orderid map. + """ + self.sys_local_orderid_map[sys_orderid] = local_orderid + self.local_sys_orderid_map[local_orderid] = sys_orderid + + self.check_cancel_request(local_orderid) + self.check_push_data(sys_orderid) + + def check_push_data(self, sys_orderid: str): + """ + Check if any order push data waiting. + """ + if sys_orderid not in self.push_data_buf: + return + + data = self.push_data_buf.pop(sys_orderid) + if self.push_data_callback: + self.push_data_callback(data) + + def add_push_data(self, sys_orderid: str, data: dict): + """ + Add push data into buf. + """ + self.push_data_buf[sys_orderid] = data + + def get_order_with_sys_orderid(self, sys_orderid: str): + """""" + local_orderid = self.sys_local_orderid_map.get(sys_orderid, None) + if not local_orderid: + return None + else: + return self.get_order_with_local_orderid(local_orderid) + + def get_order_with_local_orderid(self, local_orderid: str): + """""" + order = self.orders[local_orderid] + return copy(order) + + def on_order(self, order: OrderData): + """ + Keep an order buf before pushing it to gateway. + """ + self.orders[order.orderid] = copy(order) + self.gateway.on_order(order) + + def cancel_order(self, req: CancelRequest): + """ + """ + sys_orderid = self.get_sys_orderid(req.orderid) + if not sys_orderid: + self.cancel_request_buf[req.orderid] = req + return + + self.gateway.cancel_order(req) + + def check_cancel_request(self, local_orderid: str): + """ + """ + if local_orderid not in self.cancel_request_buf: + return + + req = self.cancel_request_buf.pop(local_orderid) + self.gateway.cancel_order(req)