diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md
index d0d205ee..f37c5876 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE.md
@@ -1,8 +1,8 @@
## 环境
* 操作系统: 如Windows 10或者Ubuntu 18.04
-* Anaconda版本: 如Anaconda 18.12 Python 3.7 64位
-* vn.py版本: 如v2.0发行版或者dev branch 20190101(下载日期)
+* Python版本: 如VNStudio-2.0.6
+* vn.py版本: 如v2.0.5发行版或者dev branch 20190101(下载日期)
## Issue类型
三选一:Bug/Enhancement/Question
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
deleted file mode 100644
index d12c40ca..00000000
--- a/.gitlab-ci.yml
+++ /dev/null
@@ -1,142 +0,0 @@
-# This file is a template, and might need editing before it works on your project.
-# Official language image. Look for the different tagged releases at:
-# https://hub.docker.com/r/library/python/tags/
-image: registry.cn-shanghai.aliyuncs.com/vnpy-ci/gcc-7-python-3.7:latest
-
-.services:
- services: &services
- - postgres:latest
- - mysql:latest
- - mongo:latest
-
-# Change pip's cache directory to be inside the project directory since we can
-# only cache local items.
-variables: &variables
- GIT_DEPTH: "1"
- PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
- POSTGRES_DB: &db_name "vnpy"
- POSTGRES_USER: "postgres"
- POSTGRES_PASSWORD: &db_password "1234"
- VNPY_TEST_POSTGRESQL_PASSWORD: *db_password
- MYSQL_DATABASE: *db_name
- MYSQL_ROOT_PASSWORD: *db_password
- VNPY_TEST_MYSQL_PASSWORD: *db_password
- VNPY_BUILD_PARALLEL: "auto"
-
-# Pip's cache doesn't store the python packages
-# https://pip.pypa.io/en/stable/reference/pip_install/#caching
-#
-# If you want to also cache the installed packages, you have to install
-# them in a virtualenv and cache it as well.
-.default_cache:
- cache:
- <<: &cache
- key: "pip_and_venv"
- untracked: false
- policy: pull
- paths:
- - .cache/pip
- - venv/
-
-
-
-before_script:
- - echo $PWD
- - python -V
- - gcc --version
- - free
- - date
-
- # venv
- - pip install virtualenv
- - virtualenv venv
- - source venv/bin/activate
-
- # some envs
- - source ci/env.sh
-
-.scripts:
- script:
- - &install_scripts |
- date
- python -m pip --version
- python -m pip install --upgrade pip wheel setuptools
- python -m pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- bash ci/gitlab_pre_install.sh
-
- date
- bash ./install.sh
- date
-
- - &test_scripts |
- date
- cd tests
- python test_all.py
- date
-##################################
-# stages
-
-stages: # I use anchors for IDE hints only
- - &single_module single_module
- - &build_all build_all
-
-
-###########################################
-## jobs:
-flake8:
- stage: *single_module
- image: python:3.7
- cache:
- key: 'flake8'
- paths:
- - .cache/pip
- - venv/
- script:
- - pip install flake8
- - flake8
-
-ctp:
- <<: &test_single_module
- stage: *single_module
- image: registry.cn-shanghai.aliyuncs.com/vnpy-ci/gcc-8-python-3.7:latest
- services: *services
- cache:
- <<: *cache
- script:
- - *install_scripts
- - *test_scripts
- variables:
- <<: *variables
- VNPY_BUILD_CTP: 1
-
-oes:
- <<: *test_single_module
- variables:
- <<: *variables
- VNPY_BUILD_OES: 1
-
-no_building:
- <<: *test_single_module
- cache:
- <<: *cache
- policy: pull-push
- variables:
- <<: *variables
- VNPY_BUILD_OES: 0
- VNPY_BUILD_CTP: 0
-
-build-all-gcc8:
- stage: *build_all
- variables:
- <<: *variables
- image: registry.cn-shanghai.aliyuncs.com/vnpy-ci/gcc-8-python-3.7:latest
- services: *services
- cache:
- key: "build-all"
- paths: []
- script:
- - unset VNPY_BUILD_CTP
- - unset VNPY_BUILD_OES
- - *install_scripts
- - *test_scripts
-
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 10ca263c..00000000
--- a/.travis.yml
+++ /dev/null
@@ -1,101 +0,0 @@
-language: python
-
-dist: xenial # required for Python >= 3.7 (travis-ci/travis-ci#9069)
-
-cache: pip
-
-git:
- depth: 1
-
-env:
- - >
- VNPY_BUILD_PARALLEL=1
- VNPY_BUILD_CTP=1
- VNPY_BUILD_OES=1
-
-python:
- - "3.7"
-
-services:
- - mongodb
- - mysql
- - postgresql
-
-before_script:
- - psql -d postgresql://postgres:${VNPY_TEST_POSTGRESQL_PASSWORD}@localhost -c "create database vnpy;"
- - mysql -u root --password=${VNPY_TEST_MYSQL_PASSWORD} -e 'CREATE DATABASE vnpy;'
- - source ci/env.sh;
-
-script:
- - cd tests;
- - python test_all.py
-
-matrix:
- include:
- - name: "code quality analysis: flake8"
- before_install:
- - python -m pip install flake8
- install:
- - "" # prevent running "pip install -r requirements.txt"
- script:
- - flake8
-
- - name: "pip install under Ubuntu: gcc-8"
- addons:
- apt:
- sources:
- - ubuntu-toolchain-r-test
- packages:
- - g++-8
- before_install:
- - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-8 90
- - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-8 90
- - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-8 90
- install:
- # update pip & setuptools
- - python -m pip install --upgrade pip wheel setuptools
- # Linux install script
- - python -m pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- - bash ./install.sh
-
- - name: "sdist install under Ubuntu: gcc-7"
- addons:
- apt:
- sources:
- - ubuntu-toolchain-r-test
- packages:
- - g++-7
- before_install:
- - sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-7 90
- - sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++-7 90
- - sudo update-alternatives --install /usr/bin/gcc cc /usr/bin/gcc-7 90
- install:
- # Linux install script
- - python -m pip install --upgrade pip wheel setuptools
- - pushd /tmp
- - wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz
- - tar -xf ta-lib-0.4.0-src.tar.gz
- - cd ta-lib
- - ./configure --prefix=/usr
- - make # -j under gcc-7 failed!!?
- - sudo make install
- - popd
- - python -m pip install numpy
- - python -m pip install --pre --extra-index-url https://rquser:ricequant99@py.ricequant.com/simple/ rqdatac
- - python -m pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- - python setup.py sdist
- - python -m pip install dist/`ls dist`
-
- - name: "pip install under osx"
- os: osx
- language: shell # osx supports only shell
- services: []
- before_install: []
- install:
- - python3 -m pip install https://pip.vnpy.com/colletion/ibapi-9.75.1-001-py3-none-any.whl
- - bash ./install_osx.sh
- before_script: []
- script:
- - source ci/env.sh;
- - cd tests;
- - VNPY_TEST_ONLY_SQLITE=1 python3 test_all.py
diff --git a/README.md b/README.md
index 4ed0dc18..25959d96 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,10 @@
-
+
-
+
@@ -32,49 +32,61 @@ vn.py是一套基于Python的开源量化交易系统开发框架,于2015年1
* 国内市场
- * CTP(ctp):国内期货、期权
+ * CTP(ctp):国内期货、期权
- * CTP Mini(mini):国内期货、期权
+ * CTP Mini(mini):国内期货、期权
- * 飞马(femas):国内期货
+ * CTP证券(sopt):ETF期权
- * 宽睿(oes):国内证券(A股)
+ * 飞马(femas):国内期货
- * 中泰XTP(xtp):国内证券(A股)
+ * 宽睿(oes):国内证券(A股)
- * 华鑫奇点(tora):国内证券(A股)
+ * 中泰XTP(xtp):国内证券(A股)
+
+ * 华鑫奇点(tora):国内证券(A股)
+
+ * 鑫管家(xgj):期货资管
+
+ * 融航(rohon):期货资管
* 海外市场
-
- * 富途证券(futu):港股、美股
- * 老虎证券(tiger):全球证券、期货、期权、外汇等
+ * 富途证券(futu):港股、美股
- * Interactive Brokers(ib):全球证券、期货、期权、外汇等
+ * 老虎证券(tiger):全球证券、期货、期权、外汇等
- * 易盛9.0外盘(tap):全球期货
+ * Interactive Brokers(ib):全球证券、期货、期权、外汇等
+
+ * 易盛9.0外盘(tap):全球期货
+
+ * 直达期货(da):全球期货
* 数字货币
- * BitMEX(bitmex):数字货币期货、期权、永续合约
+ * BitMEX(bitmex):数字货币期货、期权、永续合约
- * OKEX合约(okexf):数字货币期货
+ * OKEX永续(okexs):数字货币永续合约
- * 火币合约(hbdm):数字货币期货
+ * OKEX合约(okexf):数字货币期货
- * 币安(binance):数字货币现货
+ * 火币合约(hbdm):数字货币期货
- * OKEX(okex):数字货币现货
+ * 币安(binance):数字货币现货
- * 火币(huobi):数字货币现货
+ * OKEX(okex):数字货币现货
- * Bitfinex(bitfinex):数字货币现货
+ * 火币(huobi):数字货币现货
- * 1Token(onetoken):数字货币券商(现货、期货)
+ * Bitfinex(bitfinex):数字货币现货
+
+ * Coinbase(coinbase):数字货币现货
+
+ * 1Token(onetoken):数字货币券商(现货、期货)
* 特殊应用
- * RPC服务(rpc):跨进程通讯接口,用于分布式架构
+ * RPC服务(rpc):跨进程通讯接口,用于分布式架构
3. 开箱即用的各类量化策略交易应用(vnpy.app):
@@ -82,6 +94,8 @@ vn.py是一套基于Python的开源量化交易系统开发框架,于2015年1
* cta_backtester:CTA策略回测模块,无需使用Jupyter Notebook,直接使用图形界面直接进行策略回测分析、参数优化等相关工作
+ * spread_trading:价差交易模块,支持自定义价差,实时计算价差行情和持仓,支持半自动价差算法交易以及全自动价差策略交易两种模式
+
* algo_trading:算法交易模块,提供多种常用的智能交易算法:TWAP、Sniper、Iceberg、BestLimit等等,支持常用算法配置保存
* script_trader:脚本策略模块,针对多标的组合类交易策略设计,同时也可以直接在命令行中实现REPL指令形式的交易,不支持回测功能
@@ -108,7 +122,7 @@ vn.py是一套基于Python的开源量化交易系统开发框架,于2015年1
## 环境准备
-* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNStudio-2.0.4](https://download.vnpy.com/vnstudio-2.0.4-r.exe),内置了最新版的vn.py框架以及VN Station量化管理平台,无需手动安装
+* 推荐使用vn.py团队为量化交易专门打造的Python发行版[VNStudio-2.0.7](https://download.vnpy.com/vnstudio-2.0.7.exe),内置了最新版的vn.py框架以及VN Station量化管理平台,无需手动安装
* 支持的系统版本:Windows 7以上/Windows Server 2008以上/Ubuntu 18.04 LTS
* 支持的Python版本:Python 3.7 64位(**注意必须是Python 3.7 64位版本**)
@@ -153,24 +167,24 @@ from vnpy.gateway.ctp import CtpGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.cta_backtester import CtaBacktesterApp
-def main():
+def main():
"""Start VN Trader"""
- qapp = create_qapp()
+ qapp = create_qapp()
- event_engine = EventEngine()
- main_engine = MainEngine(event_engine)
+ event_engine = EventEngine()
+ main_engine = MainEngine(event_engine)
- main_engine.add_gateway(CtpGateway)
- main_engine.add_app(CtaStrategyApp)
- main_engine.add_app(CtaBacktesterApp)
+ main_engine.add_gateway(CtpGateway)
+ main_engine.add_app(CtaStrategyApp)
+ main_engine.add_app(CtaBacktesterApp)
- main_window = MainWindow(main_engine, event_engine)
- main_window.showMaximized()
+ main_window = MainWindow(main_engine, event_engine)
+ main_window.showMaximized()
- qapp.exec()
+ qapp.exec()
if __name__ == "__main__":
- main()
+ main()
```
在该目录下打开CMD(按住Shift->点击鼠标右键->在此处打开命令窗口/PowerShell)后运行下列命令启动VN Trader:
@@ -179,14 +193,14 @@ if __name__ == "__main__":
## 贡献代码
-vn.py使用Github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程:
+vn.py使用Github托管其源代码,如果希望贡献代码请使用github的PR(Pull Request)的流程:
-1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下,较小的improvement(如文档改进,bugfix等)直接发PR即可
+1. [创建 Issue](https://github.com/vnpy/vnpy/issues/new) - 对于较大的改动(如新功能,大型重构等)最好先开issue讨论一下,较小的improvement(如文档改进,bugfix等)直接发PR即可
2. Fork [vn.py](https://github.com/vnpy/vnpy) - 点击右上角**Fork**按钮
3. Clone你自己的fork: ```git clone https://github.com/$userid/vnpy.git```
- * 如果你的fork已经过时,需要手动sync:[https://help.github.com/articles/syncing-a-fork/](https://help.github.com/articles/syncing-a-fork/)
+ * 如果你的fork已经过时,需要手动sync:[同步方法](https://help.github.com/articles/syncing-a-fork/)
4. 从**dev**创建你自己的feature branch: ```git checkout -b $my_feature_branch dev```
diff --git a/examples/vn_trader/run.py b/examples/vn_trader/run.py
index 74e3ca1f..f7c23125 100644
--- a/examples/vn_trader/run.py
+++ b/examples/vn_trader/run.py
@@ -5,10 +5,10 @@ from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import MainWindow, create_qapp
# from vnpy.gateway.binance import BinanceGateway
-# from vnpy.gateway.bitmex import BitmexGateway
+from vnpy.gateway.bitmex import BitmexGateway
# from vnpy.gateway.futu import FutuGateway
# from vnpy.gateway.ib import IbGateway
-# from vnpy.gateway.ctp import CtpGateway
+from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.ctptest import CtptestGateway
# from vnpy.gateway.mini import MiniGateway
# from vnpy.gateway.sopt import SoptGateway
@@ -18,7 +18,7 @@ from vnpy.trader.ui import MainWindow, create_qapp
# from vnpy.gateway.oes import OesGateway
# from vnpy.gateway.okex import OkexGateway
# from vnpy.gateway.huobi import HuobiGateway
-# from vnpy.gateway.bitfinex import BitfinexGateway
+from vnpy.gateway.bitfinex import BitfinexGateway
# from vnpy.gateway.onetoken import OnetokenGateway
from vnpy.gateway.okexf import OkexfGateway
from vnpy.gateway.okexs import OkexsGateway
@@ -38,6 +38,7 @@ from vnpy.app.cta_backtester import CtaBacktesterApp
# from vnpy.app.risk_manager import RiskManagerApp
from vnpy.app.script_trader import ScriptTraderApp
from vnpy.app.rpc_service import RpcServiceApp
+from vnpy.app.spread_trading import SpreadTradingApp
def main():
@@ -49,7 +50,7 @@ def main():
main_engine = MainEngine(event_engine)
# main_engine.add_gateway(BinanceGateway)
- # main_engine.add_gateway(CtpGateway)
+ main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(CtptestGateway)
# main_engine.add_gateway(MiniGateway)
# main_engine.add_gateway(SoptGateway)
@@ -57,12 +58,12 @@ def main():
# main_engine.add_gateway(FemasGateway)
# main_engine.add_gateway(IbGateway)
# main_engine.add_gateway(FutuGateway)
- # main_engine.add_gateway(BitmexGateway)
+ main_engine.add_gateway(BitmexGateway)
# main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(OesGateway)
# main_engine.add_gateway(OkexGateway)
# main_engine.add_gateway(HuobiGateway)
- # main_engine.add_gateway(BitfinexGateway)
+ main_engine.add_gateway(BitfinexGateway)
# main_engine.add_gateway(OnetokenGateway)
# main_engine.add_gateway(OkexfGateway)
# main_engine.add_gateway(HbdmGateway)
@@ -80,8 +81,9 @@ def main():
# main_engine.add_app(AlgoTradingApp)
# main_engine.add_app(DataRecorderApp)
# main_engine.add_app(RiskManagerApp)
- main_engine.add_app(ScriptTraderApp)
- main_engine.add_app(RpcServiceApp)
+ # main_engine.add_app(ScriptTraderApp)
+ # main_engine.add_app(RpcServiceApp)
+ main_engine.add_app(SpreadTradingApp)
main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized()
diff --git a/requirements.txt b/requirements.txt
index 72df05c9..26a5e53b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,4 +18,5 @@ ta-lib
ibapi
deap
pyzmq
-sortedcontainers
\ No newline at end of file
+sortedcontainers
+wmi
diff --git a/vnpy/__init__.py b/vnpy/__init__.py
index ff6ef86d..962c851b 100644
--- a/vnpy/__init__.py
+++ b/vnpy/__init__.py
@@ -1 +1 @@
-__version__ = "2.0.6"
+__version__ = "2.0.7"
diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py
index 2ab3d85e..ee8e98ba 100644
--- a/vnpy/api/rest/rest_client.py
+++ b/vnpy/api/rest/rest_client.py
@@ -1,7 +1,9 @@
+import logging
import multiprocessing
import os
import sys
import traceback
+import uuid
from datetime import datetime
from enum import Enum
from multiprocessing.dummy import Pool
@@ -11,6 +13,8 @@ from typing import Any, Callable, List, Optional, Type, Union
import requests
+from vnpy.trader.utility import get_file_logger
+
class RequestStatus(Enum):
ready = 0 # Request created
@@ -110,10 +114,12 @@ class RestClient(object):
"""
"""
self.url_base = '' # type: str
- self._active = False
+ self.logger: Optional[logging.Logger] = None
self.proxies = None
+ self._active = False
+
self._tasks_lock = Lock()
self._tasks: List[multiprocessing.pool.AsyncResult] = []
self._sessions_lock = Lock()
@@ -123,12 +129,24 @@ class RestClient(object):
def alive(self):
return self._active
- def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0):
+ def init(self,
+ url_base: str,
+ proxy_host: str = "",
+ proxy_port: int = 0,
+ log_path: Optional[str] = None,
+ ):
"""
Init rest client with url_base which is the API root address.
e.g. 'https://www.bitmex.com/api/v1/'
+ :param url_base:
+ :param proxy_host:
+ :param proxy_port:
+ :param log_path: optional. file to save log.
"""
self.url_base = url_base
+ if log_path is not None:
+ self.logger = get_file_logger(log_path)
+ self.logger.setLevel(logging.DEBUG)
if proxy_host and proxy_port:
proxy = f"{proxy_host}:{proxy_port}"
@@ -275,6 +293,11 @@ class RestClient(object):
"""
return True
+ def _log(self, msg, *args):
+ logger = self.logger
+ if logger:
+ logger.debug(msg, *args)
+
def _process_request(
self, request: Request
):
@@ -283,19 +306,32 @@ class RestClient(object):
"""
try:
with self._get_session() as session:
+ # sign
request = self.sign(request)
+ # send request
url = self.make_full_url(request.path)
+ uid = uuid.uuid4()
+ method = request.method
+ headers = request.headers
+ params = request.params
+ data = request.data
+ self._log("[%s] sending request %s %s, headers:%s, params:%s, data:%s",
+ uid, method, url,
+ headers, params, data)
response = session.request(
- request.method,
+ method,
url,
- headers=request.headers,
- params=request.params,
- data=request.data,
+ headers=headers,
+ params=params,
+ data=data,
proxies=self.proxies,
)
request.response = response
+ self._log("[%s] received response from %s:%s", uid, method, url)
+
+ # check result & call corresponding callbacks
status_code = response.status_code
success = False
diff --git a/vnpy/api/tap/vntap.pyd b/vnpy/api/tap/vntap.pyd
index 699dff73..ce2c9f27 100644
Binary files a/vnpy/api/tap/vntap.pyd and b/vnpy/api/tap/vntap.pyd differ
diff --git a/vnpy/api/tap/vntap/custom/custom_wrappers.hpp b/vnpy/api/tap/vntap/custom/custom_wrappers.hpp
index c0fd67c4..70484824 100644
--- a/vnpy/api/tap/vntap/custom/custom_wrappers.hpp
+++ b/vnpy/api/tap/vntap/custom/custom_wrappers.hpp
@@ -6,168 +6,169 @@
namespace c2py
{
- // TapAPIOrderInfoNotice
- struct FixedTapAPIOrderInfoNotice : ITapTrade::TapAPIOrderInfoNotice
+ // // TapAPIOrderInfoNotice
+ // struct FixedTapAPIOrderInfoNotice : ITapTrade::TapAPIOrderInfoNotice
+ // {
+ // ITapTrade::TapAPIOrderInfo order_info;
+ //
+ // // copy from original structure
+ // FixedTapAPIOrderInfoNotice(const ITapTrade::TapAPIOrderInfoNotice* info)
+ // : TapAPIOrderInfoNotice(*info), order_info(info->OrderInfo != nullptr ? *info->OrderInfo : ITapTrade::TapAPIOrderInfo{})
+ // {
+ // // fix pointer if there is one
+ // this->OrderInfo = info->OrderInfo != nullptr ? &this->order_info : nullptr;
+ // }
+ //
+ // // copy constructor
+ // FixedTapAPIOrderInfoNotice(const FixedTapAPIOrderInfoNotice& fixed)
+ // : TapAPIOrderInfoNotice(fixed), order_info(fixed.order_info)
+ // {
+ // // fix pointer if there is one
+ // this->OrderInfo = this->OrderInfo != nullptr ? &this->order_info : nullptr;
+ // }
+ // };
+ // // TapAPIPositionProfit
+ // struct FixedTapAPIPositionProfitNotice : ITapTrade::TapAPIPositionProfitNotice
+ // {
+ // ITapTrade::TapAPIPositionProfit data;
+ //
+ // // copy from original structure
+ // FixedTapAPIPositionProfitNotice(const ITapTrade::TapAPIPositionProfitNotice* info)
+ // : TapAPIPositionProfitNotice(*info), data(info->Data != nullptr ? *info->Data : ITapTrade::TapAPIPositionProfit{})
+ // {
+ // // fix pointer if there is one
+ // this->Data = info->Data != nullptr ? &this->data : nullptr;
+ // }
+ //
+ // // copy constructor
+ // FixedTapAPIPositionProfitNotice(const FixedTapAPIPositionProfitNotice& fixed)
+ // : TapAPIPositionProfitNotice(fixed), data(fixed.data)
+ // {
+ // // fix pointer if there is one
+ // this->Data = this->Data != nullptr ? &this->data : nullptr;
+ // }
+ // };
+ //
+ // namespace arg_helper
+ // {
+ // inline auto save(const ITapTrade::TapAPIOrderInfoNotice* info)
+ // { // match char []
+ // return FixedTapAPIOrderInfoNotice(info);
+ // }
+ //
+ // template <>
+ // struct loader
+ // {
+ // inline FixedTapAPIOrderInfoNotice operator ()(FixedTapAPIOrderInfoNotice& val)
+ // {
+ // return val;
+ // }
+ // };
+ //
+ // inline auto save(const ITapTrade::TapAPIPositionProfitNotice* info)
+ // { // match char []
+ // return FixedTapAPIPositionProfitNotice(info);
+ // }
+ //
+ // template <>
+ // struct loader
+ // {
+ // inline FixedTapAPIPositionProfitNotice operator ()(FixedTapAPIPositionProfitNotice& val)
+ // {
+ // return val;
+ // }
+ // };
+ // }
+
+ template<>
+ struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRspOrderAction>
{
- ITapTrade::TapAPIOrderInfo order_info;
-
- // copy from original structure
- FixedTapAPIOrderInfoNotice(const ITapTrade::TapAPIOrderInfoNotice* info)
- : TapAPIOrderInfoNotice(*info), order_info(info->OrderInfo != nullptr ? *info->OrderInfo : ITapTrade::TapAPIOrderInfo{})
+ inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, ITapTrade::TAPIUINT32 sessionID, ITapTrade::TAPIINT32 errorCode, const ITapTrade::TapAPIOrderActionRsp* info)
{
- // fix pointer if there is one
- this->OrderInfo = info->OrderInfo != nullptr ? &this->order_info : nullptr;
- }
+ ITapTrade::TapAPIOrderInfo orderInfo;
+ if (info->OrderInfo != nullptr)
+ {
+ orderInfo = *info->OrderInfo;
+ }
+ ITapTrade::TapAPIOrderActionRsp copied_info = *info;
+ auto task = [=]() mutable
+ {
+ if (copied_info.OrderInfo != nullptr)
+ {
+ copied_info.OrderInfo = &orderInfo; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
+ }
+ try
+ {
+ return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRspOrderAction>::sync(instance, py_func_name, sessionID, errorCode, &copied_info);
+ }
+ catch (const async_dispatch_exception& e)
+ {
+ async_callback_exception_handler::handle_excepiton(e);
+ }
- // copy constructor
- FixedTapAPIOrderInfoNotice(const FixedTapAPIOrderInfoNotice& fixed)
- : TapAPIOrderInfoNotice(fixed), order_info(fixed.order_info)
- {
- // fix pointer if there is one
- this->OrderInfo = this->OrderInfo != nullptr ? &this->order_info : nullptr;
+ };
+ dispatcher::instance().add(std::move(task));
}
};
- // TapAPIPositionProfit
- struct FixedTapAPIPositionProfitNotice : ITapTrade::TapAPIPositionProfitNotice
+
+ template<>
+ struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnOrder>
{
- ITapTrade::TapAPIPositionProfit data;
-
- // copy from original structure
- FixedTapAPIPositionProfitNotice(const ITapTrade::TapAPIPositionProfitNotice* info)
- : TapAPIPositionProfitNotice(*info), data(info->Data != nullptr ? *info->Data : ITapTrade::TapAPIPositionProfit{})
+ inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, const ITapTrade::TapAPIOrderInfoNotice* info)
{
- // fix pointer if there is one
- this->Data = info->Data != nullptr ? &this->data : nullptr;
- }
-
- // copy constructor
- FixedTapAPIPositionProfitNotice(const FixedTapAPIPositionProfitNotice& fixed)
- : TapAPIPositionProfitNotice(fixed), data(fixed.data)
- {
- // fix pointer if there is one
- this->Data = this->Data != nullptr ? &this->data : nullptr;
+ ITapTrade::TapAPIOrderInfo orderInfo;
+ if (info->OrderInfo != nullptr)
+ {
+ orderInfo = *info->OrderInfo;
+ }
+ ITapTrade::TapAPIOrderInfoNotice copied_info = *info;
+ auto task = [=]() mutable
+ {
+ if (copied_info.OrderInfo != nullptr)
+ {
+ copied_info.OrderInfo = &orderInfo; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
+ }
+ try
+ {
+ return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnOrder>::sync(instance, py_func_name, &copied_info);
+ }
+ catch (const async_dispatch_exception& e)
+ {
+ async_callback_exception_handler::handle_excepiton(e);
+ }
+ };
+ dispatcher::instance().add(std::move(task));
}
};
- namespace arg_helper
+ template<>
+ struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnPositionProfit>
{
- inline auto save(const ITapTrade::TapAPIOrderInfoNotice* info)
- { // match char []
- return FixedTapAPIOrderInfoNotice(info);
- }
-
- template <>
- struct loader
+ inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, const ITapTrade::TapAPIPositionProfitNotice* info)
{
- inline FixedTapAPIOrderInfoNotice operator ()(FixedTapAPIOrderInfoNotice& val)
- {
- return val;
- }
- };
- inline auto save(const ITapTrade::TapAPIPositionProfitNotice* info)
- { // match char []
- return FixedTapAPIPositionProfitNotice(info);
+ ITapTrade::TapAPIPositionProfit profit;
+ if (info->Data != nullptr)
+ {
+ profit = *info->Data;
+ }
+ ITapTrade::TapAPIPositionProfitNotice copied_info = *info;
+ auto task = [=]() mutable
+ {
+ if (copied_info.Data != nullptr)
+ {
+ copied_info.Data = &profit; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
+ }
+ try
+ {
+ return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnPositionProfit>::sync(instance, py_func_name, &copied_info);
+ }
+ catch (const async_dispatch_exception& e)
+ {
+ async_callback_exception_handler::handle_excepiton(e);
+ }
+ };
+ dispatcher::instance().add(std::move(task));
}
-
- template <>
- struct loader
- {
- inline FixedTapAPIPositionProfitNotice operator ()(FixedTapAPIPositionProfitNotice& val)
- {
- return val;
- }
- };
- }
-
- //template<>
- //struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRspOrderAction>
- //{
- // inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, ITapTrade::TAPIUINT32 sessionID, ITapTrade::TAPIINT32 errorCode, const ITapTrade::TapAPIOrderActionRsp* info)
- // {
- // ITapTrade::TapAPIOrderInfo orderInfo;
- // if (info->OrderInfo != nullptr)
- // {
- // orderInfo = *info->OrderInfo;
- // }
- // ITapTrade::TapAPIOrderActionRsp copied_info = *info;
- // auto task = [=]() mutable
- // {
- // if (copied_info.OrderInfo != nullptr)
- // {
- // copied_info.OrderInfo = &orderInfo; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
- // }
- // try
- // {
- // return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRspOrderAction>::sync(instance, py_func_name, sessionID, errorCode, &copied_info);
- // }
- // catch (const async_dispatch_exception& e)
- // {
- // async_callback_exception_handler::handle_excepiton(e);
- // }
-
- // };
- // dispatcher::instance().add(std::move(task));
- // }
- //};
- //template<>
- //struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnOrder>
- //{
- // inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, const ITapTrade::TapAPIOrderInfoNotice* info)
- // {
- // ITapTrade::TapAPIOrderInfo orderInfo;
- // if (info->OrderInfo != nullptr)
- // {
- // orderInfo = *info->OrderInfo;
- // }
- // ITapTrade::TapAPIOrderInfoNotice copied_info = *info;
- // auto task = [=]() mutable
- // {
- // if (copied_info.OrderInfo != nullptr)
- // {
- // copied_info.OrderInfo = &orderInfo; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
- // }
- // try
- // {
- // return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnOrder>::sync(instance, py_func_name, &copied_info);
- // }
- // catch (const async_dispatch_exception& e)
- // {
- // async_callback_exception_handler::handle_excepiton(e);
- // }
- // };
- // dispatcher::instance().add(std::move(task));
- // }
- //};
-
- //template<>
- //struct callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnPositionProfit>
- //{
- // inline static void call(ITapTrade::ITapTradeAPINotify* instance, const char* py_func_name, const ITapTrade::TapAPIPositionProfitNotice* info)
- // {
-
- // ITapTrade::TapAPIPositionProfit profit;
- // if (info->Data != nullptr)
- // {
- // profit = *info->Data;
- // }
- // ITapTrade::TapAPIPositionProfitNotice copied_info = *info;
- // auto task = [=]() mutable
- // {
- // if (copied_info.Data != nullptr)
- // {
- // copied_info.Data = &profit; // ensure pointer is pointer to the correct address(address changes after constructed lambda)
- // }
- // try
- // {
- // return default_callback_wrapper<&ITapTrade::ITapTradeAPINotify::OnRtnPositionProfit>::sync(instance, py_func_name, &copied_info);
- // }
- // catch (const async_dispatch_exception& e)
- // {
- // async_callback_exception_handler::handle_excepiton(e);
- // }
- // };
- // dispatcher::instance().add(std::move(task));
- // }
- //};
+ };
}
diff --git a/vnpy/api/tap/vntap/vntap.vcxproj b/vnpy/api/tap/vntap/vntap.vcxproj
index f2013708..af7ed955 100644
--- a/vnpy/api/tap/vntap/vntap.vcxproj
+++ b/vnpy/api/tap/vntap/vntap.vcxproj
@@ -115,30 +115,30 @@
.pyd
$(SolutionDir)
$(Platform)\$(Configuration)\
- $(ProjectDir);$(ProjectDir)include;C:\Python37\include;$(IncludePath)
- $(ProjectDir)\libs;C:\Python37\libs;$(LibraryPath)
+ $(ProjectDir);$(ProjectDir)include;C:\Python37\include;C:\Python373\include;$(IncludePath)
+ $(ProjectDir)\libs;C:\Python37\libs;C:\Python373\libs;$(LibraryPath)
true
.pyd
$(SolutionDir)
- $(ProjectDir);$(ProjectDir)include;C:\Python37\include;$(IncludePath)
- $(ProjectDir)\libs;C:\Python37\libs;$(LibraryPath)
+ $(ProjectDir);$(ProjectDir)include;C:\Python37\include;C:\Python373\include;$(IncludePath)
+ $(ProjectDir)\libs;C:\Python37\libs;C:\Python373\libs;$(LibraryPath)
false
.pyd
$(SolutionDir)
$(Platform)\$(Configuration)\
- $(ProjectDir);$(ProjectDir)include;C:\Python37\include;$(IncludePath)
- $(ProjectDir)\libs;C:\Python37\libs;$(LibraryPath)
+ $(ProjectDir);$(ProjectDir)include;C:\Python37\include;C:\Python373\include;$(IncludePath)
+ $(ProjectDir)\libs;C:\Python37\libs;C:\Python373\libs;$(LibraryPath)
false
.pyd
$(SolutionDir)
- $(ProjectDir);$(ProjectDir)include;C:\Python37\include;$(IncludePath)
- $(ProjectDir)\libs;C:\Python37\libs;$(LibraryPath)
+ $(ProjectDir);$(ProjectDir)include;C:\Python37\include;C:\Python373\include;$(IncludePath)
+ $(ProjectDir)\libs;C:\Python37\libs;C:\Python373\libs;$(LibraryPath)
diff --git a/vnpy/api/tap/vntap/vntap.vcxproj.filters b/vnpy/api/tap/vntap/vntap.vcxproj.filters
index 5740dd73..9611ea32 100644
--- a/vnpy/api/tap/vntap/vntap.vcxproj.filters
+++ b/vnpy/api/tap/vntap/vntap.vcxproj.filters
@@ -98,8 +98,5 @@
Source Files
-
- Source Files
-
\ No newline at end of file
diff --git a/vnpy/api/websocket/websocket_client.py b/vnpy/api/websocket/websocket_client.py
index e7c767a0..d89529a4 100644
--- a/vnpy/api/websocket/websocket_client.py
+++ b/vnpy/api/websocket/websocket_client.py
@@ -1,14 +1,18 @@
import json
+import logging
+import socket
import ssl
import sys
import traceback
-import socket
from datetime import datetime
from threading import Lock, Thread
from time import sleep
+from typing import Optional
import websocket
+from vnpy.trader.utility import get_file_logger
+
class WebsocketClient(object):
"""
@@ -47,19 +51,36 @@ class WebsocketClient(object):
self.proxy_host = None
self.proxy_port = None
- self.ping_interval = 60 # seconds
+ self.ping_interval = 60 # seconds
self.header = {}
+ self.logger: Optional[logging.Logger] = None
+
# For debugging
self._last_sent_text = None
self._last_received_text = None
- def init(self, host: str, proxy_host: str = "", proxy_port: int = 0, ping_interval: int = 60, header: dict = None):
+ def init(self,
+ host: str,
+ proxy_host: str = "",
+ proxy_port: int = 0,
+ ping_interval: int = 60,
+ header: dict = None,
+ log_path: Optional[str] = None,
+ ):
"""
+ :param host:
+ :param proxy_host:
+ :param proxy_port:
+ :param header:
:param ping_interval: unit: seconds, type: int
+ :param log_path: optional. file to save log.
"""
self.host = host
self.ping_interval = ping_interval # seconds
+ if log_path is not None:
+ self.logger = get_file_logger(log_path)
+ self.logger.setLevel(logging.DEBUG)
if header:
self.header = header
@@ -109,6 +130,11 @@ class WebsocketClient(object):
self._record_last_sent_text(text)
return self._send_text(text)
+ def _log(self, msg, *args):
+ logger = self.logger
+ if logger:
+ logger.debug(msg, *args)
+
def _send_text(self, text: str):
"""
Send a text string to server.
@@ -116,6 +142,7 @@ class WebsocketClient(object):
ws = self._ws
if ws:
ws.send(text, opcode=websocket.ABNF.OPCODE_TEXT)
+ self._log('sent text: %s', text)
def _send_binary(self, data: bytes):
"""
@@ -124,6 +151,7 @@ class WebsocketClient(object):
ws = self._ws
if ws:
ws._send_binary(data)
+ self._log('sent binary: %s', data)
def _create_connection(self, *args, **kwargs):
""""""
@@ -184,6 +212,7 @@ class WebsocketClient(object):
print("websocket unable to parse data: " + text)
raise e
+ self._log('recv data: %s', data)
self.on_packet(data)
# ws is closed before recv function is called
# For socket.error, see Issue #1608
diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py
index b64045a2..5b677c9a 100644
--- a/vnpy/app/cta_backtester/engine.py
+++ b/vnpy/app/cta_backtester/engine.py
@@ -91,14 +91,16 @@ class BacktesterEngine(BaseEngine):
"""
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
+ # Load python source code file
if filename.endswith(".py"):
strategy_module_name = ".".join(
[module_name, filename.replace(".py", "")])
+ self.load_strategy_class_from_module(strategy_module_name)
+ # Load compiled pyd binary file
elif filename.endswith(".pyd"):
strategy_module_name = ".".join(
[module_name, filename.split(".")[0]])
-
- self.load_strategy_class_from_module(strategy_module_name)
+ self.load_strategy_class_from_module(strategy_module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
@@ -354,18 +356,24 @@ class BacktesterEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol)
- # If history data provided in gateway, then query
- if contract and contract.history_data:
- data = self.main_engine.query_history(req, contract.gateway_name)
- # Otherwise use RQData to query data
- else:
- data = rqdata_client.query_history(req)
+ try:
+ # If history data provided in gateway, then query
+ if contract and contract.history_data:
+ data = self.main_engine.query_history(
+ req, contract.gateway_name
+ )
+ # Otherwise use RQData to query data
+ else:
+ data = rqdata_client.query_history(req)
- if data:
- database_manager.save_bar_data(data)
- self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
- else:
- self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
+ if data:
+ database_manager.save_bar_data(data)
+ self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
+ else:
+ self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
+ except Exception:
+ msg = f"数据下载失败,触发异常:\n{traceback.format_exc()}"
+ self.write_log(msg)
# Clear thread object handler.
self.thread = None
diff --git a/vnpy/app/cta_backtester/ui/widget.py b/vnpy/app/cta_backtester/ui/widget.py
index 70752e05..6412a99a 100644
--- a/vnpy/app/cta_backtester/ui/widget.py
+++ b/vnpy/app/cta_backtester/ui/widget.py
@@ -332,7 +332,7 @@ class BacktesterManager(QtWidgets.QWidget):
end_date = self.end_date_edit.date()
start = datetime(start_date.year(), start_date.month(), start_date.day())
- end = datetime(end_date.year(), end_date.month(), end_date.day())
+ end = datetime(end_date.year(), end_date.month(), end_date.day(), 23, 59, 59)
self.backtester_engine.start_downloading(
vt_symbol,
diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py
index 5245faca..c4622fc2 100644
--- a/vnpy/app/cta_strategy/engine.py
+++ b/vnpy/app/cta_strategy/engine.py
@@ -38,6 +38,7 @@ from vnpy.trader.constant import (
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
from vnpy.trader.database import database_manager
from vnpy.trader.rqdata import rqdata_client
+from vnpy.trader.converter import OffsetConverter
from .base import (
APP_NAME,
@@ -50,7 +51,6 @@ from .base import (
STOPORDER_PREFIX
)
from .template import CtaTemplate
-from .converter import OffsetConverter
STOP_STATUS_MAP = {
diff --git a/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py b/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py
index aa608b95..bcbe766d 100644
--- a/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py
+++ b/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py
@@ -78,7 +78,12 @@ class TurtleSignalStrategy(CtaTemplate):
if not self.am.inited:
return
- self.entry_up, self.entry_down = self.am.donchian(self.entry_window)
+ # Only calculates new entry channel when no position holding
+ if not self.pos:
+ self.entry_up, self.entry_down = self.am.donchian(
+ self.entry_window
+ )
+
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
if not self.pos:
@@ -92,13 +97,13 @@ class TurtleSignalStrategy(CtaTemplate):
self.send_buy_orders(self.entry_up)
self.send_short_orders(self.entry_down)
elif self.pos > 0:
- self.send_buy_orders(self.long_entry)
+ self.send_buy_orders(self.entry_up)
sell_price = max(self.long_stop, self.exit_down)
self.sell(sell_price, abs(self.pos), True)
elif self.pos < 0:
- self.send_short_orders(self.short_entry)
+ self.send_short_orders(self.entry_down)
cover_price = min(self.short_stop, self.exit_up)
self.cover(cover_price, abs(self.pos), True)
diff --git a/vnpy/app/spread_trading/__init__.py b/vnpy/app/spread_trading/__init__.py
new file mode 100644
index 00000000..360a4390
--- /dev/null
+++ b/vnpy/app/spread_trading/__init__.py
@@ -0,0 +1,28 @@
+from pathlib import Path
+
+from vnpy.trader.app import BaseApp
+from vnpy.trader.object import (
+ OrderData,
+ TradeData
+)
+
+from .engine import (
+ SpreadEngine,
+ APP_NAME,
+ SpreadData,
+ LegData,
+ SpreadStrategyTemplate,
+ SpreadAlgoTemplate
+)
+
+
+class SpreadTradingApp(BaseApp):
+ """"""
+
+ app_name = APP_NAME
+ app_module = __module__
+ app_path = Path(__file__).parent
+ display_name = "价差交易"
+ engine_class = SpreadEngine
+ widget_name = "SpreadManager"
+ icon_name = "spread.ico"
diff --git a/vnpy/app/spread_trading/algo.py b/vnpy/app/spread_trading/algo.py
new file mode 100644
index 00000000..bda5d317
--- /dev/null
+++ b/vnpy/app/spread_trading/algo.py
@@ -0,0 +1,141 @@
+from typing import Any
+
+from vnpy.trader.constant import Direction
+from vnpy.trader.object import (TickData, OrderData, TradeData)
+
+from .template import SpreadAlgoTemplate
+from .base import SpreadData
+
+
+class SpreadTakerAlgo(SpreadAlgoTemplate):
+ """"""
+ algo_name = "SpreadTaker"
+
+ def __init__(
+ self,
+ algo_engine: Any,
+ algoid: str,
+ spread: SpreadData,
+ direction: Direction,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool
+ ):
+ """"""
+ super().__init__(
+ algo_engine, algoid, spread, direction,
+ price, volume, payup, interval, lock
+ )
+
+ self.cancel_interval: int = 2
+ self.timer_count: int = 0
+
+ def on_tick(self, tick: TickData):
+ """"""
+ # Return if tick not inited
+ if not self.spread.bid_volume or not self.spread.ask_volume:
+ return
+
+ # Return if there are any existing orders
+ if not self.check_order_finished():
+ return
+
+ # Hedge if active leg is not fully hedged
+ if not self.check_hedge_finished():
+ self.hedge_passive_legs()
+ return
+
+ # Otherwise check if should take active leg
+ if self.direction == Direction.LONG:
+ if self.spread.ask_price <= self.price:
+ self.take_active_leg()
+ else:
+ if self.spread.bid_price >= self.price:
+ self.take_active_leg()
+
+ def on_order(self, order: OrderData):
+ """"""
+ # Only care active leg order update
+ if order.vt_symbol != self.spread.active_leg.vt_symbol:
+ return
+
+ # Do nothing if still any existing orders
+ if not self.check_order_finished():
+ return
+
+ # Hedge passive legs if necessary
+ if not self.check_hedge_finished():
+ self.hedge_passive_legs()
+
+ def on_trade(self, trade: TradeData):
+ """"""
+ pass
+
+ def on_interval(self):
+ """"""
+ if not self.check_order_finished():
+ self.cancel_all_order()
+
+ def take_active_leg(self):
+ """"""
+ # Calculate spread order volume of new round trade
+ spread_volume_left = self.target - self.traded
+
+ if self.direction == Direction.LONG:
+ spread_order_volume = self.spread.ask_volume
+ spread_order_volume = min(spread_order_volume, spread_volume_left)
+ else:
+ spread_order_volume = -self.spread.bid_volume
+ spread_order_volume = max(spread_order_volume, spread_volume_left)
+
+ # Calculate active leg order volume
+ leg_order_volume = self.spread.calculate_leg_volume(
+ self.spread.active_leg.vt_symbol,
+ spread_order_volume
+ )
+
+ # Send active leg order
+ self.send_leg_order(
+ self.spread.active_leg.vt_symbol,
+ leg_order_volume
+ )
+
+ def hedge_passive_legs(self):
+ """
+ Send orders to hedge all passive legs.
+ """
+ # Calcualte spread volume to hedge
+ active_leg = self.spread.active_leg
+ active_traded = self.leg_traded[active_leg.vt_symbol]
+
+ hedge_volume = self.spread.calculate_spread_volume(
+ active_leg.vt_symbol,
+ active_traded
+ )
+
+ # Calculate passive leg target volume and do hedge
+ for leg in self.spread.passive_legs:
+ passive_traded = self.leg_traded[leg.vt_symbol]
+ passive_target = self.spread.calculate_leg_volume(
+ leg.vt_symbol,
+ hedge_volume
+ )
+
+ leg_order_volume = passive_target - passive_traded
+ if leg_order_volume:
+ self.send_leg_order(leg.vt_symbol, leg_order_volume)
+
+ def send_leg_order(self, vt_symbol: str, leg_volume: float):
+ """"""
+ leg = self.spread.legs[vt_symbol]
+ leg_tick = self.get_tick(vt_symbol)
+ leg_contract = self.get_contract(vt_symbol)
+
+ if leg_volume > 0:
+ price = leg_tick.ask_price_1 + leg_contract.pricetick * self.payup
+ self.send_long_order(leg.vt_symbol, price, abs(leg_volume))
+ elif leg_volume < 0:
+ price = leg_tick.bid_price_1 - leg_contract.pricetick * self.payup
+ self.send_short_order(leg.vt_symbol, price, abs(leg_volume))
diff --git a/vnpy/app/spread_trading/base.py b/vnpy/app/spread_trading/base.py
new file mode 100644
index 00000000..5eff4b48
--- /dev/null
+++ b/vnpy/app/spread_trading/base.py
@@ -0,0 +1,235 @@
+from typing import Dict, List
+from math import floor, ceil
+from datetime import datetime
+
+from vnpy.trader.object import TickData, PositionData, TradeData
+from vnpy.trader.constant import Direction, Offset, Exchange
+
+
+EVENT_SPREAD_DATA = "eSpreadData"
+EVENT_SPREAD_POS = "eSpreadPos"
+EVENT_SPREAD_LOG = "eSpreadLog"
+EVENT_SPREAD_ALGO = "eSpreadAlgo"
+EVENT_SPREAD_STRATEGY = "eSpreadStrategy"
+
+
+class LegData:
+ """"""
+
+ def __init__(self, vt_symbol: str):
+ """"""
+ self.vt_symbol: str = vt_symbol
+
+ # Price and position data
+ self.bid_price: float = 0
+ self.ask_price: float = 0
+ self.bid_volume: float = 0
+ self.ask_volume: float = 0
+
+ self.long_pos: float = 0
+ self.short_pos: float = 0
+ self.net_pos: float = 0
+
+ # Tick data buf
+ self.tick: TickData = None
+
+ def update_tick(self, tick: TickData):
+ """"""
+ self.bid_price = tick.bid_price_1
+ self.ask_price = tick.ask_price_1
+ self.bid_volume = tick.bid_volume_1
+ self.ask_volume = tick.ask_volume_1
+
+ self.tick = tick
+
+ def update_position(self, position: PositionData):
+ """"""
+ if position.direction == Direction.NET:
+ self.net_pos = position.volume
+ else:
+ if position.direction == Direction.LONG:
+ self.long_pos = position.volume
+ else:
+ self.short_pos = position.volume
+ self.net_pos = self.long_pos - self.short_pos
+
+ def update_trade(self, trade: TradeData):
+ """"""
+ if trade.direction == Direction.LONG:
+ if trade.offset == Offset.OPEN:
+ self.long_pos += trade.volume
+ else:
+ self.short_pos -= trade.volume
+ else:
+ if trade.offset == Offset.OPEN:
+ self.short_pos += trade.volume
+ else:
+ self.long_pos -= trade.volume
+
+ self.net_pos = self.long_pos - self.net_pos
+
+
+class SpreadData:
+ """"""
+
+ def __init__(
+ self,
+ name: str,
+ legs: List[LegData],
+ price_multipliers: Dict[str, int],
+ trading_multipliers: Dict[str, int],
+ active_symbol: str
+ ):
+ """"""
+ self.name: str = name
+
+ self.legs: Dict[str, LegData] = {}
+ self.active_leg: LegData = None
+ self.passive_legs: List[LegData] = []
+
+ # For calculating spread price
+ self.price_multipliers: Dict[str: int] = price_multipliers
+
+ # For calculating spread pos and sending orders
+ self.trading_multipliers: Dict[str: int] = trading_multipliers
+
+ self.price_formula: str = ""
+ self.trading_formula: str = ""
+
+ for leg in legs:
+ self.legs[leg.vt_symbol] = leg
+ if leg.vt_symbol == active_symbol:
+ self.active_leg = leg
+ else:
+ self.passive_legs.append(leg)
+
+ price_multiplier = self.price_multipliers[leg.vt_symbol]
+ if price_multiplier > 0:
+ self.price_formula += f"+{price_multiplier}*{leg.vt_symbol}"
+ else:
+ self.price_formula += f"{price_multiplier}*{leg.vt_symbol}"
+
+ trading_multiplier = self.trading_multipliers[leg.vt_symbol]
+ if trading_multiplier > 0:
+ self.trading_formula += f"+{trading_multiplier}*{leg.vt_symbol}"
+ else:
+ self.trading_formula += f"{trading_multiplier}*{leg.vt_symbol}"
+
+ # Spread data
+ self.bid_price: float = 0
+ self.ask_price: float = 0
+ self.bid_volume: float = 0
+ self.ask_volume: float = 0
+
+ self.net_pos: float = 0
+ self.datetime: datetime = None
+
+ def calculate_price(self):
+ """"""
+ self.clear_price()
+
+ # Go through all legs to calculate price
+ for n, leg in enumerate(self.legs.values()):
+ # Filter not all leg price data has been received
+ if not leg.bid_volume or not leg.ask_volume:
+ self.clear_price()
+ return
+
+ # Calculate price
+ price_multiplier = self.price_multipliers[leg.vt_symbol]
+ if price_multiplier > 0:
+ self.bid_price += leg.bid_price * price_multiplier
+ self.ask_price += leg.ask_price * price_multiplier
+ else:
+ self.bid_price += leg.ask_price * price_multiplier
+ self.ask_price += leg.bid_price * price_multiplier
+
+ # Calculate volume
+ trading_multiplier = self.trading_multipliers[leg.vt_symbol]
+
+ if trading_multiplier > 0:
+ adjusted_bid_volume = floor(
+ leg.bid_volume / trading_multiplier)
+ adjusted_ask_volume = floor(
+ leg.ask_volume / trading_multiplier)
+ else:
+ adjusted_bid_volume = floor(
+ leg.ask_volume / abs(trading_multiplier))
+ adjusted_ask_volume = floor(
+ leg.bid_volume / abs(trading_multiplier))
+
+ # For the first leg, just initialize
+ if not n:
+ self.bid_volume = adjusted_bid_volume
+ self.ask_volume = adjusted_ask_volume
+ # For following legs, use min value of each leg quoting volume
+ else:
+ self.bid_volume = min(self.bid_volume, adjusted_bid_volume)
+ self.ask_volume = min(self.ask_volume, adjusted_ask_volume)
+
+ # Update calculate time
+ self.datetime = datetime.now()
+
+ def calculate_pos(self):
+ """"""
+ self.net_pos = 0
+
+ for n, leg in enumerate(self.legs.values()):
+ trading_multiplier = self.trading_multipliers[leg.vt_symbol]
+ adjusted_net_pos = leg.net_pos / trading_multiplier
+
+ if adjusted_net_pos > 0:
+ adjusted_net_pos = floor(adjusted_net_pos)
+ else:
+ adjusted_net_pos = ceil(adjusted_net_pos)
+
+ if not n:
+ self.net_pos = adjusted_net_pos
+ else:
+ if adjusted_net_pos > 0:
+ self.net_pos = min(self.net_pos, adjusted_net_pos)
+ else:
+ self.net_pos = max(self.net_pos, adjusted_net_pos)
+
+ def clear_price(self):
+ """"""
+ self.bid_price = 0
+ self.ask_price = 0
+ self.bid_volume = 0
+ self.ask_volume = 0
+
+ def calculate_leg_volume(self, vt_symbol: str, spread_volume: float) -> float:
+ """"""
+ leg = self.legs[vt_symbol]
+ trading_multiplier = self.trading_multipliers[leg.vt_symbol]
+ leg_volume = spread_volume * trading_multiplier
+ return leg_volume
+
+ def calculate_spread_volume(self, vt_symbol: str, leg_volume: float) -> float:
+ """"""
+ leg = self.legs[vt_symbol]
+ trading_multiplier = self.trading_multipliers[leg.vt_symbol]
+ spread_volume = leg_volume / trading_multiplier
+
+ if spread_volume > 0:
+ spread_volume = floor(spread_volume)
+ else:
+ spread_volume = ceil(spread_volume)
+
+ return spread_volume
+
+ def to_tick(self):
+ """"""
+ tick = TickData(
+ symbol=self.name,
+ exchange=Exchange.LOCAL,
+ datetime=self.datetime,
+ name=self.name,
+ last_price=(self.bid_price + self.ask_price) / 2,
+ bid_price_1=self.bid_price,
+ ask_price_1=self.ask_price,
+ bid_volume_1=self.bid_volume,
+ ask_volume_1=self.ask_volume,
+ gateway_name="SPREAD"
+ )
+ return tick
diff --git a/vnpy/app/spread_trading/engine.py b/vnpy/app/spread_trading/engine.py
new file mode 100644
index 00000000..837d299e
--- /dev/null
+++ b/vnpy/app/spread_trading/engine.py
@@ -0,0 +1,997 @@
+import traceback
+import importlib
+import os
+from typing import List, Dict, Set, Callable, Any, Type
+from collections import defaultdict
+from copy import copy
+from pathlib import Path
+
+from vnpy.event import EventEngine, Event
+from vnpy.trader.engine import BaseEngine, MainEngine
+from vnpy.trader.event import (
+ EVENT_TICK, EVENT_POSITION, EVENT_CONTRACT,
+ EVENT_ORDER, EVENT_TRADE, EVENT_TIMER
+)
+from vnpy.trader.utility import load_json, save_json
+from vnpy.trader.object import (
+ TickData, ContractData, LogData,
+ SubscribeRequest, OrderRequest
+)
+from vnpy.trader.constant import Direction, Offset, OrderType
+from vnpy.trader.converter import OffsetConverter
+
+from .base import (
+ LegData, SpreadData,
+ EVENT_SPREAD_DATA, EVENT_SPREAD_POS,
+ EVENT_SPREAD_ALGO, EVENT_SPREAD_LOG,
+ EVENT_SPREAD_STRATEGY
+)
+from .template import SpreadAlgoTemplate, SpreadStrategyTemplate
+from .algo import SpreadTakerAlgo
+
+
+APP_NAME = "SpreadTrading"
+
+
+class SpreadEngine(BaseEngine):
+ """"""
+
+ def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
+ """Constructor"""
+ super().__init__(main_engine, event_engine, APP_NAME)
+
+ self.active = False
+
+ self.data_engine: SpreadDataEngine = SpreadDataEngine(self)
+ self.algo_engine: SpreadAlgoEngine = SpreadAlgoEngine(self)
+ self.strategy_engine: SpreadStrategyEngine = SpreadStrategyEngine(self)
+
+ self.add_spread = self.data_engine.add_spread
+ self.remove_spread = self.data_engine.remove_spread
+ self.get_spread = self.data_engine.get_spread
+ self.get_all_spreads = self.data_engine.get_all_spreads
+
+ self.start_algo = self.algo_engine.start_algo
+ self.stop_algo = self.algo_engine.stop_algo
+
+ def start(self):
+ """"""
+ if self.active:
+ return
+ self.active = True
+
+ self.data_engine.start()
+ self.algo_engine.start()
+ self.strategy_engine.start()
+
+ def stop(self):
+ """"""
+ self.data_engine.stop()
+ self.algo_engine.stop()
+ self.strategy_engine.stop()
+
+ def write_log(self, msg: str):
+ """"""
+ log = LogData(
+ msg=msg,
+ gateway_name=APP_NAME
+ )
+ event = Event(EVENT_SPREAD_LOG, log)
+ self.event_engine.put(event)
+
+
+class SpreadDataEngine:
+ """"""
+ setting_filename = "spread_trading_setting.json"
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ self.spread_engine: SpreadEngine = spread_engine
+ self.main_engine: MainEngine = spread_engine.main_engine
+ self.event_engine: EventEngine = spread_engine.event_engine
+
+ self.write_log = spread_engine.write_log
+
+ self.legs: Dict[str, LegData] = {} # vt_symbol: leg
+ self.spreads: Dict[str, SpreadData] = {} # name: spread
+ self.symbol_spread_map: Dict[str, List[SpreadData]] = defaultdict(list)
+
+ def start(self):
+ """"""
+ self.load_setting()
+ self.register_event()
+
+ self.write_log("价差数据引擎启动成功")
+
+ def stop(self):
+ """"""
+ pass
+
+ def load_setting(self) -> None:
+ """"""
+ setting = load_json(self.setting_filename)
+
+ for spread_setting in setting:
+ self.add_spread(
+ spread_setting["name"],
+ spread_setting["leg_settings"],
+ spread_setting["active_symbol"],
+ save=False
+ )
+
+ def save_setting(self) -> None:
+ """"""
+ setting = []
+
+ for spread in self.spreads.values():
+ leg_settings = []
+ for leg in spread.legs.values():
+ price_multiplier = spread.price_multipliers[leg.vt_symbol]
+ trading_multiplier = spread.trading_multipliers[leg.vt_symbol]
+
+ leg_setting = {
+ "vt_symbol": leg.vt_symbol,
+ "price_multiplier": price_multiplier,
+ "trading_multiplier": trading_multiplier
+ }
+ leg_settings.append(leg_setting)
+
+ spread_setting = {
+ "name": spread.name,
+ "leg_settings": leg_settings,
+ "active_symbol": spread.active_leg.vt_symbol
+ }
+ setting.append(spread_setting)
+
+ save_json(self.setting_filename, setting)
+
+ def register_event(self) -> None:
+ """"""
+ self.event_engine.register(EVENT_TICK, self.process_tick_event)
+ self.event_engine.register(EVENT_TRADE, self.process_trade_event)
+ self.event_engine.register(EVENT_POSITION, self.process_position_event)
+ self.event_engine.register(EVENT_CONTRACT, self.process_contract_event)
+
+ def process_tick_event(self, event: Event) -> None:
+ """"""
+ tick = event.data
+
+ leg = self.legs.get(tick.vt_symbol, None)
+ if not leg:
+ return
+ leg.update_tick(tick)
+
+ for spread in self.symbol_spread_map[tick.vt_symbol]:
+ spread.calculate_price()
+ self.put_data_event(spread)
+
+ def process_position_event(self, event: Event) -> None:
+ """"""
+ position = event.data
+
+ leg = self.legs.get(position.vt_symbol, None)
+ if not leg:
+ return
+ leg.update_position(position)
+
+ for spread in self.symbol_spread_map[position.vt_symbol]:
+ spread.calculate_pos()
+ self.put_pos_event(spread)
+
+ def process_trade_event(self, event: Event) -> None:
+ """"""
+ trade = event.data
+
+ leg = self.legs.get(trade.vt_symbol, None)
+ if not leg:
+ return
+ leg.update_trade(trade)
+
+ for spread in self.symbol_spread_map[trade.vt_symbol]:
+ spread.calculate_pos()
+ self.put_pos_event(spread)
+
+ def process_contract_event(self, event: Event) -> None:
+ """"""
+ contract = event.data
+
+ if contract.vt_symbol in self.legs:
+ req = SubscribeRequest(
+ contract.symbol, contract.exchange
+ )
+ self.main_engine.subscribe(req, contract.gateway_name)
+
+ def put_data_event(self, spread: SpreadData) -> None:
+ """"""
+ event = Event(EVENT_SPREAD_DATA, spread)
+ self.event_engine.put(event)
+
+ def put_pos_event(self, spread: SpreadData) -> None:
+ """"""
+ event = Event(EVENT_SPREAD_POS, spread)
+ self.event_engine.put(event)
+
+ def get_leg(self, vt_symbol: str) -> LegData:
+ """"""
+ leg = self.legs.get(vt_symbol, None)
+
+ if not leg:
+ leg = LegData(vt_symbol)
+ self.legs[vt_symbol] = leg
+
+ # Subscribe market data
+ contract = self.main_engine.get_contract(vt_symbol)
+ if contract:
+ req = SubscribeRequest(
+ contract.symbol,
+ contract.exchange
+ )
+ self.main_engine.subscribe(req, contract.gateway_name)
+
+ return leg
+
+ def add_spread(
+ self,
+ name: str,
+ leg_settings: List[Dict],
+ active_symbol: str,
+ save: bool = True
+ ) -> None:
+ """"""
+ if name in self.spreads:
+ self.write_log("价差创建失败,名称重复:{}".format(name))
+ return
+
+ legs: List[LegData] = []
+ price_multipliers: Dict[str, int] = {}
+ trading_multipliers: Dict[str, int] = {}
+
+ for leg_setting in leg_settings:
+ vt_symbol = leg_setting["vt_symbol"]
+ leg = self.get_leg(vt_symbol)
+
+ legs.append(leg)
+ price_multipliers[vt_symbol] = leg_setting["price_multiplier"]
+ trading_multipliers[vt_symbol] = leg_setting["trading_multiplier"]
+
+ spread = SpreadData(
+ name,
+ legs,
+ price_multipliers,
+ trading_multipliers,
+ active_symbol
+ )
+ self.spreads[name] = spread
+
+ for leg in spread.legs.values():
+ self.symbol_spread_map[leg.vt_symbol].append(spread)
+
+ if save:
+ self.save_setting()
+
+ self.write_log("价差创建成功:{}".format(name))
+ self.put_data_event(spread)
+
+ def remove_spread(self, name: str) -> None:
+ """"""
+ if name not in self.spreads:
+ return
+
+ spread = self.spreads.pop(name)
+
+ for leg in spread.legs.values():
+ self.symbol_spread_map[leg.vt_symbol].remove(spread)
+
+ self.save_setting()
+ self.write_log("价差移除成功:{},重启后生效".format(name))
+
+ def get_spread(self, name: str) -> SpreadData:
+ """"""
+ spread = self.spreads.get(name, None)
+ return spread
+
+ def get_all_spreads(self) -> List[SpreadData]:
+ """"""
+ return list(self.spreads.values())
+
+
+class SpreadAlgoEngine:
+ """"""
+ algo_class = SpreadTakerAlgo
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ self.spread_engine: SpreadEngine = spread_engine
+ self.main_engine: MainEngine = spread_engine.main_engine
+ self.event_engine: EventEngine = spread_engine.event_engine
+
+ self.write_log = spread_engine.write_log
+
+ self.spreads: Dict[str: SpreadData] = {}
+ self.algos: Dict[str: SpreadAlgoTemplate] = {}
+
+ self.order_algo_map: dict[str: SpreadAlgoTemplate] = {}
+ self.symbol_algo_map: dict[str: SpreadAlgoTemplate] = defaultdict(list)
+
+ self.algo_count: int = 0
+ self.vt_tradeids: Set = set()
+
+ self.offset_converter: OffsetConverter = OffsetConverter(
+ self.main_engine
+ )
+
+ def start(self):
+ """"""
+ self.register_event()
+
+ self.write_log("价差算法引擎启动成功")
+
+ def stop(self):
+ """"""
+ for algo in self.algos.values():
+ self.stop_algo(algo)
+
+ def register_event(self):
+ """"""
+ self.event_engine.register(EVENT_TICK, self.process_tick_event)
+ self.event_engine.register(EVENT_ORDER, self.process_order_event)
+ self.event_engine.register(EVENT_TRADE, self.process_trade_event)
+ self.event_engine.register(EVENT_POSITION, self.process_position_event)
+ self.event_engine.register(EVENT_TIMER, self.process_timer_event)
+ self.event_engine.register(
+ EVENT_SPREAD_DATA, self.process_spread_event
+ )
+
+ def process_spread_event(self, event: Event):
+ """"""
+ spread: SpreadData = event.data
+ self.spreads[spread.name] = spread
+
+ def process_tick_event(self, event: Event):
+ """"""
+ tick = event.data
+ algos = self.symbol_algo_map[tick.vt_symbol]
+ if not algos:
+ return
+
+ buf = copy(algos)
+ for algo in buf:
+ if not algo.is_active():
+ algos.remove(algo)
+ else:
+ algo.update_tick(tick)
+
+ def process_order_event(self, event: Event):
+ """"""
+ order = event.data
+
+ self.offset_converter.update_order(order)
+
+ algo = self.order_algo_map.get(order.vt_orderid, None)
+ if algo and algo.is_active():
+ algo.update_order(order)
+
+ def process_trade_event(self, event: Event):
+ """"""
+ trade = event.data
+
+ # Filter duplicate trade push
+ if trade.vt_tradeid in self.vt_tradeids:
+ return
+ self.vt_tradeids.add(trade.vt_tradeid)
+
+ self.offset_converter.update_trade(trade)
+
+ algo = self.order_algo_map.get(trade.vt_orderid, None)
+ if algo and algo.is_active():
+ algo.update_trade(trade)
+
+ def process_position_event(self, event: Event):
+ """"""
+ position = event.data
+
+ self.offset_converter.update_position(position)
+
+ def process_timer_event(self, event: Event):
+ """"""
+ buf = list(self.algos.values())
+
+ for algo in buf:
+ if not algo.is_active():
+ self.algos.pop(algo.algoid)
+ else:
+ algo.update_timer()
+
+ def start_algo(
+ self,
+ spread_name: str,
+ direction: Direction,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool
+ ) -> str:
+ # Find spread object
+ spread = self.spreads.get(spread_name, None)
+ if not spread:
+ self.write_log("创建价差算法失败,找不到价差:{}".format(spread_name))
+ return ""
+
+ # Generate algoid str
+ self.algo_count += 1
+ algo_count_str = str(self.algo_count).rjust(6, "0")
+ algoid = f"{self.algo_class.algo_name}_{algo_count_str}"
+
+ # Create algo object
+ algo = self.algo_class(
+ self,
+ algoid,
+ spread,
+ direction,
+ price,
+ volume,
+ payup,
+ interval,
+ lock
+ )
+ self.algos[algoid] = algo
+
+ # Generate map between vt_symbol and algo
+ for leg in spread.legs.values():
+ self.symbol_algo_map[leg.vt_symbol].append(algo)
+
+ # Put event to update GUI
+ self.put_algo_event(algo)
+
+ return algoid
+
+ def stop_algo(
+ self,
+ algoid: str
+ ):
+ """"""
+ algo = self.algos.get(algoid, None)
+ if not algo:
+ self.write_log("停止价差算法失败,找不到算法:{}".format(algoid))
+ return
+
+ algo.stop()
+
+ def put_algo_event(self, algo: SpreadAlgoTemplate) -> None:
+ """"""
+ event = Event(EVENT_SPREAD_ALGO, algo)
+ self.event_engine.put(event)
+
+ def write_algo_log(self, algo: SpreadAlgoTemplate, msg: str) -> None:
+ """"""
+ msg = f"{algo.algoid}:{msg}"
+ self.write_log(msg)
+
+ def send_order(
+ self,
+ algo: SpreadAlgoTemplate,
+ vt_symbol: str,
+ price: float,
+ volume: float,
+ direction: Direction,
+ lock: bool
+ ) -> List[str]:
+ """"""
+ holding = self.offset_converter.get_position_holding(vt_symbol)
+ contract = self.main_engine.get_contract(vt_symbol)
+
+ if direction == Direction.LONG:
+ available = holding.short_pos - holding.short_pos_frozen
+ else:
+ available = holding.long_pos - holding.long_pos_frozen
+
+ # If no position to close, just open new
+ if not available:
+ offset = Offset.OPEN
+ # If enougth position to close, just close old
+ elif volume < available:
+ offset = Offset.CLOSE
+ # Otherwise, just close existing position
+ else:
+ volume = available
+ offset = Offset.CLOSE
+
+ original_req = OrderRequest(
+ symbol=contract.symbol,
+ exchange=contract.exchange,
+ direction=direction,
+ offset=offset,
+ type=OrderType.LIMIT,
+ price=price,
+ volume=volume
+ )
+
+ # Convert with offset converter
+ req_list = self.offset_converter.convert_order_request(
+ original_req, lock)
+
+ # Send Orders
+ vt_orderids = []
+
+ for req in req_list:
+ vt_orderid = self.main_engine.send_order(
+ req, contract.gateway_name)
+
+ # Check if sending order successful
+ if not vt_orderid:
+ continue
+
+ vt_orderids.append(vt_orderid)
+
+ self.offset_converter.update_order_request(req, vt_orderid)
+
+ # Save relationship between orderid and algo.
+ self.order_algo_map[vt_orderid] = algo
+
+ return vt_orderids
+
+ def cancel_order(self, algo: SpreadAlgoTemplate, vt_orderid: str) -> None:
+ """"""
+ order = self.main_engine.get_order(vt_orderid)
+ if not order:
+ self.write_algo_log(algo, "撤单失败,找不到委托{}".format(vt_orderid))
+ return
+
+ req = order.create_cancel_request()
+ self.main_engine.cancel_order(req, order.gateway_name)
+
+ def get_tick(self, vt_symbol: str) -> TickData:
+ """"""
+ return self.main_engine.get_tick(vt_symbol)
+
+ def get_contract(self, vt_symbol: str) -> ContractData:
+ """"""
+ return self.main_engine.get_contract(vt_symbol)
+
+
+class SpreadStrategyEngine:
+ """"""
+
+ setting_filename = "spraed_trading_strategy.json"
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ self.spread_engine: SpreadEngine = spread_engine
+ self.main_engine: MainEngine = spread_engine.main_engine
+ self.event_engine: EventEngine = spread_engine.event_engine
+
+ self.write_log = spread_engine.write_log
+
+ self.strategy_setting: Dict[str: Dict] = {}
+
+ self.classes: Dict[str: Type[SpreadStrategyTemplate]] = {}
+ self.strategies: Dict[str: SpreadStrategyTemplate] = {}
+
+ self.order_strategy_map: dict[str: SpreadStrategyTemplate] = {}
+ self.algo_strategy_map: dict[str: SpreadStrategyTemplate] = {}
+ self.spread_strategy_map: dict[str: SpreadStrategyTemplate] = defaultdict(
+ list)
+
+ self.vt_tradeids: Set = set()
+
+ self.load_strategy_class()
+
+ def start(self):
+ """"""
+ self.load_strategy_setting()
+ self.register_event()
+
+ self.write_log("价差策略引擎启动成功")
+
+ def close(self):
+ """"""
+ self.stop_all_strategies()
+
+ def load_strategy_class(self):
+ """
+ Load strategy class from source code.
+ """
+ path1 = Path(__file__).parent.joinpath("strategies")
+ self.load_strategy_class_from_folder(
+ path1, "vnpy.app.spread_trading.strategies")
+
+ path2 = Path.cwd().joinpath("strategies")
+ self.load_strategy_class_from_folder(path2, "strategies")
+
+ def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
+ """
+ Load strategy class from certain folder.
+ """
+ for dirpath, dirnames, filenames in os.walk(str(path)):
+ for filename in filenames:
+ if filename.endswith(".py"):
+ strategy_module_name = ".".join(
+ [module_name, filename.replace(".py", "")])
+ elif filename.endswith(".pyd"):
+ strategy_module_name = ".".join(
+ [module_name, filename.split(".")[0]])
+
+ self.load_strategy_class_from_module(strategy_module_name)
+
+ def load_strategy_class_from_module(self, module_name: str):
+ """
+ Load strategy class from module file.
+ """
+ try:
+ module = importlib.import_module(module_name)
+
+ for name in dir(module):
+ value = getattr(module, name)
+ if (isinstance(value, type) and issubclass(value, SpreadStrategyTemplate) and value is not SpreadStrategyTemplate):
+ self.classes[value.__name__] = value
+ except: # noqa
+ msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
+ self.write_log(msg)
+
+ def get_all_strategy_class_names(self):
+ """"""
+ return list(self.classes.keys())
+
+ def load_strategy_setting(self):
+ """
+ Load setting file.
+ """
+ self.strategy_setting = load_json(self.setting_filename)
+
+ for strategy_name, strategy_config in self.strategy_setting.items():
+ self.add_strategy(
+ strategy_config["class_name"],
+ strategy_name,
+ strategy_config["spread_name"],
+ strategy_config["setting"]
+ )
+
+ def update_strategy_setting(self, strategy_name: str, setting: dict):
+ """
+ Update setting file.
+ """
+ strategy = self.strategies[strategy_name]
+
+ self.strategy_setting[strategy_name] = {
+ "class_name": strategy.__class__.__name__,
+ "spread_name": strategy.spread_name,
+ "setting": setting,
+ }
+ save_json(self.setting_filename, self.strategy_setting)
+
+ def remove_strategy_setting(self, strategy_name: str):
+ """
+ Update setting file.
+ """
+ if strategy_name not in self.strategy_setting:
+ return
+
+ self.strategy_setting.pop(strategy_name)
+ save_json(self.setting_filename, self.strategy_setting)
+
+ def register_event(self):
+ """"""
+ ee = self.event_engine
+ ee.register(EVENT_ORDER, self.process_order_event)
+ ee.register(EVENT_TRADE, self.process_trade_event)
+ ee.register(EVENT_SPREAD_DATA, self.process_spread_data_event)
+ ee.register(EVENT_SPREAD_POS, self.process_spread_pos_event)
+ ee.register(EVENT_SPREAD_ALGO, self.process_spread_algo_event)
+
+ def process_spread_data_event(self, event: Event):
+ """"""
+ spread = event.data
+ strategies = self.spread_strategy_map[spread.name]
+
+ for strategy in strategies:
+ if strategy.inited:
+ self.call_strategy_func(strategy, strategy.on_spread_data)
+
+ def process_spread_pos_event(self, event: Event):
+ """"""
+ spread = event.data
+ strategies = self.spread_strategy_map[spread.name]
+
+ for strategy in strategies:
+ if strategy.inited:
+ self.call_strategy_func(strategy, strategy.on_spread_pos)
+
+ def process_spread_algo_event(self, event: Event):
+ """"""
+ algo = event.data
+ strategy = self.algo_strategy_map.get(algo.algoid, None)
+
+ if strategy:
+ self.call_strategy_func(strategy, strategy.update_spread_algo, algo)
+
+ def process_order_event(self, event: Event):
+ """"""
+ order = event.data
+ strategy = self.order_strategy_map.get(order.vt_orderid, None)
+
+ if strategy:
+ self.call_strategy_func(strategy, strategy.update_order, order)
+
+ def process_trade_event(self, event: Event):
+ """"""
+ trade = event.data
+ strategy = self.order_strategy_map.get(trade.vt_orderid, None)
+
+ if strategy:
+ self.call_strategy_func(strategy, strategy.on_trade, trade)
+
+ def call_strategy_func(
+ self, strategy: SpreadStrategyTemplate, func: Callable, params: Any = None
+ ):
+ """
+ Call function of a strategy and catch any exception raised.
+ """
+ try:
+ if params:
+ func(params)
+ else:
+ func()
+ except Exception:
+ strategy.trading = False
+ strategy.inited = False
+
+ msg = f"触发异常已停止\n{traceback.format_exc()}"
+ self.write_strategy_log(strategy, msg)
+
+ def add_strategy(
+ self, class_name: str, strategy_name: str, spread_name: str, setting: dict
+ ):
+ """
+ Add a new strategy.
+ """
+ if strategy_name in self.strategies:
+ self.write_log(f"创建策略失败,存在重名{strategy_name}")
+ return
+
+ strategy_class = self.classes.get(class_name, None)
+ if not strategy_class:
+ self.write_log(f"创建策略失败,找不到策略类{class_name}")
+ return
+
+ spread = self.spread_engine.get_spread(spread_name)
+ if not spread:
+ self.write_log(f"创建策略失败,找不到价差{spread_name}")
+ return
+
+ strategy = strategy_class(self, strategy_name, spread, setting)
+ self.strategies[strategy_name] = strategy
+
+ # Add vt_symbol to strategy map.
+ strategies = self.spread_strategy_map[spread_name]
+ strategies.append(strategy)
+
+ # Update to setting file.
+ self.update_strategy_setting(strategy_name, setting)
+
+ self.put_strategy_event(strategy)
+
+ def edit_strategy(self, strategy_name: str, setting: dict):
+ """
+ Edit parameters of a strategy.
+ """
+ strategy = self.strategies[strategy_name]
+ strategy.update_setting(setting)
+
+ self.update_strategy_setting(strategy_name, setting)
+ self.put_strategy_event(strategy)
+
+ def remove_strategy(self, strategy_name: str):
+ """
+ Remove a strategy.
+ """
+ strategy = self.strategies[strategy_name]
+ if strategy.trading:
+ self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
+ return
+
+ # Remove setting
+ self.remove_strategy_setting(strategy_name)
+
+ # Remove from symbol strategy map
+ strategies = self.spread_strategy_map[strategy.spread_name]
+ strategies.remove(strategy)
+
+ # Remove from strategies
+ self.strategies.pop(strategy_name)
+
+ return True
+
+ def init_strategy(self, strategy_name: str):
+ """"""
+ strategy = self.strategies[strategy_name]
+
+ if strategy.inited:
+ self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
+ return
+
+ self.call_strategy_func(strategy, strategy.on_init)
+ strategy.inited = True
+
+ self.put_strategy_event(strategy)
+ self.write_log(f"{strategy_name}初始化完成")
+
+ def start_strategy(self, strategy_name: str):
+ """"""
+ strategy = self.strategies[strategy_name]
+ if not strategy.inited:
+ self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
+ return
+
+ if strategy.trading:
+ self.write_log(f"{strategy_name}已经启动,请勿重复操作")
+ return
+
+ self.call_strategy_func(strategy, strategy.on_start)
+ strategy.trading = True
+
+ self.put_strategy_event(strategy)
+
+ def stop_strategy(self, strategy_name: str):
+ """"""
+ strategy = self.strategies[strategy_name]
+ if not strategy.trading:
+ return
+
+ self.call_strategy_func(strategy, strategy.on_stop)
+
+ strategy.stop_all_algos()
+ strategy.cancel_all_orders()
+
+ strategy.trading = False
+
+ self.put_strategy_event(strategy)
+
+ def init_all_strategies(self):
+ """"""
+ for strategy in self.strategies.keys():
+ self.init_strategy(strategy)
+
+ def start_all_strategies(self):
+ """"""
+ for strategy in self.strategies.keys():
+ self.start_strategy(strategy)
+
+ def stop_all_strategies(self):
+ """"""
+ for strategy in self.strategies.keys():
+ self.stop_strategy(strategy)
+
+ def get_strategy_class_parameters(self, class_name: str):
+ """
+ Get default parameters of a strategy class.
+ """
+ strategy_class = self.classes[class_name]
+
+ parameters = {}
+ for name in strategy_class.parameters:
+ parameters[name] = getattr(strategy_class, name)
+
+ return parameters
+
+ def get_strategy_parameters(self, strategy_name):
+ """
+ Get parameters of a strategy.
+ """
+ strategy = self.strategies[strategy_name]
+ return strategy.get_parameters()
+
+ def start_algo(
+ self,
+ strategy: SpreadStrategyTemplate,
+ spread_name: str,
+ direction: Direction,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool
+ ) -> str:
+ """"""
+ algoid = self.spread_engine.start_algo(
+ spread_name,
+ direction,
+ price,
+ volume,
+ payup,
+ interval,
+ lock
+ )
+
+ self.algo_strategy_map[algoid] = strategy
+
+ return algoid
+
+ def stop_algo(self, strategy: SpreadStrategyTemplate, algoid: str):
+ """"""
+ self.spread_engine.stop_algo(algoid)
+
+ def stop_all_algos(self, strategy: SpreadStrategyTemplate):
+ """"""
+ pass
+
+ def send_order(
+ self,
+ strategy: SpreadStrategyTemplate,
+ vt_symbol: str,
+ price: float,
+ volume: float,
+ direction: Direction,
+ offset: Offset,
+ lock: bool
+ ) -> List[str]:
+ contract = self.main_engine.get_contract(vt_symbol)
+
+ original_req = OrderRequest(
+ symbol=contract.symbol,
+ exchange=contract.exchange,
+ direction=direction,
+ offset=offset,
+ type=OrderType.LIMIT,
+ price=price,
+ volume=volume
+ )
+
+ # Convert with offset converter
+ req_list = self.offset_converter.convert_order_request(
+ original_req, lock)
+
+ # Send Orders
+ vt_orderids = []
+
+ for req in req_list:
+ vt_orderid = self.main_engine.send_order(
+ req, contract.gateway_name)
+
+ # Check if sending order successful
+ if not vt_orderid:
+ continue
+
+ vt_orderids.append(vt_orderid)
+
+ self.offset_converter.update_order_request(req, vt_orderid)
+
+ # Save relationship between orderid and strategy.
+ self.order_strategy_map[vt_orderid] = strategy
+
+ return vt_orderids
+
+ def cancel_order(self, strategy: SpreadStrategyTemplate, vt_orderid: str):
+ """"""
+ order = self.main_engine.get_order(vt_orderid)
+ if not order:
+ self.write_strategy_log(
+ strategy, "撤单失败,找不到委托{}".format(vt_orderid))
+ return
+
+ req = order.create_cancel_request()
+ self.main_engine.cancel_order(req, order.gateway_name)
+
+ def cancel_all_orders(self, strategy: SpreadStrategyTemplate):
+ """"""
+ pass
+
+ def put_strategy_event(self, strategy: SpreadStrategyTemplate):
+ """"""
+ data = strategy.get_data()
+ event = Event(EVENT_SPREAD_STRATEGY, data)
+ self.event_engine.put(event)
+
+ def write_strategy_log(self, strategy: SpreadStrategyTemplate, msg: str):
+ """"""
+ msg = f"{strategy.strategy_name}:{msg}"
+ self.write_log(msg)
+
+ def send_strategy_email(self, strategy: SpreadStrategyTemplate, msg: str):
+ """"""
+ if strategy:
+ subject = f"{strategy.strategy_name}"
+ else:
+ subject = "价差策略引擎"
+
+ self.main_engine.send_email(subject, msg)
diff --git a/vnpy/app/spread_trading/strategies/__init__.py b/vnpy/app/spread_trading/strategies/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vnpy/app/spread_trading/strategies/basic_spread_strategy.py b/vnpy/app/spread_trading/strategies/basic_spread_strategy.py
new file mode 100644
index 00000000..b1213531
--- /dev/null
+++ b/vnpy/app/spread_trading/strategies/basic_spread_strategy.py
@@ -0,0 +1,168 @@
+from vnpy.app.spread_trading import (
+ SpreadStrategyTemplate,
+ SpreadAlgoTemplate,
+ SpreadData,
+ OrderData,
+ TradeData
+)
+
+
+class BasicSpreadStrategy(SpreadStrategyTemplate):
+ """"""
+
+ author = "用Python的交易员"
+
+ buy_price = 0.0
+ sell_price = 0.0
+ cover_price = 0.0
+ short_price = 0.0
+ max_pos = 0.0
+ payup = 10
+ interval = 5
+
+ spread_pos = 0.0
+ buy_algoid = ""
+ sell_algoid = ""
+ short_algoid = ""
+ cover_algoid = ""
+
+ parameters = [
+ "buy_price",
+ "sell_price",
+ "cover_price",
+ "short_price",
+ "max_pos",
+ "payup",
+ "interval"
+ ]
+ variables = [
+ "spread_pos",
+ "buy_algoid",
+ "sell_algoid",
+ "short_algoid",
+ "cover_algoid",
+ ]
+
+ def __init__(
+ self,
+ strategy_engine,
+ strategy_name: str,
+ spread: SpreadData,
+ setting: dict
+ ):
+ """"""
+ super().__init__(
+ strategy_engine, strategy_name, spread, setting
+ )
+
+ def on_init(self):
+ """
+ Callback when strategy is inited.
+ """
+ self.write_log("策略初始化")
+
+ def on_start(self):
+ """
+ Callback when strategy is started.
+ """
+ self.write_log("策略启动")
+
+ def on_stop(self):
+ """
+ Callback when strategy is stopped.
+ """
+ self.write_log("策略停止")
+
+ self.buy_algoid = ""
+ self.sell_algoid = ""
+ self.short_algoid = ""
+ self.cover_algoid = ""
+ self.put_event()
+
+ def on_spread_data(self):
+ """
+ Callback when spread price is updated.
+ """
+ self.spread_pos = self.get_spread_pos()
+
+ # No position
+ if not self.spread_pos:
+ # Start open algos
+ if not self.buy_algoid:
+ self.buy_algoid = self.start_long_algo(
+ self.buy_price, self.max_pos, self.payup, self.interval
+ )
+
+ if not self.short_algoid:
+ self.short_algoid = self.start_short_algo(
+ self.short_price, self.max_pos, self.payup, self.interval
+ )
+
+ # Stop close algos
+ if self.sell_algoid:
+ self.stop_algo(self.sell_algoid)
+
+ if self.cover_algoid:
+ self.stop_algo(self.cover_algoid)
+
+ # Long position
+ elif self.spread_pos > 0:
+ # Start sell close algo
+ if not self.sell_algoid:
+ self.sell_algoid = self.start_short_algo(
+ self.sell_price, self.spread_pos, self.payup, self.interval
+ )
+
+ # Stop buy open algo
+ if self.buy_algoid:
+ self.stop_algo(self.buy_algoid)
+
+ # Short position
+ elif self.spread_pos < 0:
+ # Start cover close algo
+ if not self.cover_algoid:
+ self.cover_algoid = self.start_long_algo(
+ self.cover_price, abs(
+ self.spread_pos), self.payup, self.interval
+ )
+
+ # Stop short open algo
+ if self.short_algoid:
+ self.stop_algo(self.short_algoid)
+
+ self.put_event()
+
+ def on_spread_pos(self):
+ """
+ Callback when spread position is updated.
+ """
+ self.spread_pos = self.get_spread_pos()
+ self.put_event()
+
+ def on_spread_algo(self, algo: SpreadAlgoTemplate):
+ """
+ Callback when algo status is updated.
+ """
+ if not algo.is_active():
+ if self.buy_algoid == algo.algoid:
+ self.buy_algoid = ""
+ elif self.sell_algoid == algo.algoid:
+ self.sell_algoid = ""
+ elif self.short_algoid == algo.algoid:
+ self.short_algoid = ""
+ else:
+ self.cover_algoid = ""
+
+ self.put_event()
+
+ def on_order(self, order: OrderData):
+ """
+ Callback when order status is updated.
+ """
+ pass
+
+ def on_trade(self, trade: TradeData):
+ """
+ Callback when new trade data is received.
+ """
+ pass
diff --git a/vnpy/app/spread_trading/template.py b/vnpy/app/spread_trading/template.py
new file mode 100644
index 00000000..de4ca343
--- /dev/null
+++ b/vnpy/app/spread_trading/template.py
@@ -0,0 +1,592 @@
+
+from collections import defaultdict
+from typing import Dict, List, Set
+from math import floor, ceil
+from copy import copy
+
+from vnpy.trader.object import TickData, TradeData, OrderData, ContractData
+from vnpy.trader.constant import Direction, Status, Offset
+from vnpy.trader.utility import virtual
+
+from .base import SpreadData
+
+
+class SpreadAlgoTemplate:
+ """
+ Template for implementing spread trading algos.
+ """
+ algo_name = "AlgoTemplate"
+
+ def __init__(
+ self,
+ algo_engine,
+ algoid: str,
+ spread: SpreadData,
+ direction: Direction,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool
+ ):
+ """"""
+ self.algo_engine = algo_engine
+ self.algoid: str = algoid
+
+ self.spread: SpreadData = spread
+ self.spread_name: str = spread.name
+
+ self.direction: Direction = direction
+ self.price: float = price
+ self.volume: float = volume
+ self.payup: int = payup
+ self.interval = interval
+ self.lock = lock
+
+ if direction == Direction.LONG:
+ self.target = volume
+ else:
+ self.target = -volume
+
+ self.status: Status = Status.NOTTRADED # Algo status
+ self.count: int = 0 # Timer count
+ self.traded: float = 0 # Volume traded
+ self.traded_volume: float = 0 # Volume traded (Abs value)
+
+ self.leg_traded: Dict[str, float] = defaultdict(int)
+ self.leg_orders: Dict[str, List[str]] = defaultdict(list)
+
+ self.write_log("算法已启动")
+
+ def is_active(self):
+ """"""
+ if self.status not in [Status.CANCELLED, Status.ALLTRADED]:
+ return True
+ else:
+ return False
+
+ def check_order_finished(self):
+ """"""
+ finished = True
+
+ for leg in self.spread.legs.values():
+ vt_orderids = self.leg_orders[leg.vt_symbol]
+
+ if vt_orderids:
+ finished = False
+ break
+
+ return finished
+
+ def check_hedge_finished(self):
+ """"""
+ active_symbol = self.spread.active_leg.vt_symbol
+ active_traded = self.leg_traded[active_symbol]
+
+ spread_volume = self.spread.calculate_spread_volume(
+ active_symbol, active_traded
+ )
+
+ finished = True
+
+ for leg in self.spread.passive_legs:
+ passive_symbol = leg.vt_symbol
+
+ leg_target = self.spread.calculate_leg_volume(
+ passive_symbol, spread_volume
+ )
+ leg_traded = self.leg_traded[passive_symbol]
+
+ if leg_traded != leg_target:
+ finished = False
+ break
+
+ return finished
+
+ def stop(self):
+ """"""
+ if self.is_active():
+ self.cancel_all_order()
+ self.status = Status.CANCELLED
+ self.write_log("算法已停止")
+ self.put_event()
+
+ def update_tick(self, tick: TickData):
+ """"""
+ self.on_tick(tick)
+
+ def update_trade(self, trade: TradeData):
+ """"""
+ if trade.direction == Direction.LONG:
+ self.leg_traded[trade.vt_symbol] += trade.volume
+ else:
+ self.leg_traded[trade.vt_symbol] -= trade.volume
+
+ msg = "委托成交,{},{},{}@{}".format(
+ trade.vt_symbol,
+ trade.direction,
+ trade.volume,
+ trade.price
+ )
+ self.write_log(msg)
+
+ self.calculate_traded()
+ self.put_event()
+
+ self.on_trade(trade)
+
+ def update_order(self, order: OrderData):
+ """"""
+ if not order.is_active():
+ vt_orderids = self.leg_orders[order.vt_symbol]
+ if order.vt_orderid in vt_orderids:
+ vt_orderids.remove(order.vt_orderid)
+
+ self.on_order(order)
+
+ def update_timer(self):
+ """"""
+ self.count += 1
+ if self.count > self.interval:
+ self.count = 0
+ self.on_interval()
+
+ self.put_event()
+
+ def put_event(self):
+ """"""
+ self.algo_engine.put_algo_event(self)
+
+ def write_log(self, msg: str):
+ """"""
+ self.algo_engine.write_algo_log(self, msg)
+
+ def send_long_order(self, vt_symbol: str, price: float, volume: float):
+ """"""
+ self.send_order(vt_symbol, price, volume, Direction.LONG)
+
+ def send_short_order(self, vt_symbol: str, price: float, volume: float):
+ """"""
+ self.send_order(vt_symbol, price, volume, Direction.SHORT)
+
+ def send_order(
+ self,
+ vt_symbol: str,
+ price: float,
+ volume: float,
+ direction: Direction,
+ ):
+ """"""
+ vt_orderids = self.algo_engine.send_order(
+ self,
+ vt_symbol,
+ price,
+ volume,
+ direction,
+ self.lock
+ )
+
+ self.leg_orders[vt_symbol].extend(vt_orderids)
+
+ msg = "发出委托,{},{},{}@{}".format(
+ vt_symbol,
+ direction,
+ volume,
+ price
+ )
+ self.write_log(msg)
+
+ def cancel_leg_order(self, vt_symbol: str):
+ """"""
+ for vt_orderid in self.leg_orders[vt_symbol]:
+ self.algo_engine.cancel_order(self, vt_orderid)
+
+ def cancel_all_order(self):
+ """"""
+ for vt_symbol in self.leg_orders.keys():
+ self.cancel_leg_order(vt_symbol)
+
+ def calculate_traded(self):
+ """"""
+ self.traded = 0
+
+ for n, leg in enumerate(self.spread.legs.values()):
+ leg_traded = self.leg_traded[leg.vt_symbol]
+ trading_multiplier = self.spread.trading_multipliers[
+ leg.vt_symbol]
+ adjusted_leg_traded = leg_traded / trading_multiplier
+
+ if adjusted_leg_traded > 0:
+ adjusted_leg_traded = floor(adjusted_leg_traded)
+ else:
+ adjusted_leg_traded = ceil(adjusted_leg_traded)
+
+ if not n:
+ self.traded = adjusted_leg_traded
+ else:
+ if adjusted_leg_traded > 0:
+ self.traded = min(self.traded, adjusted_leg_traded)
+ elif adjusted_leg_traded < 0:
+ self.traded = max(self.traded, adjusted_leg_traded)
+ else:
+ self.traded = 0
+
+ self.traded_volume = abs(self.traded)
+
+ if self.traded == self.target:
+ self.status = Status.ALLTRADED
+ elif not self.traded:
+ self.status = Status.NOTTRADED
+ else:
+ self.status = Status.PARTTRADED
+
+ def get_tick(self, vt_symbol: str) -> TickData:
+ """"""
+ return self.algo_engine.get_tick(vt_symbol)
+
+ def get_contract(self, vt_symbol: str) -> ContractData:
+ """"""
+ return self.algo_engine.get_contract(vt_symbol)
+
+ @virtual
+ def on_tick(self, tick: TickData):
+ """"""
+ pass
+
+ @virtual
+ def on_order(self, order: OrderData):
+ """"""
+ pass
+
+ @virtual
+ def on_trade(self, trade: TradeData):
+ """"""
+ pass
+
+ @virtual
+ def on_interval(self):
+ """"""
+ pass
+
+
+class SpreadStrategyTemplate:
+ """
+ Template for implementing spread trading strategies.
+ """
+
+ author: str = ""
+ parameters: List[str] = []
+ variables: List[str] = []
+
+ def __init__(
+ self,
+ strategy_engine,
+ strategy_name: str,
+ spread: SpreadData,
+ setting: dict
+ ):
+ """"""
+ self.strategy_engine = strategy_engine
+ self.strategy_name = strategy_name
+ self.spread = spread
+ self.spread_name = spread.name
+
+ self.inited = False
+ self.trading = False
+
+ self.variables = copy(self.variables)
+ self.variables.insert(0, "inited")
+ self.variables.insert(1, "trading")
+
+ self.vt_orderids: Set[str] = set()
+ self.algoids: Set[str] = set()
+
+ self.update_setting(setting)
+
+ def update_setting(self, setting: dict):
+ """
+ Update strategy parameter wtih value in setting dict.
+ """
+ for name in self.parameters:
+ if name in setting:
+ setattr(self, name, setting[name])
+
+ @classmethod
+ def get_class_parameters(cls):
+ """
+ Get default parameters dict of strategy class.
+ """
+ class_parameters = {}
+ for name in cls.parameters:
+ class_parameters[name] = getattr(cls, name)
+ return class_parameters
+
+ def get_parameters(self):
+ """
+ Get strategy parameters dict.
+ """
+ strategy_parameters = {}
+ for name in self.parameters:
+ strategy_parameters[name] = getattr(self, name)
+ return strategy_parameters
+
+ def get_variables(self):
+ """
+ Get strategy variables dict.
+ """
+ strategy_variables = {}
+ for name in self.variables:
+ strategy_variables[name] = getattr(self, name)
+ return strategy_variables
+
+ def get_data(self):
+ """
+ Get strategy data.
+ """
+ strategy_data = {
+ "strategy_name": self.strategy_name,
+ "spread_name": self.spread_name,
+ "class_name": self.__class__.__name__,
+ "author": self.author,
+ "parameters": self.get_parameters(),
+ "variables": self.get_variables(),
+ }
+ return strategy_data
+
+ def update_spread_algo(self, algo: SpreadAlgoTemplate):
+ """
+ Callback when algo status is updated.
+ """
+ if not algo.is_active() and algo.algoid in self.algoids:
+ self.algoids.remove(algo.algoid)
+
+ self.on_spread_algo(algo)
+
+ def update_order(self, order: OrderData):
+ """
+ Callback when order status is updated.
+ """
+ if not order.is_active() and order.vt_orderid in self.vt_orderids:
+ self.vt_orderids.remove(order.vt_orderid)
+
+ self.on_order(order)
+
+ @virtual
+ def on_init(self):
+ """
+ Callback when strategy is inited.
+ """
+ pass
+
+ @virtual
+ def on_start(self):
+ """
+ Callback when strategy is started.
+ """
+ pass
+
+ @virtual
+ def on_stop(self):
+ """
+ Callback when strategy is stopped.
+ """
+ pass
+
+ @virtual
+ def on_spread_data(self):
+ """
+ Callback when spread price is updated.
+ """
+ pass
+
+ @virtual
+ def on_spread_pos(self):
+ """
+ Callback when spread position is updated.
+ """
+ pass
+
+ @virtual
+ def on_spread_algo(self, algo: SpreadAlgoTemplate):
+ """
+ Callback when algo status is updated.
+ """
+ pass
+
+ @virtual
+ def on_order(self, order: OrderData):
+ """
+ Callback when order status is updated.
+ """
+ pass
+
+ @virtual
+ def on_trade(self, trade: TradeData):
+ """
+ Callback when new trade data is received.
+ """
+ pass
+
+ def start_algo(
+ self,
+ direction: Direction,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool
+ ) -> str:
+ """"""
+ if not self.trading:
+ return ""
+
+ algoid: str = self.strategy_engine.start_algo(
+ self,
+ self.spread_name,
+ direction,
+ price,
+ volume,
+ payup,
+ interval,
+ lock
+ )
+
+ self.algoids.add(algoid)
+
+ return algoid
+
+ def start_long_algo(
+ self,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool = False
+ ) -> str:
+ """"""
+ return self.start_algo(Direction.LONG, price, volume, payup, interval, lock)
+
+ def start_short_algo(
+ self,
+ price: float,
+ volume: float,
+ payup: int,
+ interval: int,
+ lock: bool = False
+ ) -> str:
+ """"""
+ return self.start_algo(Direction.SHORT, price, volume, payup, interval, lock)
+
+ def stop_algo(self, algoid: str):
+ """"""
+ if not self.trading:
+ return
+
+ self.strategy_engine.stop_algo(self, algoid)
+
+ def stop_all_algos(self):
+ """"""
+ for algoid in self.algoids:
+ self.stop_algo(algoid)
+
+ def buy(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
+ """"""
+ return self.send_order(vt_symbol, price, volume, Direction.LONG, Offset.OPEN, lock)
+
+ def sell(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
+ """"""
+ return self.send_order(vt_symbol, price, volume, Direction.SHORT, Offset.CLOSE, lock)
+
+ def short(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
+ """"""
+ return self.send_order(vt_symbol, price, volume, Direction.SHORT, Offset.OPEN, lock)
+
+ def cover(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
+ """"""
+ return self.send_order(vt_symbol, price, volume, Direction.LONG, Offset.CLOSE, lock)
+
+ def send_order(
+ self,
+ vt_symbol: str,
+ price: float,
+ volume: float,
+ direction: Direction,
+ offset: Offset,
+ lock: bool
+ ) -> List[str]:
+ """"""
+ if not self.trading:
+ return []
+
+ vt_orderids: List[str] = self.strategy_engine.send_order(
+ self,
+ vt_symbol,
+ price,
+ volume,
+ direction,
+ offset,
+ lock
+ )
+
+ for vt_orderid in vt_orderids:
+ self.vt_orderids.add(vt_orderid)
+
+ return vt_orderids
+
+ def cancel_order(self, vt_orderid: str):
+ """"""
+ if not self.trading:
+ return
+
+ self.strategy_engine.cancel_order(self, vt_orderid)
+
+ def cancel_all_orders(self):
+ """"""
+ for vt_orderid in self.vt_orderids:
+ self.cancel_order(vt_orderid)
+
+ def put_event(self):
+ """"""
+ self.strategy_engine.put_strategy_event(self)
+
+ def write_log(self, msg: str):
+ """"""
+ self.strategy_engine.write_strategy_log(self, msg)
+
+ def get_spread_tick(self) -> TickData:
+ """"""
+ return self.spread.to_tick()
+
+ def get_spread_pos(self) -> float:
+ """"""
+ return self.spread.net_pos
+
+ def get_leg_tick(self, vt_symbol: str) -> TickData:
+ """"""
+ leg = self.spread.legs.get(vt_symbol, None)
+
+ if not leg:
+ return None
+
+ return leg.tick
+
+ def get_leg_pos(self, vt_symbol: str, direction: Direction = Direction.NET) -> float:
+ """"""
+ leg = self.spread.legs.get(vt_symbol, None)
+
+ if not leg:
+ return None
+
+ if direction == Direction.NET:
+ return leg.net_pos
+ elif direction == Direction.LONG:
+ return leg.long_pos
+ else:
+ return leg.short_pos
+
+ def send_email(self, msg: str):
+ """
+ Send email to default receiver.
+ """
+ if self.inited:
+ self.strategy_engine.send_email(msg, self)
diff --git a/vnpy/app/spread_trading/ui/__init__.py b/vnpy/app/spread_trading/ui/__init__.py
new file mode 100644
index 00000000..c7639754
--- /dev/null
+++ b/vnpy/app/spread_trading/ui/__init__.py
@@ -0,0 +1 @@
+from .widget import SpreadManager
diff --git a/vnpy/app/spread_trading/ui/spread.ico b/vnpy/app/spread_trading/ui/spread.ico
new file mode 100644
index 00000000..05b3d571
Binary files /dev/null and b/vnpy/app/spread_trading/ui/spread.ico differ
diff --git a/vnpy/app/spread_trading/ui/widget.py b/vnpy/app/spread_trading/ui/widget.py
new file mode 100644
index 00000000..00ef63a5
--- /dev/null
+++ b/vnpy/app/spread_trading/ui/widget.py
@@ -0,0 +1,810 @@
+"""
+Widget for spread trading.
+"""
+
+from vnpy.event import EventEngine, Event
+from vnpy.trader.engine import MainEngine
+from vnpy.trader.constant import Direction
+from vnpy.trader.ui import QtWidgets, QtCore, QtGui
+from vnpy.trader.ui.widget import (
+ BaseMonitor, BaseCell,
+ BidCell, AskCell,
+ TimeCell, PnlCell,
+ DirectionCell, EnumCell,
+)
+
+from ..engine import (
+ SpreadEngine,
+ SpreadStrategyEngine,
+ APP_NAME,
+ EVENT_SPREAD_DATA,
+ EVENT_SPREAD_POS,
+ EVENT_SPREAD_LOG,
+ EVENT_SPREAD_ALGO,
+ EVENT_SPREAD_STRATEGY
+)
+
+
+class SpreadManager(QtWidgets.QWidget):
+ """"""
+
+ def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
+ """"""
+ super().__init__()
+
+ self.main_engine = main_engine
+ self.event_engine = event_engine
+
+ self.spread_engine = main_engine.get_engine(APP_NAME)
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ self.setWindowTitle("价差交易")
+
+ self.algo_dialog = SpreadAlgoWidget(self.spread_engine)
+ algo_group = self.create_group("交易", self.algo_dialog)
+ algo_group.setMaximumWidth(300)
+
+ self.data_monitor = SpreadDataMonitor(
+ self.main_engine,
+ self.event_engine
+ )
+ self.log_monitor = SpreadLogMonitor(
+ self.main_engine,
+ self.event_engine
+ )
+ self.algo_monitor = SpreadAlgoMonitor(
+ self.spread_engine
+ )
+
+ self.strategy_monitor = SpreadStrategyMonitor(
+ self.spread_engine
+ )
+
+ grid = QtWidgets.QGridLayout()
+ grid.addWidget(self.create_group("价差", self.data_monitor), 0, 0)
+ grid.addWidget(self.create_group("日志", self.log_monitor), 1, 0)
+ grid.addWidget(self.create_group("算法", self.algo_monitor), 0, 1)
+ grid.addWidget(self.create_group("策略", self.strategy_monitor), 1, 1)
+
+ hbox = QtWidgets.QHBoxLayout()
+ hbox.addWidget(algo_group)
+ hbox.addLayout(grid)
+
+ self.setLayout(hbox)
+
+ def show(self):
+ """"""
+ self.spread_engine.start()
+ self.algo_dialog.update_class_combo()
+ self.showMaximized()
+
+ def create_group(self, title: str, widget: QtWidgets.QWidget):
+ """"""
+ group = QtWidgets.QGroupBox()
+
+ vbox = QtWidgets.QVBoxLayout()
+ vbox.addWidget(widget)
+
+ group.setLayout(vbox)
+ group.setTitle(title)
+
+ return group
+
+
+class SpreadDataMonitor(BaseMonitor):
+ """
+ Monitor for spread data.
+ """
+
+ event_type = EVENT_SPREAD_DATA
+ data_key = "name"
+ sorting = False
+
+ headers = {
+ "name": {"display": "名称", "cell": BaseCell, "update": False},
+ "bid_volume": {"display": "买量", "cell": BidCell, "update": True},
+ "bid_price": {"display": "买价", "cell": BidCell, "update": True},
+ "ask_price": {"display": "卖价", "cell": AskCell, "update": True},
+ "ask_volume": {"display": "卖量", "cell": AskCell, "update": True},
+ "net_pos": {"display": "净仓", "cell": PnlCell, "update": True},
+ "datetime": {"display": "时间", "cell": TimeCell, "update": True},
+ "price_formula": {"display": "定价", "cell": BaseCell, "update": False},
+ "trading_formula": {"display": "交易", "cell": BaseCell, "update": False},
+ }
+
+ def register_event(self):
+ """
+ Register event handler into event engine.
+ """
+ super().register_event()
+ self.event_engine.register(EVENT_SPREAD_POS, self.signal.emit)
+
+
+class SpreadLogMonitor(QtWidgets.QTextEdit):
+ """
+ Monitor for log data.
+ """
+ signal = QtCore.pyqtSignal(Event)
+
+ def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
+ """"""
+ super().__init__()
+
+ self.main_engine = main_engine
+ self.event_engine = event_engine
+
+ self.init_ui()
+ self.register_event()
+
+ def init_ui(self):
+ """"""
+ self.setReadOnly(True)
+
+ def register_event(self):
+ """"""
+ self.signal.connect(self.process_log_event)
+
+ self.event_engine.register(EVENT_SPREAD_LOG, self.signal.emit)
+
+ def process_log_event(self, event: Event):
+ """"""
+ log = event.data
+ msg = f"{log.time.strftime('%H:%M:%S')}\t{log.msg}"
+ self.append(msg)
+
+
+class SpreadAlgoMonitor(BaseMonitor):
+ """
+ Monitor for algo status.
+ """
+
+ event_type = EVENT_SPREAD_ALGO
+ data_key = "algoid"
+ sorting = False
+
+ headers = {
+ "algoid": {"display": "算法", "cell": BaseCell, "update": False},
+ "spread_name": {"display": "价差", "cell": BaseCell, "update": False},
+ "direction": {"display": "方向", "cell": DirectionCell, "update": False},
+ "price": {"display": "价格", "cell": BaseCell, "update": False},
+ "payup": {"display": "超价", "cell": BaseCell, "update": False},
+ "volume": {"display": "数量", "cell": BaseCell, "update": False},
+ "traded_volume": {"display": "成交", "cell": BaseCell, "update": True},
+ "interval": {"display": "间隔", "cell": BaseCell, "update": False},
+ "count": {"display": "计数", "cell": BaseCell, "update": True},
+ "status": {"display": "状态", "cell": EnumCell, "update": True},
+ }
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ super().__init__(spread_engine.main_engine, spread_engine.event_engine)
+
+ self.spread_engine = spread_engine
+
+ def init_ui(self):
+ """
+ Connect signal.
+ """
+ super().init_ui()
+
+ self.setToolTip("双击单元格停止算法")
+ self.itemDoubleClicked.connect(self.stop_algo)
+
+ def stop_algo(self, cell):
+ """
+ Stop algo if cell double clicked.
+ """
+ algo = cell.get_data()
+ self.spread_engine.stop_algo(algo.algoid)
+
+
+class SpreadAlgoWidget(QtWidgets.QFrame):
+ """"""
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ super().__init__()
+
+ self.spread_engine: SpreadEngine = spread_engine
+ self.strategy_engine: SpreadStrategyEngine = spread_engine.strategy_engine
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ self.setWindowTitle("启动算法")
+ self.setFrameShape(self.Box)
+ self.setLineWidth(1)
+
+ self.name_line = QtWidgets.QLineEdit()
+
+ self.direction_combo = QtWidgets.QComboBox()
+ self.direction_combo.addItems(
+ [Direction.LONG.value, Direction.SHORT.value]
+ )
+
+ float_validator = QtGui.QDoubleValidator()
+
+ self.price_line = QtWidgets.QLineEdit()
+ self.price_line.setValidator(float_validator)
+
+ self.volume_line = QtWidgets.QLineEdit()
+ self.volume_line.setValidator(float_validator)
+
+ int_validator = QtGui.QIntValidator()
+
+ self.payup_line = QtWidgets.QLineEdit()
+ self.payup_line.setValidator(int_validator)
+
+ self.interval_line = QtWidgets.QLineEdit()
+ self.interval_line.setValidator(int_validator)
+
+ button_start = QtWidgets.QPushButton("启动")
+ button_start.clicked.connect(self.start_algo)
+
+ self.lock_combo = QtWidgets.QComboBox()
+ self.lock_combo.addItems(
+ ["否", "是"]
+ )
+
+ self.class_combo = QtWidgets.QComboBox()
+
+ add_button = QtWidgets.QPushButton("添加策略")
+ add_button.clicked.connect(self.add_strategy)
+
+ init_button = QtWidgets.QPushButton("全部初始化")
+ init_button.clicked.connect(self.strategy_engine.init_all_strategies)
+
+ start_button = QtWidgets.QPushButton("全部启动")
+ start_button.clicked.connect(self.strategy_engine.start_all_strategies)
+
+ stop_button = QtWidgets.QPushButton("全部停止")
+ stop_button.clicked.connect(self.strategy_engine.stop_all_strategies)
+
+ add_spread_button = QtWidgets.QPushButton("创建价差")
+ add_spread_button.clicked.connect(self.add_spread)
+
+ remove_spread_button = QtWidgets.QPushButton("移除价差")
+ remove_spread_button.clicked.connect(self.remove_spread)
+
+ form = QtWidgets.QFormLayout()
+ form.addRow("价差", self.name_line)
+ form.addRow("方向", self.direction_combo)
+ form.addRow("价格", self.price_line)
+ form.addRow("数量", self.volume_line)
+ form.addRow("超价", self.payup_line)
+ form.addRow("间隔", self.interval_line)
+ form.addRow("锁仓", self.lock_combo)
+ form.addRow(button_start)
+
+ vbox = QtWidgets.QVBoxLayout()
+ vbox.addLayout(form)
+ vbox.addStretch()
+ vbox.addWidget(self.class_combo)
+ vbox.addWidget(add_button)
+ vbox.addWidget(init_button)
+ vbox.addWidget(start_button)
+ vbox.addWidget(stop_button)
+ vbox.addStretch()
+ vbox.addWidget(add_spread_button)
+ vbox.addWidget(remove_spread_button)
+
+ self.setLayout(vbox)
+
+ def start_algo(self):
+ """"""
+ name = self.name_line.text()
+ direction = Direction(self.direction_combo.currentText())
+ price = float(self.price_line.text())
+ volume = float(self.volume_line.text())
+ payup = int(self.payup_line.text())
+ interval = int(self.interval_line.text())
+
+ lock_str = self.lock_combo.currentText()
+ if lock_str == "是":
+ lock = True
+ else:
+ lock = False
+
+ self.spread_engine.start_algo(
+ name, direction, price, volume, payup, interval, lock
+ )
+
+ def add_spread(self):
+ """"""
+ dialog = SpreadDataDialog(self.spread_engine)
+ dialog.exec_()
+
+ def remove_spread(self):
+ """"""
+ dialog = SpreadRemoveDialog(self.spread_engine)
+ dialog.exec_()
+
+ def update_class_combo(self):
+ """"""
+ self.class_combo.addItems(
+ self.strategy_engine.get_all_strategy_class_names()
+ )
+
+ def remove_strategy(self, strategy_name):
+ """"""
+ manager = self.managers.pop(strategy_name)
+ manager.deleteLater()
+
+ def add_strategy(self):
+ """"""
+ class_name = str(self.class_combo.currentText())
+ if not class_name:
+ return
+
+ parameters = self.strategy_engine.get_strategy_class_parameters(
+ class_name)
+ editor = SettingEditor(parameters, class_name=class_name)
+ n = editor.exec_()
+
+ if n == editor.Accepted:
+ setting = editor.get_setting()
+ spread_name = setting.pop("spread_name")
+ strategy_name = setting.pop("strategy_name")
+
+ self.strategy_engine.add_strategy(
+ class_name, strategy_name, spread_name, setting
+ )
+
+
+class SpreadDataDialog(QtWidgets.QDialog):
+ """"""
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ super().__init__()
+
+ self.spread_engine: SpreadEngine = spread_engine
+
+ self.leg_widgets = []
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ self.setWindowTitle("创建价差")
+
+ self.name_line = QtWidgets.QLineEdit()
+ self.active_line = QtWidgets.QLineEdit()
+
+ self.grid = QtWidgets.QGridLayout()
+
+ button_add = QtWidgets.QPushButton("创建价差")
+ button_add.clicked.connect(self.add_spread)
+
+ Label = QtWidgets.QLabel
+
+ grid = QtWidgets.QGridLayout()
+ grid.addWidget(Label("价差名称"), 0, 0)
+ grid.addWidget(self.name_line, 0, 1, 1, 3)
+ grid.addWidget(Label("主动腿代码"), 1, 0)
+ grid.addWidget(self.active_line, 1, 1, 1, 3)
+
+ grid.addWidget(Label(""), 2, 0)
+ grid.addWidget(Label("本地代码"), 3, 1)
+ grid.addWidget(Label("价格乘数"), 3, 2)
+ grid.addWidget(Label("交易乘数"), 3, 3)
+
+ int_validator = QtGui.QIntValidator()
+
+ leg_count = 5
+ for i in range(leg_count):
+ symbol_line = QtWidgets.QLineEdit()
+
+ price_line = QtWidgets.QLineEdit()
+ price_line.setValidator(int_validator)
+
+ trading_line = QtWidgets.QLineEdit()
+ trading_line.setValidator(int_validator)
+
+ grid.addWidget(Label("腿{}".format(i + 1)), 4 + i, 0)
+ grid.addWidget(symbol_line, 4 + i, 1)
+ grid.addWidget(price_line, 4 + i, 2)
+ grid.addWidget(trading_line, 4 + i, 3)
+
+ d = {
+ "symbol": symbol_line,
+ "price": price_line,
+ "trading": trading_line
+ }
+ self.leg_widgets.append(d)
+
+ grid.addWidget(Label(""), 4 + leg_count, 0,)
+ grid.addWidget(button_add, 5 + leg_count, 0, 1, 4)
+
+ self.setLayout(grid)
+
+ def add_spread(self):
+ """"""
+ spread_name = self.name_line.text()
+ if not spread_name:
+ QtWidgets.QMessageBox.warning(
+ self,
+ "创建失败",
+ "请输入价差名称",
+ QtWidgets.QMessageBox.Ok
+ )
+ return
+
+ active_symbol = self.active_line.text()
+
+ leg_settings = {}
+ for d in self.leg_widgets:
+ try:
+ vt_symbol = d["symbol"].text()
+ price_multiplier = int(d["price"].text())
+ trading_multiplier = int(d["trading"].text())
+
+ leg_settings[vt_symbol] = {
+ "vt_symbol": vt_symbol,
+ "price_multiplier": price_multiplier,
+ "trading_multiplier": trading_multiplier
+ }
+ except ValueError:
+ pass
+
+ if len(leg_settings) < 2:
+ QtWidgets.QMessageBox.warning(
+ self,
+ "创建失败",
+ "价差最少需要2条腿",
+ QtWidgets.QMessageBox.Ok
+ )
+ return
+
+ if active_symbol not in leg_settings:
+ QtWidgets.QMessageBox.warning(
+ self,
+ "创建失败",
+ "各条腿中找不到主动腿代码",
+ QtWidgets.QMessageBox.Ok
+ )
+ return
+
+ self.spread_engine.add_spread(
+ spread_name,
+ list(leg_settings.values()),
+ active_symbol
+ )
+ self.accept()
+
+
+class SpreadRemoveDialog(QtWidgets.QDialog):
+ """"""
+
+ def __init__(self, spread_engine: SpreadEngine):
+ """"""
+ super().__init__()
+
+ self.spread_engine: SpreadEngine = spread_engine
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ self.setWindowTitle("移除价差")
+ self.setMinimumWidth(300)
+
+ self.name_combo = QtWidgets.QComboBox()
+ spreads = self.spread_engine.get_all_spreads()
+ for spread in spreads:
+ self.name_combo.addItem(spread.name)
+
+ button_remove = QtWidgets.QPushButton("移除")
+ button_remove.clicked.connect(self.remove_spread)
+
+ hbox = QtWidgets.QHBoxLayout()
+ hbox.addWidget(self.name_combo)
+ hbox.addWidget(button_remove)
+
+ self.setLayout(hbox)
+
+ def remove_spread(self):
+ """"""
+ spread_name = self.name_combo.currentText()
+ self.spread_engine.remove_spread(spread_name)
+ self.accept()
+
+
+class SpreadStrategyMonitor(QtWidgets.QWidget):
+ """"""
+
+ signal_strategy = QtCore.pyqtSignal(Event)
+
+ def __init__(self, spread_engine: SpreadEngine):
+ super().__init__()
+
+ self.strategy_engine = spread_engine.strategy_engine
+ self.main_engine = spread_engine.main_engine
+ self.event_engine = spread_engine.event_engine
+
+ self.managers = {}
+
+ self.init_ui()
+ self.register_event()
+
+ def init_ui(self):
+ """"""
+ self.scroll_layout = QtWidgets.QVBoxLayout()
+ self.scroll_layout.addStretch()
+
+ scroll_widget = QtWidgets.QWidget()
+ scroll_widget.setLayout(self.scroll_layout)
+
+ scroll_area = QtWidgets.QScrollArea()
+ scroll_area.setWidgetResizable(True)
+ scroll_area.setWidget(scroll_widget)
+
+ vbox = QtWidgets.QVBoxLayout()
+ vbox.addWidget(scroll_area)
+ self.setLayout(vbox)
+
+ def register_event(self):
+ """"""
+ self.signal_strategy.connect(self.process_strategy_event)
+
+ self.event_engine.register(
+ EVENT_SPREAD_STRATEGY, self.signal_strategy.emit
+ )
+
+ def process_strategy_event(self, event):
+ """
+ Update strategy status onto its monitor.
+ """
+ data = event.data
+ strategy_name = data["strategy_name"]
+
+ if strategy_name in self.managers:
+ manager = self.managers[strategy_name]
+ manager.update_data(data)
+ else:
+ manager = SpreadStrategyWidget(self, self.strategy_engine, data)
+ self.scroll_layout.insertWidget(0, manager)
+ self.managers[strategy_name] = manager
+
+ def remove_strategy(self, strategy_name):
+ """"""
+ manager = self.managers.pop(strategy_name)
+ manager.deleteLater()
+
+
+class SpreadStrategyWidget(QtWidgets.QFrame):
+ """
+ Manager for a strategy
+ """
+
+ def __init__(
+ self,
+ strategy_monitor: SpreadStrategyMonitor,
+ strategy_engine: SpreadStrategyEngine,
+ data: dict
+ ):
+ """"""
+ super().__init__()
+
+ self.strategy_monitor = strategy_monitor
+ self.strategy_engine = strategy_engine
+
+ self.strategy_name = data["strategy_name"]
+ self._data = data
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ self.setFixedHeight(300)
+ self.setFrameShape(self.Box)
+ self.setLineWidth(1)
+
+ init_button = QtWidgets.QPushButton("初始化")
+ init_button.clicked.connect(self.init_strategy)
+
+ start_button = QtWidgets.QPushButton("启动")
+ start_button.clicked.connect(self.start_strategy)
+
+ stop_button = QtWidgets.QPushButton("停止")
+ stop_button.clicked.connect(self.stop_strategy)
+
+ edit_button = QtWidgets.QPushButton("编辑")
+ edit_button.clicked.connect(self.edit_strategy)
+
+ remove_button = QtWidgets.QPushButton("移除")
+ remove_button.clicked.connect(self.remove_strategy)
+
+ strategy_name = self._data["strategy_name"]
+ spread_name = self._data["spread_name"]
+ class_name = self._data["class_name"]
+ author = self._data["author"]
+
+ label_text = (
+ f"{strategy_name} - {spread_name} ({class_name} by {author})"
+ )
+ label = QtWidgets.QLabel(label_text)
+ label.setAlignment(QtCore.Qt.AlignCenter)
+
+ self.parameters_monitor = StrategyDataMonitor(self._data["parameters"])
+ self.variables_monitor = StrategyDataMonitor(self._data["variables"])
+
+ hbox = QtWidgets.QHBoxLayout()
+ hbox.addWidget(init_button)
+ hbox.addWidget(start_button)
+ hbox.addWidget(stop_button)
+ hbox.addWidget(edit_button)
+ hbox.addWidget(remove_button)
+
+ vbox = QtWidgets.QVBoxLayout()
+ vbox.addWidget(label)
+ vbox.addLayout(hbox)
+ vbox.addWidget(self.parameters_monitor)
+ vbox.addWidget(self.variables_monitor)
+ self.setLayout(vbox)
+
+ def update_data(self, data: dict):
+ """"""
+ self._data = data
+
+ self.parameters_monitor.update_data(data["parameters"])
+ self.variables_monitor.update_data(data["variables"])
+
+ def init_strategy(self):
+ """"""
+ self.strategy_engine.init_strategy(self.strategy_name)
+
+ def start_strategy(self):
+ """"""
+ self.strategy_engine.start_strategy(self.strategy_name)
+
+ def stop_strategy(self):
+ """"""
+ self.strategy_engine.stop_strategy(self.strategy_name)
+
+ def edit_strategy(self):
+ """"""
+ strategy_name = self._data["strategy_name"]
+
+ parameters = self.strategy_engine.get_strategy_parameters(
+ strategy_name)
+ editor = SettingEditor(parameters, strategy_name=strategy_name)
+ n = editor.exec_()
+
+ if n == editor.Accepted:
+ setting = editor.get_setting()
+ self.strategy_engine.edit_strategy(strategy_name, setting)
+
+ def remove_strategy(self):
+ """"""
+ result = self.strategy_engine.remove_strategy(self.strategy_name)
+
+ # Only remove strategy gui manager if it has been removed from engine
+ if result:
+ self.strategy_monitor.remove_strategy(self.strategy_name)
+
+
+class StrategyDataMonitor(QtWidgets.QTableWidget):
+ """
+ Table monitor for parameters and variables.
+ """
+
+ def __init__(self, data: dict):
+ """"""
+ super().__init__()
+
+ self._data = data
+ self.cells = {}
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ labels = list(self._data.keys())
+ self.setColumnCount(len(labels))
+ self.setHorizontalHeaderLabels(labels)
+
+ self.setRowCount(1)
+ self.verticalHeader().setSectionResizeMode(
+ QtWidgets.QHeaderView.Stretch
+ )
+ self.verticalHeader().setVisible(False)
+ self.setEditTriggers(self.NoEditTriggers)
+
+ for column, name in enumerate(self._data.keys()):
+ value = self._data[name]
+
+ cell = QtWidgets.QTableWidgetItem(str(value))
+ cell.setTextAlignment(QtCore.Qt.AlignCenter)
+
+ self.setItem(0, column, cell)
+ self.cells[name] = cell
+
+ def update_data(self, data: dict):
+ """"""
+ for name, value in data.items():
+ cell = self.cells[name]
+ cell.setText(str(value))
+
+
+class SettingEditor(QtWidgets.QDialog):
+ """
+ For creating new strategy and editing strategy parameters.
+ """
+
+ def __init__(
+ self, parameters: dict, strategy_name: str = "", class_name: str = ""
+ ):
+ """"""
+ super(SettingEditor, self).__init__()
+
+ self.parameters = parameters
+ self.strategy_name = strategy_name
+ self.class_name = class_name
+
+ self.edits = {}
+
+ self.init_ui()
+
+ def init_ui(self):
+ """"""
+ form = QtWidgets.QFormLayout()
+
+ # Add spread_name and name edit if add new strategy
+ if self.class_name:
+ self.setWindowTitle(f"添加策略:{self.class_name}")
+ button_text = "添加"
+ parameters = {"strategy_name": "", "spread_name": ""}
+ parameters.update(self.parameters)
+ else:
+ self.setWindowTitle(f"参数编辑:{self.strategy_name}")
+ button_text = "确定"
+ parameters = self.parameters
+
+ for name, value in parameters.items():
+ type_ = type(value)
+
+ edit = QtWidgets.QLineEdit(str(value))
+ if type_ is int:
+ validator = QtGui.QIntValidator()
+ edit.setValidator(validator)
+ elif type_ is float:
+ validator = QtGui.QDoubleValidator()
+ edit.setValidator(validator)
+
+ form.addRow(f"{name} {type_}", edit)
+
+ self.edits[name] = (edit, type_)
+
+ button = QtWidgets.QPushButton(button_text)
+ button.clicked.connect(self.accept)
+ form.addRow(button)
+
+ self.setLayout(form)
+
+ def get_setting(self):
+ """"""
+ setting = {}
+
+ if self.class_name:
+ setting["class_name"] = self.class_name
+
+ for name, tp in self.edits.items():
+ edit, type_ = tp
+ value_text = edit.text()
+
+ if type_ == bool:
+ if value_text == "True":
+ value = True
+ else:
+ value = False
+ else:
+ value = type_(value_text)
+
+ setting[name] = value
+
+ return setting
diff --git a/vnpy/gateway/bitfinex/bitfinex_gateway.py b/vnpy/gateway/bitfinex/bitfinex_gateway.py
index 47d58fdd..a37dd995 100644
--- a/vnpy/gateway/bitfinex/bitfinex_gateway.py
+++ b/vnpy/gateway/bitfinex/bitfinex_gateway.py
@@ -11,6 +11,8 @@ from datetime import datetime, timedelta
from urllib.parse import urlencode
from vnpy.api.rest import Request, RestClient
from vnpy.api.websocket import WebsocketClient
+from vnpy.event import Event
+from vnpy.trader.event import EVENT_TIMER
from vnpy.trader.constant import (
Direction,
@@ -92,6 +94,9 @@ class BitfinexGateway(BaseGateway):
"""Constructor"""
super(BitfinexGateway, self).__init__(event_engine, "BITFINEX")
+ self.timer_count = 0
+ self.resubscribe_interval = 60
+
self.rest_api = BitfinexRestApi(self)
self.ws_api = BitfinexWebsocketApi(self)
@@ -104,9 +109,10 @@ class BitfinexGateway(BaseGateway):
proxy_port = setting["proxy_port"]
self.rest_api.connect(key, secret, session, proxy_host, proxy_port)
-
self.ws_api.connect(key, secret, proxy_host, proxy_port)
+ self.event_engine.register(EVENT_TIMER, self.process_timer_event)
+
def subscribe(self, req: SubscribeRequest):
""""""
self.ws_api.subscribe(req)
@@ -136,6 +142,16 @@ class BitfinexGateway(BaseGateway):
self.rest_api.stop()
self.ws_api.stop()
+ def process_timer_event(self, event: Event):
+ """"""
+ self.timer_count += 1
+
+ if self.timer_count < self.resubscribe_interval:
+ return
+
+ self.timer_count = 0
+ self.ws_api.resubscribe()
+
class BitfinexRestApi(RestClient):
"""
@@ -359,11 +375,12 @@ class BitfinexWebsocketApi(WebsocketClient):
self.accounts = {}
self.orders = {}
self.trades = set()
- self.tickDict = {}
- self.bidDict = {}
- self.askDict = {}
- self.orderLocalDict = {}
- self.channelDict = {} # channel_id : (Channel, Symbol)
+ self.ticks = {}
+ self.bids = {}
+ self.asks = {}
+ self.channels = {} # channel_id : (Channel, Symbol)
+
+ self.subscribed = {}
def connect(
self, key: str, secret: str, proxy_host: str, proxy_port: int
@@ -378,12 +395,16 @@ class BitfinexWebsocketApi(WebsocketClient):
"""
Subscribe to tick data upate.
"""
+ if req.symbol not in self.subscribed:
+ self.subscribed[req.symbol] = req
+
d = {
"event": "subscribe",
"channel": "book",
"symbol": req.symbol,
}
self.send_packet(d)
+
d = {
"event": "subscribe",
"channel": "ticker",
@@ -393,6 +414,11 @@ class BitfinexWebsocketApi(WebsocketClient):
return int(round(time.time() * 1000))
+ def resubscribe(self):
+ """"""
+ for req in self.subscribed.values():
+ self.subscribe(req)
+
def _gen_unqiue_cid(self):
self.order_id += 1
local_oid = time.strftime("%y%m%d") + str(self.order_id)
@@ -463,7 +489,7 @@ class BitfinexWebsocketApi(WebsocketClient):
if data["event"] == "subscribed":
symbol = str(data["symbol"].replace("t", ""))
- self.channelDict[data["chanId"]] = (data["channel"], symbol)
+ self.channels[data["chanId"]] = (data["channel"], symbol)
def on_update(self, data):
""""""
@@ -480,12 +506,12 @@ class BitfinexWebsocketApi(WebsocketClient):
def on_data_update(self, data):
""""""
channel_id = data[0]
- channel, symbol = self.channelDict[channel_id]
+ channel, symbol = self.channels[channel_id]
symbol = str(symbol.replace("t", ""))
# Get the Tick object
- if symbol in self.tickDict:
- tick = self.tickDict[symbol]
+ if symbol in self.ticks:
+ tick = self.ticks[symbol]
else:
tick = TickData(
symbol=symbol,
@@ -495,7 +521,7 @@ class BitfinexWebsocketApi(WebsocketClient):
gateway_name=self.gateway_name,
)
- self.tickDict[symbol] = tick
+ self.ticks[symbol] = tick
l_data1 = data[1]
@@ -509,8 +535,8 @@ class BitfinexWebsocketApi(WebsocketClient):
# Update deep quote
elif channel == "book":
- bid = self.bidDict.setdefault(symbol, {})
- ask = self.askDict.setdefault(symbol, {})
+ bid = self.bids.setdefault(symbol, {})
+ ask = self.asks.setdefault(symbol, {})
if len(l_data1) > 3:
for price, count, amount in l_data1:
@@ -558,7 +584,7 @@ class BitfinexWebsocketApi(WebsocketClient):
# ASK
ask_keys = ask.keys()
- askPriceList = sorted(ask_keys, reverse=True)
+ askPriceList = sorted(ask_keys)
tick.ask_price_1 = askPriceList[0]
tick.ask_price_2 = askPriceList[1]
diff --git a/vnpy/gateway/bitmex/bitmex_gateway.py b/vnpy/gateway/bitmex/bitmex_gateway.py
index ca62365f..18de3404 100644
--- a/vnpy/gateway/bitmex/bitmex_gateway.py
+++ b/vnpy/gateway/bitmex/bitmex_gateway.py
@@ -471,7 +471,10 @@ class BitmexRestApi(RestClient):
headers = request.response.headers
self.rate_limit_remaining = int(headers["x-ratelimit-remaining"])
- self.rate_limit_sleep = int(headers.get("Retry-After", 0)) + 1 # 1 extra second sleep
+
+ self.rate_limit_sleep = int(headers.get("Retry-After", 0))
+ if self.rate_limit_sleep:
+ self.rate_limit_sleep += 1 # 1 extra second sleep
def reset_rate_limit(self):
"""
diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py
index d08f3864..ab7919bf 100644
--- a/vnpy/gateway/ctp/ctp_gateway.py
+++ b/vnpy/gateway/ctp/ctp_gateway.py
@@ -479,46 +479,48 @@ class CtpTdApi(TdApi):
if not data:
return
- # Get buffered position object
- key = f"{data['InstrumentID'], data['PosiDirection']}"
- position = self.positions.get(key, None)
- if not position:
- position = PositionData(
- symbol=data["InstrumentID"],
- exchange=symbol_exchange_map[data["InstrumentID"]],
- direction=DIRECTION_CTP2VT[data["PosiDirection"]],
- gateway_name=self.gateway_name
- )
- self.positions[key] = position
+ # Check if contract data received
+ if data["InstrumentID"] in symbol_exchange_map:
+ # Get buffered position object
+ key = f"{data['InstrumentID'], data['PosiDirection']}"
+ position = self.positions.get(key, None)
+ if not position:
+ position = PositionData(
+ symbol=data["InstrumentID"],
+ exchange=symbol_exchange_map[data["InstrumentID"]],
+ direction=DIRECTION_CTP2VT[data["PosiDirection"]],
+ gateway_name=self.gateway_name
+ )
+ self.positions[key] = position
- # For SHFE position data update
- if position.exchange == Exchange.SHFE:
- if data["YdPosition"] and not data["TodayPosition"]:
- position.yd_volume = data["Position"]
- # For other exchange position data update
- else:
- position.yd_volume = data["Position"] - data["TodayPosition"]
+ # For SHFE position data update
+ if position.exchange == Exchange.SHFE:
+ if data["YdPosition"] and not data["TodayPosition"]:
+ position.yd_volume = data["Position"]
+ # For other exchange position data update
+ else:
+ position.yd_volume = data["Position"] - data["TodayPosition"]
- # Get contract size (spread contract has no size value)
- size = symbol_size_map.get(position.symbol, 0)
+ # Get contract size (spread contract has no size value)
+ size = symbol_size_map.get(position.symbol, 0)
- # Calculate previous position cost
- cost = position.price * position.volume * size
+ # Calculate previous position cost
+ cost = position.price * position.volume * size
- # Update new position volume
- position.volume += data["Position"]
- position.pnl += data["PositionProfit"]
+ # Update new position volume
+ position.volume += data["Position"]
+ position.pnl += data["PositionProfit"]
- # Calculate average position price
- if position.volume and size:
- cost += data["PositionCost"]
- position.price = cost / (position.volume * size)
+ # Calculate average position price
+ if position.volume and size:
+ cost += data["PositionCost"]
+ position.price = cost / (position.volume * size)
- # Get frozen volume
- if position.direction == Direction.LONG:
- position.frozen += data["ShortFrozen"]
- else:
- position.frozen += data["LongFrozen"]
+ # Get frozen volume
+ if position.direction == Direction.LONG:
+ position.frozen += data["ShortFrozen"]
+ else:
+ position.frozen += data["LongFrozen"]
if last:
for position in self.positions.values():
@@ -717,6 +719,10 @@ class CtpTdApi(TdApi):
"""
self.order_ref += 1
+ if req.offset not in OFFSET_VT2CTP:
+ self.gateway.write_log("请选择开平方向")
+ return ""
+
ctp_req = {
"InstrumentID": req.symbol,
"ExchangeID": req.exchange.value,
diff --git a/vnpy/gateway/ctptest/ctptest_gateway.py b/vnpy/gateway/ctptest/ctptest_gateway.py
index a25f9dec..e22ef0a6 100644
--- a/vnpy/gateway/ctptest/ctptest_gateway.py
+++ b/vnpy/gateway/ctptest/ctptest_gateway.py
@@ -478,46 +478,48 @@ class CtpTdApi(TdApi):
if not data:
return
- # Get buffered position object
- key = f"{data['InstrumentID'], data['PosiDirection']}"
- position = self.positions.get(key, None)
- if not position:
- position = PositionData(
- symbol=data["InstrumentID"],
- exchange=symbol_exchange_map[data["InstrumentID"]],
- direction=DIRECTION_CTP2VT[data["PosiDirection"]],
- gateway_name=self.gateway_name
- )
- self.positions[key] = position
+ # Check if contract data received
+ if data["InstrumentID"] in symbol_exchange_map:
+ # Get buffered position object
+ key = f"{data['InstrumentID'], data['PosiDirection']}"
+ position = self.positions.get(key, None)
+ if not position:
+ position = PositionData(
+ symbol=data["InstrumentID"],
+ exchange=symbol_exchange_map[data["InstrumentID"]],
+ direction=DIRECTION_CTP2VT[data["PosiDirection"]],
+ gateway_name=self.gateway_name
+ )
+ self.positions[key] = position
- # For SHFE position data update
- if position.exchange == Exchange.SHFE:
- if data["YdPosition"] and not data["TodayPosition"]:
- position.yd_volume = data["Position"]
- # For other exchange position data update
- else:
- position.yd_volume = data["Position"] - data["TodayPosition"]
+ # For SHFE position data update
+ if position.exchange == Exchange.SHFE:
+ if data["YdPosition"] and not data["TodayPosition"]:
+ position.yd_volume = data["Position"]
+ # For other exchange position data update
+ else:
+ position.yd_volume = data["Position"] - data["TodayPosition"]
- # Get contract size (spread contract has no size value)
- size = symbol_size_map.get(position.symbol, 0)
+ # Get contract size (spread contract has no size value)
+ size = symbol_size_map.get(position.symbol, 0)
- # Calculate previous position cost
- cost = position.price * position.volume * size
+ # Calculate previous position cost
+ cost = position.price * position.volume * size
- # Update new position volume
- position.volume += data["Position"]
- position.pnl += data["PositionProfit"]
+ # Update new position volume
+ position.volume += data["Position"]
+ position.pnl += data["PositionProfit"]
- # Calculate average position price
- if position.volume and size:
- cost += data["PositionCost"]
- position.price = cost / (position.volume * size)
+ # Calculate average position price
+ if position.volume and size:
+ cost += data["PositionCost"]
+ position.price = cost / (position.volume * size)
- # Get frozen volume
- if position.direction == Direction.LONG:
- position.frozen += data["ShortFrozen"]
- else:
- position.frozen += data["LongFrozen"]
+ # Get frozen volume
+ if position.direction == Direction.LONG:
+ position.frozen += data["ShortFrozen"]
+ else:
+ position.frozen += data["LongFrozen"]
if last:
for position in self.positions.values():
diff --git a/vnpy/gateway/hbdm/hbdm_gateway.py b/vnpy/gateway/hbdm/hbdm_gateway.py
index 663feaf0..b8de5b76 100644
--- a/vnpy/gateway/hbdm/hbdm_gateway.py
+++ b/vnpy/gateway/hbdm/hbdm_gateway.py
@@ -63,6 +63,9 @@ ORDERTYPE_VT2HBDM = {
ORDERTYPE_HBDM2VT = {v: k for k, v in ORDERTYPE_VT2HBDM.items()}
ORDERTYPE_HBDM2VT[1] = OrderType.LIMIT
ORDERTYPE_HBDM2VT[3] = OrderType.MARKET
+ORDERTYPE_HBDM2VT[4] = OrderType.MARKET
+ORDERTYPE_HBDM2VT[5] = OrderType.STOP
+ORDERTYPE_HBDM2VT[6] = OrderType.LIMIT
DIRECTION_VT2HBDM = {
Direction.LONG: "buy",
diff --git a/vnpy/gateway/ib/ib_gateway.py b/vnpy/gateway/ib/ib_gateway.py
index 1f0387d7..2e4b0d87 100644
--- a/vnpy/gateway/ib/ib_gateway.py
+++ b/vnpy/gateway/ib/ib_gateway.py
@@ -64,6 +64,7 @@ EXCHANGE_VT2IB = {
Exchange.ICE: "ICE",
Exchange.SEHK: "SEHK",
Exchange.HKFE: "HKFE",
+ Exchange.CFE: "CFE"
}
EXCHANGE_IB2VT = {v: k for k, v in EXCHANGE_VT2IB.items()}
diff --git a/vnpy/gateway/mini/mini_gateway.py b/vnpy/gateway/mini/mini_gateway.py
index 1811396f..ff41b911 100644
--- a/vnpy/gateway/mini/mini_gateway.py
+++ b/vnpy/gateway/mini/mini_gateway.py
@@ -491,7 +491,7 @@ class MiniTdApi(TdApi):
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
""""""
- if data:
+ if data and data["InstrumentID"] in symbol_exchange_map:
# Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None)
diff --git a/vnpy/gateway/minitest/minitest_gateway.py b/vnpy/gateway/minitest/minitest_gateway.py
index cd315540..0a96ef87 100644
--- a/vnpy/gateway/minitest/minitest_gateway.py
+++ b/vnpy/gateway/minitest/minitest_gateway.py
@@ -491,7 +491,7 @@ class MiniTdApi(TdApi):
def onRspQryInvestorPosition(self, data: dict, error: dict, reqid: int, last: bool):
""""""
- if data:
+ if data and data["InstrumentID"] in symbol_exchange_map:
# Get buffered position object
key = f"{data['InstrumentID'], data['PosiDirection']}"
position = self.positions.get(key, None)
diff --git a/vnpy/gateway/okexf/okexf_gateway.py b/vnpy/gateway/okexf/okexf_gateway.py
index e1b9d3cd..abcb660b 100644
--- a/vnpy/gateway/okexf/okexf_gateway.py
+++ b/vnpy/gateway/okexf/okexf_gateway.py
@@ -561,7 +561,7 @@ class OkexfRestApi(RestClient):
for l in data:
ts, o, h, l, c, v, _ = l
- dt = datetime.strptime(ts, "%Y-%m-%dT%H:%M:%S.%fZ")
+ dt = utc_to_local(ts)
bar = BarData(
symbol=req.symbol,
exchange=req.exchange,
diff --git a/vnpy/rpc/__init__.py b/vnpy/rpc/__init__.py
index d1b4877a..f29bc00b 100644
--- a/vnpy/rpc/__init__.py
+++ b/vnpy/rpc/__init__.py
@@ -1,3 +1,4 @@
+from zmq.backend.cython.constants import NOBLOCK
import signal
import threading
import traceback
@@ -7,10 +8,11 @@ from typing import Any, Callable
import zmq
-_ = lambda x: x
+
+def _(x): return x
# Achieve Ctrl-c interrupt recv
-from zmq.backend.cython.constants import NOBLOCK
+
signal.signal(signal.SIGINT, signal.SIG_DFL)
@@ -100,7 +102,7 @@ class RpcServer:
def join(self):
# Wait for RpcServer thread to exit
- if self.__thread.isAlive():
+ if self.__thread and self.__thread.is_alive():
self.__thread.join()
self.__thread = None
@@ -237,7 +239,7 @@ class RpcClient:
def join(self):
# Wait for RpcClient thread to exit
- if self.__thread.isAlive():
+ if self.__thread and self.__thread.is_alive():
self.__thread.join()
self.__thread = None
diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py
index 714e86e4..5ed93469 100644
--- a/vnpy/trader/constant.py
+++ b/vnpy/trader/constant.py
@@ -100,6 +100,8 @@ class Exchange(Enum):
HKFE = "HKFE" # Hong Kong Futures Exchange
SGX = "SGX" # Singapore Global Exchange
CBOT = "CBT" # Chicago Board of Trade
+ CBOE = "CBOE" # Chicago Board Options Exchange
+ CFE = "CFE" # CBOE Futures Exchange
DME = "DME" # Dubai Mercantile Exchange
EUREX = "EUX" # Eurex Exchange
APEX = "APEX" # Asia Pacific Exchange
@@ -118,6 +120,9 @@ class Exchange(Enum):
BYBIT = "BYBIT" # bybit.com
COINBASE = "COINBASE"
+ # Special Function
+ LOCAL = "LOCAL" # For local generated data
+
class Currency(Enum):
"""
diff --git a/vnpy/app/cta_strategy/converter.py b/vnpy/trader/converter.py
similarity index 100%
rename from vnpy/app/cta_strategy/converter.py
rename to vnpy/trader/converter.py
diff --git a/vnpy/trader/ui/widget.py b/vnpy/trader/ui/widget.py
index ad8c0297..591b0b00 100644
--- a/vnpy/trader/ui/widget.py
+++ b/vnpy/trader/ui/widget.py
@@ -156,6 +156,9 @@ class TimeCell(BaseCell):
"""
Time format is 12:12:12.5
"""
+ if content is None:
+ return
+
timestamp = content.strftime("%H:%M:%S")
millisecond = int(content.microsecond / 1000)
diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py
index 2c7f4b50..51471b70 100644
--- a/vnpy/trader/utility.py
+++ b/vnpy/trader/utility.py
@@ -3,8 +3,10 @@ General utility functions.
"""
import json
+import logging
from pathlib import Path
-from typing import Callable
+from typing import Callable, Dict
+from decimal import Decimal
import numpy as np
import talib
@@ -13,6 +15,9 @@ from .object import BarData, TickData
from .constant import Exchange, Interval
+log_formatter = logging.Formatter('[%(asctime)s] %(message)s')
+
+
def extract_vt_symbol(vt_symbol: str):
"""
:return: (symbol, exchange)
@@ -109,11 +114,13 @@ def save_json(filename: str, data: dict):
)
-def round_to(value: float, target: float):
+def round_to(value: float, target: float) -> float:
"""
Round price to price tick value.
"""
- rounded = int(round(value / target)) * target
+ value = Decimal(str(value))
+ target = Decimal(str(target))
+ rounded = float(int(round(value / target)) * target)
return rounded
@@ -458,3 +465,25 @@ def virtual(func: "callable"):
that can be (re)implemented by subclasses.
"""
return func
+
+
+file_handlers: Dict[str, logging.FileHandler] = {}
+
+
+def _get_file_logger_handler(filename: str):
+ handler = file_handlers.get(filename, None)
+ if handler is None:
+ handler = logging.FileHandler(filename)
+ file_handlers[filename] = handler # Am i need a lock?
+ return handler
+
+
+def get_file_logger(filename: str):
+ """
+ return a logger that writes records into a file.
+ """
+ logger = logging.getLogger(filename)
+ handler = _get_file_logger_handler(filename) # get singleton handler.
+ handler.setFormatter(log_formatter)
+ logger.addHandler(handler) # each handler will be added only once.
+ return logger