Merge pull request #3 from vnpy/DEV

Dev
This commit is contained in:
RobinLiu 2019-04-07 14:15:56 +08:00 committed by GitHub
commit b075e1d653
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 2754 additions and 150 deletions

View File

@ -29,6 +29,8 @@ matrix:
- choco install python3 --version 3.7.2 - choco install python3 --version 3.7.2
install: install:
- python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade pip wheel setuptools
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- pip install -r requirements.txt - pip install -r requirements.txt
- pip install . - pip install .
@ -43,8 +45,10 @@ matrix:
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
# Linux install script # update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade pip wheel setuptools
# Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh - bash ./install.sh
- name: "pip install under Ubuntu: gcc-7" - name: "pip install under Ubuntu: gcc-7"
@ -58,8 +62,10 @@ matrix:
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90 - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90 - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
# Linux install script # update pip & setuptools
- python -m pip install --upgrade pip wheel setuptools - python -m pip install --upgrade pip wheel setuptools
# Linux install script
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ./install.sh - bash ./install.sh
- name: "sdist install under Windows" - name: "sdist install under Windows"

View File

@ -1,10 +1,18 @@
# By Traders, For Traders. # By Traders, For Traders.
<p align="center"> <p align="center">
<img src ="https://vnpy.oss-cn-shanghai.aliyuncs.com/vnpy-logo.png" /> <img src ="https://vnpy.oss-cn-shanghai.aliyuncs.com/vnpy-logo.png"/>
</p> </p>
vn.py是一套基于Python的开源量化交易系统开发框架自2015年1月正式发布以来在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台目前国内外金融机构用户已经超过300家包括私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。 <p align="center">
<img src ="https://img.shields.io/badge/version-2.0.1-blueviolet.svg"/>
<img src ="https://img.shields.io/badge/platform-windows|linux|macos-yellow.svg"/>
<img src ="https://img.shields.io/badge/python-3.7-blue.svg" />
<img src ="https://img.shields.io/travis/com/vnpy/vnpy/master.svg"/>
<img src ="https://img.shields.io/github/license/vnpy/vnpy.svg?color=orange"/>
</p>
vn.py是一套基于Python的开源量化交易系统开发框架于2015年1月正式发布在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台目前国内外金融机构用户已经超过300家包括私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。
2.0版本基于Python 3.7全新重构开发目前功能还在逐步完善中。如需Python 2上的版本请点击[长期支持版本v1.9.2 LTS](https://github.com/vnpy/vnpy/tree/v1.9.2-LTS)。 2.0版本基于Python 3.7全新重构开发目前功能还在逐步完善中。如需Python 2上的版本请点击[长期支持版本v1.9.2 LTS](https://github.com/vnpy/vnpy/tree/v1.9.2-LTS)。
@ -14,17 +22,23 @@ vn.py是一套基于Python的开源量化交易系统开发框架自2015年1
2. 覆盖国内外所有交易品种的交易接口vnpy.gateway 2. 覆盖国内外所有交易品种的交易接口vnpy.gateway
* CTP(ctpGateway):国内期货、期权 * CTP(ctp):国内期货、期权
* 富途证券(futuGateway):港股、美 * 宽睿(oes)A
* Interactive Brokers(ibGateway):全球证券、期货、期权、外汇等 * 富途证券(futu):港股、美股
* BitMEX (bitmexGateway):数字货币期货、期权、永续合约 * 老虎证券(tiger):全球证券、期货、期权、外汇等
* Interactive Brokers(ib):全球证券、期货、期权、外汇等
* BitMEX (bitmex):数字货币期货、期权、永续合约
3. 开箱即用的各类量化策略交易应用vnpy.app 3. 开箱即用的各类量化策略交易应用vnpy.app
* CtaStrategyCTA策略引擎模块在保持易用性的同时允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制降低交易滑点、实现高频策略 * cta_strategyCTA策略引擎模块在保持易用性的同时允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制降低交易滑点、实现高频策略
* csv_loaderCSV历史数据加载器用于加载CSV格式文件中的历史数据到平台数据库中用于策略的回测研究以及实盘初始化等功能支持自定义数据表头格式
4. Python交易API接口封装vnpy.api提供上述交易接口的底层对接实现。 4. Python交易API接口封装vnpy.api提供上述交易接口的底层对接实现。
@ -36,11 +50,9 @@ vn.py是一套基于Python的开源量化交易系统开发框架自2015年1
## 环境准备 ## 环境准备
* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0-Windows-x86_64.exe)内置了最新版的vn.py无需手动安装 * 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe)内置了最新版的vn.py框架以及VN Station量化管理平台,无需手动安装
* 支持的系统版本Windows 7以上/Windows Server 2008以上/Ubuntu 18.04 LTS * 支持的系统版本Windows 7以上/Windows Server 2008以上/Ubuntu 18.04 LTS
* 支持的Python版本Python 3.7 64位**注意必须是Python 3.7 64位版本** * 支持的Python版本Python 3.7 64位**注意必须是Python 3.7 64位版本**
* 如需使用IB API请在[Interactive Brokers Github](https://interactivebrokers.github.io/#)页面下载安装**IB API Latest**
## 安装步骤 ## 安装步骤
@ -59,15 +71,20 @@ vn.py是一套基于Python的开源量化交易系统开发框架自2015年1
1. 在[SimNow](http://www.simnow.com.cn/)注册CTP仿真账号并在[该页面](http://www.simnow.com.cn/product.action)获取经纪商代码以及交易行情服务器地址。 1. 在[SimNow](http://www.simnow.com.cn/)注册CTP仿真账号并在[该页面](http://www.simnow.com.cn/product.action)获取经纪商代码以及交易行情服务器地址。
2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码,论坛最新的注册邀请码为**El86Pa1p** 2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码(论坛账号密码即是)
3. 启动VN Station安装VNConda后会在桌面自动创建快捷方式输入上一步的账号密码登录 3. 启动VN Station安装VNConda后会在桌面自动创建快捷方式输入上一步的账号密码登录
4. 点击底部的**VN Trader**按钮选择运行目录默认在系统用户目录即可在对话框中勾选CTP接口以及CtaStrategy应用点击右下方的**启动**按钮,开始你的交易!!! 4. 点击底部的**VN Trader Lite**按钮,开始你的交易!!!
5. 在VN Trader的运行过程中请勿关闭VN Station会自动退出 注意:
* 在VN Trader的运行过程中请勿关闭VN Station会自动退出
* 如需要灵活配置量化交易应用组件,请使用**VN Trader Pro**
6. 如选择了VNConda以外的安装方式不推荐新手可以在任意目录下创建run.py写入以下示例代码后运行
## 脚本运行
除了基于VN Station的图形化启动方式外也可以在任意目录下创建run.py写入以下示例代码
```Python ```Python
from vnpy.event import EventEngine from vnpy.event import EventEngine
@ -77,7 +94,7 @@ from vnpy.gateway.ctp import CtpGateway
from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.cta_strategy import CtaStrategyApp
def main(): def main():
"""启动VN Trader""" """Start VN Trader"""
qapp = create_qapp() qapp = create_qapp()
event_engine = EventEngine() event_engine = EventEngine()
@ -95,11 +112,13 @@ if __name__ == "__main__":
main() main()
``` ```
在该目录下打开CMD按住Shift->点击鼠标右键->在此处打开命令窗口/PowerShell后运行下列命令启动VN Trader
python run.py
## 贡献代码 ## 贡献代码
vn.py使用github托管其源代码如果希望贡献代码请使用github的PR(Pull Request)的流程: vn.py使用Github托管其源代码如果希望贡献代码请使用github的PR(Pull Request)的流程:
1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下较小的improvement(如文档改进bugfix等)直接发PR即可 1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下较小的improvement(如文档改进bugfix等)直接发PR即可

View File

@ -28,7 +28,6 @@ version = '2.0'
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
release = '2.0-DEV' release = '2.0-DEV'
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here. # If your documentation needs a minimal Sphinx version, state it here.
@ -75,19 +74,20 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = None pygments_style = None
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
# #
html_theme = 'alabaster' html_theme = 'alabaster'
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
# documentation. # documentation.
# #
# html_theme_options = {} html_theme_options = {
"base_bg": "inherit",
"narrow_sidebar_bg": "inherit",
}
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
@ -110,7 +110,6 @@ html_static_path = ['_static']
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'vnpydoc' htmlhelp_basename = 'vnpydoc'
# -- Options for LaTeX output ------------------------------------------------ # -- Options for LaTeX output ------------------------------------------------
latex_elements = { latex_elements = {
@ -139,7 +138,6 @@ latex_documents = [
'vn.py Team', 'manual'), 'vn.py Team', 'manual'),
] ]
# -- Options for manual page output ------------------------------------------ # -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
@ -149,7 +147,6 @@ man_pages = [
[author], 1) [author], 1)
] ]
# -- Options for Texinfo output ---------------------------------------------- # -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples # Grouping the document tree into Texinfo files. List of tuples
@ -161,7 +158,6 @@ texinfo_documents = [
'Miscellaneous'), 'Miscellaneous'),
] ]
# -- Options for Epub output ------------------------------------------------- # -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info. # Bibliographic Dublin Core info.

0
docs/csv_loader.md Normal file
View File

View File

@ -1 +1,23 @@
# Introduction # CTA策略模块
## 模块构成
## 历史数据
## 策略开发
## 回测研究
## 参数优化
## 实盘运行

0
docs/gateway.md Normal file
View File

View File

@ -1,6 +1,15 @@
# vn.py文档 # vn.py文档
* [vn.py简介](introduction.md) * 快速入门
* [项目安装](install.md) * [项目简介](introduction.md)
* [基本使用](quickstart.md) * [环境安装](install.md)
* [CTA策略模块](cta_strategy.md) * [基本使用](quickstart.md)
* 应用模块
* [CTA策略](cta_strategy.md)
* [CSV载入](csv_loader.md)
* [交易接口](gateway.md)
* [RPC应用](rpc.md)
* [贡献代码](contribution.md)

View File

@ -1,8 +1,33 @@
# 安装指南 # 安装指南
## Windows
### 使用VNConda
**ssleay32.dll问题**
如果电脑上之前安装过其他的Python环境或者应用软件有可能会存在SSL相关动态链接库路径被修改的问题在运行VN Station时弹出如下图所示的错误
![ssleay32.dll](https://user-images.githubusercontent.com/7112268/55474371-8bd06a00-5643-11e9-8b35-f064a45edfd1.png)
解决方法:
1. 找到你的VNConda目录
2. 将VNConda\Lib\site-packages\PyQt5\Qt\bin目录的两个动态库libeay32.dll和ssleay32.dll拷贝到VNConda\下即可
### 手动安装
## Ubuntu ## Ubuntu
如果是英文系统,请先运行下列命令安装中文编码:
### 安装脚本
### TA-Lib
### 中文编码
如果是英文系统(如阿里云),请先运行下列命令安装中文编码:
``` ```
sudo locale-gen zh_CN.GB18030 sudo locale-gen zh_CN.GB18030

View File

@ -1 +1,15 @@
# Introduction # Introduction
## 目标用户
## 应用场景
## 支持的接口
## 支持的应用

View File

@ -1 +1,20 @@
# Introduction # 基本使用
## 启动VN Trader
## 连接接口
## 订阅行情
## 委托交易
## 数据监控
## 应用模块

0
docs/rpc.md Normal file
View File

View File

@ -1,6 +1,6 @@
::Install talib and ibapi ::Install talib and ibapi
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
::Install Python Modules ::Install Python Modules
pip install -r requirements.txt pip install -r requirements.txt

View File

@ -13,6 +13,10 @@ popd
# old versions of ta-lib imports numpy in setup.py # old versions of ta-lib imports numpy in setup.py
pip install numpy pip install numpy
# Install extra packages
pip install ta-lib
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
# Install Python Modules # Install Python Modules
pip install -r requirements.txt pip install -r requirements.txt

View File

@ -10,11 +10,5 @@ matplotlib
seaborn seaborn
futu-api futu-api
tigeropen tigeropen
ta-lib
# ta-lib ibapi
ta-lib; platform_system=="Linux"
https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl; platform_system=="Windows"
# ibapi
https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl

View File

@ -0,0 +1,189 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import multiprocessing\n",
"import numpy as np\n",
"from deap import creator, base, tools, algorithms\n",
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
"from boll_channel_strategy import BollChannelStrategy\n",
"from datetime import datetime\n",
"import multiprocessing #多进程\n",
"from scoop import futures #多进程"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def parameter_generate():\n",
" '''\n",
" 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n",
" '''\n",
" parameter_list = []\n",
" p1 = random.randrange(4,50,2) #布林带窗口\n",
" p2 = random.randrange(4,50,2) #布林带通道阈值\n",
" p3 = random.randrange(4,50,2) #CCI窗口\n",
" p4 = random.randrange(18,40,2) #ATR窗口 \n",
"\n",
" parameter_list.append(p1)\n",
" parameter_list.append(p2)\n",
" parameter_list.append(p3)\n",
" parameter_list.append(p4)\n",
"\n",
" return parameter_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def object_func(strategy_avg):\n",
" \"\"\"\n",
" 本函数为优化目标函数根据随机生成的策略参数运行回测后自动返回2个结果指标收益回撤比和夏普比率\n",
" \"\"\"\n",
" # 创建回测引擎对象\n",
" engine = BacktestingEngine()\n",
" engine.set_parameters(\n",
" vt_symbol=\"IF88.CFFEX\",\n",
" interval=\"1m\",\n",
" start=datetime(2018, 9, 1),\n",
" end=datetime(2019, 1,1),\n",
" rate=0,\n",
" slippage=0,\n",
" size=300,\n",
" pricetick=0.2,\n",
" capital=1_000_000,\n",
" )\n",
"\n",
" setting = {'boll_window': strategy_avg[0], #布林带窗口\n",
" 'boll_dev': strategy_avg[1], #布林带通道阈值\n",
" 'cci_window': strategy_avg[2], #CCI窗口\n",
" 'atr_window': strategy_avg[3],} #ATR窗口 \n",
"\n",
" #加载策略 \n",
" #engine.initStrategy(TurtleTradingStrategy, setting)\n",
" engine.add_strategy(BollChannelStrategy, setting)\n",
" engine.load_data()\n",
" engine.run_backtesting()\n",
" engine.calculate_result()\n",
" result = engine.calculate_statistics(Output=False)\n",
"\n",
" return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n",
" sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n",
" return return_drawdown_ratio , sharpe_ratio"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"#设置优化方向:最大化收益回撤比,最大化夏普比率\n",
"creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n",
"creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n",
"\n",
"def optimize():\n",
" \"\"\"\"\"\" \n",
" toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱里面包含遗传算法中所用到的各种函数\n",
"\n",
" # 初始化 \n",
" toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体随机生成的策略参数parameter_generate() \n",
" toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n",
" toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n",
" toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n",
" toolbox.register(\"evaluate\", object_func) #注册评估优化目标函数object_func() \n",
" toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n",
" #pool = multiprocessing.Pool()\n",
" #toolbox.register(\"map\", pool.map)\n",
" #toolbox.register(\"map\", futures.map)\n",
"\n",
" #遗传算法参数设置\n",
" MU = 40 #设置每一代选择的个体数\n",
" LAMBDA = 160 #设置每一代产生的子女数\n",
" pop = toolbox.population(400) #设置族群里面的个体数量\n",
" CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n",
" hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n",
"\n",
" #解的集合的描述统计信息\n",
" #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n",
" #收敛程度低可以增加算法的迭代次数\n",
" stats = tools.Statistics(lambda ind: ind.fitness.values)\n",
" np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n",
" stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n",
" stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n",
" stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n",
" stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n",
"\n",
" #运行算法\n",
" algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n",
" halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n",
"\n",
" return pop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"optimize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -10,6 +10,8 @@ from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway from vnpy.gateway.ctp import CtpGateway
from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.tiger import TigerGateway
from vnpy.gateway.oes import OesGateway from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway
from vnpy.gateway.huobi import HuobiGateway
from vnpy.app.cta_strategy import CtaStrategyApp from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp from vnpy.app.csv_loader import CsvLoaderApp
@ -28,6 +30,8 @@ def main():
main_engine.add_gateway(BitmexGateway) main_engine.add_gateway(BitmexGateway)
main_engine.add_gateway(TigerGateway) main_engine.add_gateway(TigerGateway)
main_engine.add_gateway(OesGateway) main_engine.add_gateway(OesGateway)
main_engine.add_gateway(OkexGateway)
main_engine.add_gateway(HuobiGateway)
main_engine.add_app(CtaStrategyApp) main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CsvLoaderApp) main_engine.add_app(CsvLoaderApp)

View File

@ -1 +1 @@
__version__ = "2.0.1b0" __version__ = "2.0.1"

Binary file not shown.

View File

@ -1,2 +1,2 @@
from .vnapex import ApexApi from .vnapex import *
from .fiddef import * from .fiddef import *

View File

@ -31,7 +31,7 @@ class ApexApi:
def set_app_info(self, name: str, version: str): def set_app_info(self, name: str, version: str):
"""设置应用信息""" """设置应用信息"""
n = APEX.Fix_SetAppInfo(c_char_p(name), c_char_p(version)) n = APEX.Fix_SetAppInfo(to_bytes(name), to_bytes(version))
return bool(n) return bool(n)
def uninitialize(self): def uninitialize(self):
@ -42,19 +42,19 @@ class ApexApi:
def set_default_info(self, user: str, wtfs: str, fbdm: str, dest: str): def set_default_info(self, user: str, wtfs: str, fbdm: str, dest: str):
"""设置默认信息""" """设置默认信息"""
n = APEX.Fix_SetDefaultInfo( n = APEX.Fix_SetDefaultInfo(
c_char_p(user), to_bytes(user),
c_char_p(wtfs), to_bytes(wtfs),
c_char_p(fbdm), to_bytes(fbdm),
c_char_p(dest) to_bytes(dest)
) )
return bool(n) return bool(n)
def connect(self, address: str, khh: str, pwd: str, timeout: int): def connect(self, address: str, khh: str, pwd: str, timeout: int):
"""连接交易""" """连接交易"""
conn = APEX.Fix_Connect( conn = APEX.Fix_Connect(
c_char_p(address), to_bytes(address),
c_char_p(khh), to_bytes(khh),
c_char_p(pwd), to_bytes(pwd),
timeout timeout
) )
return conn return conn
@ -66,13 +66,13 @@ class ApexApi:
): ):
"""连接交易""" """连接交易"""
conn = APEX.Fix_ConnectEx( conn = APEX.Fix_ConnectEx(
c_char_p(address), to_bytes(address),
c_char_p(khh), to_bytes(khh),
c_char_p(pwd), to_bytes(pwd),
c_char_p(file_cert), to_bytes(file_cert),
c_char_p(cert_pwd), to_bytes(cert_pwd),
c_char_p(file_ca), to_bytes(file_ca),
c_char_p(procotol), to_bytes(procotol),
verify, verify,
timeout timeout
) )
@ -100,27 +100,27 @@ class ApexApi:
def set_wtfs(self, sess: int, wtfs: str): def set_wtfs(self, sess: int, wtfs: str):
"""设置委托方式""" """设置委托方式"""
n = APEX.Fix_SetWTFS(sess, c_char_p(wtfs)) n = APEX.Fix_SetWTFS(sess, to_bytes(wtfs))
return bool(n) return bool(n)
def set_fbdm(self, sess: int, fbdm: str): def set_fbdm(self, sess: int, fbdm: str):
"""设置来源营业部""" """设置来源营业部"""
n = APEX.Fix_SetFBDM(sess, c_char_p(fbdm)) n = APEX.Fix_SetFBDM(sess, to_bytes(fbdm))
return bool(n) return bool(n)
def set_dest_fbdm(self, sess: int, fbdm: str): def set_dest_fbdm(self, sess: int, fbdm: str):
"""设置目标营业部""" """设置目标营业部"""
n = APEX.Fix_SetDestFBDM(sess, c_char_p(fbdm)) n = APEX.Fix_SetDestFBDM(sess, to_bytes(fbdm))
return bool(n) return bool(n)
def set_node(self, sess: int, node: str): def set_node(self, sess: int, node: str):
"""设置业务站点""" """设置业务站点"""
n = APEX.Fix_SetNode(sess, c_char_p(node)) n = APEX.Fix_SetNode(sess, to_bytes(node))
return bool(n) return bool(n)
def set_gydm(self, sess: int, gydm: str): def set_gydm(self, sess: int, gydm: str):
"""设置柜员号""" """设置柜员号"""
n = APEX.Fix_SetGYDM(sess, c_char_p(gydm)) n = APEX.Fix_SetGYDM(sess, to_bytes(gydm))
return bool(n) return bool(n)
def create_head(self, sess: int, func: int): def create_head(self, sess: int, func: int):
@ -170,7 +170,7 @@ class ApexApi:
def get_err_msg(self, sess: int): def get_err_msg(self, sess: int):
"""获取错误信息""" """获取错误信息"""
size = 256 size = 256
out = create_string_buffer("", size) out = create_string_buffer(b"", size)
APEX.Fix_GetErrMsg(sess, out, size) APEX.Fix_GetErrMsg(sess, out, size)
return out.value return out.value
@ -182,7 +182,7 @@ class ApexApi:
def get_item(self, sess: int, fid: int, row: int): def get_item(self, sess: int, fid: int, row: int):
"""获取字符串内容""" """获取字符串内容"""
size = 256 size = 256
out = create_string_buffer("", size) out = create_string_buffer(b"", size)
APEX.Fix_GetItem(sess, fid, out, size, row) APEX.Fix_GetItem(sess, fid, out, size, row)
return out.value return out.value
@ -210,19 +210,21 @@ class ApexApi:
def get_token(self, sess: int): def get_token(self, sess: int):
"""获取业务令牌""" """获取业务令牌"""
size = 256 size = 256
out = create_string_buffer("", size) out = create_string_buffer(b"", size)
APEX.Fix_GetToken(sess, out, size) APEX.Fix_GetToken(sess, out, size)
return out.value return out.value
def encode(self, data): def encode(self, data: str):
"""加密""" """加密"""
data = to_bytes(data)
buf = create_string_buffer(data, 512) buf = create_string_buffer(data, 512)
APEX.Fix_Encode(buf) APEX.Fix_Encode(buf)
return buf.value return to_unicode(buf.value)
def add_backup_svc_addr(self, address: str): def add_backup_svc_addr(self, address: str):
"""设置业务令牌""" """设置业务令牌"""
address = to_bytes(address)
n = APEX.Fix_AddBackupSvrAddr(address) n = APEX.Fix_AddBackupSvrAddr(address)
return bool(n) return bool(n)
@ -238,9 +240,11 @@ class ApexApi:
def subscribe_by_customer(self, conn: int, svc: int, khh: str, pwd: str): def subscribe_by_customer(self, conn: int, svc: int, khh: str, pwd: str):
"""订阅数据""" """订阅数据"""
func = APEX[93] func = APEX[108]
n = func(conn, svc, self.push_call_func, c_char_p(""), khh, pwd) n = func(conn, svc, self.push_call_func,
return bool(n) to_bytes(""), to_bytes(khh), to_bytes(pwd))
return n
def unsubscribe_by_handle(self, handle: int): def unsubscribe_by_handle(self, handle: int):
"""退订推送""" """退订推送"""
@ -254,22 +258,22 @@ class ApexApi:
def get_val_with_id_by_index(self, sess: int, row: int, col: int): def get_val_with_id_by_index(self, sess: int, row: int, col: int):
"""根据行列获取数据""" """根据行列获取数据"""
s = 256 s = 256
buf = create_string_buffer("", s) buf = create_string_buffer(b"", s)
fid = c_long(0) fid = c_long(0)
size = c_int(s) size = c_int(s)
APEX.Fix_GetValWithIdByIndex( APEX.Fix_GetValWithIdByIndex(
sess, row, col, byref(fid), buf, byref(size)) sess, row, col, byref(fid), buf, byref(size))
return fid.value, buf.value return fid.value, to_unicode(buf.value)
def set_system_no(self, sess: int, val: str): def set_system_no(self, sess: int, val: str):
"""设置系统编号""" """设置系统编号"""
n = APEX.Fix_SetSystemNo(sess, c_char_p(val)) n = APEX.Fix_SetSystemNo(sess, to_bytes(val))
return bool(n) return bool(n)
def set_default_system_no(self, val: str): def set_default_system_no(self, val: str):
"""设置默认系统编号""" """设置默认系统编号"""
n = APEX.Fix_SetDefaultSystemNo(c_char_p(val)) n = APEX.Fix_SetDefaultSystemNo(to_bytes(val))
return bool(n) return bool(n)
def set_auto_reconnect(self, conn: int, reconnect: int): def set_auto_reconnect(self, conn: int, reconnect: int):
@ -291,23 +295,23 @@ class ApexApi:
"""获取缓存数据""" """获取缓存数据"""
size = 1024 size = 1024
outlen = c_int(size) outlen = c_int(size)
buf = create_string_buffer("", size) buf = create_string_buffer(b"", size)
APEX.Fix_GetItemBuf(sess, buf, byref(outlen), row) APEX.Fix_GetItemBuf(sess, buf, byref(outlen), row)
return buf return buf
def set_item(self, sess: int, fid: int, val: str): def set_item(self, sess: int, fid: int, val: str):
"""设置请求内容""" """设置请求内容"""
n = APEX.Fix_SetString(sess, fid, c_char_p(val)) n = APEX.Fix_SetString(sess, fid, to_bytes(val))
return bool(n) return bool(n)
def get_last_err_msg(self): def get_last_err_msg(self):
"""获取错误信息""" """获取错误信息"""
size = 256 size = 256
out = create_string_buffer("", size) out = create_string_buffer(b"", size)
APEX.Fix_GetLastErrMsg(out, size) APEX.Fix_GetLastErrMsg(out, size)
return out.value return to_unicode(out.value)
def reg_reply_call_func(self, sess: int = 0): def reg_reply_call_func(self, sess: int = 0):
"""注册回调函数""" """注册回调函数"""
@ -328,3 +332,21 @@ class ApexApi:
def on_conn(self, conn: int, event, recv): def on_conn(self, conn: int, event, recv):
"""连接回调(需要继承)""" """连接回调(需要继承)"""
return True return True
def to_bytes(data: str):
"""
将unicode字符串转换为bytes
"""
try:
bytes_data = data.encode("GBK")
return bytes_data
except AttributeError:
return data
def to_unicode(data: bytes):
"""
将bytes字符串转换为unicode
"""
return data.decode("GBK")

View File

@ -48,14 +48,18 @@ class WebsocketClient(object):
self.proxy_host = None self.proxy_host = None
self.proxy_port = None self.proxy_port = None
self.ping_interval = 60 # seconds
# For debugging # For debugging
self._last_sent_text = None self._last_sent_text = None
self._last_received_text = None self._last_received_text = None
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0): def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60):
"""""" """
:param ping_interval: unit: seconds, type: int
"""
self.host = host self.host = host
self.ping_interval = ping_interval # seconds
if proxy_host and proxy_port: if proxy_host and proxy_port:
self.proxy_host = proxy_host self.proxy_host = proxy_host
@ -206,7 +210,7 @@ class WebsocketClient(object):
et, ev, tb = sys.exc_info() et, ev, tb = sys.exc_info()
self.on_error(et, ev, tb) self.on_error(et, ev, tb)
self._reconnect() self._reconnect()
for i in range(60): for i in range(self.ping_interval):
if not self._active: if not self._active:
break break
sleep(1) sleep(1)

View File

@ -0,0 +1,17 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from .engine import AlgoEngine, APP_NAME
class CtaStrategyApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "算法交易"
engine_class = AlgoEngine
widget_name = "AlgoManager"
icon_name = "algo.ico"

View File

@ -0,0 +1,59 @@
# encoding: UTF-8
'''
动态载入所有的策略类
'''
from __future__ import print_function
import os
import importlib
import traceback
# 用来保存算法类和控件类的字典
ALGO_DICT = {}
WIDGET_DICT = {}
#----------------------------------------------------------------------
def loadAlgoModule(path, prefix):
"""使用importlib动态载入算法"""
for root, subdirs, files in os.walk(path):
for name in files:
# 只有文件名以Algo.py结尾的才是算法文件
if len(name)>7 and name[-7:] == 'Algo.py':
try:
# 模块名称需要模块路径前缀
moduleName = prefix + name.replace('.py', '')
module = importlib.import_module(moduleName)
# 获取算法类和控件类
algo = None
widget = None
for k in dir(module):
# 以Algo结尾的类是算法
if k[-4:] == 'Algo':
algo = module.__getattribute__(k)
# 以Widget结尾的类是控件
if k[-6:] == 'Widget':
widget = module.__getattribute__(k)
# 保存到字典中
if algo and widget:
ALGO_DICT[algo.templateName] = algo
WIDGET_DICT[algo.templateName] = widget
except:
print ('-' * 20)
print ('Failed to import strategy file %s:' %moduleName)
traceback.print_exc()
# 遍历algo目录下的文件
path1 = os.path.abspath(os.path.dirname(__file__))
loadAlgoModule(path1, 'vnpy.trader.app.algoTrading.algo.')
# 遍历工作目录下的文件
path2 = os.getcwd()
loadAlgoModule(path2, '')

View File

View File

@ -0,0 +1,65 @@
from vnpy.event import EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.event import (EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
class AlgoEngine(BaseEngine):
""""""
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
"""Constructor"""
super().__init__(main_engine, event_engine)
self.algos = {}
self.symbol_algo_map = {}
self.orderid_algo_map = {}
self.register_event()
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
def process_tick_event(self):
""""""
pass
def process_timer_event(self):
""""""
pass
def process_trade_event(self):
""""""
pass
def process_order_event(self):
""""""
pass
def start_algo(self, setting: dict):
""""""
pass
def stop_algo(self, algo_name: dict):
""""""
pass
def stop_all(self):
""""""
pass
def subscribe(self, algo, vt_symbol):
""""""
pass
def send_order(
self,
algo,
vt_symbol
):
""""""
pass

View File

@ -0,0 +1,123 @@
from vnpy.trader.engine import BaseEngine
from vnpy.trader.object import TickData, OrderData, TradeData
from vnpy.trader.constant import OrderType, Offset
class AlgoTemplate:
""""""
count = 0
def __init__(
self,
algo_engine: BaseEngine,
algo_name: str,
setting: dict
):
"""Constructor"""
self.algo_engine = algo_engine
self.algo_name = algo_name
self.active = False
self.active_orders = {} # vt_orderid:order
@staticmethod
def new(cls, algo_engine:BaseEngine, setting: dict):
"""Create new algo instance"""
cls.count += 1
algo_name = f"{cls.__name__}_{cls.count}"
algo = cls(algo_engine, algo_name, setting)
def update_tick(self, tick: TickData):
""""""
if self.active:
self.on_tick(tick)
def update_order(self, order: OrderData):
""""""
if self.active:
if order.is_active():
self.active_orders[order.vt_orderid] = order
elif order.vt_orderid in self.active_orders:
self.active_orders.pop(order.vt_orderid)
self.on_order(order)
def update_trade(self, trade: TradeData):
""""""
if self.active:
self.on_trade(trade)
def update_timer(self):
""""""
if self.active:
self.on_timer()
def on_start(self):
""""""
pass
def on_stop(self):
""""""
pass
def on_tick(self, tick: TickData):
""""""
pass
def on_order(self, order: OrderData):
""""""
pass
def on_trade(self, trade: TradeData):
""""""
pass
def on_timer(self):
""""""
pass
def start(self):
""""""
pass
def stop(self):
""""""
pass
def buy(
self,
vt_symbol,
price,
volume,
order_type: OrderType = OrderType.LIMIT,
offset: Offset = Offset.NONE
):
""""""
return self.algo_engine.buy(
vt_symbol,
price,
volume,
order_type,
offset
)
def sell(
self,
vt_symbol,
price,
volume,
order_type: OrderType = OrderType.LIMIT,
offset: Offset = Offset.NONE
):
""""""
return self.algo_engine.buy(
vt_symbol,
price,
volume,
order_type,
offset
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View File

View File

@ -23,9 +23,11 @@ Sample csv file:
import csv import csv
from datetime import datetime from datetime import datetime
from peewee import chunked
from vnpy.event import EventEngine from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange, Interval from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import DbBarData from vnpy.trader.database import DbBarData, DB
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
@ -75,30 +77,37 @@ class CsvLoaderEngine(BaseEngine):
with open(file_path, 'rt') as f: with open(file_path, 'rt') as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
db_bars = []
for item in reader: for item in reader:
db_bar = DbBarData() dt = datetime.strptime(item[datetime_head], datetime_format)
db_bar.symbol = symbol db_bar = {
db_bar.exchange = exchange.value "symbol": symbol,
db_bar.datetime = datetime.strptime( "exchange": exchange.value,
item[datetime_head], datetime_format "datetime": dt,
) "interval": interval.value,
db_bar.interval = interval.value "volume": item[volume_head],
db_bar.volume = item[volume_head] "open_price": item[open_head],
db_bar.open_price = item[open_head] "high_price": item[high_head],
db_bar.high_price = item[high_head] "low_price": item[low_head],
db_bar.low_price = item[low_head] "close_price": item[close_head],
db_bar.close_price = item[close_head] "vt_symbol": vt_symbol,
db_bar.vt_symbol = vt_symbol "gateway_name": "DB"
db_bar.gateway_name = "DB" }
db_bar.replace() db_bars.append(db_bar)
# do some statistics # do some statistics
count += 1 count += 1
if not start: if not start:
start = db_bar.datetime start = db_bar["datetime"]
end = db_bar.datetime end = db_bar["datetime"]
# Insert into DB
with DB.atomic():
for batch in chunked(db_bars, 500):
DbBarData.insert_many(batch).on_conflict_replace().execute()
return start, end, count return start, end, count

View File

@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget):
def select_file(self): def select_file(self):
"""""" """"""
result: str = QtWidgets.QFileDialog.getOpenFileName(self) result: str = QtWidgets.QFileDialog.getOpenFileName(
self, filter="CSV (*.csv)")
filename = result[0] filename = result[0]
if filename: if filename:
self.file_edit.setText(filename) self.file_edit.setText(filename)

View File

@ -35,7 +35,7 @@ class OptimizationSetting:
def __init__(self): def __init__(self):
"""""" """"""
self.params = {} self.params = {}
self.target = "" self.target_name = ""
def add_parameter( def add_parameter(
self, name: str, start: float, end: float = None, step: float = None self, name: str, start: float, end: float = None, step: float = None
@ -62,9 +62,9 @@ class OptimizationSetting:
self.params[name] = value_list self.params[name] = value_list
def set_target(self, target: str): def set_target(self, target_name: str):
"""""" """"""
self.target = target self.target_name = target_name
def generate_setting(self): def generate_setting(self):
"""""" """"""
@ -293,7 +293,7 @@ class BacktestingEngine:
self.output("逐日盯市盈亏计算完成") self.output("逐日盯市盈亏计算完成")
return self.daily_df return self.daily_df
def calculate_statistics(self, df: DataFrame = None): def calculate_statistics(self, df: DataFrame = None, Output=True):
"""""" """"""
self.output("开始计算策略统计指标") self.output("开始计算策略统计指标")
@ -325,6 +325,7 @@ class BacktestingEngine:
daily_return = 0 daily_return = 0
return_std = 0 return_std = 0
sharpe_ratio = 0 sharpe_ratio = 0
return_drawdown_ratio = 0
else: else:
# Calculate balance related time series data # Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital df["balance"] = df["net_pnl"].cumsum() + self.capital
@ -373,38 +374,42 @@ class BacktestingEngine:
else: else:
sharpe_ratio = 0 sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output # Output
self.output("-" * 30) if Output:
self.output(f"首个交易日:\t{start_date}") self.output("-" * 30)
self.output(f"最后交易日:\t{end_date}") self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}") self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}") self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}") self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}") self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}") self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%") self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%") self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}") self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}") self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}") self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}") self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}") self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}") self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}") self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}") self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}") self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}") self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%") self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%") self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}") self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = { statistics = {
"start_date": start_date, "start_date": start_date,
@ -430,6 +435,7 @@ class BacktestingEngine:
"daily_return": daily_return, "daily_return": daily_return,
"return_std": return_std, "return_std": return_std,
"sharpe_ratio": sharpe_ratio, "sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
} }
return statistics return statistics
@ -473,7 +479,7 @@ class BacktestingEngine:
return return
if not target_name: if not target_name:
self.output("优化目标设置,请检查") self.output("优化目标设置,请检查")
return return
# Use multiprocessing pool for running backtesting with different setting # Use multiprocessing pool for running backtesting with different setting

View File

@ -151,13 +151,14 @@ class CtaEngine(BaseEngine):
Query bar data from RQData. Query bar data from RQData.
""" """
symbol, exchange_str = vt_symbol.split(".") symbol, exchange_str = vt_symbol.split(".")
if symbol.upper() not in self.rq_symbols: rq_symbol = to_rq_symbol(vt_symbol)
if rq_symbol not in self.rq_symbols:
return None return None
end += timedelta(1) # For querying night trading period data end += timedelta(1) # For querying night trading period data
df = self.rq_client.get_price( df = self.rq_client.get_price(
symbol.upper(), rq_symbol,
frequency=interval.value, frequency=interval.value,
fields=["open", "high", "low", "close", "volume"], fields=["open", "high", "low", "close", "volume"],
start_date=start, start_date=start,
@ -540,7 +541,7 @@ class CtaEngine(BaseEngine):
DbBarData.select() DbBarData.select()
.where( .where(
(DbBarData.vt_symbol == vt_symbol) (DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval) & (DbBarData.interval == interval.value)
& (DbBarData.datetime >= start) & (DbBarData.datetime >= start)
& (DbBarData.datetime <= end) & (DbBarData.datetime <= end)
) )
@ -910,3 +911,29 @@ class CtaEngine(BaseEngine):
subject = "CTA策略引擎" subject = "CTA策略引擎"
self.main_engine.send_email(subject, msg) self.main_engine.send_email(subject, msg)
def to_rq_symbol(vt_symbol: str):
"""
CZCE product of RQData has symbol like "TA1905" while
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
"""
symbol, exchange_str = vt_symbol.split(".")
if exchange_str != "CZCE":
return symbol.upper()
for count, word in enumerate(symbol):
if word.isdigit():
break
product = symbol[:count]
year = symbol[count]
month = symbol[count + 1:]
if year == "9":
year = "1" + year
else:
year = "2" + year
rq_symbol = f"{product}{year}{month}".upper()
return rq_symbol

View File

@ -2,7 +2,7 @@
from abc import ABC from abc import ABC
from typing import Any, Callable from typing import Any, Callable
from vnpy.trader.constant import Interval, Status, Direction, Offset from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.object import BarData, TickData, OrderData, TradeData from vnpy.trader.object import BarData, TickData, OrderData, TradeData
from .base import StopOrder, EngineType from .base import StopOrder, EngineType

View File

@ -405,7 +405,7 @@ class CtpTdApi(TdApi):
"""""" """"""
if not error['ErrorID']: if not error['ErrorID']:
self.authStatus = True self.authStatus = True
self.writeLog("交易授权验证成功") self.gateway.write_log("交易授权验证成功")
self.login() self.login()
else: else:
self.gateway.write_error("交易授权验证失败", error) self.gateway.write_error("交易授权验证失败", error)
@ -418,7 +418,7 @@ class CtpTdApi(TdApi):
self.login_status = True self.login_status = True
self.gateway.write_log("交易登录成功") self.gateway.write_log("交易登录成功")
# Confirm settelment # Confirm settlement
req = { req = {
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"InvestorID": self.userid "InvestorID": self.userid
@ -549,13 +549,16 @@ class CtpTdApi(TdApi):
product=product, product=product,
size=data["VolumeMultiple"], size=data["VolumeMultiple"],
pricetick=data["PriceTick"], pricetick=data["PriceTick"],
option_underlying=data["UnderlyingInstrID"],
option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
option_strike=data["StrikePrice"],
option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"),
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# For option only
if data["OptionsType"]:
contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"],
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
symbol_exchange_map[contract.symbol] = contract.exchange symbol_exchange_map[contract.symbol] = contract.exchange
@ -662,7 +665,7 @@ class CtpTdApi(TdApi):
"UserID": self.userid, "UserID": self.userid,
"BrokerID": self.brokerid, "BrokerID": self.brokerid,
"AuthCode": self.auth_code, "AuthCode": self.auth_code,
"ProductInfo": self.product_info "UserProductInfo": self.product_info
} }
self.reqid += 1 self.reqid += 1
@ -678,7 +681,8 @@ class CtpTdApi(TdApi):
req = { req = {
"UserID": self.userid, "UserID": self.userid,
"Password": self.password, "Password": self.password,
"BrokerID": self.brokerid "BrokerID": self.brokerid,
"UserProductInfo": self.product_info
} }
self.reqid += 1 self.reqid += 1

View File

@ -0,0 +1,4 @@
from .huobi_gateway import HuobiGateway

View File

@ -0,0 +1,715 @@
# encoding: UTF-8
"""
火币交易接口
"""
import re
import urllib
import base64
import json
import zlib
import hashlib
import hmac
from copy import copy
from datetime import datetime
from vnpy.event import Event
from vnpy.api.rest import RestClient
from vnpy.api.websocket import WebsocketClient
from vnpy.trader.constant import (
Direction,
Exchange,
Product,
Status,
OrderType
)
from vnpy.trader.gateway import BaseGateway, LocalOrderManager
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest
)
from vnpy.trader.event import EVENT_TIMER
REST_HOST = "https://api.huobipro.com"
WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data
WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order
STATUS_HUOBI2VT = {
"submitted": Status.NOTTRADED,
"partial-filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"cancelling": Status.CANCELLED,
"partial-canceled": Status.CANCELLED,
"canceled": Status.CANCELLED,
}
ORDERTYPE_VT2HUOBI = {
(Direction.LONG, OrderType.MARKET): "buy-market",
(Direction.SHORT, OrderType.MARKET): "sell-market",
(Direction.LONG, OrderType.LIMIT): "buy-limit",
(Direction.SHORT, OrderType.LIMIT): "sell-limit",
}
ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()}
huobi_symbols = set()
symbol_name_map = {}
class HuobiGateway(BaseGateway):
"""
VN Trader Gateway for Huobi connection.
"""
default_setting = {
"API Key": "",
"Secret Key": "",
"会话数": 3,
"代理地址": "127.0.0.1",
"代理端口": 1080,
}
def __init__(self, event_engine):
"""Constructor"""
super(HuobiGateway, self).__init__(event_engine, "HUOBI")
self.order_manager = LocalOrderManager(self)
self.rest_api = HuobiRestApi(self)
self.trade_ws_api = HuobiTradeWebsocketApi(self)
self.market_ws_api = HuobiDataWebsocketApi(self)
def connect(self, setting: dict):
""""""
key = setting["API Key"]
secret = setting["Secret Key"]
session_number = setting["会话数"]
proxy_host = setting["代理地址"]
proxy_port = setting["代理端口"]
self.rest_api.connect(key, secret, session_number,
proxy_host, proxy_port)
self.trade_ws_api.connect(key, secret, proxy_host, proxy_port)
self.market_ws_api.connect(key, secret, proxy_host, proxy_port)
self.init_query()
def subscribe(self, req: SubscribeRequest):
""""""
self.market_ws_api.subscribe(req)
self.trade_ws_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
self.rest_api.query_account_balance()
def query_position(self):
""""""
pass
def close(self):
""""""
self.rest_api.stop()
self.trade_ws_api.stop()
self.market_ws_api.stop()
def process_timer_event(self, event: Event):
""""""
self.count += 1
if self.count < 3:
return
self.query_account()
def init_query(self):
""""""
self.count = 0
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
class HuobiRestApi(RestClient):
"""
HUOBI REST API
"""
def __init__(self, gateway: BaseGateway):
""""""
super(HuobiRestApi, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.order_manager = gateway.order_manager
self.host = ""
self.key = ""
self.secret = ""
self.account_id = ""
self.cancel_requests = {}
self.orders = {}
def sign(self, request):
"""
Generate HUOBI signature.
"""
request.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36"
}
params_with_signature = create_signature(
self.key,
request.method,
self.host,
request.path,
self.secret,
request.params
)
request.params = params_with_signature
if request.method == "POST":
request.headers["Content-Type"] = "application/json"
if request.data:
request.data = json.dumps(request.data)
return request
def connect(
self,
key: str,
secret: str,
session_number: int,
proxy_host: str,
proxy_port: int
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.host, _ = _split_url(REST_HOST)
self.init(REST_HOST, proxy_host, proxy_port)
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.query_contract()
self.query_account()
self.query_order()
def query_account(self):
""""""
self.add_request(
method="GET",
path="/v1/account/accounts",
callback=self.on_query_account
)
def query_account_balance(self):
""""""
path = f"/v1/account/accounts/{self.account_id}/balance"
self.add_request(
method="GET",
path=path,
callback=self.on_query_account_balance
)
def query_order(self):
""""""
self.add_request(
method="GET",
path="/v1/order/openOrders",
callback=self.on_query_order
)
def query_contract(self):
""""""
self.add_request(
method="GET",
path="/v1/common/symbols",
callback=self.on_query_contract
)
def send_order(self, req: OrderRequest):
""""""
huobi_type = ORDERTYPE_VT2HUOBI.get(
(req.direction, req.type), ""
)
local_orderid = self.order_manager.new_local_orderid()
order = req.create_order_data(
local_orderid,
self.gateway_name
)
order.time = datetime.now().strftime("%H:%M:%S")
data = {
"account-id": self.account_id,
"amount": str(req.volume),
"symbol": req.symbol,
"type": huobi_type,
"price": str(req.price),
"source": "api"
}
self.add_request(
method="POST",
path="/v1/order/orders/place",
callback=self.on_send_order,
data=data,
extra=order,
)
self.order_manager.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
sys_orderid = self.order_manager.get_sys_orderid(req.orderid)
path = f"/v1/order/orders/{sys_orderid}/submitcancel"
self.add_request(
method="POST",
path=path,
callback=self.on_cancel_order,
extra=req
)
def on_query_account(self, data, request):
""""""
if self.check_error(data, "查询账户"):
return
for d in data["data"]:
if d["type"] == "spot":
self.account_id = d["id"]
self.gateway.write_log(f"账户代码{self.account_id}查询成功")
self.query_account_balance()
def on_query_account_balance(self, data, request):
""""""
if self.check_error(data, "查询账户资金"):
return
buf = {}
for d in data["data"]["list"]:
currency = d["currency"]
currency_data = buf.setdefault(currency, {})
currency_data[d["type"]] = float(d["balance"])
for currency, currency_data in buf.items():
account = AccountData(
accountid=currency,
balance=currency_data["trade"] + currency_data["frozen"],
frozen=currency_data["frozen"],
gateway_name=self.gateway_name,
)
if account.balance:
self.gateway.on_account(account)
def on_query_order(self, data, request):
""""""
if self.check_error(data, "查询委托"):
return
for d in data["data"]:
sys_orderid = d["id"]
local_orderid = self.order_manager.get_local_orderid(sys_orderid)
direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]]
dt = datetime.fromtimestamp(d["created-at"] / 1000)
time = dt.strftime("%H:%M:%S")
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.HUOBI,
price=float(d["price"]),
volume=float(d["amount"]),
type=order_type,
direction=direction,
traded=float(d["filled-amount"]),
status=STATUS_HUOBI2VT.get(d["state"], None),
time=time,
gateway_name=self.gateway_name,
)
self.order_manager.on_order(order)
self.gateway.write_log("委托信息查询成功")
def on_query_contract(self, data, request): # type: (dict, Request)->None
""""""
if self.check_error(data, "查询合约"):
return
for d in data["data"]:
base_currency = d["base-currency"]
quote_currency = d["quote-currency"]
name = f"{base_currency.upper()}/{quote_currency.upper()}"
pricetick = 1 / pow(10, d["price-precision"])
size = 1 / pow(10, d["amount-precision"])
contract = ContractData(
symbol=d["symbol"],
exchange=Exchange.HUOBI,
name=name,
pricetick=pricetick,
size=size,
product=Product.SPOT,
gateway_name=self.gateway_name,
)
self.gateway.on_contract(contract)
huobi_symbols.add(contract.symbol)
symbol_name_map[contract.symbol] = contract.name
self.gateway.write_log("合约信息查询成功")
def on_send_order(self, data, request):
""""""
order = request.extra
if self.check_error(data, "委托"):
order.status = Status.REJECTED
self.order_manager.on_order(order)
return
sys_orderid = data["data"]
self.order_manager.update_orderid_map(order.orderid, sys_orderid)
def on_cancel_order(self, data, request):
""""""
if self.check_error(data, "撤单"):
return
cancel_request = request.extra
local_orderid = cancel_request.orderid
order = self.order_manager.get_order_with_local_orderid(local_orderid)
order.status = Status.CANCELLED
self.order_manager.on_order(order)
self.gateway.write_log(f"委托撤单成功:{order.orderid}")
def check_error(self, data: dict, func: str = ""):
""""""
if data["status"] != "error":
return False
error_code = data["err-code"]
error_msg = data["err-msg"]
self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}")
return True
class HuobiWebsocketApiBase(WebsocketClient):
""""""
def __init__(self, gateway):
""""""
super(HuobiWebsocketApiBase, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.sign_host = ""
self.path = ""
def connect(
self,
key: str,
secret: str,
url: str,
proxy_host: str,
proxy_port: int
):
""""""
self.key = key
self.secret = secret
host, path = _split_url(url)
self.sign_host = host
self.path = path
self.init(url, proxy_host, proxy_port)
self.start()
def login(self):
""""""
params = {"op": "auth"}
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
return self.send_packet(params)
def on_login(self, packet):
""""""
pass
@staticmethod
def unpack_data(data):
""""""
return json.loads(zlib.decompress(data, 31))
def on_packet(self, packet):
""""""
if "ping" in packet:
self.send_packet({"pong": packet["ping"]})
elif "err-msg" in packet:
return self.on_error_msg(packet)
elif "op" in packet and packet["op"] == "auth":
return self.on_login()
else:
self.on_data(packet)
def on_data(self, packet):
""""""
print("data : {}".format(packet))
def on_error_msg(self, packet):
""""""
msg = packet["err-msg"]
if msg == "invalid pong":
return
self.gateway.write_log(packet["err-msg"])
class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
""""""
def __init__(self, gateway):
""""""
super().__init__(gateway)
self.order_manager = gateway.order_manager
self.order_manager.push_data_callback = self.on_data
self.req_id = 0
def connect(self, key, secret, proxy_host, proxy_port):
""""""
super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port)
def subscribe(self, req: SubscribeRequest):
""""""
self.req_id += 1
req = {
"op": "sub",
"cid": str(self.req_id),
"topic": f"orders.{req.symbol}"
}
self.send_packet(req)
def on_connected(self):
""""""
self.gateway.write_log("交易Websocket API连接成功")
self.login()
def on_login(self):
""""""
self.gateway.write_log("交易Websocket API登录成功")
def on_data(self, packet): # type: (dict)->None
""""""
op = packet.get("op", None)
if op != "notify":
return
topic = packet["topic"]
if "orders" in topic:
self.on_order(packet["data"])
def on_order(self, data: dict):
""""""
sys_orderid = str(data["order-id"])
order = self.order_manager.get_order_with_sys_orderid(sys_orderid)
if not order:
self.order_manager.add_push_data(sys_orderid, data)
return
traded_volume = float(data["filled-amount"])
# Push order event
order.traded += traded_volume
order.status = STATUS_HUOBI2VT.get(data["order-state"], None)
self.order_manager.on_order(order)
# Push trade event
if not traded_volume:
return
trade = TradeData(
symbol=order.symbol,
exchange=Exchange.HUOBI,
orderid=order.orderid,
tradeid=str(data["seq-id"]),
direction=order.direction,
price=float(data["price"]),
volume=float(data["filled-amount"]),
time=datetime.now().strftime("%H:%M:%S"),
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
""""""
def __init__(self, gateway):
""""""
super().__init__(gateway)
self.req_id = 0
self.ticks = {}
def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int):
""""""
super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port)
def on_connected(self):
""""""
self.gateway.write_log("行情Websocket API连接成功")
def subscribe(self, req: SubscribeRequest):
""""""
symbol = req.symbol
# Create tick data buffer
tick = TickData(
symbol=symbol,
name=symbol_name_map.get(symbol, ""),
exchange=Exchange.HUOBI,
datetime=datetime.now(),
gateway_name=self.gateway_name,
)
self.ticks[symbol] = tick
# Subscribe to market depth update
self.req_id += 1
req = {
"sub": f"market.{symbol}.depth.step0",
"id": str(self.req_id)
}
self.send_packet(req)
# Subscribe to market detail update
self.req_id += 1
req = {
"sub": f"market.{symbol}.detail",
"id": str(self.req_id)
}
self.send_packet(req)
def on_data(self, packet): # type: (dict)->None
""""""
channel = packet.get("ch", None)
if channel:
if "depth.step" in channel:
self.on_market_depth(packet)
elif "detail" in channel:
self.on_market_detail(packet)
elif "err-code" in packet:
code = packet["err-code"]
msg = packet["err-msg"]
self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}")
def on_market_depth(self, data):
"""行情深度推送 """
symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
bids = data["tick"]["bids"]
for n in range(5):
price, volume = bids[n]
tick.__setattr__("bid_price_" + str(n + 1), float(price))
tick.__setattr__("bid_volume_" + str(n + 1), float(volume))
asks = data["tick"]["asks"]
for n in range(5):
price, volume = asks[n]
tick.__setattr__("ask_price_" + str(n + 1), float(price))
tick.__setattr__("ask_volume_" + str(n + 1), float(volume))
if tick.last_price:
self.gateway.on_tick(copy(tick))
def on_market_detail(self, data):
"""市场细节推送"""
symbol = data["ch"].split(".")[1]
tick = self.ticks[symbol]
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
tick_data = data["tick"]
tick.open_price = float(tick_data["open"])
tick.high_price = float(tick_data["high"])
tick.low_price = float(tick_data["low"])
tick.last_price = float(tick_data["close"])
tick.volume = float(tick_data["vol"])
if tick.bid_price_1:
self.gateway.on_tick(copy(tick))
def _split_url(url):
"""
将url拆分为host和path
:return: host, path
"""
result = re.match("\w+://([^/]*)(.*)", url) # noqa
if result:
return result.group(1), result.group(2)
def create_signature(api_key, method, host, path, secret_key, get_params=None):
"""
创建签名
:param get_params: dict 使用GET方法时附带的额外参数(urlparams)
:return:
"""
sorted_params = [
("AccessKeyId", api_key),
("SignatureMethod", "HmacSHA256"),
("SignatureVersion", "2"),
("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S"))
]
if get_params:
sorted_params.extend(list(get_params.items()))
sorted_params = list(sorted(sorted_params))
encode_params = urllib.parse.urlencode(sorted_params)
payload = [method, host, path, encode_params]
payload = "\n".join(payload)
payload = payload.encode(encoding="UTF8")
secret_key = secret_key.encode(encoding="UTF8")
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
signature = base64.b64encode(digest)
params = dict(sorted_params)
params["Signature"] = signature.decode("UTF8")
return params

View File

@ -0,0 +1 @@
from .okex_gateway import OkexGateway

View File

@ -0,0 +1,699 @@
# encoding: UTF-8
"""
"""
import hashlib
import hmac
import sys
import time
import json
import base64
import zlib
from copy import copy
from datetime import datetime
from threading import Lock
from urllib.parse import urlencode
from requests import ConnectionError
from vnpy.api.rest import Request, RestClient
from vnpy.api.websocket import WebsocketClient
from vnpy.trader.constant import (
Direction,
Exchange,
OrderType,
Product,
Status
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest,
)
REST_HOST = "https://www.okex.com"
WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3"
STATUS_OKEX2VT = {
"ordering": Status.SUBMITTING,
"open": Status.NOTTRADED,
"part_filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"cancelled": Status.CANCELLED,
"cancelling": Status.CANCELLED,
"failure": Status.REJECTED,
}
DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"}
DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()}
ORDERTYPE_VT2OKEX = {
OrderType.LIMIT: "limit",
OrderType.MARKET: "market"
}
ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()}
instruments = set()
currencies = set()
class OkexGateway(BaseGateway):
"""
VN Trader Gateway for OKEX connection.
"""
default_setting = {
"API Key": "",
"Secret Key": "",
"Passphrase": "",
"会话数": 3,
"代理地址": "127.0.0.1",
"代理端口": 1080,
}
def __init__(self, event_engine):
"""Constructor"""
super(OkexGateway, self).__init__(event_engine, "OKEX")
self.rest_api = OkexRestApi(self)
self.ws_api = OkexWebsocketApi(self)
def connect(self, setting: dict):
""""""
key = setting["API Key"]
secret = setting["Secret Key"]
passphrase = setting["Passphrase"]
session_number = setting["会话数"]
proxy_host = setting["代理地址"]
proxy_port = setting["代理端口"]
self.rest_api.connect(key, secret, passphrase,
session_number, proxy_host, proxy_port)
self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port)
def subscribe(self, req: SubscribeRequest):
""""""
self.ws_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
pass
def query_position(self):
""""""
pass
def close(self):
""""""
self.rest_api.stop()
self.ws_api.stop()
class OkexRestApi(RestClient):
"""
OKEX REST API
"""
def __init__(self, gateway: BaseGateway):
""""""
super(OkexRestApi, self).__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.passphrase = ""
self.order_count = 10000
self.order_count_lock = Lock()
self.connect_time = 0
def sign(self, request):
"""
Generate OKEX signature.
"""
# Sign
# timestamp = str(time.time())
timestamp = get_timestamp()
request.data = json.dumps(request.data)
if request.params:
path = request.path + '?' + urlencode(request.params)
else:
path = request.path
msg = timestamp + request.method + path + request.data
signature = generate_signature(msg, self.secret)
# Add headers
request.headers = {
'OK-ACCESS-KEY': self.key,
'OK-ACCESS-SIGN': signature,
'OK-ACCESS-TIMESTAMP': timestamp,
'OK-ACCESS-PASSPHRASE': self.passphrase,
'Content-Type': 'application/json'
}
return request
def connect(
self,
key: str,
secret: str,
passphrase: str,
session_number: int,
proxy_host: str,
proxy_port: int,
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(REST_HOST, proxy_host, proxy_port)
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.query_time()
self.query_contract()
self.query_account()
self.query_order()
def _new_order_id(self):
with self.order_count_lock:
self.order_count += 1
return self.order_count
def send_order(self, req: OrderRequest):
""""""
orderid = f"a{self.connect_time}{self._new_order_id()}"
data = {
"client_oid": orderid,
"type": ORDERTYPE_VT2OKEX[req.type],
"side": DIRECTION_VT2OKEX[req.direction],
"instrument_id": req.symbol
}
if req.type == OrderType.MARKET:
if req.direction == Direction.LONG:
data["notional"] = req.volume
else:
data["size"] = req.volume
else:
data["price"] = req.price
data["size"] = req.volume
order = req.create_order_data(orderid, self.gateway_name)
self.add_request(
"POST",
"/api/spot/v3/orders",
callback=self.on_send_order,
data=data,
extra=order,
on_failed=self.on_send_order_failed,
on_error=self.on_send_order_error,
)
self.gateway.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
data = {
"instrument_id": req.symbol,
"client_oid": req.orderid
}
path = "/api/spot/v3/cancel_orders/" + req.orderid
self.add_request(
"POST",
path,
callback=self.on_cancel_order,
data=data,
on_error=self.on_cancel_order_error,
)
def query_contract(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/instruments",
callback=self.on_query_contract
)
def query_account(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/accounts",
callback=self.on_query_account
)
def query_order(self):
""""""
self.add_request(
"GET",
"/api/spot/v3/orders_pending",
callback=self.on_query_order
)
def query_time(self):
""""""
self.add_request(
"GET",
"/api/general/v3/time",
callback=self.on_query_time
)
def on_query_contract(self, data, request):
""""""
for instrument_data in data:
symbol = instrument_data["instrument_id"]
contract = ContractData(
symbol=symbol,
exchange=Exchange.OKEX,
name=symbol,
product=Product.SPOT,
size=1,
pricetick=instrument_data["tick_size"],
gateway_name=self.gateway_name
)
self.gateway.on_contract(contract)
instruments.add(instrument_data["instrument_id"])
currencies.add(instrument_data["base_currency"])
currencies.add(instrument_data["quote_currency"])
self.gateway.write_log("合约信息查询成功")
# Start websocket api after instruments data collected
self.gateway.ws_api.start()
def on_query_account(self, data, request):
""""""
for account_data in data:
account = AccountData(
accountid=account_data["currency"],
balance=float(account_data["balance"]),
frozen=float(account_data["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
self.gateway.write_log("账户资金查询成功")
def on_query_order(self, data, request):
""""""
for order_data in data:
order = OrderData(
symbol=order_data["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[order_data["type"]],
orderid=order_data["client_oid"],
direction=DIRECTION_OKEX2VT[order_data["side"]],
price=float(order_data["price"]),
volume=float(order_data["size"]),
time=order_data["timestamp"][11:19],
status=STATUS_OKEX2VT[order_data["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
self.gateway.write_log("委托信息查询成功")
def on_query_time(self, data, request):
""""""
server_time = data["iso"]
local_time = datetime.utcnow().isoformat()
msg = f"服务器时间:{server_time},本机时间:{local_time}"
self.gateway.write_log(msg)
def on_send_order_failed(self, status_code: str, request: Request):
"""
Callback when sending order failed on server.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_send_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when sending order caused exception.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_send_order(self, data, request):
"""Websocket will push a new order status"""
order = request.extra
error_msg = data["error_message"]
if error_msg:
order.status = Status.REJECTED
self.gateway.on_order(order)
self.gateway.write_log(f"委托失败:{error_msg}")
def on_cancel_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback when cancelling order failed on server.
"""
# Record exception if not ConnectionError
if not issubclass(exception_type, ConnectionError):
self.on_error(exception_type, exception_value, tb, request)
def on_cancel_order(self, data, request):
"""Websocket will push a new order status"""
pass
def on_failed(self, status_code: int, request: Request):
"""
Callback to handle request failed.
"""
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback to handler request exception.
"""
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type, exception_value, tb, request)
)
class OkexWebsocketApi(WebsocketClient):
""""""
def __init__(self, gateway):
""""""
super(OkexWebsocketApi, self).__init__()
self.ping_interval = 20 # OKEX use 30 seconds for ping
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.passphrase = ""
self.trade_count = 10000
self.connect_time = 0
self.callbacks = {}
self.ticks = {}
def connect(
self,
key: str,
secret: str,
passphrase: str,
proxy_host: str,
proxy_port: int
):
""""""
self.key = key
self.secret = secret.encode()
self.passphrase = passphrase
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
# self.start()
def unpack_data(self, data):
""""""
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
def subscribe(self, req: SubscribeRequest):
"""
Subscribe to tick data upate.
"""
tick = TickData(
symbol=req.symbol,
exchange=req.exchange,
name=req.symbol,
datetime=datetime.now(),
gateway_name=self.gateway_name,
)
self.ticks[req.symbol] = tick
channel_ticker = f"spot/ticker:{req.symbol}"
channel_depth = f"spot/depth5:{req.symbol}"
self.callbacks[channel_ticker] = self.on_ticker
self.callbacks[channel_depth] = self.on_depth
req = {
"op": "subscribe",
"args": [channel_ticker, channel_depth]
}
self.send_packet(req)
def on_connected(self):
""""""
self.gateway.write_log("Websocket API连接成功")
self.login()
def on_disconnected(self):
""""""
self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict):
""""""
if "event" in packet:
event = packet["event"]
if event == "subscribe":
return
elif event == "error":
msg = packet["message"]
self.gateway.write_log(f"Websocket API请求异常{msg}")
elif event == "login":
self.on_login(packet)
else:
channel = packet["table"]
data = packet["data"]
callback = self.callbacks.get(channel, None)
if callback:
for d in data:
callback(d)
def on_error(self, exception_type: type, exception_value: Exception, tb):
""""""
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(self.exception_detail(
exception_type, exception_value, tb))
def login(self):
"""
Need to login befores subscribe to websocket topic.
"""
timestamp = str(time.time())
msg = timestamp + 'GET' + '/users/self/verify'
signature = generate_signature(msg, self.secret)
req = {
"op": "login",
"args": [
self.key,
self.passphrase,
timestamp,
signature.decode("utf-8")
]
}
self.send_packet(req)
self.callbacks['login'] = self.on_login
def subscribe_topic(self):
"""
Subscribe to all private topics.
"""
self.callbacks["spot/ticker"] = self.on_ticker
self.callbacks["spot/depth5"] = self.on_depth
self.callbacks["spot/account"] = self.on_account
self.callbacks["spot/order"] = self.on_order
# Subscribe to order update
channels = []
for instrument_id in instruments:
channel = f"spot/order:{instrument_id}"
channels.append(channel)
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
# Subscribe to account update
channels = []
for currency in currencies:
channel = f"spot/account:{currency}"
channels.append(channel)
req = {
"op": "subscribe",
"args": channels
}
self.send_packet(req)
# Subscribe to BTC/USDT trade for keep connection alive
req = {
"op": "subscribe",
"args": ["spot/trade:BTC-USDT"]
}
self.send_packet(req)
def on_login(self, data: dict):
""""""
success = data.get("success", False)
if success:
self.gateway.write_log("Websocket API登录成功")
self.subscribe_topic()
else:
self.gateway.write_log("Websocket API登录失败")
def on_ticker(self, d):
""""""
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
tick.last_price = d["last"]
tick.open = d["open_24h"]
tick.high = d["high_24h"]
tick.low = d["low_24h"]
tick.volume = d["base_volume_24h"]
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_depth(self, d):
""""""
for tick_data in d:
symbol = d["instrument_id"]
tick = self.ticks.get(symbol, None)
if not tick:
return
bids = d["bids"]
asks = d["asks"]
for n, buf in enumerate(bids):
price, volume, _ = buf
tick.__setattr__("bid_price_%s" % (n + 1), price)
tick.__setattr__("bid_volume_%s" % (n + 1), volume)
for n, buf in enumerate(asks):
price, volume, _ = buf
tick.__setattr__("ask_price_%s" % (n + 1), price)
tick.__setattr__("ask_volume_%s" % (n + 1), volume)
tick.datetime = datetime.strptime(
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.gateway.on_tick(copy(tick))
def on_order(self, d):
""""""
order = OrderData(
symbol=d["instrument_id"],
exchange=Exchange.OKEX,
type=ORDERTYPE_OKEX2VT[d["type"]],
orderid=d["client_oid"],
direction=DIRECTION_OKEX2VT[d["side"]],
price=d["price"],
volume=d["size"],
traded=d["filled_size"],
time=d["timestamp"][11:19],
status=STATUS_OKEX2VT[d["status"]],
gateway_name=self.gateway_name,
)
self.gateway.on_order(copy(order))
trade_volume = float(d.get("last_fill_qty", 0))
if not trade_volume:
return
self.trade_count += 1
tradeid = f"{self.connect_time}{self.trade_count}"
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=tradeid,
direction=order.direction,
price=float(d["last_fill_px"]),
volume=float(trade_volume),
time=d["last_fill_time"][11:19],
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def on_account(self, d):
""""""
account = AccountData(
accountid=d["currency"],
balance=float(d["balance"]),
frozen=float(d["hold"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(copy(account))
def generate_signature(msg: str, secret_key: str):
"""OKEX V3 signature"""
return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest())
def get_timestamp():
""""""
now = datetime.utcnow()
timestamp = now.isoformat("T", "milliseconds")
return timestamp + "Z"

View File

@ -1,5 +1,6 @@
# encoding: UTF-8 # encoding: UTF-8
""" """
Author: KeKe
Please install tiger-api before use. Please install tiger-api before use.
pip install tigeropen pip install tigeropen
""" """
@ -123,6 +124,9 @@ class TigerGateway(BaseGateway):
self.contracts = {} self.contracts = {}
self.symbol_names = {} self.symbol_names = {}
self.push_connected = False
self.subscribed_symbols = set()
def run(self): def run(self):
"""""" """"""
while self.active: while self.active:
@ -203,23 +207,33 @@ class TigerGateway(BaseGateway):
""" """
protocol, host, port = self.client_config.socket_host_port protocol, host, port = self.client_config.socket_host_port
self.push_client = PushClient(host, port, (protocol == 'ssl')) self.push_client = PushClient(host, port, (protocol == 'ssl'))
self.push_client.connect(
self.client_config.tiger_id, self.client_config.private_key)
self.push_client.quote_changed = self.on_quote_change self.push_client.quote_changed = self.on_quote_change
self.push_client.asset_changed = self.on_asset_change self.push_client.asset_changed = self.on_asset_change
self.push_client.position_changed = self.on_position_change self.push_client.position_changed = self.on_position_change
self.push_client.order_changed = self.on_order_change self.push_client.order_changed = self.on_order_change
self.write_log("推送接口连接成功") self.push_client.connect(
self.client_config.tiger_id, self.client_config.private_key)
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
self.push_client.subscribe_quote([req.symbol]) self.subscribed_symbols.add(req.symbol)
if self.push_connected:
self.push_client.subscribe_quote([req.symbol])
def on_push_connected(self):
""""""
self.push_connected = True
self.write_log("推送接口连接成功")
self.push_client.subscribe_asset() self.push_client.subscribe_asset()
self.push_client.subscribe_position() self.push_client.subscribe_position()
self.push_client.subscribe_order() self.push_client.subscribe_order()
self.push_client.subscribe_quote(list(self.subscribed_symbols))
def on_quote_change(self, tiger_symbol: str, data: list, trading: bool): def on_quote_change(self, tiger_symbol: str, data: list, trading: bool):
"""""" """"""
data = dict(data) data = dict(data)

1
vnpy/rpc/__init__.py Normal file
View File

@ -0,0 +1 @@
from .vnrpc import RpcServer, RpcClient, RemoteException

36
vnpy/rpc/test_client.py Normal file
View File

@ -0,0 +1,36 @@
from __future__ import print_function
from __future__ import absolute_import
from time import sleep
from vnpy.rpc import RpcClient
class TestClient(RpcClient):
"""
Test RpcClient
"""
def __init__(self, req_address, sub_address):
"""
Constructor
"""
super(TestClient, self).__init__(req_address, sub_address)
def callback(self, topic, data):
"""
Realize callable function
"""
print('client received topic:', topic, ', data:', data)
if __name__ == '__main__':
req_address = 'tcp://localhost:2014'
sub_address = 'tcp://localhost:0602'
tc = TestClient(req_address, sub_address)
tc.subscribeTopic('')
tc.start()
while 1:
print(tc.add(1, 3))
sleep(2)

40
vnpy/rpc/test_server.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import print_function
from __future__ import absolute_import
from time import sleep, time
from vnpy.rpc import RpcServer
class TestServer(RpcServer):
"""
Test RpcServer
"""
def __init__(self, rep_address, pub_address):
"""
Constructor
"""
super(TestServer, self).__init__(rep_address, pub_address)
self.register(self.add)
def add(self, a, b):
"""
Test function
"""
print('receiving: %s, %s' % (a, b))
return a + b
if __name__ == '__main__':
rep_address = 'tcp://*:2014'
pub_address = 'tcp://*:0602'
ts = TestServer(rep_address, pub_address)
ts.start()
while 1:
content = 'current server time is %s' % time()
print(content)
ts.publish('test', content)
sleep(2)

329
vnpy/rpc/vnrpc.py Normal file
View File

@ -0,0 +1,329 @@
import threading
import traceback
import signal
import zmq
from msgpack import packb, unpackb
from json import dumps, loads
import pickle
p_dumps = pickle.dumps
p_loads = pickle.loads
# Achieve Ctrl-c interrupt recv
signal.signal(signal.SIGINT, signal.SIG_DFL)
class RpcObject(object):
"""
Referred to serialization of packing and unpacking, we offer 3 tools:
1) maspack: higher performance, but usually requires the installation of msgpack related tools;
2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries;
3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly.
Therefore, it is recommended to use msgpack.
Use json, if you want to communicate with some languages without providing msgpack.
Use cPickle, when the data being transferred contains many custom Python objects.
"""
def __init__(self):
"""
Constructor
Use msgpack as default serialization tool
"""
self.use_msgpack()
def pack(self, data):
""""""
pass
def unpack(self, data):
""""""
pass
def __json_pack(self, data):
"""
Pack with json
"""
return dumps(data)
def __json_unpack(self, data):
"""
Unpack with json
"""
return loads(data)
def __msgpack_pack(self, data):
"""
Pack with msgpack
"""
return packb(data)
def __msgpack_unpack(self, data):
"""
Unpack with msgpack
"""
return unpackb(data)
def __pickle_pack(self, data):
"""
Pack with cPickle
"""
return p_dumps(data)
def __pickle_unpack(self, data):
"""
Unpack with cPickle
"""
return p_loads(data)
def use_json(self):
"""
Use json as serialization tool
"""
self.pack = self.__json_pack
self.unpack = self.__json_unpack
def use_msgpack(self):
"""
Use msgpack as serialization tool
"""
self.pack = self.__msgpack_pack
self.unpack = self.__msgpack_unpack
def use_pickle(self):
"""
Use cPickle as serialization tool
"""
self.pack = self.__pickle_pack
self.unpack = self.__pickle_unpack
class RpcServer(RpcObject):
""""""
def __init__(self, rep_address, pub_address):
"""
Constructor
"""
super(RpcServer, self).__init__()
# Save functions dict: key is fuction name, value is fuction object
self.__functions = {}
# Zmq port related
self.__context = zmq.Context()
self.__socket_rep = self.__context.socket(
zmq.REP) # Reply socket (Requestreply pattern)
self.__socket_rep.bind(rep_address)
# Publish socket (Publishsubscribe pattern)
self.__socket_pub = self.__context.socket(zmq.PUB)
self.__socket_pub.bind(pub_address)
# Woker thread related
self.__active = False # RpcServer status
self.__thread = threading.Thread(target=self.run) # RpcServer thread
def start(self):
"""
Start RpcServer
"""
# Start RpcServer status
self.__active = True
# Start RpcServer thread
if not self.__thread.isAlive():
self.__thread.start()
def stop(self, join=False):
"""
Stop RpcServer
"""
# Stop RpcServer status
self.__active = False
# Wait for RpcServer thread to exit
if join and self.__thread.isAlive():
self.__thread.join()
def run(self):
"""
Run RpcServer functions
"""
while self.__active:
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
if not self.__socket_rep.poll(1000):
continue
# Receive request data from Reply socket
reqb = self.__socket_rep.recv()
# Unpack request by deserialization
req = self.unpack(reqb)
# Get function name and parameters
name, args, kwargs = req
# Try to get and execute callable function object; capture exception information if it fails
name = name.decode("UTF-8")
try:
func = self.__functions[name]
r = func(*args, **kwargs)
rep = [True, r]
except Exception as e: # noqa
rep = [False, traceback.format_exc()]
# Pack response by serialization
repb = self.pack(rep)
# send callable response by Reply socket
self.__socket_rep.send(repb)
def publish(self, topic, data):
"""
Publish data
"""
# Serialized data
topic = bytes(topic, "UTF-8")
datab = self.pack(data)
# Send data by Publish socket
# topci must be ascii encoding
self.__socket_pub.send_multipart([topic, datab])
def register(self, func):
"""
Register function
"""
self.__functions[func.__name__] = func
class RpcClient(RpcObject):
""""""
def __init__(self, req_address, sub_address):
"""Constructor"""
super(RpcClient, self).__init__()
# zmq port related
self.__req_address = req_address
self.__sub_address = sub_address
self.__context = zmq.Context()
# Request socket (Requestreply pattern)
self.__socket_req = self.__context.socket(zmq.REQ)
# Subscribe socket (Publishsubscribe pattern)
self.__socket_sub = self.__context.socket(zmq.SUB)
# Woker thread relate, used to process data pushed from server
self.__active = False # RpcClient status
self.__thread = threading.Thread(
target=self.run) # RpcClient thread
def __getattr__(self, name):
"""
Realize remote call function
"""
# Perform remote call task
def dorpc(*args, **kwargs):
# Generate request
req = [name, args, kwargs]
# Pack request by serialization
reqb = self.pack(req)
# Send request and wait for response
self.__socket_req.send(reqb)
repb = self.__socket_req.recv()
# Unpack response by deserialization
rep = self.unpack(repb)
# Return response if successed; Trigger exception if failed
if rep[0]:
return rep[1]
else:
raise RemoteException(rep[1].decode("UTF-8"))
return dorpc
def start(self):
"""
Start RpcClient
"""
# Connect zmq port
self.__socket_req.connect(self.__req_address)
self.__socket_sub.connect(self.__sub_address)
# Start RpcClient status
self.__active = True
# Start RpcClient thread
if not self.__thread.isAlive():
self.__thread.start()
def stop(self):
"""
Stop RpcClient
"""
# Stop RpcClient status
self.__active = False
# Wait for RpcClient thread to exit
if self.__thread.isAlive():
self.__thread.join()
def run(self):
"""
Run RpcClient function
"""
while self.__active:
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
if not self.__socket_sub.poll(1000):
continue
# Receive data from subscribe socket
topic, datab = self.__socket_sub.recv_multipart()
# Unpack data by deserialization
data = self.unpack(datab)
# Process data by callable function
topic = topic.decode("UTF-8")
self.callback(topic, data)
def callback(self, topic, data):
"""
Callable function
"""
raise NotImplementedError
def subscribeTopic(self, topic):
"""
Subscribe data
"""
topic = bytes(topic, "UTF-8")
self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic)
class RemoteException(Exception):
"""
RPC remote exception
"""
def __init__(self, value):
"""
Constructor
"""
self.__value = value
def __str__(self):
"""
Output error message
"""
return self.__value

View File

@ -99,6 +99,8 @@ class Exchange(Enum):
# CryptoCurrency # CryptoCurrency
BITMEX = "BITMEX" BITMEX = "BITMEX"
OKEX = "OKEX"
HUOBI = "HUOBI"
class Currency(Enum): class Currency(Enum):

View File

@ -52,6 +52,7 @@ class MainEngine:
""" """
engine = engine_class(self, self.event_engine) engine = engine_class(self, self.event_engine)
self.engines[engine.engine_name] = engine self.engines[engine.engine_name] = engine
return engine
def add_gateway(self, gateway_class: BaseGateway): def add_gateway(self, gateway_class: BaseGateway):
""" """
@ -59,6 +60,7 @@ class MainEngine:
""" """
gateway = gateway_class(self.event_engine) gateway = gateway_class(self.event_engine)
self.gateways[gateway.gateway_name] = gateway self.gateways[gateway.gateway_name] = gateway
return gateway
def add_app(self, app_class: BaseApp): def add_app(self, app_class: BaseApp):
""" """
@ -67,7 +69,8 @@ class MainEngine:
app = app_class() app = app_class()
self.apps[app.app_name] = app self.apps[app.app_name] = app
self.add_engine(app.engine_class) engine = self.add_engine(app.engine_class)
return engine
def init_engines(self): def init_engines(self):
""" """

View File

@ -4,6 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from copy import copy
from vnpy.event import Event, EventEngine from vnpy.event import Event, EventEngine
from .event import ( from .event import (
@ -227,3 +228,124 @@ class BaseGateway(ABC):
Return default setting dict. Return default setting dict.
""" """
return self.default_setting return self.default_setting
class LocalOrderManager:
"""
Management tool to support use local order id for trading.
"""
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
# For generating local orderid
self.order_prefix = ""
self.order_count = 0
self.orders = {} # local_orderid:order
# Map between local and system orderid
self.local_sys_orderid_map = {}
self.sys_local_orderid_map = {}
# Push order data buf
self.push_data_buf = {} # sys_orderid:data
# Callback for processing push order data
self.push_data_callback = None
# Cancel request buf
self.cancel_request_buf = {} # local_orderid:req
def new_local_orderid(self):
"""
Generate a new local orderid.
"""
self.order_count += 1
local_orderid = str(self.order_count).rjust(8, "0")
return local_orderid
def get_local_orderid(self, sys_orderid: str):
"""
Get local orderid with sys orderid.
"""
local_orderid = self.sys_local_orderid_map.get(sys_orderid, "")
if not local_orderid:
local_orderid = self.new_local_orderid()
self.update_orderid_map(local_orderid, sys_orderid)
return local_orderid
def get_sys_orderid(self, local_orderid: str):
"""
Get sys orderid with local orderid.
"""
sys_orderid = self.local_sys_orderid_map.get(local_orderid, "")
return sys_orderid
def update_orderid_map(self, local_orderid: str, sys_orderid: str):
"""
Update orderid map.
"""
self.sys_local_orderid_map[sys_orderid] = local_orderid
self.local_sys_orderid_map[local_orderid] = sys_orderid
self.check_cancel_request(local_orderid)
self.check_push_data(sys_orderid)
def check_push_data(self, sys_orderid: str):
"""
Check if any order push data waiting.
"""
if sys_orderid not in self.push_data_buf:
return
data = self.push_data_buf.pop(sys_orderid)
if self.push_data_callback:
self.push_data_callback(data)
def add_push_data(self, sys_orderid: str, data: dict):
"""
Add push data into buf.
"""
self.push_data_buf[sys_orderid] = data
def get_order_with_sys_orderid(self, sys_orderid: str):
""""""
local_orderid = self.sys_local_orderid_map.get(sys_orderid, None)
if not local_orderid:
return None
else:
return self.get_order_with_local_orderid(local_orderid)
def get_order_with_local_orderid(self, local_orderid: str):
""""""
order = self.orders[local_orderid]
return copy(order)
def on_order(self, order: OrderData):
"""
Keep an order buf before pushing it to gateway.
"""
self.orders[order.orderid] = copy(order)
self.gateway.on_order(order)
def cancel_order(self, req: CancelRequest):
"""
"""
sys_orderid = self.get_sys_orderid(req.orderid)
if not sys_orderid:
self.cancel_request_buf[req.orderid] = req
return
self.gateway.cancel_order(req)
def check_cancel_request(self, local_orderid: str):
"""
"""
if local_orderid not in self.cancel_request_buf:
return
req = self.cancel_request_buf.pop(local_orderid)
self.gateway.cancel_order(req)