commit
b075e1d653
10
.travis.yml
10
.travis.yml
@ -29,6 +29,8 @@ matrix:
|
||||
- choco install python3 --version 3.7.2
|
||||
install:
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
- pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- pip install -r requirements.txt
|
||||
- pip install .
|
||||
|
||||
@ -43,8 +45,10 @@ matrix:
|
||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
|
||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
|
||||
# Linux install script
|
||||
# update pip & setuptools
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
# Linux install script
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- bash ./install.sh
|
||||
|
||||
- name: "pip install under Ubuntu: gcc-7"
|
||||
@ -58,8 +62,10 @@ matrix:
|
||||
- sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
|
||||
- sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
|
||||
- sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
|
||||
# Linux install script
|
||||
# update pip & setuptools
|
||||
- python -m pip install --upgrade pip wheel setuptools
|
||||
# Linux install script
|
||||
- pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
- bash ./install.sh
|
||||
|
||||
- name: "sdist install under Windows"
|
||||
|
51
README.md
51
README.md
@ -1,10 +1,18 @@
|
||||
# By Traders, For Traders.
|
||||
|
||||
<p align="center">
|
||||
<img src ="https://vnpy.oss-cn-shanghai.aliyuncs.com/vnpy-logo.png" />
|
||||
<img src ="https://vnpy.oss-cn-shanghai.aliyuncs.com/vnpy-logo.png"/>
|
||||
</p>
|
||||
|
||||
vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1月正式发布以来,在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台,目前国内外金融机构用户已经超过300家,包括:私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。
|
||||
<p align="center">
|
||||
<img src ="https://img.shields.io/badge/version-2.0.1-blueviolet.svg"/>
|
||||
<img src ="https://img.shields.io/badge/platform-windows|linux|macos-yellow.svg"/>
|
||||
<img src ="https://img.shields.io/badge/python-3.7-blue.svg" />
|
||||
<img src ="https://img.shields.io/travis/com/vnpy/vnpy/master.svg"/>
|
||||
<img src ="https://img.shields.io/github/license/vnpy/vnpy.svg?color=orange"/>
|
||||
</p>
|
||||
|
||||
vn.py是一套基于Python的开源量化交易系统开发框架,于2015年1月正式发布,在开源社区5年持续不断的贡献下一步步成长为全功能量化交易平台,目前国内外金融机构用户已经超过300家,包括:私募基金、证券自营和资管、期货资管和子公司、高校研究机构、自营交易公司、交易所、Token Fund等。
|
||||
|
||||
2.0版本基于Python 3.7全新重构开发,目前功能还在逐步完善中。如需Python 2上的版本请点击:[长期支持版本v1.9.2 LTS](https://github.com/vnpy/vnpy/tree/v1.9.2-LTS)。
|
||||
|
||||
@ -14,17 +22,23 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1
|
||||
|
||||
2. 覆盖国内外所有交易品种的交易接口(vnpy.gateway):
|
||||
|
||||
* CTP(ctpGateway):国内期货、期权
|
||||
* CTP(ctp):国内期货、期权
|
||||
|
||||
* 富途证券(futuGateway):港股、美股
|
||||
* 宽睿(oes):A股
|
||||
|
||||
* Interactive Brokers(ibGateway):全球证券、期货、期权、外汇等
|
||||
* 富途证券(futu):港股、美股
|
||||
|
||||
* BitMEX (bitmexGateway):数字货币期货、期权、永续合约
|
||||
* 老虎证券(tiger):全球证券、期货、期权、外汇等
|
||||
|
||||
* Interactive Brokers(ib):全球证券、期货、期权、外汇等
|
||||
|
||||
* BitMEX (bitmex):数字货币期货、期权、永续合约
|
||||
|
||||
3. 开箱即用的各类量化策略交易应用(vnpy.app):
|
||||
|
||||
* CtaStrategy:CTA策略引擎模块,在保持易用性的同时,允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制(降低交易滑点、实现高频策略)
|
||||
* cta_strategy:CTA策略引擎模块,在保持易用性的同时,允许用户针对CTA类策略运行过程中委托的报撤行为进行细粒度控制(降低交易滑点、实现高频策略)
|
||||
|
||||
* csv_loader:CSV历史数据加载器,用于加载CSV格式文件中的历史数据到平台数据库中,用于策略的回测研究以及实盘初始化等功能,支持自定义数据表头格式
|
||||
|
||||
4. Python交易API接口封装(vnpy.api),提供上述交易接口的底层对接实现。
|
||||
|
||||
@ -36,11 +50,9 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1
|
||||
|
||||
## 环境准备
|
||||
|
||||
* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0-Windows-x86_64.exe),内置了最新版的vn.py,无需手动安装
|
||||
* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNConda-2.0.1-Windows-x86_64](https://conda.vnpy.com/VNConda-2.0.1-Windows-x86_64.exe),内置了最新版的vn.py框架以及VN Station量化管理平台,无需手动安装
|
||||
* 支持的系统版本:Windows 7以上/Windows Server 2008以上/Ubuntu 18.04 LTS
|
||||
* 支持的Python版本:Python 3.7 64位(**注意必须是Python 3.7 64位版本**)
|
||||
* 如需使用IB API,请在[Interactive Brokers Github](https://interactivebrokers.github.io/#)页面下载安装**IB API Latest**
|
||||
|
||||
|
||||
## 安装步骤
|
||||
|
||||
@ -59,15 +71,20 @@ vn.py是一套基于Python的开源量化交易系统开发框架,自2015年1
|
||||
|
||||
1. 在[SimNow](http://www.simnow.com.cn/)注册CTP仿真账号,并在[该页面](http://www.simnow.com.cn/product.action)获取经纪商代码以及交易行情服务器地址。
|
||||
|
||||
2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码,论坛最新的注册邀请码为**El86Pa1p**
|
||||
2. 在[vn.py社区论坛](https://www.vnpy.com/forum/)注册获得VN Station账号密码(论坛账号密码即是)
|
||||
|
||||
3. 启动VN Station(安装VNConda后会在桌面自动创建快捷方式),输入上一步的账号密码登录
|
||||
|
||||
4. 点击底部的**VN Trader**按钮,选择运行目录(默认在系统用户目录即可)后,在对话框中勾选CTP接口以及CtaStrategy应用,点击右下方的**启动**按钮,开始你的交易!!!
|
||||
4. 点击底部的**VN Trader Lite**按钮,开始你的交易!!!
|
||||
|
||||
5. 在VN Trader的运行过程中请勿关闭VN Station(会自动退出)
|
||||
注意:
|
||||
* 在VN Trader的运行过程中请勿关闭VN Station(会自动退出)
|
||||
* 如需要灵活配置量化交易应用组件,请使用**VN Trader Pro**
|
||||
|
||||
6. 如选择了VNConda以外的安装方式(不推荐新手),可以在任意目录下创建run.py,写入以下示例代码后运行:
|
||||
|
||||
## 脚本运行
|
||||
|
||||
除了基于VN Station的图形化启动方式外,也可以在任意目录下创建run.py,写入以下示例代码:
|
||||
|
||||
```Python
|
||||
from vnpy.event import EventEngine
|
||||
@ -77,7 +94,7 @@ from vnpy.gateway.ctp import CtpGateway
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
|
||||
def main():
|
||||
"""启动VN Trader"""
|
||||
"""Start VN Trader"""
|
||||
qapp = create_qapp()
|
||||
|
||||
event_engine = EventEngine()
|
||||
@ -95,11 +112,13 @@ if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
在该目录下打开CMD(按住Shift->点击鼠标右键->在此处打开命令窗口/PowerShell)后运行下列命令启动VN Trader:
|
||||
|
||||
python run.py
|
||||
|
||||
## 贡献代码
|
||||
|
||||
vn.py使用github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程:
|
||||
vn.py使用Github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程:
|
||||
|
||||
1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下,较小的improvement(如文档改进,bugfix等)直接发PR即可
|
||||
|
||||
|
12
docs/conf.py
12
docs/conf.py
@ -28,7 +28,6 @@ version = '2.0'
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = '2.0-DEV'
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# If your documentation needs a minimal Sphinx version, state it here.
|
||||
@ -75,19 +74,20 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = None
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
html_theme = 'alabaster'
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
#
|
||||
# html_theme_options = {}
|
||||
html_theme_options = {
|
||||
"base_bg": "inherit",
|
||||
"narrow_sidebar_bg": "inherit",
|
||||
}
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
@ -110,7 +110,6 @@ html_static_path = ['_static']
|
||||
# Output file base name for HTML help builder.
|
||||
htmlhelp_basename = 'vnpydoc'
|
||||
|
||||
|
||||
# -- Options for LaTeX output ------------------------------------------------
|
||||
|
||||
latex_elements = {
|
||||
@ -139,7 +138,6 @@ latex_documents = [
|
||||
'vn.py Team', 'manual'),
|
||||
]
|
||||
|
||||
|
||||
# -- Options for manual page output ------------------------------------------
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
@ -149,7 +147,6 @@ man_pages = [
|
||||
[author], 1)
|
||||
]
|
||||
|
||||
|
||||
# -- Options for Texinfo output ----------------------------------------------
|
||||
|
||||
# Grouping the document tree into Texinfo files. List of tuples
|
||||
@ -161,7 +158,6 @@ texinfo_documents = [
|
||||
'Miscellaneous'),
|
||||
]
|
||||
|
||||
|
||||
# -- Options for Epub output -------------------------------------------------
|
||||
|
||||
# Bibliographic Dublin Core info.
|
||||
|
0
docs/csv_loader.md
Normal file
0
docs/csv_loader.md
Normal file
@ -1 +1,23 @@
|
||||
# Introduction
|
||||
# CTA策略模块
|
||||
|
||||
|
||||
## 模块构成
|
||||
|
||||
|
||||
## 历史数据
|
||||
|
||||
|
||||
|
||||
## 策略开发
|
||||
|
||||
|
||||
## 回测研究
|
||||
|
||||
|
||||
|
||||
## 参数优化
|
||||
|
||||
|
||||
|
||||
## 实盘运行
|
||||
|
||||
|
0
docs/gateway.md
Normal file
0
docs/gateway.md
Normal file
@ -1,6 +1,15 @@
|
||||
# vn.py文档
|
||||
|
||||
* [vn.py简介](introduction.md)
|
||||
* [项目安装](install.md)
|
||||
* [基本使用](quickstart.md)
|
||||
* [CTA策略模块](cta_strategy.md)
|
||||
* 快速入门
|
||||
* [项目简介](introduction.md)
|
||||
* [环境安装](install.md)
|
||||
* [基本使用](quickstart.md)
|
||||
|
||||
* 应用模块
|
||||
* [CTA策略](cta_strategy.md)
|
||||
* [CSV载入](csv_loader.md)
|
||||
|
||||
* [交易接口](gateway.md)
|
||||
|
||||
* [RPC应用](rpc.md)
|
||||
* [贡献代码](contribution.md)
|
@ -1,8 +1,33 @@
|
||||
# 安装指南
|
||||
|
||||
|
||||
## Windows
|
||||
|
||||
### 使用VNConda
|
||||
|
||||
**ssleay32.dll问题**
|
||||
|
||||
如果电脑上之前安装过其他的Python环境或者应用软件,有可能会存在SSL相关动态链接库路径被修改的问题,在运行VN Station时弹出如下图所示的错误:
|
||||
|
||||
![ssleay32.dll](https://user-images.githubusercontent.com/7112268/55474371-8bd06a00-5643-11e9-8b35-f064a45edfd1.png)
|
||||
|
||||
解决方法:
|
||||
1. 找到你的VNConda目录
|
||||
2. 将VNConda\Lib\site-packages\PyQt5\Qt\bin目录的两个动态库libeay32.dll和ssleay32.dll拷贝到VNConda\下即可
|
||||
|
||||
### 手动安装
|
||||
|
||||
|
||||
|
||||
## Ubuntu
|
||||
如果是英文系统,请先运行下列命令安装中文编码:
|
||||
|
||||
### 安装脚本
|
||||
|
||||
### TA-Lib
|
||||
|
||||
### 中文编码
|
||||
|
||||
如果是英文系统(如阿里云),请先运行下列命令安装中文编码:
|
||||
|
||||
```
|
||||
sudo locale-gen zh_CN.GB18030
|
||||
|
@ -1 +1,15 @@
|
||||
# Introduction
|
||||
|
||||
|
||||
## 目标用户
|
||||
|
||||
|
||||
## 应用场景
|
||||
|
||||
|
||||
|
||||
## 支持的接口
|
||||
|
||||
|
||||
|
||||
## 支持的应用
|
@ -1 +1,20 @@
|
||||
# Introduction
|
||||
# 基本使用
|
||||
|
||||
|
||||
## 启动VN Trader
|
||||
|
||||
|
||||
## 连接接口
|
||||
|
||||
|
||||
## 订阅行情
|
||||
|
||||
|
||||
## 委托交易
|
||||
|
||||
|
||||
## 数据监控
|
||||
|
||||
|
||||
## 应用模块
|
||||
|
||||
|
0
docs/rpc.md
Normal file
0
docs/rpc.md
Normal file
@ -1,6 +1,6 @@
|
||||
::Install talib and ibapi
|
||||
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||
pip install https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl
|
||||
pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
|
||||
::Install Python Modules
|
||||
pip install -r requirements.txt
|
||||
|
@ -13,6 +13,10 @@ popd
|
||||
# old versions of ta-lib imports numpy in setup.py
|
||||
pip install numpy
|
||||
|
||||
# Install extra packages
|
||||
pip install ta-lib
|
||||
pip install https://vnpy-pip.oss-cn-shanghai.aliyuncs.com/colletion/ibapi-9.75.1-py3-none-any.whl
|
||||
|
||||
# Install Python Modules
|
||||
pip install -r requirements.txt
|
||||
|
||||
|
@ -10,11 +10,5 @@ matplotlib
|
||||
seaborn
|
||||
futu-api
|
||||
tigeropen
|
||||
|
||||
# ta-lib
|
||||
ta-lib; platform_system=="Linux"
|
||||
https://pip.vnpy.com/colletion/TA_Lib-0.4.17-cp37-cp37m-win_amd64.whl; platform_system=="Windows"
|
||||
|
||||
# ibapi
|
||||
https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
|
||||
|
||||
ta-lib
|
||||
ibapi
|
||||
|
189
tests/backtesting/genetic_algorithm.ipynb
Normal file
189
tests/backtesting/genetic_algorithm.ipynb
Normal file
@ -0,0 +1,189 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"import multiprocessing\n",
|
||||
"import numpy as np\n",
|
||||
"from deap import creator, base, tools, algorithms\n",
|
||||
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
|
||||
"from boll_channel_strategy import BollChannelStrategy\n",
|
||||
"from datetime import datetime\n",
|
||||
"import multiprocessing #多进程\n",
|
||||
"from scoop import futures #多进程"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def parameter_generate():\n",
|
||||
" '''\n",
|
||||
" 根据设置的起始值,终止值和步进,随机生成待优化的策略参数\n",
|
||||
" '''\n",
|
||||
" parameter_list = []\n",
|
||||
" p1 = random.randrange(4,50,2) #布林带窗口\n",
|
||||
" p2 = random.randrange(4,50,2) #布林带通道阈值\n",
|
||||
" p3 = random.randrange(4,50,2) #CCI窗口\n",
|
||||
" p4 = random.randrange(18,40,2) #ATR窗口 \n",
|
||||
"\n",
|
||||
" parameter_list.append(p1)\n",
|
||||
" parameter_list.append(p2)\n",
|
||||
" parameter_list.append(p3)\n",
|
||||
" parameter_list.append(p4)\n",
|
||||
"\n",
|
||||
" return parameter_list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def object_func(strategy_avg):\n",
|
||||
" \"\"\"\n",
|
||||
" 本函数为优化目标函数,根据随机生成的策略参数,运行回测后自动返回2个结果指标:收益回撤比和夏普比率\n",
|
||||
" \"\"\"\n",
|
||||
" # 创建回测引擎对象\n",
|
||||
" engine = BacktestingEngine()\n",
|
||||
" engine.set_parameters(\n",
|
||||
" vt_symbol=\"IF88.CFFEX\",\n",
|
||||
" interval=\"1m\",\n",
|
||||
" start=datetime(2018, 9, 1),\n",
|
||||
" end=datetime(2019, 1,1),\n",
|
||||
" rate=0,\n",
|
||||
" slippage=0,\n",
|
||||
" size=300,\n",
|
||||
" pricetick=0.2,\n",
|
||||
" capital=1_000_000,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" setting = {'boll_window': strategy_avg[0], #布林带窗口\n",
|
||||
" 'boll_dev': strategy_avg[1], #布林带通道阈值\n",
|
||||
" 'cci_window': strategy_avg[2], #CCI窗口\n",
|
||||
" 'atr_window': strategy_avg[3],} #ATR窗口 \n",
|
||||
"\n",
|
||||
" #加载策略 \n",
|
||||
" #engine.initStrategy(TurtleTradingStrategy, setting)\n",
|
||||
" engine.add_strategy(BollChannelStrategy, setting)\n",
|
||||
" engine.load_data()\n",
|
||||
" engine.run_backtesting()\n",
|
||||
" engine.calculate_result()\n",
|
||||
" result = engine.calculate_statistics(Output=False)\n",
|
||||
"\n",
|
||||
" return_drawdown_ratio = round(result['return_drawdown_ratio'],2) #收益回撤比\n",
|
||||
" sharpe_ratio= round(result['sharpe_ratio'],2) #夏普比率\n",
|
||||
" return return_drawdown_ratio , sharpe_ratio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"#设置优化方向:最大化收益回撤比,最大化夏普比率\n",
|
||||
"creator.create(\"FitnessMulti\", base.Fitness, weights=(1.0, 1.0)) # 1.0 求最大值;-1.0 求最小值\n",
|
||||
"creator.create(\"Individual\", list, fitness=creator.FitnessMulti)\n",
|
||||
"\n",
|
||||
"def optimize():\n",
|
||||
" \"\"\"\"\"\" \n",
|
||||
" toolbox = base.Toolbox() #Toolbox是deap库内置的工具箱,里面包含遗传算法中所用到的各种函数\n",
|
||||
"\n",
|
||||
" # 初始化 \n",
|
||||
" toolbox.register(\"individual\", tools.initIterate, creator.Individual,parameter_generate) # 注册个体:随机生成的策略参数parameter_generate() \n",
|
||||
" toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual) #注册种群:个体形成种群 \n",
|
||||
" toolbox.register(\"mate\", tools.cxTwoPoint) #注册交叉:两点交叉 \n",
|
||||
" toolbox.register(\"mutate\", tools.mutUniformInt,low = 4,up = 40,indpb=0.6) #注册变异:随机生成一定区间内的整数\n",
|
||||
" toolbox.register(\"evaluate\", object_func) #注册评估:优化目标函数object_func() \n",
|
||||
" toolbox.register(\"select\", tools.selNSGA2) #注册选择:NSGA-II(带精英策略的非支配排序的遗传算法)\n",
|
||||
" #pool = multiprocessing.Pool()\n",
|
||||
" #toolbox.register(\"map\", pool.map)\n",
|
||||
" #toolbox.register(\"map\", futures.map)\n",
|
||||
"\n",
|
||||
" #遗传算法参数设置\n",
|
||||
" MU = 40 #设置每一代选择的个体数\n",
|
||||
" LAMBDA = 160 #设置每一代产生的子女数\n",
|
||||
" pop = toolbox.population(400) #设置族群里面的个体数量\n",
|
||||
" CXPB, MUTPB, NGEN = 0.5, 0.35,40 #分别为种群内部个体的交叉概率、变异概率、产生种群代数\n",
|
||||
" hof = tools.ParetoFront() #解的集合:帕累托前沿(非占优最优集)\n",
|
||||
"\n",
|
||||
" #解的集合的描述统计信息\n",
|
||||
" #集合内平均值,标准差,最小值,最大值可以体现集合的收敛程度\n",
|
||||
" #收敛程度低可以增加算法的迭代次数\n",
|
||||
" stats = tools.Statistics(lambda ind: ind.fitness.values)\n",
|
||||
" np.set_printoptions(suppress=True) #对numpy默认输出的科学计数法转换\n",
|
||||
" stats.register(\"mean\", np.mean, axis=0) #统计目标优化函数结果的平均值\n",
|
||||
" stats.register(\"std\", np.std, axis=0) #统计目标优化函数结果的标准差\n",
|
||||
" stats.register(\"min\", np.min, axis=0) #统计目标优化函数结果的最小值\n",
|
||||
" stats.register(\"max\", np.max, axis=0) #统计目标优化函数结果的最大值\n",
|
||||
"\n",
|
||||
" #运行算法\n",
|
||||
" algorithms.eaMuPlusLambda(pop, toolbox, MU, LAMBDA, CXPB, MUTPB, NGEN, stats,\n",
|
||||
" halloffame=hof) #esMuPlusLambda是一种基于(μ+λ)选择策略的多目标优化分段遗传算法\n",
|
||||
"\n",
|
||||
" return pop"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimize()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.1"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
"nav_menu": {},
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"skip_h1_title": false,
|
||||
"title_cell": "Table of Contents",
|
||||
"title_sidebar": "Contents",
|
||||
"toc_cell": false,
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -10,6 +10,8 @@ from vnpy.gateway.ib import IbGateway
|
||||
from vnpy.gateway.ctp import CtpGateway
|
||||
from vnpy.gateway.tiger import TigerGateway
|
||||
from vnpy.gateway.oes import OesGateway
|
||||
from vnpy.gateway.okex import OkexGateway
|
||||
from vnpy.gateway.huobi import HuobiGateway
|
||||
|
||||
from vnpy.app.cta_strategy import CtaStrategyApp
|
||||
from vnpy.app.csv_loader import CsvLoaderApp
|
||||
@ -28,6 +30,8 @@ def main():
|
||||
main_engine.add_gateway(BitmexGateway)
|
||||
main_engine.add_gateway(TigerGateway)
|
||||
main_engine.add_gateway(OesGateway)
|
||||
main_engine.add_gateway(OkexGateway)
|
||||
main_engine.add_gateway(HuobiGateway)
|
||||
|
||||
main_engine.add_app(CtaStrategyApp)
|
||||
main_engine.add_app(CsvLoaderApp)
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "2.0.1b0"
|
||||
__version__ = "2.0.1"
|
||||
|
Binary file not shown.
@ -1,2 +1,2 @@
|
||||
from .vnapex import ApexApi
|
||||
from .vnapex import *
|
||||
from .fiddef import *
|
||||
|
@ -31,7 +31,7 @@ class ApexApi:
|
||||
|
||||
def set_app_info(self, name: str, version: str):
|
||||
"""设置应用信息"""
|
||||
n = APEX.Fix_SetAppInfo(c_char_p(name), c_char_p(version))
|
||||
n = APEX.Fix_SetAppInfo(to_bytes(name), to_bytes(version))
|
||||
return bool(n)
|
||||
|
||||
def uninitialize(self):
|
||||
@ -42,19 +42,19 @@ class ApexApi:
|
||||
def set_default_info(self, user: str, wtfs: str, fbdm: str, dest: str):
|
||||
"""设置默认信息"""
|
||||
n = APEX.Fix_SetDefaultInfo(
|
||||
c_char_p(user),
|
||||
c_char_p(wtfs),
|
||||
c_char_p(fbdm),
|
||||
c_char_p(dest)
|
||||
to_bytes(user),
|
||||
to_bytes(wtfs),
|
||||
to_bytes(fbdm),
|
||||
to_bytes(dest)
|
||||
)
|
||||
return bool(n)
|
||||
|
||||
def connect(self, address: str, khh: str, pwd: str, timeout: int):
|
||||
"""连接交易"""
|
||||
conn = APEX.Fix_Connect(
|
||||
c_char_p(address),
|
||||
c_char_p(khh),
|
||||
c_char_p(pwd),
|
||||
to_bytes(address),
|
||||
to_bytes(khh),
|
||||
to_bytes(pwd),
|
||||
timeout
|
||||
)
|
||||
return conn
|
||||
@ -66,13 +66,13 @@ class ApexApi:
|
||||
):
|
||||
"""连接交易"""
|
||||
conn = APEX.Fix_ConnectEx(
|
||||
c_char_p(address),
|
||||
c_char_p(khh),
|
||||
c_char_p(pwd),
|
||||
c_char_p(file_cert),
|
||||
c_char_p(cert_pwd),
|
||||
c_char_p(file_ca),
|
||||
c_char_p(procotol),
|
||||
to_bytes(address),
|
||||
to_bytes(khh),
|
||||
to_bytes(pwd),
|
||||
to_bytes(file_cert),
|
||||
to_bytes(cert_pwd),
|
||||
to_bytes(file_ca),
|
||||
to_bytes(procotol),
|
||||
verify,
|
||||
timeout
|
||||
)
|
||||
@ -100,27 +100,27 @@ class ApexApi:
|
||||
|
||||
def set_wtfs(self, sess: int, wtfs: str):
|
||||
"""设置委托方式"""
|
||||
n = APEX.Fix_SetWTFS(sess, c_char_p(wtfs))
|
||||
n = APEX.Fix_SetWTFS(sess, to_bytes(wtfs))
|
||||
return bool(n)
|
||||
|
||||
def set_fbdm(self, sess: int, fbdm: str):
|
||||
"""设置来源营业部"""
|
||||
n = APEX.Fix_SetFBDM(sess, c_char_p(fbdm))
|
||||
n = APEX.Fix_SetFBDM(sess, to_bytes(fbdm))
|
||||
return bool(n)
|
||||
|
||||
def set_dest_fbdm(self, sess: int, fbdm: str):
|
||||
"""设置目标营业部"""
|
||||
n = APEX.Fix_SetDestFBDM(sess, c_char_p(fbdm))
|
||||
n = APEX.Fix_SetDestFBDM(sess, to_bytes(fbdm))
|
||||
return bool(n)
|
||||
|
||||
def set_node(self, sess: int, node: str):
|
||||
"""设置业务站点"""
|
||||
n = APEX.Fix_SetNode(sess, c_char_p(node))
|
||||
n = APEX.Fix_SetNode(sess, to_bytes(node))
|
||||
return bool(n)
|
||||
|
||||
def set_gydm(self, sess: int, gydm: str):
|
||||
"""设置柜员号"""
|
||||
n = APEX.Fix_SetGYDM(sess, c_char_p(gydm))
|
||||
n = APEX.Fix_SetGYDM(sess, to_bytes(gydm))
|
||||
return bool(n)
|
||||
|
||||
def create_head(self, sess: int, func: int):
|
||||
@ -170,7 +170,7 @@ class ApexApi:
|
||||
def get_err_msg(self, sess: int):
|
||||
"""获取错误信息"""
|
||||
size = 256
|
||||
out = create_string_buffer("", size)
|
||||
out = create_string_buffer(b"", size)
|
||||
|
||||
APEX.Fix_GetErrMsg(sess, out, size)
|
||||
return out.value
|
||||
@ -182,7 +182,7 @@ class ApexApi:
|
||||
def get_item(self, sess: int, fid: int, row: int):
|
||||
"""获取字符串内容"""
|
||||
size = 256
|
||||
out = create_string_buffer("", size)
|
||||
out = create_string_buffer(b"", size)
|
||||
|
||||
APEX.Fix_GetItem(sess, fid, out, size, row)
|
||||
return out.value
|
||||
@ -210,19 +210,21 @@ class ApexApi:
|
||||
def get_token(self, sess: int):
|
||||
"""获取业务令牌"""
|
||||
size = 256
|
||||
out = create_string_buffer("", size)
|
||||
out = create_string_buffer(b"", size)
|
||||
|
||||
APEX.Fix_GetToken(sess, out, size)
|
||||
return out.value
|
||||
|
||||
def encode(self, data):
|
||||
def encode(self, data: str):
|
||||
"""加密"""
|
||||
data = to_bytes(data)
|
||||
buf = create_string_buffer(data, 512)
|
||||
APEX.Fix_Encode(buf)
|
||||
return buf.value
|
||||
return to_unicode(buf.value)
|
||||
|
||||
def add_backup_svc_addr(self, address: str):
|
||||
"""设置业务令牌"""
|
||||
address = to_bytes(address)
|
||||
n = APEX.Fix_AddBackupSvrAddr(address)
|
||||
return bool(n)
|
||||
|
||||
@ -238,9 +240,11 @@ class ApexApi:
|
||||
|
||||
def subscribe_by_customer(self, conn: int, svc: int, khh: str, pwd: str):
|
||||
"""订阅数据"""
|
||||
func = APEX[93]
|
||||
n = func(conn, svc, self.push_call_func, c_char_p(""), khh, pwd)
|
||||
return bool(n)
|
||||
func = APEX[108]
|
||||
n = func(conn, svc, self.push_call_func,
|
||||
to_bytes(""), to_bytes(khh), to_bytes(pwd))
|
||||
|
||||
return n
|
||||
|
||||
def unsubscribe_by_handle(self, handle: int):
|
||||
"""退订推送"""
|
||||
@ -254,22 +258,22 @@ class ApexApi:
|
||||
def get_val_with_id_by_index(self, sess: int, row: int, col: int):
|
||||
"""根据行列获取数据"""
|
||||
s = 256
|
||||
buf = create_string_buffer("", s)
|
||||
buf = create_string_buffer(b"", s)
|
||||
fid = c_long(0)
|
||||
size = c_int(s)
|
||||
|
||||
APEX.Fix_GetValWithIdByIndex(
|
||||
sess, row, col, byref(fid), buf, byref(size))
|
||||
return fid.value, buf.value
|
||||
return fid.value, to_unicode(buf.value)
|
||||
|
||||
def set_system_no(self, sess: int, val: str):
|
||||
"""设置系统编号"""
|
||||
n = APEX.Fix_SetSystemNo(sess, c_char_p(val))
|
||||
n = APEX.Fix_SetSystemNo(sess, to_bytes(val))
|
||||
return bool(n)
|
||||
|
||||
def set_default_system_no(self, val: str):
|
||||
"""设置默认系统编号"""
|
||||
n = APEX.Fix_SetDefaultSystemNo(c_char_p(val))
|
||||
n = APEX.Fix_SetDefaultSystemNo(to_bytes(val))
|
||||
return bool(n)
|
||||
|
||||
def set_auto_reconnect(self, conn: int, reconnect: int):
|
||||
@ -291,23 +295,23 @@ class ApexApi:
|
||||
"""获取缓存数据"""
|
||||
size = 1024
|
||||
outlen = c_int(size)
|
||||
buf = create_string_buffer("", size)
|
||||
buf = create_string_buffer(b"", size)
|
||||
|
||||
APEX.Fix_GetItemBuf(sess, buf, byref(outlen), row)
|
||||
return buf
|
||||
|
||||
def set_item(self, sess: int, fid: int, val: str):
|
||||
"""设置请求内容"""
|
||||
n = APEX.Fix_SetString(sess, fid, c_char_p(val))
|
||||
n = APEX.Fix_SetString(sess, fid, to_bytes(val))
|
||||
return bool(n)
|
||||
|
||||
def get_last_err_msg(self):
|
||||
"""获取错误信息"""
|
||||
size = 256
|
||||
out = create_string_buffer("", size)
|
||||
out = create_string_buffer(b"", size)
|
||||
|
||||
APEX.Fix_GetLastErrMsg(out, size)
|
||||
return out.value
|
||||
return to_unicode(out.value)
|
||||
|
||||
def reg_reply_call_func(self, sess: int = 0):
|
||||
"""注册回调函数"""
|
||||
@ -328,3 +332,21 @@ class ApexApi:
|
||||
def on_conn(self, conn: int, event, recv):
|
||||
"""连接回调(需要继承)"""
|
||||
return True
|
||||
|
||||
|
||||
def to_bytes(data: str):
|
||||
"""
|
||||
将unicode字符串转换为bytes
|
||||
"""
|
||||
try:
|
||||
bytes_data = data.encode("GBK")
|
||||
return bytes_data
|
||||
except AttributeError:
|
||||
return data
|
||||
|
||||
|
||||
def to_unicode(data: bytes):
|
||||
"""
|
||||
将bytes字符串转换为unicode
|
||||
"""
|
||||
return data.decode("GBK")
|
||||
|
@ -48,14 +48,18 @@ class WebsocketClient(object):
|
||||
|
||||
self.proxy_host = None
|
||||
self.proxy_port = None
|
||||
self.ping_interval = 60 # seconds
|
||||
|
||||
# For debugging
|
||||
self._last_sent_text = None
|
||||
self._last_received_text = None
|
||||
|
||||
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0):
|
||||
""""""
|
||||
def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60):
|
||||
"""
|
||||
:param ping_interval: unit: seconds, type: int
|
||||
"""
|
||||
self.host = host
|
||||
self.ping_interval = ping_interval # seconds
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
self.proxy_host = proxy_host
|
||||
@ -206,7 +210,7 @@ class WebsocketClient(object):
|
||||
et, ev, tb = sys.exc_info()
|
||||
self.on_error(et, ev, tb)
|
||||
self._reconnect()
|
||||
for i in range(60):
|
||||
for i in range(self.ping_interval):
|
||||
if not self._active:
|
||||
break
|
||||
sleep(1)
|
||||
|
17
vnpy/app/algo_trading/__init__.py
Normal file
17
vnpy/app/algo_trading/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
|
||||
from .engine import AlgoEngine, APP_NAME
|
||||
|
||||
|
||||
class CtaStrategyApp(BaseApp):
|
||||
""""""
|
||||
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "算法交易"
|
||||
engine_class = AlgoEngine
|
||||
widget_name = "AlgoManager"
|
||||
icon_name = "algo.ico"
|
59
vnpy/app/algo_trading/algos/__init__.py
Normal file
59
vnpy/app/algo_trading/algos/__init__.py
Normal file
@ -0,0 +1,59 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
'''
|
||||
动态载入所有的策略类
|
||||
'''
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import importlib
|
||||
import traceback
|
||||
|
||||
|
||||
# 用来保存算法类和控件类的字典
|
||||
ALGO_DICT = {}
|
||||
WIDGET_DICT = {}
|
||||
|
||||
|
||||
#----------------------------------------------------------------------
|
||||
def loadAlgoModule(path, prefix):
|
||||
"""使用importlib动态载入算法"""
|
||||
for root, subdirs, files in os.walk(path):
|
||||
for name in files:
|
||||
# 只有文件名以Algo.py结尾的才是算法文件
|
||||
if len(name)>7 and name[-7:] == 'Algo.py':
|
||||
try:
|
||||
# 模块名称需要模块路径前缀
|
||||
moduleName = prefix + name.replace('.py', '')
|
||||
module = importlib.import_module(moduleName)
|
||||
|
||||
# 获取算法类和控件类
|
||||
algo = None
|
||||
widget = None
|
||||
|
||||
for k in dir(module):
|
||||
# 以Algo结尾的类,是算法
|
||||
if k[-4:] == 'Algo':
|
||||
algo = module.__getattribute__(k)
|
||||
|
||||
# 以Widget结尾的类,是控件
|
||||
if k[-6:] == 'Widget':
|
||||
widget = module.__getattribute__(k)
|
||||
|
||||
# 保存到字典中
|
||||
if algo and widget:
|
||||
ALGO_DICT[algo.templateName] = algo
|
||||
WIDGET_DICT[algo.templateName] = widget
|
||||
except:
|
||||
print ('-' * 20)
|
||||
print ('Failed to import strategy file %s:' %moduleName)
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 遍历algo目录下的文件
|
||||
path1 = os.path.abspath(os.path.dirname(__file__))
|
||||
loadAlgoModule(path1, 'vnpy.trader.app.algoTrading.algo.')
|
||||
|
||||
# 遍历工作目录下的文件
|
||||
path2 = os.getcwd()
|
||||
loadAlgoModule(path2, '')
|
0
vnpy/app/algo_trading/algos/iceberg_algo.py
Normal file
0
vnpy/app/algo_trading/algos/iceberg_algo.py
Normal file
0
vnpy/app/algo_trading/algos/sniper_algo.py
Normal file
0
vnpy/app/algo_trading/algos/sniper_algo.py
Normal file
0
vnpy/app/algo_trading/algos/twap_algo.py
Normal file
0
vnpy/app/algo_trading/algos/twap_algo.py
Normal file
65
vnpy/app/algo_trading/engine.py
Normal file
65
vnpy/app/algo_trading/engine.py
Normal file
@ -0,0 +1,65 @@
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.event import (EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
|
||||
|
||||
|
||||
class AlgoEngine(BaseEngine):
|
||||
""""""
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
"""Constructor"""
|
||||
super().__init__(main_engine, event_engine)
|
||||
|
||||
self.algos = {}
|
||||
self.symbol_algo_map = {}
|
||||
self.orderid_algo_map = {}
|
||||
|
||||
self.register_event()
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.event_engine.register(EVENT_TICK, self.process_tick_event)
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
self.event_engine.register(EVENT_ORDER, self.process_order_event)
|
||||
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
|
||||
|
||||
def process_tick_event(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def process_timer_event(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def process_trade_event(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def process_order_event(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def start_algo(self, setting: dict):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def stop_algo(self, algo_name: dict):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def stop_all(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def subscribe(self, algo, vt_symbol):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
algo,
|
||||
vt_symbol
|
||||
):
|
||||
""""""
|
||||
pass
|
123
vnpy/app/algo_trading/template.py
Normal file
123
vnpy/app/algo_trading/template.py
Normal file
@ -0,0 +1,123 @@
|
||||
from vnpy.trader.engine import BaseEngine
|
||||
from vnpy.trader.object import TickData, OrderData, TradeData
|
||||
from vnpy.trader.constant import OrderType, Offset
|
||||
|
||||
class AlgoTemplate:
|
||||
""""""
|
||||
count = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algo_engine: BaseEngine,
|
||||
algo_name: str,
|
||||
setting: dict
|
||||
):
|
||||
"""Constructor"""
|
||||
self.algo_engine = algo_engine
|
||||
self.algo_name = algo_name
|
||||
|
||||
self.active = False
|
||||
self.active_orders = {} # vt_orderid:order
|
||||
|
||||
@staticmethod
|
||||
def new(cls, algo_engine:BaseEngine, setting: dict):
|
||||
"""Create new algo instance"""
|
||||
cls.count += 1
|
||||
algo_name = f"{cls.__name__}_{cls.count}"
|
||||
algo = cls(algo_engine, algo_name, setting)
|
||||
|
||||
def update_tick(self, tick: TickData):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_tick(tick)
|
||||
|
||||
def update_order(self, order: OrderData):
|
||||
""""""
|
||||
if self.active:
|
||||
if order.is_active():
|
||||
self.active_orders[order.vt_orderid] = order
|
||||
elif order.vt_orderid in self.active_orders:
|
||||
self.active_orders.pop(order.vt_orderid)
|
||||
|
||||
self.on_order(order)
|
||||
|
||||
def update_trade(self, trade: TradeData):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_trade(trade)
|
||||
|
||||
def update_timer(self):
|
||||
""""""
|
||||
if self.active:
|
||||
self.on_timer()
|
||||
|
||||
def on_start(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def on_stop(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def on_tick(self, tick: TickData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def on_timer(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def buy(
|
||||
self,
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type: OrderType = OrderType.LIMIT,
|
||||
offset: Offset = Offset.NONE
|
||||
):
|
||||
""""""
|
||||
return self.algo_engine.buy(
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type,
|
||||
offset
|
||||
)
|
||||
|
||||
def sell(
|
||||
self,
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type: OrderType = OrderType.LIMIT,
|
||||
offset: Offset = Offset.NONE
|
||||
):
|
||||
""""""
|
||||
return self.algo_engine.buy(
|
||||
vt_symbol,
|
||||
price,
|
||||
volume,
|
||||
order_type,
|
||||
offset
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
vnpy/app/algo_trading/ui/algo.ico
Normal file
BIN
vnpy/app/algo_trading/ui/algo.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
0
vnpy/app/algo_trading/ui/widget.py
Normal file
0
vnpy/app/algo_trading/ui/widget.py
Normal file
@ -23,9 +23,11 @@ Sample csv file:
|
||||
import csv
|
||||
from datetime import datetime
|
||||
|
||||
from peewee import chunked
|
||||
|
||||
from vnpy.event import EventEngine
|
||||
from vnpy.trader.constant import Exchange, Interval
|
||||
from vnpy.trader.database import DbBarData
|
||||
from vnpy.trader.database import DbBarData, DB
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
|
||||
|
||||
@ -75,30 +77,37 @@ class CsvLoaderEngine(BaseEngine):
|
||||
with open(file_path, 'rt') as f:
|
||||
reader = csv.DictReader(f)
|
||||
|
||||
db_bars = []
|
||||
|
||||
for item in reader:
|
||||
db_bar = DbBarData()
|
||||
dt = datetime.strptime(item[datetime_head], datetime_format)
|
||||
|
||||
db_bar.symbol = symbol
|
||||
db_bar.exchange = exchange.value
|
||||
db_bar.datetime = datetime.strptime(
|
||||
item[datetime_head], datetime_format
|
||||
)
|
||||
db_bar.interval = interval.value
|
||||
db_bar.volume = item[volume_head]
|
||||
db_bar.open_price = item[open_head]
|
||||
db_bar.high_price = item[high_head]
|
||||
db_bar.low_price = item[low_head]
|
||||
db_bar.close_price = item[close_head]
|
||||
db_bar.vt_symbol = vt_symbol
|
||||
db_bar.gateway_name = "DB"
|
||||
db_bar = {
|
||||
"symbol": symbol,
|
||||
"exchange": exchange.value,
|
||||
"datetime": dt,
|
||||
"interval": interval.value,
|
||||
"volume": item[volume_head],
|
||||
"open_price": item[open_head],
|
||||
"high_price": item[high_head],
|
||||
"low_price": item[low_head],
|
||||
"close_price": item[close_head],
|
||||
"vt_symbol": vt_symbol,
|
||||
"gateway_name": "DB"
|
||||
}
|
||||
|
||||
db_bar.replace()
|
||||
db_bars.append(db_bar)
|
||||
|
||||
# do some statistics
|
||||
count += 1
|
||||
if not start:
|
||||
start = db_bar.datetime
|
||||
start = db_bar["datetime"]
|
||||
|
||||
end = db_bar.datetime
|
||||
end = db_bar["datetime"]
|
||||
|
||||
# Insert into DB
|
||||
with DB.atomic():
|
||||
for batch in chunked(db_bars, 500):
|
||||
DbBarData.insert_many(batch).on_conflict_replace().execute()
|
||||
|
||||
return start, end, count
|
||||
|
@ -90,7 +90,8 @@ class CsvLoaderWidget(QtWidgets.QWidget):
|
||||
|
||||
def select_file(self):
|
||||
""""""
|
||||
result: str = QtWidgets.QFileDialog.getOpenFileName(self)
|
||||
result: str = QtWidgets.QFileDialog.getOpenFileName(
|
||||
self, filter="CSV (*.csv)")
|
||||
filename = result[0]
|
||||
if filename:
|
||||
self.file_edit.setText(filename)
|
||||
|
@ -35,7 +35,7 @@ class OptimizationSetting:
|
||||
def __init__(self):
|
||||
""""""
|
||||
self.params = {}
|
||||
self.target = ""
|
||||
self.target_name = ""
|
||||
|
||||
def add_parameter(
|
||||
self, name: str, start: float, end: float = None, step: float = None
|
||||
@ -62,9 +62,9 @@ class OptimizationSetting:
|
||||
|
||||
self.params[name] = value_list
|
||||
|
||||
def set_target(self, target: str):
|
||||
def set_target(self, target_name: str):
|
||||
""""""
|
||||
self.target = target
|
||||
self.target_name = target_name
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
@ -293,7 +293,7 @@ class BacktestingEngine:
|
||||
self.output("逐日盯市盈亏计算完成")
|
||||
return self.daily_df
|
||||
|
||||
def calculate_statistics(self, df: DataFrame = None):
|
||||
def calculate_statistics(self, df: DataFrame = None, Output=True):
|
||||
""""""
|
||||
self.output("开始计算策略统计指标")
|
||||
|
||||
@ -325,6 +325,7 @@ class BacktestingEngine:
|
||||
daily_return = 0
|
||||
return_std = 0
|
||||
sharpe_ratio = 0
|
||||
return_drawdown_ratio = 0
|
||||
else:
|
||||
# Calculate balance related time series data
|
||||
df["balance"] = df["net_pnl"].cumsum() + self.capital
|
||||
@ -373,38 +374,42 @@ class BacktestingEngine:
|
||||
else:
|
||||
sharpe_ratio = 0
|
||||
|
||||
return_drawdown_ratio = -total_return / max_ddpercent
|
||||
|
||||
# Output
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
if Output:
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
|
||||
self.output(f"总交易日:\t{total_days}")
|
||||
self.output(f"盈利交易日:\t{profit_days}")
|
||||
self.output(f"亏损交易日:\t{loss_days}")
|
||||
self.output(f"总交易日:\t{total_days}")
|
||||
self.output(f"盈利交易日:\t{profit_days}")
|
||||
self.output(f"亏损交易日:\t{loss_days}")
|
||||
|
||||
self.output(f"起始资金:\t{self.capital:,.2f}")
|
||||
self.output(f"结束资金:\t{end_balance:,.2f}")
|
||||
self.output(f"起始资金:\t{self.capital:,.2f}")
|
||||
self.output(f"结束资金:\t{end_balance:,.2f}")
|
||||
|
||||
self.output(f"总收益率:\t{total_return:,.2f}%")
|
||||
self.output(f"年化收益:\t{annual_return:,.2f}%")
|
||||
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
|
||||
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
|
||||
self.output(f"总收益率:\t{total_return:,.2f}%")
|
||||
self.output(f"年化收益:\t{annual_return:,.2f}%")
|
||||
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
|
||||
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
|
||||
|
||||
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
|
||||
self.output(f"总手续费:\t{total_commission:,.2f}")
|
||||
self.output(f"总滑点:\t{total_slippage:,.2f}")
|
||||
self.output(f"总成交金额:\t{total_turnover:,.2f}")
|
||||
self.output(f"总成交笔数:\t{total_trade_count}")
|
||||
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
|
||||
self.output(f"总手续费:\t{total_commission:,.2f}")
|
||||
self.output(f"总滑点:\t{total_slippage:,.2f}")
|
||||
self.output(f"总成交金额:\t{total_turnover:,.2f}")
|
||||
self.output(f"总成交笔数:\t{total_trade_count}")
|
||||
|
||||
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
|
||||
self.output(f"日均手续费:\t{daily_commission:,.2f}")
|
||||
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
|
||||
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
|
||||
self.output(f"日均成交笔数:\t{daily_trade_count}")
|
||||
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
|
||||
self.output(f"日均手续费:\t{daily_commission:,.2f}")
|
||||
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
|
||||
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
|
||||
self.output(f"日均成交笔数:\t{daily_trade_count}")
|
||||
|
||||
self.output(f"日均收益率:\t{daily_return:,.2f}%")
|
||||
self.output(f"收益标准差:\t{return_std:,.2f}%")
|
||||
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
|
||||
self.output(f"日均收益率:\t{daily_return:,.2f}%")
|
||||
self.output(f"收益标准差:\t{return_std:,.2f}%")
|
||||
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
|
||||
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
|
||||
|
||||
statistics = {
|
||||
"start_date": start_date,
|
||||
@ -430,6 +435,7 @@ class BacktestingEngine:
|
||||
"daily_return": daily_return,
|
||||
"return_std": return_std,
|
||||
"sharpe_ratio": sharpe_ratio,
|
||||
"return_drawdown_ratio": return_drawdown_ratio,
|
||||
}
|
||||
|
||||
return statistics
|
||||
@ -473,7 +479,7 @@ class BacktestingEngine:
|
||||
return
|
||||
|
||||
if not target_name:
|
||||
self.output("优化目标为设置,请检查")
|
||||
self.output("优化目标未设置,请检查")
|
||||
return
|
||||
|
||||
# Use multiprocessing pool for running backtesting with different setting
|
||||
|
@ -151,13 +151,14 @@ class CtaEngine(BaseEngine):
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
if symbol.upper() not in self.rq_symbols:
|
||||
rq_symbol = to_rq_symbol(vt_symbol)
|
||||
if rq_symbol not in self.rq_symbols:
|
||||
return None
|
||||
|
||||
end += timedelta(1) # For querying night trading period data
|
||||
|
||||
df = self.rq_client.get_price(
|
||||
symbol.upper(),
|
||||
rq_symbol,
|
||||
frequency=interval.value,
|
||||
fields=["open", "high", "low", "close", "volume"],
|
||||
start_date=start,
|
||||
@ -540,7 +541,7 @@ class CtaEngine(BaseEngine):
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.interval == interval.value)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
@ -910,3 +911,29 @@ class CtaEngine(BaseEngine):
|
||||
subject = "CTA策略引擎"
|
||||
|
||||
self.main_engine.send_email(subject, msg)
|
||||
|
||||
|
||||
def to_rq_symbol(vt_symbol: str):
|
||||
"""
|
||||
CZCE product of RQData has symbol like "TA1905" while
|
||||
vt symbol is "TA905.CZCE" so need to add "1" in symbol.
|
||||
"""
|
||||
symbol, exchange_str = vt_symbol.split(".")
|
||||
if exchange_str != "CZCE":
|
||||
return symbol.upper()
|
||||
|
||||
for count, word in enumerate(symbol):
|
||||
if word.isdigit():
|
||||
break
|
||||
|
||||
product = symbol[:count]
|
||||
year = symbol[count]
|
||||
month = symbol[count + 1:]
|
||||
|
||||
if year == "9":
|
||||
year = "1" + year
|
||||
else:
|
||||
year = "2" + year
|
||||
|
||||
rq_symbol = f"{product}{year}{month}".upper()
|
||||
return rq_symbol
|
||||
|
@ -2,7 +2,7 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Callable
|
||||
|
||||
from vnpy.trader.constant import Interval, Status, Direction, Offset
|
||||
from vnpy.trader.constant import Interval, Direction, Offset
|
||||
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
|
||||
|
||||
from .base import StopOrder, EngineType
|
||||
|
@ -405,7 +405,7 @@ class CtpTdApi(TdApi):
|
||||
""""""
|
||||
if not error['ErrorID']:
|
||||
self.authStatus = True
|
||||
self.writeLog("交易授权验证成功")
|
||||
self.gateway.write_log("交易授权验证成功")
|
||||
self.login()
|
||||
else:
|
||||
self.gateway.write_error("交易授权验证失败", error)
|
||||
@ -418,7 +418,7 @@ class CtpTdApi(TdApi):
|
||||
self.login_status = True
|
||||
self.gateway.write_log("交易登录成功")
|
||||
|
||||
# Confirm settelment
|
||||
# Confirm settlement
|
||||
req = {
|
||||
"BrokerID": self.brokerid,
|
||||
"InvestorID": self.userid
|
||||
@ -549,13 +549,16 @@ class CtpTdApi(TdApi):
|
||||
product=product,
|
||||
size=data["VolumeMultiple"],
|
||||
pricetick=data["PriceTick"],
|
||||
option_underlying=data["UnderlyingInstrID"],
|
||||
option_type=OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
||||
option_strike=data["StrikePrice"],
|
||||
option_expiry=datetime.strptime(data["ExpireDate"], "%Y%m%d"),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
|
||||
# For option only
|
||||
if data["OptionsType"]:
|
||||
contract.option_underlying = data["UnderlyingInstrID"],
|
||||
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
||||
contract.option_strike = data["StrikePrice"],
|
||||
contract.option_expiry = datetime.strptime(data["ExpireDate"], "%Y%m%d"),
|
||||
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
symbol_exchange_map[contract.symbol] = contract.exchange
|
||||
@ -662,7 +665,7 @@ class CtpTdApi(TdApi):
|
||||
"UserID": self.userid,
|
||||
"BrokerID": self.brokerid,
|
||||
"AuthCode": self.auth_code,
|
||||
"ProductInfo": self.product_info
|
||||
"UserProductInfo": self.product_info
|
||||
}
|
||||
|
||||
self.reqid += 1
|
||||
@ -678,7 +681,8 @@ class CtpTdApi(TdApi):
|
||||
req = {
|
||||
"UserID": self.userid,
|
||||
"Password": self.password,
|
||||
"BrokerID": self.brokerid
|
||||
"BrokerID": self.brokerid,
|
||||
"UserProductInfo": self.product_info
|
||||
}
|
||||
|
||||
self.reqid += 1
|
||||
|
4
vnpy/gateway/huobi/__init__.py
Normal file
4
vnpy/gateway/huobi/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .huobi_gateway import HuobiGateway
|
||||
|
||||
|
||||
|
715
vnpy/gateway/huobi/huobi_gateway.py
Normal file
715
vnpy/gateway/huobi/huobi_gateway.py
Normal file
@ -0,0 +1,715 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
"""
|
||||
火币交易接口
|
||||
"""
|
||||
|
||||
import re
|
||||
import urllib
|
||||
import base64
|
||||
import json
|
||||
import zlib
|
||||
import hashlib
|
||||
import hmac
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
|
||||
from vnpy.event import Event
|
||||
from vnpy.api.rest import RestClient
|
||||
from vnpy.api.websocket import WebsocketClient
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
Exchange,
|
||||
Product,
|
||||
Status,
|
||||
OrderType
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway, LocalOrderManager
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
AccountData,
|
||||
ContractData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest
|
||||
)
|
||||
from vnpy.trader.event import EVENT_TIMER
|
||||
|
||||
|
||||
REST_HOST = "https://api.huobipro.com"
|
||||
WEBSOCKET_DATA_HOST = "wss://api.huobi.pro/ws" # Market Data
|
||||
WEBSOCKET_TRADE_HOST = "wss://api.huobi.pro/ws/v1" # Account and Order
|
||||
|
||||
STATUS_HUOBI2VT = {
|
||||
"submitted": Status.NOTTRADED,
|
||||
"partial-filled": Status.PARTTRADED,
|
||||
"filled": Status.ALLTRADED,
|
||||
"cancelling": Status.CANCELLED,
|
||||
"partial-canceled": Status.CANCELLED,
|
||||
"canceled": Status.CANCELLED,
|
||||
}
|
||||
|
||||
ORDERTYPE_VT2HUOBI = {
|
||||
(Direction.LONG, OrderType.MARKET): "buy-market",
|
||||
(Direction.SHORT, OrderType.MARKET): "sell-market",
|
||||
(Direction.LONG, OrderType.LIMIT): "buy-limit",
|
||||
(Direction.SHORT, OrderType.LIMIT): "sell-limit",
|
||||
}
|
||||
ORDERTYPE_HUOBI2VT = {v: k for k, v in ORDERTYPE_VT2HUOBI.items()}
|
||||
|
||||
|
||||
huobi_symbols = set()
|
||||
symbol_name_map = {}
|
||||
|
||||
|
||||
class HuobiGateway(BaseGateway):
|
||||
"""
|
||||
VN Trader Gateway for Huobi connection.
|
||||
"""
|
||||
|
||||
default_setting = {
|
||||
"API Key": "",
|
||||
"Secret Key": "",
|
||||
"会话数": 3,
|
||||
"代理地址": "127.0.0.1",
|
||||
"代理端口": 1080,
|
||||
}
|
||||
|
||||
def __init__(self, event_engine):
|
||||
"""Constructor"""
|
||||
super(HuobiGateway, self).__init__(event_engine, "HUOBI")
|
||||
|
||||
self.order_manager = LocalOrderManager(self)
|
||||
|
||||
self.rest_api = HuobiRestApi(self)
|
||||
self.trade_ws_api = HuobiTradeWebsocketApi(self)
|
||||
self.market_ws_api = HuobiDataWebsocketApi(self)
|
||||
|
||||
def connect(self, setting: dict):
|
||||
""""""
|
||||
key = setting["API Key"]
|
||||
secret = setting["Secret Key"]
|
||||
session_number = setting["会话数"]
|
||||
proxy_host = setting["代理地址"]
|
||||
proxy_port = setting["代理端口"]
|
||||
|
||||
self.rest_api.connect(key, secret, session_number,
|
||||
proxy_host, proxy_port)
|
||||
self.trade_ws_api.connect(key, secret, proxy_host, proxy_port)
|
||||
self.market_ws_api.connect(key, secret, proxy_host, proxy_port)
|
||||
|
||||
self.init_query()
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.market_ws_api.subscribe(req)
|
||||
self.trade_ws_api.subscribe(req)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
return self.rest_api.send_order(req)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
self.rest_api.cancel_order(req)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.rest_api.query_account_balance()
|
||||
|
||||
def query_position(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
self.rest_api.stop()
|
||||
self.trade_ws_api.stop()
|
||||
self.market_ws_api.stop()
|
||||
|
||||
def process_timer_event(self, event: Event):
|
||||
""""""
|
||||
self.count += 1
|
||||
if self.count < 3:
|
||||
return
|
||||
|
||||
self.query_account()
|
||||
|
||||
def init_query(self):
|
||||
""""""
|
||||
self.count = 0
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
|
||||
|
||||
class HuobiRestApi(RestClient):
|
||||
"""
|
||||
HUOBI REST API
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
super(HuobiRestApi, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
self.order_manager = gateway.order_manager
|
||||
|
||||
self.host = ""
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.account_id = ""
|
||||
|
||||
self.cancel_requests = {}
|
||||
self.orders = {}
|
||||
|
||||
def sign(self, request):
|
||||
"""
|
||||
Generate HUOBI signature.
|
||||
"""
|
||||
request.headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/39.0.2171.71 Safari/537.36"
|
||||
}
|
||||
params_with_signature = create_signature(
|
||||
self.key,
|
||||
request.method,
|
||||
self.host,
|
||||
request.path,
|
||||
self.secret,
|
||||
request.params
|
||||
)
|
||||
request.params = params_with_signature
|
||||
|
||||
if request.method == "POST":
|
||||
request.headers["Content-Type"] = "application/json"
|
||||
|
||||
if request.data:
|
||||
request.data = json.dumps(request.data)
|
||||
|
||||
return request
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
session_number: int,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
"""
|
||||
Initialize connection to REST server.
|
||||
"""
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
|
||||
self.host, _ = _split_url(REST_HOST)
|
||||
|
||||
self.init(REST_HOST, proxy_host, proxy_port)
|
||||
self.start(session_number)
|
||||
|
||||
self.gateway.write_log("REST API启动成功")
|
||||
|
||||
self.query_contract()
|
||||
self.query_account()
|
||||
self.query_order()
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/account/accounts",
|
||||
callback=self.on_query_account
|
||||
)
|
||||
|
||||
def query_account_balance(self):
|
||||
""""""
|
||||
path = f"/v1/account/accounts/{self.account_id}/balance"
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path=path,
|
||||
callback=self.on_query_account_balance
|
||||
)
|
||||
|
||||
def query_order(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/order/openOrders",
|
||||
callback=self.on_query_order
|
||||
)
|
||||
|
||||
def query_contract(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
method="GET",
|
||||
path="/v1/common/symbols",
|
||||
callback=self.on_query_contract
|
||||
)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
huobi_type = ORDERTYPE_VT2HUOBI.get(
|
||||
(req.direction, req.type), ""
|
||||
)
|
||||
|
||||
local_orderid = self.order_manager.new_local_orderid()
|
||||
order = req.create_order_data(
|
||||
local_orderid,
|
||||
self.gateway_name
|
||||
)
|
||||
order.time = datetime.now().strftime("%H:%M:%S")
|
||||
|
||||
data = {
|
||||
"account-id": self.account_id,
|
||||
"amount": str(req.volume),
|
||||
"symbol": req.symbol,
|
||||
"type": huobi_type,
|
||||
"price": str(req.price),
|
||||
"source": "api"
|
||||
}
|
||||
|
||||
self.add_request(
|
||||
method="POST",
|
||||
path="/v1/order/orders/place",
|
||||
callback=self.on_send_order,
|
||||
data=data,
|
||||
extra=order,
|
||||
)
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
return order.vt_orderid
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
sys_orderid = self.order_manager.get_sys_orderid(req.orderid)
|
||||
|
||||
path = f"/v1/order/orders/{sys_orderid}/submitcancel"
|
||||
self.add_request(
|
||||
method="POST",
|
||||
path=path,
|
||||
callback=self.on_cancel_order,
|
||||
extra=req
|
||||
)
|
||||
|
||||
def on_query_account(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询账户"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
if d["type"] == "spot":
|
||||
self.account_id = d["id"]
|
||||
self.gateway.write_log(f"账户代码{self.account_id}查询成功")
|
||||
|
||||
self.query_account_balance()
|
||||
|
||||
def on_query_account_balance(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询账户资金"):
|
||||
return
|
||||
|
||||
buf = {}
|
||||
for d in data["data"]["list"]:
|
||||
currency = d["currency"]
|
||||
currency_data = buf.setdefault(currency, {})
|
||||
currency_data[d["type"]] = float(d["balance"])
|
||||
|
||||
for currency, currency_data in buf.items():
|
||||
account = AccountData(
|
||||
accountid=currency,
|
||||
balance=currency_data["trade"] + currency_data["frozen"],
|
||||
frozen=currency_data["frozen"],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
|
||||
if account.balance:
|
||||
self.gateway.on_account(account)
|
||||
|
||||
def on_query_order(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "查询委托"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
sys_orderid = d["id"]
|
||||
local_orderid = self.order_manager.get_local_orderid(sys_orderid)
|
||||
|
||||
direction, order_type = ORDERTYPE_HUOBI2VT[d["type"]]
|
||||
dt = datetime.fromtimestamp(d["created-at"] / 1000)
|
||||
time = dt.strftime("%H:%M:%S")
|
||||
|
||||
order = OrderData(
|
||||
orderid=local_orderid,
|
||||
symbol=d["symbol"],
|
||||
exchange=Exchange.HUOBI,
|
||||
price=float(d["price"]),
|
||||
volume=float(d["amount"]),
|
||||
type=order_type,
|
||||
direction=direction,
|
||||
traded=float(d["filled-amount"]),
|
||||
status=STATUS_HUOBI2VT.get(d["state"], None),
|
||||
time=time,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
|
||||
self.gateway.write_log("委托信息查询成功")
|
||||
|
||||
def on_query_contract(self, data, request): # type: (dict, Request)->None
|
||||
""""""
|
||||
if self.check_error(data, "查询合约"):
|
||||
return
|
||||
|
||||
for d in data["data"]:
|
||||
base_currency = d["base-currency"]
|
||||
quote_currency = d["quote-currency"]
|
||||
name = f"{base_currency.upper()}/{quote_currency.upper()}"
|
||||
pricetick = 1 / pow(10, d["price-precision"])
|
||||
size = 1 / pow(10, d["amount-precision"])
|
||||
|
||||
contract = ContractData(
|
||||
symbol=d["symbol"],
|
||||
exchange=Exchange.HUOBI,
|
||||
name=name,
|
||||
pricetick=pricetick,
|
||||
size=size,
|
||||
product=Product.SPOT,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
huobi_symbols.add(contract.symbol)
|
||||
symbol_name_map[contract.symbol] = contract.name
|
||||
|
||||
self.gateway.write_log("合约信息查询成功")
|
||||
|
||||
def on_send_order(self, data, request):
|
||||
""""""
|
||||
order = request.extra
|
||||
|
||||
if self.check_error(data, "委托"):
|
||||
order.status = Status.REJECTED
|
||||
self.order_manager.on_order(order)
|
||||
return
|
||||
|
||||
sys_orderid = data["data"]
|
||||
self.order_manager.update_orderid_map(order.orderid, sys_orderid)
|
||||
|
||||
def on_cancel_order(self, data, request):
|
||||
""""""
|
||||
if self.check_error(data, "撤单"):
|
||||
return
|
||||
|
||||
cancel_request = request.extra
|
||||
local_orderid = cancel_request.orderid
|
||||
|
||||
order = self.order_manager.get_order_with_local_orderid(local_orderid)
|
||||
order.status = Status.CANCELLED
|
||||
|
||||
self.order_manager.on_order(order)
|
||||
self.gateway.write_log(f"委托撤单成功:{order.orderid}")
|
||||
|
||||
def check_error(self, data: dict, func: str = ""):
|
||||
""""""
|
||||
if data["status"] != "error":
|
||||
return False
|
||||
|
||||
error_code = data["err-code"]
|
||||
error_msg = data["err-msg"]
|
||||
|
||||
self.gateway.write_log(f"{func}请求出错,代码:{error_code},信息:{error_msg}")
|
||||
return True
|
||||
|
||||
|
||||
class HuobiWebsocketApiBase(WebsocketClient):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super(HuobiWebsocketApiBase, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.sign_host = ""
|
||||
self.path = ""
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
url: str,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
""""""
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
|
||||
host, path = _split_url(url)
|
||||
self.sign_host = host
|
||||
self.path = path
|
||||
|
||||
self.init(url, proxy_host, proxy_port)
|
||||
self.start()
|
||||
|
||||
def login(self):
|
||||
""""""
|
||||
params = {"op": "auth"}
|
||||
params.update(create_signature(self.key, "GET", self.sign_host, self.path, self.secret))
|
||||
return self.send_packet(params)
|
||||
|
||||
def on_login(self, packet):
|
||||
""""""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def unpack_data(data):
|
||||
""""""
|
||||
return json.loads(zlib.decompress(data, 31))
|
||||
|
||||
def on_packet(self, packet):
|
||||
""""""
|
||||
if "ping" in packet:
|
||||
self.send_packet({"pong": packet["ping"]})
|
||||
elif "err-msg" in packet:
|
||||
return self.on_error_msg(packet)
|
||||
elif "op" in packet and packet["op"] == "auth":
|
||||
return self.on_login()
|
||||
else:
|
||||
self.on_data(packet)
|
||||
|
||||
def on_data(self, packet):
|
||||
""""""
|
||||
print("data : {}".format(packet))
|
||||
|
||||
def on_error_msg(self, packet):
|
||||
""""""
|
||||
msg = packet["err-msg"]
|
||||
if msg == "invalid pong":
|
||||
return
|
||||
|
||||
self.gateway.write_log(packet["err-msg"])
|
||||
|
||||
|
||||
class HuobiTradeWebsocketApi(HuobiWebsocketApiBase):
|
||||
""""""
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super().__init__(gateway)
|
||||
|
||||
self.order_manager = gateway.order_manager
|
||||
self.order_manager.push_data_callback = self.on_data
|
||||
|
||||
self.req_id = 0
|
||||
|
||||
def connect(self, key, secret, proxy_host, proxy_port):
|
||||
""""""
|
||||
super().connect(key, secret, WEBSOCKET_TRADE_HOST, proxy_host, proxy_port)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"op": "sub",
|
||||
"cid": str(self.req_id),
|
||||
"topic": f"orders.{req.symbol}"
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("交易Websocket API连接成功")
|
||||
self.login()
|
||||
|
||||
def on_login(self):
|
||||
""""""
|
||||
self.gateway.write_log("交易Websocket API登录成功")
|
||||
|
||||
def on_data(self, packet): # type: (dict)->None
|
||||
""""""
|
||||
op = packet.get("op", None)
|
||||
if op != "notify":
|
||||
return
|
||||
|
||||
topic = packet["topic"]
|
||||
if "orders" in topic:
|
||||
self.on_order(packet["data"])
|
||||
|
||||
def on_order(self, data: dict):
|
||||
""""""
|
||||
sys_orderid = str(data["order-id"])
|
||||
|
||||
order = self.order_manager.get_order_with_sys_orderid(sys_orderid)
|
||||
if not order:
|
||||
self.order_manager.add_push_data(sys_orderid, data)
|
||||
return
|
||||
|
||||
traded_volume = float(data["filled-amount"])
|
||||
|
||||
# Push order event
|
||||
order.traded += traded_volume
|
||||
order.status = STATUS_HUOBI2VT.get(data["order-state"], None)
|
||||
self.order_manager.on_order(order)
|
||||
|
||||
# Push trade event
|
||||
if not traded_volume:
|
||||
return
|
||||
|
||||
trade = TradeData(
|
||||
symbol=order.symbol,
|
||||
exchange=Exchange.HUOBI,
|
||||
orderid=order.orderid,
|
||||
tradeid=str(data["seq-id"]),
|
||||
direction=order.direction,
|
||||
price=float(data["price"]),
|
||||
volume=float(data["filled-amount"]),
|
||||
time=datetime.now().strftime("%H:%M:%S"),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_trade(trade)
|
||||
|
||||
|
||||
class HuobiDataWebsocketApi(HuobiWebsocketApiBase):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super().__init__(gateway)
|
||||
|
||||
self.req_id = 0
|
||||
self.ticks = {}
|
||||
|
||||
def connect(self, key: str, secret: str, proxy_host: str, proxy_port: int):
|
||||
""""""
|
||||
super().connect(key, secret, WEBSOCKET_DATA_HOST, proxy_host, proxy_port)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("行情Websocket API连接成功")
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
symbol = req.symbol
|
||||
|
||||
# Create tick data buffer
|
||||
tick = TickData(
|
||||
symbol=symbol,
|
||||
name=symbol_name_map.get(symbol, ""),
|
||||
exchange=Exchange.HUOBI,
|
||||
datetime=datetime.now(),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.ticks[symbol] = tick
|
||||
|
||||
# Subscribe to market depth update
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"sub": f"market.{symbol}.depth.step0",
|
||||
"id": str(self.req_id)
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to market detail update
|
||||
self.req_id += 1
|
||||
req = {
|
||||
"sub": f"market.{symbol}.detail",
|
||||
"id": str(self.req_id)
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_data(self, packet): # type: (dict)->None
|
||||
""""""
|
||||
channel = packet.get("ch", None)
|
||||
if channel:
|
||||
if "depth.step" in channel:
|
||||
self.on_market_depth(packet)
|
||||
elif "detail" in channel:
|
||||
self.on_market_detail(packet)
|
||||
elif "err-code" in packet:
|
||||
code = packet["err-code"]
|
||||
msg = packet["err-msg"]
|
||||
self.gateway.write_log(f"错误代码:{code}, 错误信息:{msg}")
|
||||
|
||||
def on_market_depth(self, data):
|
||||
"""行情深度推送 """
|
||||
symbol = data["ch"].split(".")[1]
|
||||
tick = self.ticks[symbol]
|
||||
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
|
||||
|
||||
bids = data["tick"]["bids"]
|
||||
for n in range(5):
|
||||
price, volume = bids[n]
|
||||
tick.__setattr__("bid_price_" + str(n + 1), float(price))
|
||||
tick.__setattr__("bid_volume_" + str(n + 1), float(volume))
|
||||
|
||||
asks = data["tick"]["asks"]
|
||||
for n in range(5):
|
||||
price, volume = asks[n]
|
||||
tick.__setattr__("ask_price_" + str(n + 1), float(price))
|
||||
tick.__setattr__("ask_volume_" + str(n + 1), float(volume))
|
||||
|
||||
if tick.last_price:
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_market_detail(self, data):
|
||||
"""市场细节推送"""
|
||||
symbol = data["ch"].split(".")[1]
|
||||
tick = self.ticks[symbol]
|
||||
tick.datetime = datetime.fromtimestamp(data["ts"] / 1000)
|
||||
|
||||
tick_data = data["tick"]
|
||||
tick.open_price = float(tick_data["open"])
|
||||
tick.high_price = float(tick_data["high"])
|
||||
tick.low_price = float(tick_data["low"])
|
||||
tick.last_price = float(tick_data["close"])
|
||||
tick.volume = float(tick_data["vol"])
|
||||
|
||||
if tick.bid_price_1:
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
|
||||
def _split_url(url):
|
||||
"""
|
||||
将url拆分为host和path
|
||||
:return: host, path
|
||||
"""
|
||||
result = re.match("\w+://([^/]*)(.*)", url) # noqa
|
||||
if result:
|
||||
return result.group(1), result.group(2)
|
||||
|
||||
|
||||
def create_signature(api_key, method, host, path, secret_key, get_params=None):
|
||||
"""
|
||||
创建签名
|
||||
:param get_params: dict 使用GET方法时附带的额外参数(urlparams)
|
||||
:return:
|
||||
"""
|
||||
sorted_params = [
|
||||
("AccessKeyId", api_key),
|
||||
("SignatureMethod", "HmacSHA256"),
|
||||
("SignatureVersion", "2"),
|
||||
("Timestamp", datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S"))
|
||||
]
|
||||
|
||||
if get_params:
|
||||
sorted_params.extend(list(get_params.items()))
|
||||
sorted_params = list(sorted(sorted_params))
|
||||
encode_params = urllib.parse.urlencode(sorted_params)
|
||||
|
||||
payload = [method, host, path, encode_params]
|
||||
payload = "\n".join(payload)
|
||||
payload = payload.encode(encoding="UTF8")
|
||||
|
||||
secret_key = secret_key.encode(encoding="UTF8")
|
||||
|
||||
digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
|
||||
signature = base64.b64encode(digest)
|
||||
|
||||
params = dict(sorted_params)
|
||||
params["Signature"] = signature.decode("UTF8")
|
||||
return params
|
1
vnpy/gateway/okex/__init__.py
Normal file
1
vnpy/gateway/okex/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .okex_gateway import OkexGateway
|
699
vnpy/gateway/okex/okex_gateway.py
Normal file
699
vnpy/gateway/okex/okex_gateway.py
Normal file
@ -0,0 +1,699 @@
|
||||
# encoding: UTF-8
|
||||
"""
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import base64
|
||||
import zlib
|
||||
from copy import copy
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from requests import ConnectionError
|
||||
|
||||
from vnpy.api.rest import Request, RestClient
|
||||
from vnpy.api.websocket import WebsocketClient
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
Exchange,
|
||||
OrderType,
|
||||
Product,
|
||||
Status
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
AccountData,
|
||||
ContractData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest,
|
||||
)
|
||||
|
||||
REST_HOST = "https://www.okex.com"
|
||||
WEBSOCKET_HOST = "wss://real.okex.com:10442/ws/v3"
|
||||
|
||||
STATUS_OKEX2VT = {
|
||||
"ordering": Status.SUBMITTING,
|
||||
"open": Status.NOTTRADED,
|
||||
"part_filled": Status.PARTTRADED,
|
||||
"filled": Status.ALLTRADED,
|
||||
"cancelled": Status.CANCELLED,
|
||||
"cancelling": Status.CANCELLED,
|
||||
"failure": Status.REJECTED,
|
||||
}
|
||||
|
||||
DIRECTION_VT2OKEX = {Direction.LONG: "buy", Direction.SHORT: "sell"}
|
||||
DIRECTION_OKEX2VT = {v: k for k, v in DIRECTION_VT2OKEX.items()}
|
||||
|
||||
ORDERTYPE_VT2OKEX = {
|
||||
OrderType.LIMIT: "limit",
|
||||
OrderType.MARKET: "market"
|
||||
}
|
||||
ORDERTYPE_OKEX2VT = {v: k for k, v in ORDERTYPE_VT2OKEX.items()}
|
||||
|
||||
|
||||
instruments = set()
|
||||
currencies = set()
|
||||
|
||||
|
||||
class OkexGateway(BaseGateway):
|
||||
"""
|
||||
VN Trader Gateway for OKEX connection.
|
||||
"""
|
||||
|
||||
default_setting = {
|
||||
"API Key": "",
|
||||
"Secret Key": "",
|
||||
"Passphrase": "",
|
||||
"会话数": 3,
|
||||
"代理地址": "127.0.0.1",
|
||||
"代理端口": 1080,
|
||||
}
|
||||
|
||||
def __init__(self, event_engine):
|
||||
"""Constructor"""
|
||||
super(OkexGateway, self).__init__(event_engine, "OKEX")
|
||||
|
||||
self.rest_api = OkexRestApi(self)
|
||||
self.ws_api = OkexWebsocketApi(self)
|
||||
|
||||
def connect(self, setting: dict):
|
||||
""""""
|
||||
key = setting["API Key"]
|
||||
secret = setting["Secret Key"]
|
||||
passphrase = setting["Passphrase"]
|
||||
session_number = setting["会话数"]
|
||||
proxy_host = setting["代理地址"]
|
||||
proxy_port = setting["代理端口"]
|
||||
|
||||
self.rest_api.connect(key, secret, passphrase,
|
||||
session_number, proxy_host, proxy_port)
|
||||
|
||||
self.ws_api.connect(key, secret, passphrase, proxy_host, proxy_port)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.ws_api.subscribe(req)
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
return self.rest_api.send_order(req)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
self.rest_api.cancel_order(req)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def query_position(self):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
self.rest_api.stop()
|
||||
self.ws_api.stop()
|
||||
|
||||
|
||||
class OkexRestApi(RestClient):
|
||||
"""
|
||||
OKEX REST API
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
super(OkexRestApi, self).__init__()
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.passphrase = ""
|
||||
|
||||
self.order_count = 10000
|
||||
self.order_count_lock = Lock()
|
||||
|
||||
self.connect_time = 0
|
||||
|
||||
def sign(self, request):
|
||||
"""
|
||||
Generate OKEX signature.
|
||||
"""
|
||||
# Sign
|
||||
# timestamp = str(time.time())
|
||||
timestamp = get_timestamp()
|
||||
request.data = json.dumps(request.data)
|
||||
|
||||
if request.params:
|
||||
path = request.path + '?' + urlencode(request.params)
|
||||
else:
|
||||
path = request.path
|
||||
|
||||
msg = timestamp + request.method + path + request.data
|
||||
signature = generate_signature(msg, self.secret)
|
||||
|
||||
# Add headers
|
||||
request.headers = {
|
||||
'OK-ACCESS-KEY': self.key,
|
||||
'OK-ACCESS-SIGN': signature,
|
||||
'OK-ACCESS-TIMESTAMP': timestamp,
|
||||
'OK-ACCESS-PASSPHRASE': self.passphrase,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
return request
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
passphrase: str,
|
||||
session_number: int,
|
||||
proxy_host: str,
|
||||
proxy_port: int,
|
||||
):
|
||||
"""
|
||||
Initialize connection to REST server.
|
||||
"""
|
||||
self.key = key
|
||||
self.secret = secret.encode()
|
||||
self.passphrase = passphrase
|
||||
|
||||
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
|
||||
|
||||
self.init(REST_HOST, proxy_host, proxy_port)
|
||||
self.start(session_number)
|
||||
self.gateway.write_log("REST API启动成功")
|
||||
|
||||
self.query_time()
|
||||
self.query_contract()
|
||||
self.query_account()
|
||||
self.query_order()
|
||||
|
||||
def _new_order_id(self):
|
||||
with self.order_count_lock:
|
||||
self.order_count += 1
|
||||
return self.order_count
|
||||
|
||||
def send_order(self, req: OrderRequest):
|
||||
""""""
|
||||
orderid = f"a{self.connect_time}{self._new_order_id()}"
|
||||
|
||||
data = {
|
||||
"client_oid": orderid,
|
||||
"type": ORDERTYPE_VT2OKEX[req.type],
|
||||
"side": DIRECTION_VT2OKEX[req.direction],
|
||||
"instrument_id": req.symbol
|
||||
}
|
||||
|
||||
if req.type == OrderType.MARKET:
|
||||
if req.direction == Direction.LONG:
|
||||
data["notional"] = req.volume
|
||||
else:
|
||||
data["size"] = req.volume
|
||||
else:
|
||||
data["price"] = req.price
|
||||
data["size"] = req.volume
|
||||
|
||||
order = req.create_order_data(orderid, self.gateway_name)
|
||||
|
||||
self.add_request(
|
||||
"POST",
|
||||
"/api/spot/v3/orders",
|
||||
callback=self.on_send_order,
|
||||
data=data,
|
||||
extra=order,
|
||||
on_failed=self.on_send_order_failed,
|
||||
on_error=self.on_send_order_error,
|
||||
)
|
||||
|
||||
self.gateway.on_order(order)
|
||||
return order.vt_orderid
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
""""""
|
||||
data = {
|
||||
"instrument_id": req.symbol,
|
||||
"client_oid": req.orderid
|
||||
}
|
||||
|
||||
path = "/api/spot/v3/cancel_orders/" + req.orderid
|
||||
self.add_request(
|
||||
"POST",
|
||||
path,
|
||||
callback=self.on_cancel_order,
|
||||
data=data,
|
||||
on_error=self.on_cancel_order_error,
|
||||
)
|
||||
|
||||
def query_contract(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/instruments",
|
||||
callback=self.on_query_contract
|
||||
)
|
||||
|
||||
def query_account(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/accounts",
|
||||
callback=self.on_query_account
|
||||
)
|
||||
|
||||
def query_order(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/spot/v3/orders_pending",
|
||||
callback=self.on_query_order
|
||||
)
|
||||
|
||||
def query_time(self):
|
||||
""""""
|
||||
self.add_request(
|
||||
"GET",
|
||||
"/api/general/v3/time",
|
||||
callback=self.on_query_time
|
||||
)
|
||||
|
||||
def on_query_contract(self, data, request):
|
||||
""""""
|
||||
for instrument_data in data:
|
||||
symbol = instrument_data["instrument_id"]
|
||||
contract = ContractData(
|
||||
symbol=symbol,
|
||||
exchange=Exchange.OKEX,
|
||||
name=symbol,
|
||||
product=Product.SPOT,
|
||||
size=1,
|
||||
pricetick=instrument_data["tick_size"],
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_contract(contract)
|
||||
|
||||
instruments.add(instrument_data["instrument_id"])
|
||||
currencies.add(instrument_data["base_currency"])
|
||||
currencies.add(instrument_data["quote_currency"])
|
||||
|
||||
self.gateway.write_log("合约信息查询成功")
|
||||
|
||||
# Start websocket api after instruments data collected
|
||||
self.gateway.ws_api.start()
|
||||
|
||||
def on_query_account(self, data, request):
|
||||
""""""
|
||||
for account_data in data:
|
||||
account = AccountData(
|
||||
accountid=account_data["currency"],
|
||||
balance=float(account_data["balance"]),
|
||||
frozen=float(account_data["hold"]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_account(account)
|
||||
|
||||
self.gateway.write_log("账户资金查询成功")
|
||||
|
||||
def on_query_order(self, data, request):
|
||||
""""""
|
||||
for order_data in data:
|
||||
order = OrderData(
|
||||
symbol=order_data["instrument_id"],
|
||||
exchange=Exchange.OKEX,
|
||||
type=ORDERTYPE_OKEX2VT[order_data["type"]],
|
||||
orderid=order_data["client_oid"],
|
||||
direction=DIRECTION_OKEX2VT[order_data["side"]],
|
||||
price=float(order_data["price"]),
|
||||
volume=float(order_data["size"]),
|
||||
time=order_data["timestamp"][11:19],
|
||||
status=STATUS_OKEX2VT[order_data["status"]],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_order(order)
|
||||
|
||||
self.gateway.write_log("委托信息查询成功")
|
||||
|
||||
def on_query_time(self, data, request):
|
||||
""""""
|
||||
server_time = data["iso"]
|
||||
local_time = datetime.utcnow().isoformat()
|
||||
msg = f"服务器时间:{server_time},本机时间:{local_time}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_send_order_failed(self, status_code: str, request: Request):
|
||||
"""
|
||||
Callback when sending order failed on server.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
msg = f"委托失败,状态码:{status_code},信息:{request.response.text}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_send_order_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback when sending order caused exception.
|
||||
"""
|
||||
order = request.extra
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_send_order(self, data, request):
|
||||
"""Websocket will push a new order status"""
|
||||
order = request.extra
|
||||
|
||||
error_msg = data["error_message"]
|
||||
if error_msg:
|
||||
order.status = Status.REJECTED
|
||||
self.gateway.on_order(order)
|
||||
|
||||
self.gateway.write_log(f"委托失败:{error_msg}")
|
||||
|
||||
def on_cancel_order_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback when cancelling order failed on server.
|
||||
"""
|
||||
# Record exception if not ConnectionError
|
||||
if not issubclass(exception_type, ConnectionError):
|
||||
self.on_error(exception_type, exception_value, tb, request)
|
||||
|
||||
def on_cancel_order(self, data, request):
|
||||
"""Websocket will push a new order status"""
|
||||
pass
|
||||
|
||||
def on_failed(self, status_code: int, request: Request):
|
||||
"""
|
||||
Callback to handle request failed.
|
||||
"""
|
||||
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
def on_error(
|
||||
self, exception_type: type, exception_value: Exception, tb, request: Request
|
||||
):
|
||||
"""
|
||||
Callback to handler request exception.
|
||||
"""
|
||||
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
sys.stderr.write(
|
||||
self.exception_detail(exception_type, exception_value, tb, request)
|
||||
)
|
||||
|
||||
|
||||
class OkexWebsocketApi(WebsocketClient):
|
||||
""""""
|
||||
|
||||
def __init__(self, gateway):
|
||||
""""""
|
||||
super(OkexWebsocketApi, self).__init__()
|
||||
self.ping_interval = 20 # OKEX use 30 seconds for ping
|
||||
|
||||
self.gateway = gateway
|
||||
self.gateway_name = gateway.gateway_name
|
||||
|
||||
self.key = ""
|
||||
self.secret = ""
|
||||
self.passphrase = ""
|
||||
|
||||
self.trade_count = 10000
|
||||
self.connect_time = 0
|
||||
|
||||
self.callbacks = {}
|
||||
self.ticks = {}
|
||||
|
||||
def connect(
|
||||
self,
|
||||
key: str,
|
||||
secret: str,
|
||||
passphrase: str,
|
||||
proxy_host: str,
|
||||
proxy_port: int
|
||||
):
|
||||
""""""
|
||||
self.key = key
|
||||
self.secret = secret.encode()
|
||||
self.passphrase = passphrase
|
||||
|
||||
self.connect_time = int(datetime.now().strftime("%y%m%d%H%M%S"))
|
||||
|
||||
self.init(WEBSOCKET_HOST, proxy_host, proxy_port)
|
||||
# self.start()
|
||||
|
||||
def unpack_data(self, data):
|
||||
""""""
|
||||
return json.loads(zlib.decompress(data, -zlib.MAX_WBITS))
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
"""
|
||||
Subscribe to tick data upate.
|
||||
"""
|
||||
tick = TickData(
|
||||
symbol=req.symbol,
|
||||
exchange=req.exchange,
|
||||
name=req.symbol,
|
||||
datetime=datetime.now(),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.ticks[req.symbol] = tick
|
||||
|
||||
channel_ticker = f"spot/ticker:{req.symbol}"
|
||||
channel_depth = f"spot/depth5:{req.symbol}"
|
||||
|
||||
self.callbacks[channel_ticker] = self.on_ticker
|
||||
self.callbacks[channel_depth] = self.on_depth
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": [channel_ticker, channel_depth]
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_connected(self):
|
||||
""""""
|
||||
self.gateway.write_log("Websocket API连接成功")
|
||||
self.login()
|
||||
|
||||
def on_disconnected(self):
|
||||
""""""
|
||||
self.gateway.write_log("Websocket API连接断开")
|
||||
|
||||
def on_packet(self, packet: dict):
|
||||
""""""
|
||||
if "event" in packet:
|
||||
event = packet["event"]
|
||||
if event == "subscribe":
|
||||
return
|
||||
elif event == "error":
|
||||
msg = packet["message"]
|
||||
self.gateway.write_log(f"Websocket API请求异常:{msg}")
|
||||
elif event == "login":
|
||||
self.on_login(packet)
|
||||
else:
|
||||
channel = packet["table"]
|
||||
data = packet["data"]
|
||||
callback = self.callbacks.get(channel, None)
|
||||
|
||||
if callback:
|
||||
for d in data:
|
||||
callback(d)
|
||||
|
||||
def on_error(self, exception_type: type, exception_value: Exception, tb):
|
||||
""""""
|
||||
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
|
||||
self.gateway.write_log(msg)
|
||||
|
||||
sys.stderr.write(self.exception_detail(
|
||||
exception_type, exception_value, tb))
|
||||
|
||||
def login(self):
|
||||
"""
|
||||
Need to login befores subscribe to websocket topic.
|
||||
"""
|
||||
timestamp = str(time.time())
|
||||
|
||||
msg = timestamp + 'GET' + '/users/self/verify'
|
||||
signature = generate_signature(msg, self.secret)
|
||||
|
||||
req = {
|
||||
"op": "login",
|
||||
"args": [
|
||||
self.key,
|
||||
self.passphrase,
|
||||
timestamp,
|
||||
signature.decode("utf-8")
|
||||
]
|
||||
}
|
||||
self.send_packet(req)
|
||||
self.callbacks['login'] = self.on_login
|
||||
|
||||
def subscribe_topic(self):
|
||||
"""
|
||||
Subscribe to all private topics.
|
||||
"""
|
||||
self.callbacks["spot/ticker"] = self.on_ticker
|
||||
self.callbacks["spot/depth5"] = self.on_depth
|
||||
self.callbacks["spot/account"] = self.on_account
|
||||
self.callbacks["spot/order"] = self.on_order
|
||||
|
||||
# Subscribe to order update
|
||||
channels = []
|
||||
for instrument_id in instruments:
|
||||
channel = f"spot/order:{instrument_id}"
|
||||
channels.append(channel)
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": channels
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to account update
|
||||
channels = []
|
||||
for currency in currencies:
|
||||
channel = f"spot/account:{currency}"
|
||||
channels.append(channel)
|
||||
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": channels
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
# Subscribe to BTC/USDT trade for keep connection alive
|
||||
req = {
|
||||
"op": "subscribe",
|
||||
"args": ["spot/trade:BTC-USDT"]
|
||||
}
|
||||
self.send_packet(req)
|
||||
|
||||
def on_login(self, data: dict):
|
||||
""""""
|
||||
success = data.get("success", False)
|
||||
|
||||
if success:
|
||||
self.gateway.write_log("Websocket API登录成功")
|
||||
self.subscribe_topic()
|
||||
else:
|
||||
self.gateway.write_log("Websocket API登录失败")
|
||||
|
||||
def on_ticker(self, d):
|
||||
""""""
|
||||
symbol = d["instrument_id"]
|
||||
tick = self.ticks.get(symbol, None)
|
||||
if not tick:
|
||||
return
|
||||
|
||||
tick.last_price = d["last"]
|
||||
tick.open = d["open_24h"]
|
||||
tick.high = d["high_24h"]
|
||||
tick.low = d["low_24h"]
|
||||
tick.volume = d["base_volume_24h"]
|
||||
tick.datetime = datetime.strptime(
|
||||
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_depth(self, d):
|
||||
""""""
|
||||
for tick_data in d:
|
||||
symbol = d["instrument_id"]
|
||||
tick = self.ticks.get(symbol, None)
|
||||
if not tick:
|
||||
return
|
||||
|
||||
bids = d["bids"]
|
||||
asks = d["asks"]
|
||||
for n, buf in enumerate(bids):
|
||||
price, volume, _ = buf
|
||||
tick.__setattr__("bid_price_%s" % (n + 1), price)
|
||||
tick.__setattr__("bid_volume_%s" % (n + 1), volume)
|
||||
|
||||
for n, buf in enumerate(asks):
|
||||
price, volume, _ = buf
|
||||
tick.__setattr__("ask_price_%s" % (n + 1), price)
|
||||
tick.__setattr__("ask_volume_%s" % (n + 1), volume)
|
||||
|
||||
tick.datetime = datetime.strptime(
|
||||
d["timestamp"], "%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
self.gateway.on_tick(copy(tick))
|
||||
|
||||
def on_order(self, d):
|
||||
""""""
|
||||
order = OrderData(
|
||||
symbol=d["instrument_id"],
|
||||
exchange=Exchange.OKEX,
|
||||
type=ORDERTYPE_OKEX2VT[d["type"]],
|
||||
orderid=d["client_oid"],
|
||||
direction=DIRECTION_OKEX2VT[d["side"]],
|
||||
price=d["price"],
|
||||
volume=d["size"],
|
||||
traded=d["filled_size"],
|
||||
time=d["timestamp"][11:19],
|
||||
status=STATUS_OKEX2VT[d["status"]],
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
self.gateway.on_order(copy(order))
|
||||
|
||||
trade_volume = float(d.get("last_fill_qty", 0))
|
||||
if not trade_volume:
|
||||
return
|
||||
|
||||
self.trade_count += 1
|
||||
tradeid = f"{self.connect_time}{self.trade_count}"
|
||||
|
||||
trade = TradeData(
|
||||
symbol=order.symbol,
|
||||
exchange=order.exchange,
|
||||
orderid=order.orderid,
|
||||
tradeid=tradeid,
|
||||
direction=order.direction,
|
||||
price=float(d["last_fill_px"]),
|
||||
volume=float(trade_volume),
|
||||
time=d["last_fill_time"][11:19],
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
self.gateway.on_trade(trade)
|
||||
|
||||
def on_account(self, d):
|
||||
""""""
|
||||
account = AccountData(
|
||||
accountid=d["currency"],
|
||||
balance=float(d["balance"]),
|
||||
frozen=float(d["hold"]),
|
||||
gateway_name=self.gateway_name
|
||||
)
|
||||
|
||||
self.gateway.on_account(copy(account))
|
||||
|
||||
|
||||
def generate_signature(msg: str, secret_key: str):
|
||||
"""OKEX V3 signature"""
|
||||
return base64.b64encode(hmac.new(secret_key, msg.encode(), hashlib.sha256).digest())
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
""""""
|
||||
now = datetime.utcnow()
|
||||
timestamp = now.isoformat("T", "milliseconds")
|
||||
return timestamp + "Z"
|
@ -1,5 +1,6 @@
|
||||
# encoding: UTF-8
|
||||
"""
|
||||
Author: KeKe
|
||||
Please install tiger-api before use.
|
||||
pip install tigeropen
|
||||
"""
|
||||
@ -123,6 +124,9 @@ class TigerGateway(BaseGateway):
|
||||
self.contracts = {}
|
||||
self.symbol_names = {}
|
||||
|
||||
self.push_connected = False
|
||||
self.subscribed_symbols = set()
|
||||
|
||||
def run(self):
|
||||
""""""
|
||||
while self.active:
|
||||
@ -203,23 +207,33 @@ class TigerGateway(BaseGateway):
|
||||
"""
|
||||
protocol, host, port = self.client_config.socket_host_port
|
||||
self.push_client = PushClient(host, port, (protocol == 'ssl'))
|
||||
self.push_client.connect(
|
||||
self.client_config.tiger_id, self.client_config.private_key)
|
||||
|
||||
self.push_client.quote_changed = self.on_quote_change
|
||||
self.push_client.asset_changed = self.on_asset_change
|
||||
self.push_client.position_changed = self.on_position_change
|
||||
self.push_client.order_changed = self.on_order_change
|
||||
|
||||
self.write_log("推送接口连接成功")
|
||||
self.push_client.connect(
|
||||
self.client_config.tiger_id, self.client_config.private_key)
|
||||
|
||||
def subscribe(self, req: SubscribeRequest):
|
||||
""""""
|
||||
self.push_client.subscribe_quote([req.symbol])
|
||||
self.subscribed_symbols.add(req.symbol)
|
||||
|
||||
if self.push_connected:
|
||||
self.push_client.subscribe_quote([req.symbol])
|
||||
|
||||
def on_push_connected(self):
|
||||
""""""
|
||||
self.push_connected = True
|
||||
self.write_log("推送接口连接成功")
|
||||
|
||||
self.push_client.subscribe_asset()
|
||||
self.push_client.subscribe_position()
|
||||
self.push_client.subscribe_order()
|
||||
|
||||
self.push_client.subscribe_quote(list(self.subscribed_symbols))
|
||||
|
||||
def on_quote_change(self, tiger_symbol: str, data: list, trading: bool):
|
||||
""""""
|
||||
data = dict(data)
|
||||
|
1
vnpy/rpc/__init__.py
Normal file
1
vnpy/rpc/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .vnrpc import RpcServer, RpcClient, RemoteException
|
36
vnpy/rpc/test_client.py
Normal file
36
vnpy/rpc/test_client.py
Normal file
@ -0,0 +1,36 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from time import sleep
|
||||
|
||||
from vnpy.rpc import RpcClient
|
||||
|
||||
|
||||
class TestClient(RpcClient):
|
||||
"""
|
||||
Test RpcClient
|
||||
"""
|
||||
|
||||
def __init__(self, req_address, sub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(TestClient, self).__init__(req_address, sub_address)
|
||||
|
||||
def callback(self, topic, data):
|
||||
"""
|
||||
Realize callable function
|
||||
"""
|
||||
print('client received topic:', topic, ', data:', data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
req_address = 'tcp://localhost:2014'
|
||||
sub_address = 'tcp://localhost:0602'
|
||||
|
||||
tc = TestClient(req_address, sub_address)
|
||||
tc.subscribeTopic('')
|
||||
tc.start()
|
||||
|
||||
while 1:
|
||||
print(tc.add(1, 3))
|
||||
sleep(2)
|
40
vnpy/rpc/test_server.py
Normal file
40
vnpy/rpc/test_server.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from time import sleep, time
|
||||
|
||||
from vnpy.rpc import RpcServer
|
||||
|
||||
|
||||
class TestServer(RpcServer):
|
||||
"""
|
||||
Test RpcServer
|
||||
"""
|
||||
|
||||
def __init__(self, rep_address, pub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(TestServer, self).__init__(rep_address, pub_address)
|
||||
|
||||
self.register(self.add)
|
||||
|
||||
def add(self, a, b):
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
print('receiving: %s, %s' % (a, b))
|
||||
return a + b
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rep_address = 'tcp://*:2014'
|
||||
pub_address = 'tcp://*:0602'
|
||||
|
||||
ts = TestServer(rep_address, pub_address)
|
||||
ts.start()
|
||||
|
||||
while 1:
|
||||
content = 'current server time is %s' % time()
|
||||
print(content)
|
||||
ts.publish('test', content)
|
||||
sleep(2)
|
329
vnpy/rpc/vnrpc.py
Normal file
329
vnpy/rpc/vnrpc.py
Normal file
@ -0,0 +1,329 @@
|
||||
import threading
|
||||
import traceback
|
||||
import signal
|
||||
|
||||
import zmq
|
||||
from msgpack import packb, unpackb
|
||||
from json import dumps, loads
|
||||
|
||||
import pickle
|
||||
p_dumps = pickle.dumps
|
||||
p_loads = pickle.loads
|
||||
|
||||
|
||||
# Achieve Ctrl-c interrupt recv
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
|
||||
|
||||
class RpcObject(object):
|
||||
"""
|
||||
Referred to serialization of packing and unpacking, we offer 3 tools:
|
||||
1) maspack: higher performance, but usually requires the installation of msgpack related tools;
|
||||
2) jason: Slightly lower performance but versatility is better, most programming languages have built-in libraries;
|
||||
3) cPickle: Lower performance and only can be used in Python, but it is very convenient to transfer Python objects directly.
|
||||
|
||||
Therefore, it is recommended to use msgpack.
|
||||
Use json, if you want to communicate with some languages without providing msgpack.
|
||||
Use cPickle, when the data being transferred contains many custom Python objects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Constructor
|
||||
Use msgpack as default serialization tool
|
||||
"""
|
||||
self.use_msgpack()
|
||||
|
||||
def pack(self, data):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def unpack(self, data):
|
||||
""""""
|
||||
pass
|
||||
|
||||
def __json_pack(self, data):
|
||||
"""
|
||||
Pack with json
|
||||
"""
|
||||
return dumps(data)
|
||||
|
||||
def __json_unpack(self, data):
|
||||
"""
|
||||
Unpack with json
|
||||
"""
|
||||
return loads(data)
|
||||
|
||||
def __msgpack_pack(self, data):
|
||||
"""
|
||||
Pack with msgpack
|
||||
"""
|
||||
return packb(data)
|
||||
|
||||
def __msgpack_unpack(self, data):
|
||||
"""
|
||||
Unpack with msgpack
|
||||
"""
|
||||
return unpackb(data)
|
||||
|
||||
def __pickle_pack(self, data):
|
||||
"""
|
||||
Pack with cPickle
|
||||
"""
|
||||
return p_dumps(data)
|
||||
|
||||
def __pickle_unpack(self, data):
|
||||
"""
|
||||
Unpack with cPickle
|
||||
"""
|
||||
return p_loads(data)
|
||||
|
||||
def use_json(self):
|
||||
"""
|
||||
Use json as serialization tool
|
||||
"""
|
||||
self.pack = self.__json_pack
|
||||
self.unpack = self.__json_unpack
|
||||
|
||||
def use_msgpack(self):
|
||||
"""
|
||||
Use msgpack as serialization tool
|
||||
"""
|
||||
self.pack = self.__msgpack_pack
|
||||
self.unpack = self.__msgpack_unpack
|
||||
|
||||
def use_pickle(self):
|
||||
"""
|
||||
Use cPickle as serialization tool
|
||||
"""
|
||||
self.pack = self.__pickle_pack
|
||||
self.unpack = self.__pickle_unpack
|
||||
|
||||
|
||||
class RpcServer(RpcObject):
|
||||
""""""
|
||||
|
||||
def __init__(self, rep_address, pub_address):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
super(RpcServer, self).__init__()
|
||||
|
||||
# Save functions dict: key is fuction name, value is fuction object
|
||||
self.__functions = {}
|
||||
|
||||
# Zmq port related
|
||||
self.__context = zmq.Context()
|
||||
|
||||
self.__socket_rep = self.__context.socket(
|
||||
zmq.REP) # Reply socket (Request–reply pattern)
|
||||
self.__socket_rep.bind(rep_address)
|
||||
|
||||
# Publish socket (Publish–subscribe pattern)
|
||||
self.__socket_pub = self.__context.socket(zmq.PUB)
|
||||
self.__socket_pub.bind(pub_address)
|
||||
|
||||
# Woker thread related
|
||||
self.__active = False # RpcServer status
|
||||
self.__thread = threading.Thread(target=self.run) # RpcServer thread
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start RpcServer
|
||||
"""
|
||||
# Start RpcServer status
|
||||
self.__active = True
|
||||
|
||||
# Start RpcServer thread
|
||||
if not self.__thread.isAlive():
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, join=False):
|
||||
"""
|
||||
Stop RpcServer
|
||||
"""
|
||||
# Stop RpcServer status
|
||||
self.__active = False
|
||||
|
||||
# Wait for RpcServer thread to exit
|
||||
if join and self.__thread.isAlive():
|
||||
self.__thread.join()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run RpcServer functions
|
||||
"""
|
||||
while self.__active:
|
||||
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
|
||||
if not self.__socket_rep.poll(1000):
|
||||
continue
|
||||
|
||||
# Receive request data from Reply socket
|
||||
reqb = self.__socket_rep.recv()
|
||||
|
||||
# Unpack request by deserialization
|
||||
req = self.unpack(reqb)
|
||||
|
||||
# Get function name and parameters
|
||||
name, args, kwargs = req
|
||||
|
||||
# Try to get and execute callable function object; capture exception information if it fails
|
||||
name = name.decode("UTF-8")
|
||||
|
||||
try:
|
||||
func = self.__functions[name]
|
||||
r = func(*args, **kwargs)
|
||||
rep = [True, r]
|
||||
except Exception as e: # noqa
|
||||
rep = [False, traceback.format_exc()]
|
||||
|
||||
# Pack response by serialization
|
||||
repb = self.pack(rep)
|
||||
|
||||
# send callable response by Reply socket
|
||||
self.__socket_rep.send(repb)
|
||||
|
||||
def publish(self, topic, data):
|
||||
"""
|
||||
Publish data
|
||||
"""
|
||||
# Serialized data
|
||||
topic = bytes(topic, "UTF-8")
|
||||
datab = self.pack(data)
|
||||
|
||||
# Send data by Publish socket
|
||||
# topci must be ascii encoding
|
||||
self.__socket_pub.send_multipart([topic, datab])
|
||||
|
||||
def register(self, func):
|
||||
"""
|
||||
Register function
|
||||
"""
|
||||
self.__functions[func.__name__] = func
|
||||
|
||||
|
||||
class RpcClient(RpcObject):
|
||||
""""""
|
||||
|
||||
def __init__(self, req_address, sub_address):
|
||||
"""Constructor"""
|
||||
super(RpcClient, self).__init__()
|
||||
|
||||
# zmq port related
|
||||
self.__req_address = req_address
|
||||
self.__sub_address = sub_address
|
||||
|
||||
self.__context = zmq.Context()
|
||||
# Request socket (Request–reply pattern)
|
||||
self.__socket_req = self.__context.socket(zmq.REQ)
|
||||
# Subscribe socket (Publish–subscribe pattern)
|
||||
self.__socket_sub = self.__context.socket(zmq.SUB)
|
||||
|
||||
# Woker thread relate, used to process data pushed from server
|
||||
self.__active = False # RpcClient status
|
||||
self.__thread = threading.Thread(
|
||||
target=self.run) # RpcClient thread
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
Realize remote call function
|
||||
"""
|
||||
# Perform remote call task
|
||||
def dorpc(*args, **kwargs):
|
||||
# Generate request
|
||||
req = [name, args, kwargs]
|
||||
|
||||
# Pack request by serialization
|
||||
reqb = self.pack(req)
|
||||
|
||||
# Send request and wait for response
|
||||
self.__socket_req.send(reqb)
|
||||
repb = self.__socket_req.recv()
|
||||
|
||||
# Unpack response by deserialization
|
||||
rep = self.unpack(repb)
|
||||
|
||||
# Return response if successed; Trigger exception if failed
|
||||
if rep[0]:
|
||||
return rep[1]
|
||||
else:
|
||||
raise RemoteException(rep[1].decode("UTF-8"))
|
||||
|
||||
return dorpc
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start RpcClient
|
||||
"""
|
||||
# Connect zmq port
|
||||
self.__socket_req.connect(self.__req_address)
|
||||
self.__socket_sub.connect(self.__sub_address)
|
||||
|
||||
# Start RpcClient status
|
||||
self.__active = True
|
||||
|
||||
# Start RpcClient thread
|
||||
if not self.__thread.isAlive():
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop RpcClient
|
||||
"""
|
||||
# Stop RpcClient status
|
||||
self.__active = False
|
||||
|
||||
# Wait for RpcClient thread to exit
|
||||
if self.__thread.isAlive():
|
||||
self.__thread.join()
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run RpcClient function
|
||||
"""
|
||||
while self.__active:
|
||||
# Use poll to wait event arrival, waiting time is 1 second (1000 milliseconds)
|
||||
if not self.__socket_sub.poll(1000):
|
||||
continue
|
||||
|
||||
# Receive data from subscribe socket
|
||||
topic, datab = self.__socket_sub.recv_multipart()
|
||||
|
||||
# Unpack data by deserialization
|
||||
data = self.unpack(datab)
|
||||
|
||||
# Process data by callable function
|
||||
topic = topic.decode("UTF-8")
|
||||
|
||||
self.callback(topic, data)
|
||||
|
||||
def callback(self, topic, data):
|
||||
"""
|
||||
Callable function
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def subscribeTopic(self, topic):
|
||||
"""
|
||||
Subscribe data
|
||||
"""
|
||||
topic = bytes(topic, "UTF-8")
|
||||
self.__socket_sub.setsockopt(zmq.SUBSCRIBE, topic)
|
||||
|
||||
|
||||
class RemoteException(Exception):
|
||||
"""
|
||||
RPC remote exception
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
"""
|
||||
Constructor
|
||||
"""
|
||||
self.__value = value
|
||||
|
||||
def __str__(self):
|
||||
"""
|
||||
Output error message
|
||||
"""
|
||||
return self.__value
|
@ -99,6 +99,8 @@ class Exchange(Enum):
|
||||
|
||||
# CryptoCurrency
|
||||
BITMEX = "BITMEX"
|
||||
OKEX = "OKEX"
|
||||
HUOBI = "HUOBI"
|
||||
|
||||
|
||||
class Currency(Enum):
|
||||
|
@ -52,6 +52,7 @@ class MainEngine:
|
||||
"""
|
||||
engine = engine_class(self, self.event_engine)
|
||||
self.engines[engine.engine_name] = engine
|
||||
return engine
|
||||
|
||||
def add_gateway(self, gateway_class: BaseGateway):
|
||||
"""
|
||||
@ -59,6 +60,7 @@ class MainEngine:
|
||||
"""
|
||||
gateway = gateway_class(self.event_engine)
|
||||
self.gateways[gateway.gateway_name] = gateway
|
||||
return gateway
|
||||
|
||||
def add_app(self, app_class: BaseApp):
|
||||
"""
|
||||
@ -67,7 +69,8 @@ class MainEngine:
|
||||
app = app_class()
|
||||
self.apps[app.app_name] = app
|
||||
|
||||
self.add_engine(app.engine_class)
|
||||
engine = self.add_engine(app.engine_class)
|
||||
return engine
|
||||
|
||||
def init_engines(self):
|
||||
"""
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from copy import copy
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from .event import (
|
||||
@ -227,3 +228,124 @@ class BaseGateway(ABC):
|
||||
Return default setting dict.
|
||||
"""
|
||||
return self.default_setting
|
||||
|
||||
|
||||
class LocalOrderManager:
|
||||
"""
|
||||
Management tool to support use local order id for trading.
|
||||
"""
|
||||
|
||||
def __init__(self, gateway: BaseGateway):
|
||||
""""""
|
||||
self.gateway = gateway
|
||||
|
||||
# For generating local orderid
|
||||
self.order_prefix = ""
|
||||
self.order_count = 0
|
||||
self.orders = {} # local_orderid:order
|
||||
|
||||
# Map between local and system orderid
|
||||
self.local_sys_orderid_map = {}
|
||||
self.sys_local_orderid_map = {}
|
||||
|
||||
# Push order data buf
|
||||
self.push_data_buf = {} # sys_orderid:data
|
||||
|
||||
# Callback for processing push order data
|
||||
self.push_data_callback = None
|
||||
|
||||
# Cancel request buf
|
||||
self.cancel_request_buf = {} # local_orderid:req
|
||||
|
||||
def new_local_orderid(self):
|
||||
"""
|
||||
Generate a new local orderid.
|
||||
"""
|
||||
self.order_count += 1
|
||||
local_orderid = str(self.order_count).rjust(8, "0")
|
||||
return local_orderid
|
||||
|
||||
def get_local_orderid(self, sys_orderid: str):
|
||||
"""
|
||||
Get local orderid with sys orderid.
|
||||
"""
|
||||
local_orderid = self.sys_local_orderid_map.get(sys_orderid, "")
|
||||
|
||||
if not local_orderid:
|
||||
local_orderid = self.new_local_orderid()
|
||||
self.update_orderid_map(local_orderid, sys_orderid)
|
||||
|
||||
return local_orderid
|
||||
|
||||
def get_sys_orderid(self, local_orderid: str):
|
||||
"""
|
||||
Get sys orderid with local orderid.
|
||||
"""
|
||||
sys_orderid = self.local_sys_orderid_map.get(local_orderid, "")
|
||||
return sys_orderid
|
||||
|
||||
def update_orderid_map(self, local_orderid: str, sys_orderid: str):
|
||||
"""
|
||||
Update orderid map.
|
||||
"""
|
||||
self.sys_local_orderid_map[sys_orderid] = local_orderid
|
||||
self.local_sys_orderid_map[local_orderid] = sys_orderid
|
||||
|
||||
self.check_cancel_request(local_orderid)
|
||||
self.check_push_data(sys_orderid)
|
||||
|
||||
def check_push_data(self, sys_orderid: str):
|
||||
"""
|
||||
Check if any order push data waiting.
|
||||
"""
|
||||
if sys_orderid not in self.push_data_buf:
|
||||
return
|
||||
|
||||
data = self.push_data_buf.pop(sys_orderid)
|
||||
if self.push_data_callback:
|
||||
self.push_data_callback(data)
|
||||
|
||||
def add_push_data(self, sys_orderid: str, data: dict):
|
||||
"""
|
||||
Add push data into buf.
|
||||
"""
|
||||
self.push_data_buf[sys_orderid] = data
|
||||
|
||||
def get_order_with_sys_orderid(self, sys_orderid: str):
|
||||
""""""
|
||||
local_orderid = self.sys_local_orderid_map.get(sys_orderid, None)
|
||||
if not local_orderid:
|
||||
return None
|
||||
else:
|
||||
return self.get_order_with_local_orderid(local_orderid)
|
||||
|
||||
def get_order_with_local_orderid(self, local_orderid: str):
|
||||
""""""
|
||||
order = self.orders[local_orderid]
|
||||
return copy(order)
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
"""
|
||||
Keep an order buf before pushing it to gateway.
|
||||
"""
|
||||
self.orders[order.orderid] = copy(order)
|
||||
self.gateway.on_order(order)
|
||||
|
||||
def cancel_order(self, req: CancelRequest):
|
||||
"""
|
||||
"""
|
||||
sys_orderid = self.get_sys_orderid(req.orderid)
|
||||
if not sys_orderid:
|
||||
self.cancel_request_buf[req.orderid] = req
|
||||
return
|
||||
|
||||
self.gateway.cancel_order(req)
|
||||
|
||||
def check_cancel_request(self, local_orderid: str):
|
||||
"""
|
||||
"""
|
||||
if local_orderid not in self.cancel_request_buf:
|
||||
return
|
||||
|
||||
req = self.cancel_request_buf.pop(local_orderid)
|
||||
self.gateway.cancel_order(req)
|
||||
|
Loading…
Reference in New Issue
Block a user