[bug fix]

This commit is contained in:
msincenselee 2020-02-02 19:09:05 +08:00
parent d9a06c11cd
commit f9aad85db9
29 changed files with 2309 additions and 950 deletions

View File

@ -2,59 +2,67 @@
--------------------------------------------------------------------------------------------
###FAQ
#3、碰到的问题找不到vnpy.xx.xx(原2.7环境)
3、碰到的问题找不到vnpy.xx.xx(原2.7环境)
可能你使用了vnpy的原版安装到conda环境中了。需要先卸载 pip uninstall vnpy
#4、碰到的问题importError: libGL.so.1: cannot open shared object file: No such file or directory
4、碰到的问题importError: libGL.so.1: cannot open shared object file: No such file or directory
ubuntu下sudo apt install libgl1-mesa-glx
centOS下sudo yum install mesa-libGL.x86_64
#5、碰到的问题version `GLIBCXX_3.4.21' not found
5、碰到的问题version `GLIBCXX_3.4.21' not found
conda install libgcc
若出现更高版本需求参见第10点
#6、碰到的问题在3.7 env下安装RqPlus时报错:talib/common.c:242:28: fatal error: ta-lib/ta_defs.h: No such file or directory
6、碰到的问题在3.7 env下安装RqPlus时报错:talib/common.c:242:28: fatal error: ta-lib/ta_defs.h: No such file or directory
locate ta_defs.h
找到地址:/home/user/anaconda3/pkgs/ta-lib-0.4.9-np111py27_0/include/ta-lib
# 复制一份到/usr/include目录下
sudo cp /home/user/anaconda3/pkgs/ta-lib-0.4.9-np111py27_0/include/ta-lib /usr/include -R
#7、碰到的问题:安装某些组件,提示 cl.exe Not found
7、碰到的问题安装某些组件提示 cl.exe Not found
如果你安装了VC基础组件需要增加一个用户环境变量把"C:\Program Files (x86)\Microsoft Visual Studio\Shared\14.0\VC\bin" 添加到path变量中
#8、Install Ta-Lib
如果你用py37虚拟环境
source activate py37
8、Install Ta-Lib
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --set show_channel_urls yes
conda install -c quantopian ta-lib=0.4.9
如果你用py37虚拟环境
source activate py37
# 9、数字货币的增量安装
conda install scipy
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
conda config --set show_channel_urls yes
conda install -c quantopian ta-lib=0.4.9
pip install autobahn
pip install twisted
若出现找不到rc.exe 请先使用vs x86&x64界面激活py37后再运行
pip install pyOpenSSL
9、数字货币的增量安装
conda install scipy
pip install autobahn
pip install twisted
若出现找不到rc.exe 请先使用vs x86&x64界面激活py37后再运行
pip install pyOpenSSL
10、升级gcc
# 10、升级gcc
使用奇正MOM的CTP API时提示`GLIBCXX_3.4.22' not found当前centos最高版本是 3.4.21通过yum不能升级需要手工下载升级。
wget http://ftp.de.debian.org/debian/pool/main/g/gcc-9/libstdc++6_9.2.1-8_amd64.deb
解压
解压
ar -x libstdc++6_9.2.1-8_amd64.deb
(就是 ar 命令不是tar
tar -xvf data.tar.xz
安装
安装
删除: rm /usr/lib64/libstdc++.so.6
拷贝: cp usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.28 /usr/lib64/
连接: ln -s /usr/lib64/libstdc++.so.6.0.28 /usr/lib64/libstdc++.so.6
结果
结果
strings /usr/lib64/libstdc++.so.6 | grep GLIBCXX
@ -86,7 +94,8 @@ pip install pyOpenSSL
GLIBCXX_3.4.25
GLIBCXX_DEBUG_MESSAGE_LENGTH
# 11、升级glibc
11、升级glibc
使用奇正MOM的CTP API时提示`GLIBC_2.18' not found当前centos最高版本是 3.4.21通过yum不能升级需要手工下载升级。
root 用户登录

View File

@ -0,0 +1,16 @@
{
"mongo_db":
{
"host": "192.168.0.207",
"port": 27017
},
"accounts":
{
"ctp":
{
"copy_history_trades": true,
"copy_history_orders": true
}
}
}

View File

@ -0,0 +1,4 @@
{
"accountid" : "112022",
"strategy_group": "win01"
}

View File

@ -5,10 +5,10 @@ from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import MainWindow, create_qapp
# from vnpy.gateway.binance import BinanceGateway
from vnpy.gateway.bitmex import BitmexGateway
#from vnpy.gateway.bitmex import BitmexGateway
# from vnpy.gateway.futu import FutuGateway
# from vnpy.gateway.ib import IbGateway
# from vnpy.gateway.ctp import CtpGateway
from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.ctptest import CtptestGateway
# from vnpy.gateway.mini import MiniGateway
# from vnpy.gateway.sopt import SoptGateway
@ -31,30 +31,33 @@ from vnpy.gateway.bitmex import BitmexGateway
# from vnpy.gateway.coinbase import CoinbaseGateway
# from vnpy.gateway.bitstamp import BitstampGateway
# from vnpy.gateway.gateios import GateiosGateway
from vnpy.gateway.bybit import BybitGateway
# from vnpy.gateway.bybit import BybitGateway
# from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.cta_strategy_pro import CtaStrategyProApp
# from vnpy.app.csv_loader import CsvLoaderApp
# from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp
# from vnpy.app.data_recorder import DataRecorderApp
# from vnpy.app.risk_manager import RiskManagerApp
from vnpy.app.tick_recorder import TickRecorderApp
from vnpy.app.risk_manager import RiskManagerApp
# from vnpy.app.script_trader import ScriptTraderApp
# from vnpy.app.rpc_service import RpcServiceApp
# from vnpy.app.spread_trading import SpreadTradingApp
from vnpy.app.portfolio_manager import PortfolioManagerApp
# from vnpy.app.portfolio_manager import PortfolioManagerApp
from vnpy.app.account_recorder import AccountRecorderApp
def main():
""""""
qapp = create_qapp()
event_engine = EventEngine()
event_engine = EventEngine(debug=True, over_ms=200)
main_engine = MainEngine(event_engine)
# main_engine.add_gateway(BinanceGateway)
# main_engine.add_gateway(CtpGateway)
main_engine.add_gateway(CtpGateway, gateway_name='ctp')
main_engine.add_gateway(CtpGateway, gateway_name='ctp_yh01')
# main_engine.add_gateway(CtptestGateway)
# main_engine.add_gateway(MiniGateway)
# main_engine.add_gateway(SoptGateway)
@ -62,7 +65,7 @@ def main():
# main_engine.add_gateway(FemasGateway)
# main_engine.add_gateway(IbGateway)
# main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway)
#main_engine.add_gateway(BitmexGateway)
# main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(OesGateway)
# main_engine.add_gateway(OkexGateway)
@ -80,18 +83,21 @@ def main():
# main_engine.add_gateway(CoinbaseGateway)
# main_engine.add_gateway(BitstampGateway)
# main_engine.add_gateway(GateiosGateway)
main_engine.add_gateway(BybitGateway)
#main_engine.add_gateway(BybitGateway)
# main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)
#main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaStrategyProApp)
#main_engine.add_app(CtaBacktesterApp)
# main_engine.add_app(CsvLoaderApp)
# main_engine.add_app(AlgoTradingApp)
# main_engine.add_app(DataRecorderApp)
# main_engine.add_app(RiskManagerApp)
# main_engine.add_app(TickRecorderApp)
main_engine.add_app(RiskManagerApp)
# main_engine.add_app(ScriptTraderApp)
# main_engine.add_app(RpcServiceApp)
# main_engine.add_app(SpreadTradingApp)
main_engine.add_app(PortfolioManagerApp)
# main_engine.add_app(PortfolioManagerApp)
# main_engine.add_app(AccountRecorderApp)
main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized()

View File

@ -1,14 +1,17 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from vnpy.trader.constant import Direction,Offset
from vnpy.trader.constant import Direction,Offset,Status
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
from vnpy.trader.utility import BarGenerator, ArrayManager
from .cta_position import CtaPosition
from .cta_line_bar import CtaLineBar, CtaMinuteBar, CtaHourBar, CtaDayBar, CtaWeekBar
from .cta_policy import CtaPolicy
from .cta_grid_trade import CtaGrid, CtaGridTrade
from .base import APP_NAME, StopOrder
from .engine import CtaEngine
from .template import CtaTemplate, CtaSignal, TargetPosTemplate
from .template import CtaTemplate, CtaSignal, TargetPosTemplate, CtaProTemplate
class CtaStrategyProApp(BaseApp):
""""""

View File

@ -1,11 +1,11 @@
"""
Defines constants and objects used in CtaStrategyPro App.
"""
from abc import ABC
from dataclasses import dataclass, field
from enum import Enum
from datetime import timedelta
from logging import INFO, ERROR
from vnpy.trader.constant import Direction, Offset, Interval
APP_NAME = "CtaStrategyPro"
@ -92,3 +92,30 @@ INTERVAL_DELTA_MAP = {
Interval.HOUR: timedelta(hours=1),
Interval.DAILY: timedelta(days=1),
}
class CtaComponent(ABC):
""" CTA策略基础组件"""
def __init__(self, strategy=None, **kwargs):
"""
构造
:param strategy:
"""
self.strategy = strategy
# ----------------------------------------------------------------------
def write_log(self, content: str):
"""记录日志"""
if self.strategy:
self.strategy.write_log(msg=content, level=INFO)
else:
print(content)
# ----------------------------------------------------------------------
def write_error(self, content: str, level: int = ERROR):
"""记录错误日志"""
if self.strategy:
self.strategy.write_log(msg=content, level=level)
else:
print(content, file=sys.stderr)

View File

@ -10,10 +10,9 @@ import traceback
from collections import OrderedDict
from datetime import datetime
from dataclasses import dataclass, field
from typing import List
from vnpy.trader.utility import get_folder_path
from vnpy.app.cta_strategy_pro.base import Direction
from vnpy.app.cta_strategy_pro.template import CtaComponent
from vnpy.app.cta_strategy_pro.base import Direction, CtaComponent
"""
网格交易用于套利单
@ -35,37 +34,47 @@ TREND_GRID = 'trend' # 趋势网格
LOCK_GRID = 'lock' # 对锁网格
@dataclass
class CtaGrid(object):
"""网格类
它是网格交易的最小单元
包括交易方向开仓价格平仓价格止损价格开仓状态平仓状态
"""
id: str = str(uuid.uuid1()) # gid
direction: Direction = Direction.NET # 交易方向LONG正套SHORT反套
open_price: float = 0 # 开仓价格
close_price: float = 0 # 止盈价格
stop_price: float = 0 # 止损价格
def __init__(self,
direction: Direction = None,
open_price: float = 0,
close_price: float = 0,
stop_price: float = 0,
vt_symbol: str = '',
volume: float = 0,
traded_volume: float = 0,
order_status: bool = False,
open_status: bool = False,
close_status: bool = False,
open_time: datetime = None,
order_time: datetime = None,
reuse_count: int = 0,
type: str = ''
):
vt_symbol: str = '' # 品种合约
volume: float = 0 # 开仓数量( 兼容数字货币 )
traded_volume: float = 0 # 已成交数量 开仓时,为开仓数量,平仓时,为平仓数量
order_status: bool = False # 挂单状态: True,已挂单False未挂单
order_ids: list[str] = field(default_factory=list) # order_id list
open_status: bool = False # 开仓状态
close_status: bool = False # 平仓状态
open_time: datetime = None # 开仓时间
order_time: datetime = None # 委托时间
lock_grid_ids: list[str] = field(default_factory=list) # 锁单的网格,[gid,gid]
reuse_count: int = 0 # 重用次数0 平仓后是否删除)
type: str = '' # 网格类型标签
snapshot: dict = field(default_factory=dict) # 切片数据,如记录开仓点时的某些状态数据
self.id: str = str(uuid.uuid1()) # gid
self.direction = direction # 交易方向LONG正套SHORT反套
self.open_price = open_price # 开仓价格
self.close_price = close_price # 止盈价格
self.stop_price = stop_price # 止损价格
self.vt_symbol = vt_symbol # 品种合约
self.volume = volume # 开仓数量( 兼容数字货币 )
self.traded_volume = traded_volume # 已成交数量 开仓时,为开仓数量,平仓时,为平仓数量
self.order_status = order_status # 挂单状态: True,已挂单False未挂单
self.order_ids = [] # order_id list
self.open_status = open_status # 开仓状态
self.close_status = close_status # 平仓状态
self.open_time = open_time # 开仓时间
self.order_time = order_time # 委托时间
self.lock_grid_ids = [] # 锁单的网格,[gid,gid]
self.reuse_count = reuse_count # 重用次数0 平仓后是否删除)
self.type = type # 网格类型标签
self.snapshot = {} # 切片数据,如记录开仓点时的某些状态数据
def to_json(self):
"""输出JSON"""
@ -156,10 +165,10 @@ class CtaGridTrade(CtaComponent):
vol网格开仓数
minDiff, 最小价格跳动
"""
super(CtaGridTrade).__init__(strategy=strategy)
super(CtaGridTrade, self).__init__(strategy=strategy)
self.price_tick = kwargs.get('price_tick', 1)
self.jsonName = self.strategy.name # 策略名称
self.json_name = self.strategy.strategy_name # 策略名称
self.max_lots = kwargs.get('max_lots', 10) # 缺省网格数量
self.grid_height = kwargs.get('grid_height', 10 * self.price_tick) # 最小网格高度
self.grid_win = kwargs.get('grid_win', 10 * self.price_tick) # 最小止盈高度
@ -176,7 +185,8 @@ class CtaGridTrade(CtaComponent):
self.max_up_open_price = 0.0 # 上网格最高开仓价
self.min_dn_open_price = 0.0 # 下网格最小开仓价
self.json_file_path = os.path.join(get_folder_path('data'), f'{self.jsonName}_Grids.json') # 网格的路径
# 网格json文件的路径
self.json_file_path = os.path.join(get_folder_path('data'), f'{self.json_name}_Grids.json')
def get_volume_rate(self, idx: int = 0):
"""获取网格索引对应的开仓数量比例"""
@ -554,7 +564,7 @@ class CtaGridTrade(CtaComponent):
if direction == Direction.SHORT:
for x in self.up_grids[:]:
if x.id in id:
if x.id in ids:
self.write_log(u'清除上网格[open={},close={},stop={},volume={}]'
.format(x.open_price, x.close_price, x.stop_price, x.volume))
self.up_grids.remove(x)
@ -873,12 +883,12 @@ class CtaGridTrade(CtaComponent):
grids_save_path = get_folder_path('data')
# 确保json名字与策略一致
if self.jsonName != self.strategy.name:
self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.jsonName, self.strategy.name))
self.jsonName = self.strategy.name
if self.json_name != self.strategy.strategy_name:
self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.json_name, self.strategy.strategy_name))
self.json_name = self.strategy.strategy_name
# 新版网格持久化文件
grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.jsonName))
grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.json_name))
self.json_file_path = grid_json_file
data = {}
@ -907,12 +917,12 @@ class CtaGridTrade(CtaComponent):
data = {}
grids_save_path = get_folder_path('data')
if self.jsonName != self.strategy.name:
self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.jsonName, self.strategy.name))
self.jsonName = self.strategy.name
if self.json_name != self.strategy.strategy_name:
self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.json_name, self.strategy.strategy_name))
self.json_name = self.strategy.strategy_name
# 若json文件不存在就保存一个若存在就优先使用数据文件
grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.jsonName))
grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.json_name))
if not os.path.exists(grid_json_file):
data['up_grids'] = []
data['dn_grids'] = []
@ -970,7 +980,7 @@ class CtaGridTrade(CtaComponent):
data_folder = get_folder_path('data')
self.jsonName = new_name
self.json_name = new_name
# 旧文件
old_json_file = os.path.join(data_folder, u'{0}_Grids.json'.format(old_name))

View File

@ -1016,7 +1016,12 @@ class CtaLineBar(object):
def first_tick(self, tick: TickData):
""" K线的第一个Tick数据"""
self.cur_bar = BarData() # 创建新的K线
self.cur_bar = BarData(
gateway_name=tick.gateway_name,
symbol=tick.symbol,
exchange=tick.exchange,
datetime=tick.datetime
) # 创建新的K线
# 计算K线的整点分钟周期这里周期最小是1分钟。如果你是采用非整点分钟例如1.5分钟,请把这段注解掉
if self.minute_interval and self.interval == Interval.SECOND:
self.minute_interval = int(self.bar_interval / 60)
@ -1506,7 +1511,7 @@ class CtaLineBar(object):
if self.para_ma1_len > 0:
count_len = min(self.bar_len, self.para_ma1_len)
if count_len > 0:
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len)
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len)
self._rt_ma1 = round(float(close_ma_array[-1]), self.round_n)
# 计算斜率
@ -1517,7 +1522,7 @@ class CtaLineBar(object):
if self.para_ma2_len > 0:
count_len = min(self.bar_len, self.para_ma2_len)
if count_len > 0:
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len)
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len)
self._rt_ma2 = round(float(close_ma_array[-1]), self.round_n)
# 计算斜率
@ -1528,7 +1533,7 @@ class CtaLineBar(object):
if self.para_ma3_len > 0:
count_len = min(self.bar_len, self.para_ma3_len)
if count_len > 0:
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len)
close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len)
self._rt_ma3 = round(float(close_ma_array[-1]), self.round_n)
# 计算斜率
@ -1846,27 +1851,27 @@ class CtaLineBar(object):
# 计算 ATR
if self.para_atr1_len > 0:
count_len = min(self.bar_len, self.para_atr1_len)
self.cur_atr1 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:],
self.close_array[-count_len:], count_len)
self.cur_atr1 = round(self.cur_atr1, self.round_n)
cur_atr1 = ta.ATR(self.high_array[-count_len * 2:], self.low_array[-count_len * 2:],
self.close_array[-count_len * 2:], count_len)
self.cur_atr1 = round(cur_atr1[-1], self.round_n)
if len(self.line_atr1) > self.max_hold_bars:
del self.line_atr1[0]
self.line_atr1.append(self.cur_atr1)
if self.para_atr2_len > 0:
count_len = min(self.bar_len, self.para_atr2_len)
self.cur_atr2 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:],
self.close_array[-count_len:], count_len)
self.cur_atr2 = round(self.cur_atr2, self.round_n)
cur_atr2 = ta.ATR(self.high_array[-count_len * 2:], self.low_array[-count_len * 2:],
self.close_array[-count_len * 2:], count_len)
self.cur_atr2 = round(cur_atr2[-1], self.round_n)
if len(self.line_atr2) > self.max_hold_bars:
del self.line_atr2[0]
self.line_atr2.append(self.cur_atr2)
if self.para_atr3_len > 0:
count_len = min(self.bar_len, self.para_atr3_len)
self.cur_atr3 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:],
self.close_array[-count_len:], count_len)
self.cur_atr3 = round(self.cur_atr3, self.round_n)
cur_atr3 = ta.ATR(self.high_array[-count_len * 2 :], self.low_array[-count_len * 2:],
self.close_array[-count_len * 2:], count_len)
self.cur_atr3 = round(cur_atr3[-1], self.round_n)
if len(self.line_atr3) > self.max_hold_bars:
del self.line_atr3[0]
@ -5398,424 +5403,5 @@ class CtaWeekBar(CtaLineBar):
# 实时计算
self.rt_executed = False
self.lastTick = tick
self.last_tick = tick
class test_strategy(object):
def __init__(self):
self.price_tick = 1
self.underlying_symbol = 'I'
self.vt_symbol = 'I99'
self.lineM5 = None
self.lineM30 = None
self.lineH1 = None
self.lineH2 = None
self.lineD = None
self.lineW = None
self.TMinuteInterval = 1
self.save_m30_bars = []
self.save_h1_bars = []
self.save_h2_bars = []
self.save_d_bars = []
self.save_w_bars = []
def createM5(self):
"""使用ctalinbar创建5分钟K线"""
lineM5Setting = {}
lineM5Setting['name'] = u'M5'
lineM5Setting['interval'] = Interval.MINUTE
lineM5Setting['bar_interval'] = 5
lineM5Setting['mode'] = CtaLineBar.TICK_MODE
lineM5Setting['price_tick'] = self.price_tick
lineM5Setting['underlying_symbol'] = self.underlying_symbol
self.lineM5 = CtaLineBar(self, self.onBarM5, lineM5Setting)
def onBarM5(self, bar):
self.write_log(self.lineM5.get_last_bar_str())
def createlineM30_with_macd(self):
"""使用CtaLineBar创建30分钟时间"""
# 创建M30 K线
lineM30Setting = {}
lineM30Setting['name'] = u'M30'
lineM30Setting['interval'] = Interval.MINUTE
lineM30Setting['bar_interval'] = 30
lineM30Setting['para_macd_fast_len'] = 26
lineM30Setting['para_macd_slow_len'] = 12
lineM30Setting['para_macd_signal_len'] = 9
lineM30Setting['mode'] = CtaLineBar.TICK_MODE
lineM30Setting['price_tick'] = self.price_tick
lineM30Setting['underlying_symbol'] = self.underlying_symbol
self.lineM30 = CtaLineBar(self, self.onBarM30MACD, lineM30Setting)
def onBarM30MACD(self, bar):
self.write_log(self.lineM30.get_last_bar_str())
def createLineM30(self):
"""使用ctaMinuteBar, 测试内部自动写入csv文件"""
# 创建M30 K线
lineM30Setting = {}
lineM30Setting['name'] = u'M30'
lineM30Setting['interval'] = Interval.MINUTE
lineM30Setting['bar_interval'] = 30
lineM30Setting['para_pre_len'] = 10
lineM30Setting['para_ma1_len'] = 5
lineM30Setting['para_ma2_len'] = 10
lineM30Setting['para_ma3_len'] = 60
lineM30Setting['para_active_yb'] = True
lineM30Setting['para_active_skd'] = True
lineM30Setting['mode'] = CtaLineBar.TICK_MODE
lineM30Setting['price_tick'] = self.price_tick
lineM30Setting['underlying_symbol'] = self.underlying_symbol
self.lineM30 = CtaMinuteBar(self, self.onBarM30, lineM30Setting)
# 写入文件
self.lineM30.export_filename = os.path.abspath(
os.path.join(os.getcwd(),
u'export_{}_{}.csv'.format(self.vt_symbol, self.lineM30.name)))
self.lineM30.export_fields = [
{'name': 'datetime', 'source': 'bar', 'attr': 'datetime', 'type_': 'datetime'},
{'name': 'open', 'source': 'bar', 'attr': 'open_price', 'type_': 'float'},
{'name': 'high', 'source': 'bar', 'attr': 'high_price', 'type_': 'float'},
{'name': 'low', 'source': 'bar', 'attr': 'low_price', 'type_': 'float'},
{'name': 'close', 'source': 'bar', 'attr': 'close_price', 'type_': 'float'},
{'name': 'turnover', 'source': 'bar', 'attr': 'turnover', 'type_': 'float'},
{'name': 'volume', 'source': 'bar', 'attr': 'volume', 'type_': 'float'},
{'name': 'open_interest', 'source': 'bar', 'attr': 'open_interest', 'type_': 'float'},
{'name': 'kf', 'source': 'line_bar', 'attr': 'line_statemean', 'type_': 'list'}
]
def createLineH1(self):
# 创建2小时K线
lineH1Setting = {}
lineH1Setting['name'] = u'H1'
lineH1Setting['interval'] = Interval.HOUR
lineH1Setting['bar_interval'] = 1
lineH1Setting['para_pre_len'] = 10
lineH1Setting['para_ema1_len'] = 5
lineH1Setting['para_ema2_len'] = 10
lineH1Setting['para_ema3_len'] = 60
lineH1Setting['para_active_yb'] = True
lineH1Setting['para_active_skd'] = True
lineH1Setting['mode'] = CtaLineBar.TICK_MODE
lineH1Setting['price_tick'] = self.price_tick
lineH1Setting['underlying_symbol'] = self.underlying_symbol
self.lineH1 = CtaLineBar(self, self.onBarH1, lineH1Setting)
def createLineH2(self):
# 创建2小时K线
lineH2Setting = {}
lineH2Setting['name'] = u'H2'
lineH2Setting['interval'] = Interval.HOUR
lineH2Setting['bar_interval'] = 2
lineH2Setting['para_pre_len'] = 5
lineH2Setting['para_ma1_len'] = 5
lineH2Setting['para_ma2_len'] = 10
lineH2Setting['para_ma3_len'] = 18
lineH2Setting['para_active_yb'] = True
lineH2Setting['para_active_skd'] = True
lineH2Setting['mode'] = CtaLineBar.TICK_MODE
lineH2Setting['price_tick'] = self.price_tick
lineH2Setting['underlying_symbol'] = self.underlying_symbol
self.lineH2 = CtaHourBar(self, self.onBarH2, lineH2Setting)
def createLineD(self):
# 创建的日K线
lineDaySetting = {}
lineDaySetting['name'] = u'D1'
lineDaySetting['bar_interval'] = 1
lineDaySetting['para_pre_len'] = 5
lineDaySetting['para_art1_len'] = 26
lineDaySetting['para_ma1_len'] = 5
lineDaySetting['para_ma2_len'] = 10
lineDaySetting['para_ma3_len'] = 18
lineDaySetting['para_active_yb'] = True
lineDaySetting['para_active_skd'] = True
lineDaySetting['mode'] = CtaDayBar.TICK_MODE
lineDaySetting['price_tick'] = self.price_tick
lineDaySetting['underlying_symbol'] = self.underlying_symbol
self.lineD = CtaDayBar(self, self.onBarD, lineDaySetting)
def createLineW(self):
"""创建周线"""
lineWeekSetting = {}
lineWeekSetting['name'] = u'W1'
lineWeekSetting['para_pre_len'] = 5
lineWeekSetting['para_art1_len'] = 26
lineWeekSetting['para_ma1_len'] = 5
lineWeekSetting['para_ma2_len'] = 10
lineWeekSetting['para_ma3_len'] = 18
lineWeekSetting['para_active_yb'] = True
lineWeekSetting['para_active_skd'] = True
lineWeekSetting['mode'] = CtaDayBar.TICK_MODE
lineWeekSetting['price_tick'] = self.price_tick
lineWeekSetting['underlying_symbol'] = self.underlying_symbol
self.lineW = CtaWeekBar(self, self.onBarW, lineWeekSetting)
def onBar(self, bar):
# print(u'tradingDay:{},dt:{},o:{},h:{},l:{},c:{},v:{}'.format(bar.trading_day,bar.datetime, bar.open, bar.high, bar.low_price, bar.close_price, bar.volume))
if self.lineW:
self.lineW.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineD:
self.lineD.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineH2:
self.lineH2.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineH1:
self.lineH1.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineM30:
self.lineM30.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineM5:
self.lineM5.add_bar(bar, bar_freq=self.TMinuteInterval)
# if self.lineH2:
# self.lineH2.skd_is_high_dead_cross(runtime=True, high_skd=30)
# self.lineH2.skd_is_low_golden_cross(runtime=True, low_skd=70)
def onBarM30(self, bar):
self.write_log(self.lineM30.get_last_bar_str())
self.save_m30_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineM30.line_ma1[-1] if len(self.lineM30.line_ma1) > 0 else bar.close_price,
'ma10': self.lineM30.line_ma2[-1] if len(self.lineM30.line_ma2) > 0 else bar.close_price,
'ma60': self.lineM30.line_ma3[-1] if len(self.lineM30.line_ma3) > 0 else bar.close_price,
'sk': self.lineM30.line_sk[-1] if len(self.lineM30.line_sk) > 0 else 0,
'sd': self.lineM30.line_sd[-1] if len(self.lineM30.line_sd) > 0 else 0
})
def onBarH1(self, bar):
self.write_log(self.lineH1.get_last_bar_str())
self.save_h1_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ema5': self.lineH1.line_ema1[-1] if len(self.lineH1.line_ema1) > 0 else bar.close_price,
'ema10': self.lineH1.line_ema2[-1] if len(self.lineH1.line_ema2) > 0 else bar.close_price,
'ema60': self.lineH1.line_ema3[-1] if len(self.lineH1.line_ema3) > 0 else bar.close_price,
'sk': self.lineH1.line_sk[-1] if len(self.lineH1.line_sk) > 0 else 0,
'sd': self.lineH1.line_sd[-1] if len(self.lineH1.line_sd) > 0 else 0
})
def onBarH2(self, bar):
self.write_log(self.lineH2.get_last_bar_str())
self.save_h2_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineH2.line_ma1[-1] if len(self.lineH2.line_ma1) > 0 else bar.close_price,
'ma10': self.lineH2.line_ma2[-1] if len(self.lineH2.line_ma2) > 0 else bar.close_price,
'ma18': self.lineH2.line_ma3[-1] if len(self.lineH2.line_ma3) > 0 else bar.close_price,
'sk': self.lineH2.line_sk[-1] if len(self.lineH2.line_sk) > 0 else 0,
'sd': self.lineH2.line_sd[-1] if len(self.lineH2.line_sd) > 0 else 0
})
def onBarD(self, bar):
self.write_log(self.lineD.get_last_bar_str())
self.save_d_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineD.line_ma1[-1] if len(self.lineD.line_ma1) > 0 else bar.close_price,
'ma10': self.lineD.line_ma2[-1] if len(self.lineD.line_ma2) > 0 else bar.close_price,
'ma18': self.lineD.line_ma3[-1] if len(self.lineD.line_ma3) > 0 else bar.close_price,
'sk': self.lineD.line_sk[-1] if len(self.lineD.line_sk) > 0 else 0,
'sd': self.lineD.line_sd[-1] if len(self.lineD.line_sd) > 0 else 0
})
def onBarW(self, bar):
self.write_log(self.lineW.get_last_bar_str())
self.save_w_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineW.line_ma1[-1] if len(self.lineW.line_ma1) > 0 else bar.close_price,
'ma10': self.lineW.line_ma2[-1] if len(self.lineW.line_ma2) > 0 else bar.close_price,
'ma18': self.lineW.line_ma3[-1] if len(self.lineW.line_ma3) > 0 else bar.close_price,
'sk': self.lineW.line_sk[-1] if len(self.lineW.line_sk) > 0 else 0,
'sd': self.lineW.line_sd[-1] if len(self.lineW.line_sd) > 0 else 0
})
def on_tick(self, tick):
print(u'{0},{1},ap:{2},av:{3},bp:{4},bv:{5}'.format(tick.datetime, tick.last_price, tick.ask_price_1,
tick.ask_volume_1, tick.bid_price_1, tick.bid_volume_1))
def write_log(self, content):
print(content)
def saveData(self):
if len(self.save_m30_bars) > 0:
outputFile = '{}_m30.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma60', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_m30_bars:
writer.writerow(row)
if len(self.save_h1_bars) > 0:
outputFile = '{}_h1.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ema5', 'ema10', 'ema60', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_h1_bars:
writer.writerow(row)
if len(self.save_h2_bars) > 0:
outputFile = '{}_h2.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_h2_bars:
writer.writerow(row)
if len(self.save_d_bars) > 0:
outputFile = '{}_d.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_d_bars:
writer.writerow(row)
if len(self.save_w_bars) > 0:
outputFile = '{}_w.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_w_bars:
writer.writerow(row)
if __name__ == '__main__':
t = test_strategy()
t.price_tick = 0.5
t.underlying_symbol = 'J'
t.vt_symbol = 'J99'
# t.createM5()
# t.createLineW()
# t.createlineM30_with_macd()
# 创建M30线
# t.createLineM30()
# 回测1小时线
# t.createLineH1()
# 回测2小时线
# t.createLineH2()
# 回测日线
# t.createLineD()
# 测试周线
t.createLineW()
# vnpy/app/cta_strategy_pro/
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
filename = os.path.abspath(os.path.join(vnpy_root, 'bar_data/{}_20160101_20190517_1m.csv'.format(t.vt_symbol)))
csv_bar_seconds = 60 # csv 文件内bar的时间间隔60秒
import csv
csvfile = open(filename, 'r', encoding='utf8')
reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",")
last_tradingDay = None
for row in reader:
try:
dt = datetime.strptime(row['index'], '%Y-%m-%d %H:%M:%S') - timedelta(seconds=csv_bar_seconds)
bar = BarData(
gateway_name='',
symbol=t.vt_symbol,
exchange=Exchange.LOCAL,
datetime=dt,
interval=Interval.MINUTE,
open_price=round_to(float(row['open']), t.price_tick),
high_price=round_to(float(row['high']), t.price_tick),
low_price=round_to(float(row['low']), t.price_tick),
close_price=round_to(float(row['close']), t.price_tick),
volume=float(row['volume'])
)
if 'trading_date' in row:
bar.trading_day = row['trading_date']
if len(bar.trading_day) == 8 and '-' not in bar.trading_day:
bar.trading_day = bar.trading_day[0:4] + '-' + bar.trading_day[4:6] + '-' + bar.trading_day[6:8]
else:
if bar.datetime.hour >= 21:
if bar.datetime.isoweekday() == 5:
# 星期五=》星期一
bar.trading_day = (dt + timedelta(days=3)).strftime('%Y-%m-%d')
else:
# 第二天
bar.trading_day = (dt + timedelta(days=1)).strftime('%Y-%m-%d')
elif bar.datetime.hour < 8 and bar.datetime.isoweekday() == 6:
# 星期六=>星期一
bar.trading_day = (dt + timedelta(days=2)).strftime('%Y-%m-%d')
else:
bar.trading_day = bar.datetime.strftime('%Y-%m-%d')
t.onBar(bar)
# 测试 实时计算值
# sk, sd = t.lineM30.getRuntimeSKD()
# 测试实时计算值
# if bar.datetime.minute==1:
# print('rt_Dif:{}'.format(t.lineM30.rt_Dif))
except Exception as ex:
t.write_log(u'{0}:{1}'.format(Exception, ex))
traceback.print_exc()
break
t.saveData()

View File

@ -4,7 +4,7 @@ import os
import json
from datetime import datetime
from collections import OrderedDict
from vnpy.app.cta_strategy_pro.template import CtaComponent
from vnpy.app.cta_strategy_pro.base import CtaComponent
from vnpy.trader.utility import get_folder_path
@ -18,12 +18,12 @@ class CtaPolicy(CtaComponent):
构造
:param strategy:
"""
super(CtaPolicy).__init__(strategy=strategy)
super(CtaPolicy,self).__init__(strategy=strategy, kwargs=kwargs)
self.create_time = None
self.save_time = None
def toJson(self):
def to_json(self):
"""
将数据转换成dict
datetime = string
@ -36,7 +36,7 @@ class CtaPolicy(CtaComponent):
return j
def fromJson(self, json_data):
def from_json(self, json_data):
"""
将数据从json_data中恢复
:param json_data:
@ -67,7 +67,7 @@ class CtaPolicy(CtaComponent):
从持久化文件中获取
:return:
"""
json_file = os.path.abspath(os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.name)))
json_file = os.path.abspath(os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.strategy_name)))
json_data = {}
if os.path.exists(json_file):
@ -80,7 +80,7 @@ class CtaPolicy(CtaComponent):
json_data = {}
# 从持久化文件恢复数据
self.fromJson(json_data)
self.from_json(json_data)
def save(self):
"""
@ -88,14 +88,14 @@ class CtaPolicy(CtaComponent):
:return:
"""
json_file = os.path.abspath(
os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.name)))
os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.strategy_name)))
try:
# 修改为:回测时不保存
if self.strategy and self.strategy.backtesting:
return
json_data = self.toJson()
json_data = self.to_json()
json_data['save_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
with open(json_file, 'w') as f:
data = json.dumps(json_data, indent=4)

View File

@ -2,8 +2,7 @@
import sys
from vnpy.app.cta_strategy_pro.base import Direction
from vnpy.app.cta_strategy_pro.template import CtaComponent
from vnpy.app.cta_strategy_pro.base import Direction, CtaComponent
class CtaPosition(CtaComponent):
@ -15,13 +14,13 @@ class CtaPosition(CtaComponent):
"""
def __init__(self, strategy, **kwargs):
super(CtaComponent).__init__(strategy=strategy)
super(CtaPosition, self).__init__(strategy=strategy, kwargs=kwargs)
self.long_pos = 0 # 多仓持仓(正数)
self.short_pos = 0 # 空仓持仓(负数)
self.pos = 0 # 持仓状态 0:空仓/对空平等; >=1 净多仓 <=-1 净空仓
self.maxPos = sys.maxsize # 最大持仓量(多仓+空仓总量)
def open_pos(self, direction: Direction, volume: int):
def open_pos(self, direction: Direction, volume: float):
"""开、加仓"""
# volume: 正整数
@ -31,7 +30,7 @@ class CtaPosition(CtaComponent):
# 更新
self.write_log(f'多仓:{self.long_pos}->{self.long_pos + volume}')
self.write_log(u'净:{self.pos}->{self.pos + volume}')
self.write_log(f'净:{self.pos}->{self.pos + volume}')
self.long_pos += volume
self.pos += volume
@ -40,13 +39,13 @@ class CtaPosition(CtaComponent):
self.write_error(content=f'开仓异常,净:{self.pos},空:{self.short_pos},加空:{volume},超过:{self.maxPos}')
self.write_log(f'空仓:{self.short_pos}->{self.short_pos - volume}')
self.write_log(u'净:{self.pos}->{self.pos - volume}')
self.write_log(f'净:{self.pos}->{self.pos - volume}')
self.short_pos -= volume
self.pos -= volume
return True
def close_pos(self, direction: Direction, volume):
def close_pos(self, direction: Direction, volume:float):
"""平、减仓"""
# vol: 正整数

View File

@ -43,13 +43,17 @@ from vnpy.trader.constant import (
Status
)
from vnpy.trader.utility import (
load_json, save_json,
load_json,
save_json,
extract_vt_symbol,
round_to, get_folder_path,
round_to,
TRADER_DIR,
get_folder_path,
get_underlying_symbol,
append_data)
from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.util_wechat import send_wx_msg
from vnpy.trader.converter import OffsetConverter
from .base import (
@ -88,7 +92,6 @@ class CtaEngine(BaseEngine):
6支持指定gateway的交易主引擎可接入多个gateway
"""
engine_type = EngineType.LIVE # live trading engine
engine_type = EngineType.LIVE # live trading engine
# 策略配置文件
@ -264,13 +267,17 @@ class CtaEngine(BaseEngine):
# strategy.pos -= trade.volume
# 根据策略名称,写入 data\straetgy_name_trade.csv文件
strategy_name = getattr(strategy, 'name')
trade_fields = ['time', 'symbol', 'exchange', 'vt_symbol', 'tradeid', 'vt_tradeid', 'orderid', 'vt_orderid',
trade_fields = ['datetime', 'symbol', 'exchange', 'vt_symbol', 'tradeid', 'vt_tradeid', 'orderid', 'vt_orderid',
'direction', 'offset', 'price', 'volume', 'idx_price']
trade_dict = OrderedDict()
try:
for k in trade_fields:
if k == 'time':
trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, k, '')
if k == 'datetime':
dt = getattr(trade, 'datetime')
if isinstance(dt, datetime):
trade_dict[k] = dt.strftime('%Y-%m-%d %H:%M:%S')
else:
trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, 'time', '')
if k in ['exchange', 'direction', 'offset']:
trade_dict[k] = getattr(trade, k).value
else:
@ -586,12 +593,12 @@ class CtaEngine(BaseEngine):
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_log(msg=f"撤单失败,找不到委托{vt_orderid}",
strategy_Name=strategy.name,
strategy_Name=strategy.strategy_name,
level=logging.ERROR)
return
return False
req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name)
return self.main_engine.cancel_order(req, order.gateway_name)
def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str):
"""
@ -599,7 +606,7 @@ class CtaEngine(BaseEngine):
"""
stop_order = self.stop_orders.get(stop_orderid, None)
if not stop_order:
return
return False
strategy = self.strategies[stop_order.strategy_name]
# Remove from relation map.
@ -614,6 +621,7 @@ class CtaEngine(BaseEngine):
self.call_strategy_func(strategy, strategy.on_stop_order, stop_order)
self.put_stop_order_event(stop_order)
return True
def send_order(
self,
@ -634,7 +642,7 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(msg=f"委托失败,找不到合约:{vt_symbol}",
strategy_name=strategy.name,
strategy_name=strategy.strategy_name,
level=logging.ERROR)
return ""
if contract.gateway_name and not gateway_name:
@ -659,9 +667,9 @@ class CtaEngine(BaseEngine):
"""
"""
if vt_orderid.startswith(STOPORDER_PREFIX):
self.cancel_local_stop_order(strategy, vt_orderid)
return self.cancel_local_stop_order(strategy, vt_orderid)
else:
self.cancel_server_order(strategy, vt_orderid)
return self.cancel_server_order(strategy, vt_orderid)
def cancel_all(self, strategy: CtaTemplate):
"""
@ -689,7 +697,7 @@ class CtaEngine(BaseEngine):
self.main_engine.subscribe(req, gateway_name)
else:
self.write_log(msg=f"找不到合约{vt_symbol},添加到待订阅列表",
strategy_name=strategy.name)
strategy_name=strategy.strategy_name)
self.pending_subcribe_symbol_map[f'{gateway_name}.{vt_symbol}'].add((strategy_name, is_bar))
try:
self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口')
@ -713,7 +721,7 @@ class CtaEngine(BaseEngine):
strategies.append(strategy)
# 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射
subscribe_symbol_set = self.strategy_symbol_map[strategy.name]
subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name]
subscribe_symbol_set.add(contract.vt_symbol)
return True
@ -786,6 +794,16 @@ class CtaEngine(BaseEngine):
""""""
return self.engine_type
@lru_cache()
def get_data_path(self):
data_path = os.path.abspath(os.path.join(TRADER_DIR, 'data'))
return data_path
@lru_cache()
def get_logs_path(self):
log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log'))
return log_path
def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None
):
@ -1389,3 +1407,17 @@ class CtaEngine(BaseEngine):
subject = "CTA策略引擎"
self.main_engine.send_email(subject, msg)
def send_wechat(self, msg: str, strategy: CtaTemplate = None):
"""
send wechat message to default receiver
:param msg:
:param strategy:
:return:
"""
if strategy:
subject = f"{strategy.strategy_name}"
else:
subject = "CTAPRO引擎"
send_wx_msg(content=f'{subject}:{msg}')

File diff suppressed because it is too large Load Diff

View File

@ -33,7 +33,7 @@ class TurtleSignalStrategy(CtaTemplate):
long_stop = 0
short_stop = 0
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
parameters = ["x_minute", "entry_window", "exit_window", "atr_window", "fixed_size"]
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
@ -52,7 +52,7 @@ class TurtleSignalStrategy(CtaTemplate):
Callback when strategy is inited.
"""
self.write_log("策略初始化")
#self.load_bar(20)
# self.load_bar(20)
def on_start(self):
"""

View File

@ -11,6 +11,8 @@ from vnpy.app.cta_strategy_pro import (
ArrayManager,
)
from vnpy.trader.utility import round_to
class TurtleSignalStrategy_v2(CtaTemplate):
""""""
@ -35,7 +37,7 @@ class TurtleSignalStrategy_v2(CtaTemplate):
long_stop = 0
short_stop = 0
parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"]
parameters = ["x_minute", "entry_window", "exit_window", "atr_window", "fixed_size"]
variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"]
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
@ -47,6 +49,7 @@ class TurtleSignalStrategy_v2(CtaTemplate):
# 获取合约乘数,保证金比例
self.symbol_size = self.cta_engine.get_size(self.vt_symbol)
self.symbol_margin_rate = self.cta_engine.get_margin_rate(self.vt_symbol)
self.symbol_price_tick = self.cta_engine.get_price_tick(self.vt_symbol)
self.bg = BarGenerator(self.on_bar, window=self.x_minute)
self.am = ArrayManager()
@ -58,7 +61,7 @@ class TurtleSignalStrategy_v2(CtaTemplate):
Callback when strategy is inited.
"""
self.write_log("策略初始化")
#self.load_bar(20)
# self.load_bar(20)
def on_start(self):
"""
@ -98,8 +101,12 @@ class TurtleSignalStrategy_v2(CtaTemplate):
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
if bar.datetime.strftime('%Y-%m-%d %H') == '2016-03-07 09':
a = 1 # noqa
if not self.pos:
self.atr_value = self.am.atr(self.atr_window)
self.atr_value = max(4 * self.symbol_price_tick, self.atr_value)
self.long_entry = 0
self.short_entry = 0
@ -112,13 +119,17 @@ class TurtleSignalStrategy_v2(CtaTemplate):
self.send_buy_orders(self.entry_up)
sell_price = max(self.long_stop, self.exit_down)
self.sell(sell_price, abs(self.pos), True)
refs = self.sell(sell_price, abs(self.pos), True)
if len(refs) > 0:
self.write_log(f'平多委托编号:{refs}')
elif self.pos < 0:
self.send_short_orders(self.entry_down)
cover_price = min(self.short_stop, self.exit_up)
ret = self.cover(cover_price, abs(self.pos), True)
refs = self.cover(cover_price, abs(self.pos), True)
if len(refs) > 0:
self.write_log(f'平空委托编号:{refs}')
self.put_event()
@ -161,48 +172,77 @@ class TurtleSignalStrategy_v2(CtaTemplate):
def send_buy_orders(self, price):
""""""
if self.pos >= 4:
return
if self.cur_mi_price <= price - self.atr_value/2:
if self.cur_mi_price <= price - self.atr_value / 2:
return
self.update_invest_pos()
t = self.pos / self.invest_pos
t = int(self.pos / self.invest_pos)
if t >= 4:
return
if t < 1:
self.buy(price, self.invest_pos, True)
refs = self.buy(price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'买入委托编号:{refs}')
if t < 2:
self.buy(price + self.atr_value * 0.5, self.invest_pos, True)
if t == 1 and self.cur_mi_price > price:
buy_price = round_to(price + self.atr_value * 0.5 , self.symbol_price_tick)
self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price))
refs = self.buy(buy_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'买入委托编号:{refs}')
if t < 3:
self.buy(price + self.atr_value, self.invest_pos, True)
if t == 2 and self.cur_mi_price > price + self.atr_value * 0.5:
buy_price = round_to(price + self.atr_value, self.symbol_price_tick)
self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price))
refs = self.buy(buy_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'买入委托编号:{refs}')
if t < 4:
self.buy(price + self.atr_value * 1.5, self.invest_pos, True)
if t == 3 and self.cur_mi_price > price + self.atr_value:
buy_price = round_to(price + self.atr_value * 1.5, self.symbol_price_tick)
self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price))
refs = self.buy(buy_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'买入委托编号:{refs}')
def send_short_orders(self, price):
""""""
if self.pos <= -4:
return
if self.cur_mi_price >= price + self.atr_value / 2:
return
self.update_invest_pos()
t = self.pos / self.invest_pos
t = int(self.pos / self.invest_pos)
if t <= -4:
return
if t > -1:
self.short(price, self.invest_pos, True)
refs = self.short(price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'卖出委托编号:{refs}')
if t > -2:
self.short(price - self.atr_value * 0.5, self.invest_pos, True)
if t == -1 and self.cur_mi_price < price:
short_price = round_to(price - self.atr_value * 0.5, self.symbol_price_tick)
self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price))
refs = self.short(short_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'卖出委托编号:{refs}')
if t > -3:
self.short(price - self.atr_value, self.invest_pos, True)
if t == -2 and self.cur_mi_price < price + self.atr_value * 0.5:
short_price = round_to(price - self.atr_value, self.symbol_price_tick)
self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price))
refs = self.short(short_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'卖出委托编号:{refs}')
if t > -4:
self.short(price - self.atr_value * 1.5, self.invest_pos, True)
if t == -3 and self.cur_mi_price < price + self.atr_value:
short_price = round_to(price - self.atr_value * 1.5, self.symbol_price_tick)
self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price))
refs = self.short(short_price, self.invest_pos, True)
if len(refs) > 0:
self.write_log(f'卖出委托编号:{refs}')

View File

@ -1,40 +1,23 @@
""""""
import sys
import os
import uuid
import bz2
import pickle
import copy
import traceback
from abc import ABC
from copy import copy
from typing import Any, Callable
from logging import INFO, ERROR
from vnpy.trader.constant import Interval, Direction, Offset
from datetime import datetime
from vnpy.trader.constant import Interval, Direction, Offset, Status
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
from vnpy.trader.utility import virtual
from vnpy.trader.utility import virtual, append_data, extract_vt_symbol,get_underlying_symbol
from .base import StopOrder, EngineType
class CtaComponent(ABC):
""" CTA策略基础组件"""
def __init__(self, strategy=None, **kwargs):
"""
构造
:param strategy:
"""
self.strategy = strategy
# ----------------------------------------------------------------------
def write_log(self, content: str):
"""记录日志"""
if self.strategy:
self.strategy.write_log(msg=content, level=INFO)
else:
print(content)
# ----------------------------------------------------------------------
def write_error(self, content: str, level: int = ERROR):
"""记录错误日志"""
if self.strategy:
self.strategy.write_log(msg=content, level=level)
else:
print(content, file=sys.stderr)
from .cta_grid_trade import CtaGrid, CtaGridTrade
from .cta_position import CtaPosition
class CtaTemplate(ABC):
@ -56,9 +39,17 @@ class CtaTemplate(ABC):
self.strategy_name = strategy_name
self.vt_symbol = vt_symbol
self.inited = False
self.trading = False
self.pos = 0
self.inited = False # 是否初始化完毕
self.trading = False # 是否开始交易
self.pos = 0 # 持仓/仓差
self.entrust = 0 # 是否正在委托, 0, 无委托 , 1, 委托方向是LONG -1, 委托方向是SHORT
self.tick_dict = {} # 记录所有on_tick传入最新tick
# 保存委托单编号和相关委托单的字典
# key为委托单编号
# value为该合约相关的委托单
self.active_orders = {}
# Copy a new variables list here to avoid duplicate insert when multiple
# strategy instances are created with the same strategy class.
@ -66,8 +57,7 @@ class CtaTemplate(ABC):
self.variables.insert(0, "inited")
self.variables.insert(1, "trading")
self.variables.insert(2, "pos")
self.update_setting(setting)
self.variables.insert(3, "entrust")
def update_setting(self, setting: dict):
"""
@ -119,6 +109,23 @@ class CtaTemplate(ABC):
}
return strategy_data
def get_positions(self):
""" 返回持仓数量"""
pos_list = []
if self.pos > 0:
pos_list.append({
"vt_symbol": self.vt_symbol,
"direction": "long",
"volume": self.pos
})
elif self.pos < 0:
pos_list.append({
"vt_symbol": self.vt_symbol,
"direction": "short",
"volume": abs(self.pos)
})
@virtual
def on_timer(self):
pass
@ -178,19 +185,24 @@ class CtaTemplate(ABC):
"""
pass
def buy(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''):
def buy(self, price: float, volume: float, stop: bool = False, lock: bool = False,
vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None):
"""
Send buy order to open a long position.
"""
return self.send_order(vt_symbol=vt_symbol,
direction=Direction.LONG,
offset=Offset.OPEN,
price=price,
volume=volume,
stop=stop,
lock=lock)
lock=lock,
order_time=order_time,
grid=grid)
def sell(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''):
def sell(self, price: float, volume: float, stop: bool = False, lock: bool = False,
vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None):
"""
Send sell order to close a long position.
"""
@ -200,9 +212,12 @@ class CtaTemplate(ABC):
price=price,
volume=volume,
stop=stop,
lock=lock)
lock=lock,
order_time=order_time,
grid=grid)
def short(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''):
def short(self, price: float, volume: float, stop: bool = False, lock: bool = False,
vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None):
"""
Send short order to open as short position.
"""
@ -212,9 +227,12 @@ class CtaTemplate(ABC):
price=price,
volume=volume,
stop=stop,
lock=lock)
lock=lock,
order_time=order_time,
grid=grid)
def cover(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''):
def cover(self, price: float, volume: float, stop: bool = False, lock: bool = False,
vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None):
"""
Send cover order to close a short position.
"""
@ -224,7 +242,9 @@ class CtaTemplate(ABC):
price=price,
volume=volume,
stop=stop,
lock=lock)
lock=lock,
order_time=order_time,
grid=grid)
def send_order(
self,
@ -234,7 +254,9 @@ class CtaTemplate(ABC):
price: float,
volume: float,
stop: bool = False,
lock: bool = False
lock: bool = False,
order_time: datetime = None,
grid: CtaGrid = None
):
"""
Send a new order.
@ -243,20 +265,44 @@ class CtaTemplate(ABC):
if vt_symbol == '':
vt_symbol = self.vt_symbol
if self.trading:
if not self.trading:
return []
vt_orderids = self.cta_engine.send_order(
self, vt_symbol, direction, offset, price, volume, stop, lock
)
if order_time is None:
order_time = datetime.now()
for vt_orderid in vt_orderids:
d = {
'direction': direction.value,
'offset': offset.value,
'vt_symbol': vt_symbol,
'price': price,
'volume': volume,
'traded': 0,
'order_time': order_time,
'status': Status.SUBMITTING
}
if grid:
d.update({'grid': grid})
grid.order_ids.append(vt_orderid)
self.active_orders.update({vt_orderid: d})
if direction == Direction.LONG:
self.entrust = 1
elif direction == Direction.SHORT:
self.entrust = -1
return vt_orderids
else:
return []
def cancel_order(self, vt_orderid: str):
"""
Cancel an existing order.
"""
if self.trading:
self.cta_engine.cancel_order(self, vt_orderid)
return self.cta_engine.cancel_order(self, vt_orderid)
return False
def cancel_all(self):
"""
@ -265,12 +311,32 @@ class CtaTemplate(ABC):
if self.trading:
self.cta_engine.cancel_all(self)
def is_upper_limit(self, symbol):
"""是否涨停"""
tick = self.tick_dict.get(symbol, None)
if tick is None or tick.limit_up is None or tick.limit_up == 0:
return False
if tick.bid_price_1 == tick.limit_up:
return True
def is_lower_limit(self, symbol):
"""是否跌停"""
tick = self.tick_dict.get(symbol, None)
if tick is None or tick.limit_down is None or tick.limit_down == 0:
return False
if tick.ask_price_1 == tick.limit_down:
return True
def write_log(self, msg: str, level: int = INFO):
"""
Write a log message.
"""
self.cta_engine.write_log(msg=msg, strategy_name=self.strategy_name, level=level)
def write_error(self, msg: str):
"""write error log message"""
self.write_log(msg=msg, level=ERROR)
def get_engine_type(self):
"""
Return whether the cta_engine is backtesting or live trading.
@ -452,3 +518,474 @@ class TargetPosTemplate(CtaTemplate):
else:
vt_orderids = self.short(short_price, abs(pos_change))
self.vt_orderids.extend(vt_orderids)
class CtaProTemplate(CtaTemplate):
"""
增强模板
"""
# 逻辑过程日志
dist_fieldnames = ['datetime', 'symbol', 'volume', 'price',
'operation', 'signal', 'stop_price', 'target_price',
'long_pos', 'short_pos']
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
""""""
super(CtaProTemplate, self).__init__(
cta_engine, strategy_name, vt_symbol, setting
)
self.idx_symbol = None # 指数合约
self.price_tick = 1 # 商品的最小价格跳动
self.symbol_size = 10 # 商品得合约乘数
self.cur_datetime = None # 当前Tick时间
self.cur_mi_tick = None # 最新的主力合约tick( vt_symbol)
self.cur_99_tick = None # 最新得指数合约tick( idx_symbol)
self.cur_mi_price = None # 当前价(主力合约 vt_symbol)
self.cur_99_price = None # 当前价tick时根据tick更新onBar回测时根据bar.close更新)
self.cancel_seconds = 120 # 撤单时间(秒)
self.backtesting = False
self.klines = {} # K线字典: kline_name: kline
# 增加仓位管理模块
self.position = CtaPosition(strategy=self)
# 增加网格持久化模块
self.gt = CtaGridTrade(strategy=self)
# 增加指数合约
if 'idx_symbol' not in self.parameters:
self.parameters.append('idx_symbol')
if 'backtesting' not in self.parameters:
self.parameters.append('backtesting')
def update_setting(self, setting: dict):
"""
Update strategy parameter wtih value in setting dict.
"""
for name in self.parameters:
if name in setting:
setattr(self, name, setting[name])
if self.idx_symbol is None:
symbol, exchange = extract_vt_symbol(self.vt_symbol)
self.idx_symbol = get_underlying_symbol(symbol).upper() + '99.' + exchange.value
if self.vt_symbol != self.idx_symbol:
self.write_log(f'指数合约:{self.idx_symbol}, 主力合约:{self.vt_symbol}')
self.price_tick = self.cta_engine.get_price_tick(self.vt_symbol)
self.symbol_size = self.cta_engine.get_size(self.vt_symbol)
def save_klines_to_cache(self, kline_names: list = []):
"""
保存K线数据到缓存
:param kline_names: 一般为self.klines的keys
:return:
"""
if len(kline_names) == 0:
kline_names = list(self.klines.keys())
# 获取保存路径
save_path = self.cta_engine.get_data_path()
# 保存缓存的文件名
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
with bz2.BZ2File(file_name, 'wb') as f:
klines = {}
for kline_name in kline_names:
klines.update({kline_name: self.klines.get(kline_name, None)})
pickle.dump(klines, f)
def load_klines_from_cache(self, kline_names: list = []):
"""
从缓存加载K线数据
:param kline_names:
:return:
"""
if len(kline_names) == 0:
kline_names = list(self.klines.keys())
save_path = self.cta_engine.get_data_path()
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
try:
last_bar_dt = None
with bz2.BZ2File(file_name, 'rb') as f:
klines = pickle.load(f)
# 逐一恢复K线
for kline_name in kline_names:
# 缓存的k线实例
cache_kline = klines.get(kline_name, None)
# 当前策略实例的K线实例
strategy_kline = self.klines.get(kline_name, None)
if cache_kline and strategy_kline:
# 临时保存当前的回调函数
cb_on_bar = strategy_kline.cb_on_bar
# 缓存实例数据 =》 当前实例数据
strategy_kline.__dict__.update(cache_kline.__dict__)
# 所有K线的最后时间
if last_bar_dt and strategy_kline.cur_datetime:
last_bar_dt = max(last_bar_dt, strategy_kline.cur_datetime)
else:
last_bar_dt = strategy_kline.cur_datetime
# 重新绑定k线策略与on_bar回调函数
strategy_kline.strategy = self
strategy_kline.cb_on_bar = cb_on_bar
self.write_log(f'恢复{kline_name}缓存数据,最新bar结束时间:{last_bar_dt}')
self.write_log(u'加载缓存k线数据完毕')
return last_bar_dt
except Exception as ex:
self.write_error(f'加载缓存K线数据失败:{str(ex)}')
return None
def get_klines_snapshot(self):
"""返回当前klines的切片数据"""
try:
d = {
'strategy': self.strategy_name,
'datetime': datetime.now()}
for kline_name in self.klines.keys():
d.update({kline_name: self.klines.get(kline_name).get_data()})
return d
except Exception as ex:
self.write_error(u'获取klines切片数据失败')
return {}
def init_position(self):
"""
初始化Positin
使用网格的持久化获取开仓状态的多空单更新
:return:
"""
self.write_log(u'init_position(),初始化持仓')
pos_symbols = set()
if len(self.gt.up_grids) <= 0:
self.position.short_pos = 0
# 加载已开仓的空单数据网格JSON
short_grids = self.gt.load(direction=Direction.SHORT, open_status_filter=[True])
if len(short_grids) == 0:
self.write_log(u'没有持久化的空单数据')
self.gt.up_grids = []
else:
self.gt.up_grids = short_grids
for sg in short_grids:
if len(sg.order_ids) > 0 or sg.order_status:
self.write_log(f'重置委托状态:{sg.order_status},清除委托单:{sg.order_ids}')
sg.order_status = False
sg.order_ids = []
short_symbol = sg.snapshot.get('mi_symbol', self.vt_symbol)
pos_symbols.add(short_symbol)
self.write_log(u'加载持仓空单[{},价格:{}],[指数:{},价格:{}],数量:{}'
.format(short_symbol, sg.snapshot.get('open_price'),
self.idx_symbol, sg.open_price, sg.volume))
self.position.short_pos -= sg.volume
self.write_log(u'持久化空单,共持仓:{}'.format(abs(self.position.short_pos)))
if len(self.gt.dn_grids) <= 0:
# 加载已开仓的多数据网格JSON
self.position.long_pos = 0
long_grids = self.gt.load(direction=Direction.LONG, open_status_filter=[True])
if len(long_grids) == 0:
self.write_log(u'没有持久化的多单数据')
self.gt.dn_grids = []
else:
self.gt.dn_grids = long_grids
for lg in long_grids:
if len(lg.order_ids) > 0 or lg.order_status:
self.write_log(f'重置委托状态:{lg.order_status},清除委托单:{lg.order_ids}')
lg.order_status = False
lg.order_ids = []
# lg.type = self.line.name
long_symbol = lg.snapshot.get('mi_symbol', self.vt_symbol)
pos_symbols.add(long_symbol)
self.write_log(u'加载持仓多单[{},价格:{}],[指数{},价格:{}],数量:{}'
.format(lg.snapshot.get('miSymbol'), lg.snapshot.get('open_price'),
self.idx_symbol, lg.open_price, lg.volume))
self.position.long_pos += lg.volume
self.write_log(f'持久化多单,共持仓:{self.position.long_pos}')
self.position.pos = self.position.long_pos + self.position.short_pos
self.write_log(
u'{}加载持久化数据完成,多单:{},空单:{},共:{}'
.format(self.strategy_name,
self.position.long_pos,
abs(self.position.short_pos),
self.position.pos))
self.pos = self.position.pos
self.gt.save()
self.display_grids()
if not self.backtesting:
pos_symbols.add(self.vt_symbol)
pos_symbols.add(self.idx_symbol)
# 如果持仓的合约不在self.vt_symbol中需要订阅
for symbol in list(pos_symbols):
self.write_log(f'新增订阅合约:{symbol}')
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=symbol)
def get_positions(self):
"""
获取策略当前持仓(重构使用主力合约
:return: [{'vt_symbol':symbol,'direction':direction,'volume':volume]
"""
if not self.position:
return []
pos_list = []
if self.position.long_pos > 0:
for g in self.gt.get_opened_grids(direction=Direction.LONG):
vt_symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
open_price = g.snapshot.get('open_price', g.openPrice)
pos_list.append({'vt_symbol': vt_symbol,
'direction': 'long',
'volume': g.volume - g.traded_volume,
'price': open_price})
if abs(self.position.short_pos) > 0:
for g in self.gt.get_opened_grids(direction=Direction.SHORT):
vt_symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
open_price = g.snapshot.get('open_price', g.open_price)
pos_list.append({'vt_symbol': vt_symbol,
'direction': 'short',
'volume': abs(g.volume - g.traded_volume),
'price': open_price})
if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10:
self.write_log(u'当前持仓:{}'.format(pos_list))
return pos_list
def tns_cancel_logic(self, dt, force=False):
"撤单逻辑"""
if len(self.active_orders) < 1:
self.entrust = 0
return
for vt_orderid in list(self.active_orders.keys()):
order_info = self.active_orders.get(vt_orderid)
if order_info.get('status', None) in [Status.CANCELLED, Status.REJECTED]:
self.active_orders.pop(vt_orderid, None)
continue
order_time = order_info.get('order_time')
over_ms = (dt - order_time).total_seconds()
if (over_ms > self.cancel_seconds) \
or force: # 超过设置的时间还未成交
self.write_log(f'{dt}, 超时{over_ms}秒未成交,取消委托单:{order_info}')
if self.cancel_order(vt_orderid):
order_info.update({'status': Status.CANCELLING})
else:
order_info.update({'status': Status.CANCELLED})
if len(self.active_orders) < 1:
self.entrust = 0
def tns_switch_long_pos(self):
"""切换合约,从持仓的非主力合约,切换至主力合约"""
if self.entrust != 0 and self.position.long_pos == 0:
return
if self.cur_mi_price == 0:
return
none_mi_grid = None
none_mi_symbol = None
# 找出非主力合约的持仓网格
for g in self.gt.get_opened_grids(direction=Direction.LONG):
none_mi_symbol = g.snapshot.get('mi_symbol')
if none_mi_symbol is None or none_mi_symbol == self.vt_symbol:
# 如果持仓的合约跟策略配置的vt_symbol一致则不处理
continue
if not g.open_status or g.order_status or g.volume - g.traded_volume <= 0:
continue
none_mi_grid = g
if g.traded_volume > 0 and g.volume - g.traded_volume > 0:
g.volume -= g.traded_volume
g.traded_volume = 0
break
if none_mi_grid is None:
return
# 找到行情中非主力合约/主力合约的最新价
none_mi_tick = self.tick_dict.get(none_mi_symbol)
mi_tick = self.tick_dict.get(self.vt_symbol, None)
if none_mi_tick is None or mi_tick is None:
return
# 如果涨停价,不做卖出
if self.is_upper_limit(none_mi_symbol) or self.is_upper_limit(self.vt_symbol):
return
none_mi_price = max(none_mi_tick.last_price, none_mi_tick.bid_price_1)
grid = copy.copy(none_mi_grid)
# 委托卖出非主力合约
order_ids = self.sell(price=none_mi_price, volume=none_mi_grid.volume, vt_symbol=none_mi_symbol,
grid=none_mi_grid)
if len(order_ids) > 0:
self.write_log(f'切换合约,委托卖出非主力合约{none_mi_symbol}持仓:{none_mi_grid.volume}')
# 添加买入主力合约
grid.id = str(uuid.uuid1())
grid.snapshot.update({'mi_symbol': self.vt_symbol, 'open_price': self.cur_mi_price})
self.gt.dn_grids.append(grid)
order_ids = self.buy(price=self.cur_mi_price, volume=grid.volume, vt_symbol=self.vt_symbol, grid=grid)
if len(order_ids) > 0:
self.write_log(u'切换合约,委托买入主力合约:{},价格:{},数量:{}'
.format(self.vt_symbol, self.cur_mi_price, grid.volume))
self.gt.save()
else:
self.write_error(f'委托买入主力合约:{self.vt_symbol}失败')
else:
self.write_error(f'委托卖出非主力合约:{none_mi_symbol}失败')
def tns_switch_short_pos(self):
"""切换合约,从持仓的非主力合约,切换至主力合约"""
if self.entrust != 0 and self.position.short_pos == 0:
return
if self.cur_mi_price == 0:
return
none_mi_grid = None
none_mi_symbol = None
# 找出非主力合约的持仓网格
for g in self.gt.get_opened_grids(direction=Direction.SHORT):
none_mi_symbol = g.snapshot.get('miSymbol')
if none_mi_symbol is None or none_mi_symbol == self.vt_symbol:
continue
if not g.open_status or g.order_status or g.volume - g.traded_volume <= 0:
continue
none_mi_grid = g
if g.traded_volume > 0 and g.volume - g.traded_volume > 0:
g.volume -= g.traded_volume
g.traded_volume = 0
break
# 找不到与主力合约不一致的持仓网格
if none_mi_grid is None:
return
# 找到行情中非主力合约的最新价
none_mi_tick = self.tick_dict.get(none_mi_symbol)
mi_tick = self.tick_dict.get(self.vt_symbol, None)
if none_mi_tick is None or mi_tick is None:
return
# 如果跌停价不做cover
if self.is_lower_limit(none_mi_symbol) or self.is_lower_limit(self.vt_symbol):
return
none_mi_price = max(none_mi_tick.last_price, none_mi_tick.bid_price_1)
grid = copy.copy(none_mi_grid)
# 委托平空非主力合约
order_ids = self.cover(price=none_mi_price, volume=none_mi_grid.volume, vt_symbol=self.vt_symbol,
grid=none_mi_grid)
if len(order_ids) > 0:
self.write_log(f'委托平空非主力合约{none_mi_symbol}持仓:{none_mi_grid.volume}')
# 添加卖出主力合约
grid.id = str(uuid.uuid1())
grid.snapshot.update({'mi_symbol': self.vt_symbol, 'open_price': self.cur_mi_price})
self.gt.up_grids.append(grid)
order_ids = self.short(price=self.cur_mi_price, volume=grid.volume, vt_symbol=self.vt_symbol, grid=grid)
if len(order_ids) > 0:
self.write_log(f'委托做空主力合约:{self.vt_symbol},价格:{self.cur_mi_price},数量:{grid.volume}')
self.gt.save()
else:
self.write_error(f'委托做空主力合约:{self.vt_symbol}失败')
else:
self.write_error(f'委托平空非主力合约:{none_mi_symbol}失败')
def display_grids(self):
"""更新网格显示信息"""
if not self.inited:
return
up_grids_info = self.gt.to_str(direction=Direction.SHORT)
if len(self.gt.up_grids) > 0:
self.write_log(up_grids_info)
dn_grids_info = self.gt.to_str(direction=Direction.LONG)
if len(self.gt.dn_grids) > 0:
self.write_log(dn_grids_info)
def display_tns(self):
"""显示事务的过程记录=》 log"""
if not self.inited:
return
self.write_log(u'{} 当前指数{}价格:{},当前主力{}价格:{}'
.format(self.cur_datetime,
self.idx_symbol, self.cur_99_price,
self.vt_symbol, self.cur_mi_price))
if hasattr(self, 'policy'):
policy = getattr(self, 'policy')
op = getattr(policy, 'to_json', None)
if callable(op):
self.write_log(u'当前Policy:{}'.format(policy.to_json()))
def save_dist(self, dist_data):
"""
保存策略逻辑过程记录= csv文件按
:param dist_data:
:return:
"""
if self.backtesting:
save_path = self.cta_engine.get_logs_path()
else:
save_path = self.cta_engine.get_data_path()
try:
if self.position:
dist_data.update({'long_pos': self.position.long_pos})
dist_data.update({'short_pos': self.position.short_pos})
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv'))
append_data(file_name=file_name, dict_data=dist_data, field_names=self.dist_fieldnames)
except Exception as ex:
self.write_error(u'save_dist 异常:{} {}'.format(str(ex), traceback.format_exc()))
def save_tns(self, tns_data):
"""
保存多空事务记录=csv文件,便于后续分析
:param tns_data:
:return:
"""
if self.backtesting:
save_path = self.cta_engine.get_logs_path()
else:
save_path = self.cta_engine.get_data_path()
try:
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_tns.csv'))
append_data(file_name=file_name, dict_data=tns_data)
except Exception as ex:
self.write_error(u'save_tns 异常:{} {}'.format(str(ex), traceback.format_exc()))
def send_wechat(self, msg: str):
"""实盘时才发送微信"""
if self.backtesting:
return
self.cta_engine.send_wechat(msg=msg, strategy=self)

View File

@ -0,0 +1,65 @@
# flake8: noqa
# 测试 app.cta_strategy_pro.CtaLineBar组件
# 从通达信获取历史交易记录模拟tick。推送至line_bar
import os
import sys
import json
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
if vnpy_root not in sys.path:
print(f'sys.path apppend:{vnpy_root}')
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.trader.constant import Interval, Exchange
from vnpy.data.tdx.tdx_common import FakeStrategy
from vnpy.data.tdx.tdx_future_data import TdxFutureData
from vnpy.app.cta_strategy_pro.cta_line_bar import CtaLineBar
from vnpy.trader.object import TickData
from vnpy.trader.utility import get_trading_date
t1 = FakeStrategy()
tdx_api = TdxFutureData(strategy=t1)
def on_bar(bar):
print(f'{bar.__dict__}')
# 创建10秒周期的k线
kline_setting = {}
kline_setting["name"] = "S10"
kline_setting['interval'] = Interval.SECOND
kline_setting['bar_interval'] = 10
kline_setting['price_tick'] = 0.5
kline_setting['underlying_symbol'] = 'J'
kline_s10 = CtaLineBar(strategy=t1, cb_on_bar=on_bar, setting=kline_setting)
ret, result = tdx_api.get_history_transaction_data('J99', '20200106')
# for data in result[0:10] + result[-10:]:
# print(data)
for data in result:
dt = data['datetime']
price = float(data['price'])
volume = float(data['volume'])
tick = TickData(
gateway_name='tdx',
datetime=dt,
last_price=price,
volume=volume,
symbol='J99',
exchange=Exchange('DCE'),
date=dt.strftime('%Y-%m-%d'),
time=dt.strftime('%H:%M:%S'),
trading_day=get_trading_date(dt)
)
kline_s10.on_tick(tick)
os._exit(0)

View File

@ -0,0 +1,442 @@
# flake8: noqa
# 测试 app.cta_strategy_pro.CtaLineBar组件
# 从通达信获取历史交易记录模拟tick。推送至line_bar
import os
import sys
import json
import traceback
from datetime import datetime, timedelta
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
if vnpy_root not in sys.path:
print(f'sys.path apppend:{vnpy_root}')
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.trader.constant import Interval, Exchange
from vnpy.trader.object import BarData
from vnpy.app.cta_strategy_pro.cta_line_bar import (
CtaLineBar,
CtaMinuteBar,
CtaHourBar,
CtaDayBar,
CtaWeekBar)
from vnpy.trader.utility import round_to
class test_strategy(object):
def __init__(self):
self.price_tick = 1
self.underlying_symbol = 'I'
self.vt_symbol = 'I99'
self.lineM5 = None
self.lineM30 = None
self.lineH1 = None
self.lineH2 = None
self.lineD = None
self.lineW = None
self.TMinuteInterval = 1
self.save_m30_bars = []
self.save_h1_bars = []
self.save_h2_bars = []
self.save_d_bars = []
self.save_w_bars = []
def createM5(self):
"""使用ctalinbar创建5分钟K线"""
lineM5Setting = {}
lineM5Setting['name'] = u'M5'
lineM5Setting['interval'] = Interval.MINUTE
lineM5Setting['bar_interval'] = 5
lineM5Setting['mode'] = CtaLineBar.TICK_MODE
lineM5Setting['price_tick'] = self.price_tick
lineM5Setting['underlying_symbol'] = self.underlying_symbol
self.lineM5 = CtaLineBar(self, self.onBarM5, lineM5Setting)
def onBarM5(self, bar):
self.write_log(self.lineM5.get_last_bar_str())
def createlineM30_with_macd(self):
"""使用CtaLineBar创建30分钟时间"""
# 创建M30 K线
lineM30Setting = {}
lineM30Setting['name'] = u'M30'
lineM30Setting['interval'] = Interval.MINUTE
lineM30Setting['bar_interval'] = 30
lineM30Setting['para_macd_fast_len'] = 26
lineM30Setting['para_macd_slow_len'] = 12
lineM30Setting['para_macd_signal_len'] = 9
lineM30Setting['price_tick'] = self.price_tick
lineM30Setting['underlying_symbol'] = self.underlying_symbol
self.lineM30 = CtaLineBar(self, self.onBarM30MACD, lineM30Setting)
def onBarM30MACD(self, bar):
self.write_log(self.lineM30.get_last_bar_str())
def createLineM30(self):
"""使用ctaMinuteBar, 测试内部自动写入csv文件"""
# 创建M30 K线
lineM30Setting = {}
lineM30Setting['name'] = u'M30'
lineM30Setting['interval'] = Interval.MINUTE
lineM30Setting['bar_interval'] = 30
lineM30Setting['para_pre_len'] = 10
lineM30Setting['para_ma1_len'] = 5
lineM30Setting['para_ma2_len'] = 10
lineM30Setting['para_ma3_len'] = 60
lineM30Setting['para_active_yb'] = True
lineM30Setting['para_active_skd'] = True
lineM30Setting['price_tick'] = self.price_tick
lineM30Setting['underlying_symbol'] = self.underlying_symbol
self.lineM30 = CtaMinuteBar(self, self.onBarM30, lineM30Setting)
# 写入文件
self.lineM30.export_filename = os.path.abspath(
os.path.join(os.getcwd(),
u'export_{}_{}.csv'.format(self.vt_symbol, self.lineM30.name)))
self.lineM30.export_fields = [
{'name': 'datetime', 'source': 'bar', 'attr': 'datetime', 'type_': 'datetime'},
{'name': 'open', 'source': 'bar', 'attr': 'open_price', 'type_': 'float'},
{'name': 'high', 'source': 'bar', 'attr': 'high_price', 'type_': 'float'},
{'name': 'low', 'source': 'bar', 'attr': 'low_price', 'type_': 'float'},
{'name': 'close', 'source': 'bar', 'attr': 'close_price', 'type_': 'float'},
{'name': 'turnover', 'source': 'bar', 'attr': 'turnover', 'type_': 'float'},
{'name': 'volume', 'source': 'bar', 'attr': 'volume', 'type_': 'float'},
{'name': 'open_interest', 'source': 'bar', 'attr': 'open_interest', 'type_': 'float'},
{'name': 'kf', 'source': 'line_bar', 'attr': 'line_statemean', 'type_': 'list'}
]
def createLineH1(self):
# 创建2小时K线
lineH1Setting = {}
lineH1Setting['name'] = u'H1'
lineH1Setting['interval'] = Interval.HOUR
lineH1Setting['bar_interval'] = 1
lineH1Setting['para_pre_len'] = 10
lineH1Setting['para_ema1_len'] = 5
lineH1Setting['para_ema2_len'] = 10
lineH1Setting['para_ema3_len'] = 60
lineH1Setting['para_active_yb'] = True
lineH1Setting['para_active_skd'] = True
lineH1Setting['price_tick'] = self.price_tick
lineH1Setting['underlying_symbol'] = self.underlying_symbol
self.lineH1 = CtaLineBar(self, self.onBarH1, lineH1Setting)
def createLineH2(self):
# 创建2小时K线
lineH2Setting = {}
lineH2Setting['name'] = u'H2'
lineH2Setting['interval'] = Interval.HOUR
lineH2Setting['bar_interval'] = 2
lineH2Setting['para_pre_len'] = 5
lineH2Setting['para_ma1_len'] = 5
lineH2Setting['para_ma2_len'] = 10
lineH2Setting['para_ma3_len'] = 18
lineH2Setting['para_active_yb'] = True
lineH2Setting['para_active_skd'] = True
lineH2Setting['mode'] = CtaLineBar.TICK_MODE
lineH2Setting['price_tick'] = self.price_tick
lineH2Setting['underlying_symbol'] = self.underlying_symbol
self.lineH2 = CtaHourBar(self, self.onBarH2, lineH2Setting)
def createLineD(self):
# 创建的日K线
lineDaySetting = {}
lineDaySetting['name'] = u'D1'
lineDaySetting['bar_interval'] = 1
lineDaySetting['para_pre_len'] = 5
lineDaySetting['para_art1_len'] = 26
lineDaySetting['para_ma1_len'] = 5
lineDaySetting['para_ma2_len'] = 10
lineDaySetting['para_ma3_len'] = 18
lineDaySetting['para_active_yb'] = True
lineDaySetting['para_active_skd'] = True
lineDaySetting['price_tick'] = self.price_tick
lineDaySetting['underlying_symbol'] = self.underlying_symbol
self.lineD = CtaDayBar(self, self.onBarD, lineDaySetting)
def createLineW(self):
"""创建周线"""
lineWeekSetting = {}
lineWeekSetting['name'] = u'W1'
lineWeekSetting['para_pre_len'] = 5
lineWeekSetting['para_art1_len'] = 26
lineWeekSetting['para_ma1_len'] = 5
lineWeekSetting['para_ma2_len'] = 10
lineWeekSetting['para_ma3_len'] = 18
lineWeekSetting['para_active_yb'] = True
lineWeekSetting['para_active_skd'] = True
lineWeekSetting['mode'] = CtaDayBar.TICK_MODE
lineWeekSetting['price_tick'] = self.price_tick
lineWeekSetting['underlying_symbol'] = self.underlying_symbol
self.lineW = CtaWeekBar(self, self.onBarW, lineWeekSetting)
def onBar(self, bar):
# print(u'tradingDay:{},dt:{},o:{},h:{},l:{},c:{},v:{}'.format(bar.trading_day,bar.datetime, bar.open, bar.high, bar.low_price, bar.close_price, bar.volume))
if self.lineW:
self.lineW.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineD:
self.lineD.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineH2:
self.lineH2.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineH1:
self.lineH1.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineM30:
self.lineM30.add_bar(bar, bar_freq=self.TMinuteInterval)
if self.lineM5:
self.lineM5.add_bar(bar, bar_freq=self.TMinuteInterval)
# if self.lineH2:
# self.lineH2.skd_is_high_dead_cross(runtime=True, high_skd=30)
# self.lineH2.skd_is_low_golden_cross(runtime=True, low_skd=70)
def onBarM30(self, bar):
self.write_log(self.lineM30.get_last_bar_str())
self.save_m30_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineM30.line_ma1[-1] if len(self.lineM30.line_ma1) > 0 else bar.close_price,
'ma10': self.lineM30.line_ma2[-1] if len(self.lineM30.line_ma2) > 0 else bar.close_price,
'ma60': self.lineM30.line_ma3[-1] if len(self.lineM30.line_ma3) > 0 else bar.close_price,
'sk': self.lineM30.line_sk[-1] if len(self.lineM30.line_sk) > 0 else 0,
'sd': self.lineM30.line_sd[-1] if len(self.lineM30.line_sd) > 0 else 0
})
def onBarH1(self, bar):
self.write_log(self.lineH1.get_last_bar_str())
self.save_h1_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ema5': self.lineH1.line_ema1[-1] if len(self.lineH1.line_ema1) > 0 else bar.close_price,
'ema10': self.lineH1.line_ema2[-1] if len(self.lineH1.line_ema2) > 0 else bar.close_price,
'ema60': self.lineH1.line_ema3[-1] if len(self.lineH1.line_ema3) > 0 else bar.close_price,
'sk': self.lineH1.line_sk[-1] if len(self.lineH1.line_sk) > 0 else 0,
'sd': self.lineH1.line_sd[-1] if len(self.lineH1.line_sd) > 0 else 0
})
def onBarH2(self, bar):
self.write_log(self.lineH2.get_last_bar_str())
self.save_h2_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineH2.line_ma1[-1] if len(self.lineH2.line_ma1) > 0 else bar.close_price,
'ma10': self.lineH2.line_ma2[-1] if len(self.lineH2.line_ma2) > 0 else bar.close_price,
'ma18': self.lineH2.line_ma3[-1] if len(self.lineH2.line_ma3) > 0 else bar.close_price,
'sk': self.lineH2.line_sk[-1] if len(self.lineH2.line_sk) > 0 else 0,
'sd': self.lineH2.line_sd[-1] if len(self.lineH2.line_sd) > 0 else 0
})
def onBarD(self, bar):
self.write_log(self.lineD.get_last_bar_str())
self.save_d_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineD.line_ma1[-1] if len(self.lineD.line_ma1) > 0 else bar.close_price,
'ma10': self.lineD.line_ma2[-1] if len(self.lineD.line_ma2) > 0 else bar.close_price,
'ma18': self.lineD.line_ma3[-1] if len(self.lineD.line_ma3) > 0 else bar.close_price,
'sk': self.lineD.line_sk[-1] if len(self.lineD.line_sk) > 0 else 0,
'sd': self.lineD.line_sd[-1] if len(self.lineD.line_sd) > 0 else 0
})
def onBarW(self, bar):
self.write_log(self.lineW.get_last_bar_str())
self.save_w_bars.append({
'datetime': bar.datetime,
'open': bar.open_price,
'high': bar.high_price,
'low': bar.low_price,
'close': bar.close_price,
'turnover': 0,
'volume': bar.volume,
'open_interest': 0,
'ma5': self.lineW.line_ma1[-1] if len(self.lineW.line_ma1) > 0 else bar.close_price,
'ma10': self.lineW.line_ma2[-1] if len(self.lineW.line_ma2) > 0 else bar.close_price,
'ma18': self.lineW.line_ma3[-1] if len(self.lineW.line_ma3) > 0 else bar.close_price,
'sk': self.lineW.line_sk[-1] if len(self.lineW.line_sk) > 0 else 0,
'sd': self.lineW.line_sd[-1] if len(self.lineW.line_sd) > 0 else 0
})
def on_tick(self, tick):
print(u'{0},{1},ap:{2},av:{3},bp:{4},bv:{5}'.format(tick.datetime, tick.last_price, tick.ask_price_1,
tick.ask_volume_1, tick.bid_price_1, tick.bid_volume_1))
def write_log(self, content):
print(content)
def saveData(self):
if len(self.save_m30_bars) > 0:
outputFile = '{}_m30.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma60', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_m30_bars:
writer.writerow(row)
if len(self.save_h1_bars) > 0:
outputFile = '{}_h1.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ema5', 'ema10', 'ema60', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_h1_bars:
writer.writerow(row)
if len(self.save_h2_bars) > 0:
outputFile = '{}_h2.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_h2_bars:
writer.writerow(row)
if len(self.save_d_bars) > 0:
outputFile = '{}_d.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_d_bars:
writer.writerow(row)
if len(self.save_w_bars) > 0:
outputFile = '{}_w.csv'.format(self.vt_symbol)
with open(outputFile, 'w', encoding='utf8', newline='') as f:
fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest',
'ma5', 'ma10', 'ma18', 'sk', 'sd']
writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel')
writer.writeheader()
for row in self.save_w_bars:
writer.writerow(row)
if __name__ == '__main__':
t = test_strategy()
t.price_tick = 0.5
t.underlying_symbol = 'J'
t.vt_symbol = 'J99'
# t.createM5()
# t.createLineW()
# t.createlineM30_with_macd()
# 创建M30线
# t.createLineM30()
# 回测1小时线
# t.createLineH1()
# 回测2小时线
# t.createLineH2()
# 回测日线
# t.createLineD()
# 测试周线
t.createLineW()
# vnpy/app/cta_strategy_pro/
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
filename = os.path.abspath(os.path.join(vnpy_root, 'bar_data/{}_20160101_1m.csv'.format(t.vt_symbol)))
csv_bar_seconds = 60 # csv 文件内bar的时间间隔60秒
import csv
csvfile = open(filename, 'r', encoding='utf8')
reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",")
last_tradingDay = None
for row in reader:
try:
dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S') - timedelta(seconds=csv_bar_seconds)
bar = BarData(
gateway_name='',
symbol=t.vt_symbol,
exchange=Exchange.LOCAL,
datetime=dt,
interval=Interval.MINUTE,
open_price=round_to(float(row['open']), t.price_tick),
high_price=round_to(float(row['high']), t.price_tick),
low_price=round_to(float(row['low']), t.price_tick),
close_price=round_to(float(row['close']), t.price_tick),
volume=float(row['volume'])
)
if 'trading_date' in row:
bar.trading_day = row['trading_date']
if len(bar.trading_day) == 8 and '-' not in bar.trading_day:
bar.trading_day = bar.trading_day[0:4] + '-' + bar.trading_day[4:6] + '-' + bar.trading_day[6:8]
else:
if bar.datetime.hour >= 21:
if bar.datetime.isoweekday() == 5:
# 星期五=》星期一
bar.trading_day = (dt + timedelta(days=3)).strftime('%Y-%m-%d')
else:
# 第二天
bar.trading_day = (dt + timedelta(days=1)).strftime('%Y-%m-%d')
elif bar.datetime.hour < 8 and bar.datetime.isoweekday() == 6:
# 星期六=>星期一
bar.trading_day = (dt + timedelta(days=2)).strftime('%Y-%m-%d')
else:
bar.trading_day = bar.datetime.strftime('%Y-%m-%d')
t.onBar(bar)
# 测试 实时计算值
# sk, sd = t.lineM30.getRuntimeSKD()
# 测试实时计算值
# if bar.datetime.minute==1:
# print('rt_Dif:{}'.format(t.lineM30.rt_Dif))
except Exception as ex:
t.write_log(u'{0}:{1}'.format(Exception, ex))
traceback.print_exc()
break
t.saveData()

View File

@ -3,8 +3,6 @@
# 通达信指数行情发布器
# 华富资产
import os
import sys
import copy
import json
import traceback

View File

@ -11,7 +11,6 @@ from vnpy.trader.event import EVENT_TRADE, EVENT_ORDER, EVENT_LOG, EVENT_ACCOUNT
from vnpy.trader.constant import Status
from vnpy.trader.utility import load_json, save_json
APP_NAME = "RiskManager"
@ -178,6 +177,24 @@ class RiskManagerEngine(BaseEngine):
account.balance, account_percent, self.percent_limit)
self.write_log(msg)
def get_account(self, vt_accountid: str = ""):
"""获取账号的当前净值,可用资金,账号当前仓位百分比,允许的最大仓位百分比"""
if vt_accountid:
account = self.account_dict.get(vt_accountid, None)
if account:
return account.balance, \
account.available, \
round(account.frozen * 100 / (account.balance + 0.01), 2), \
self.percent_limit
if len(self.account_dict.values()) > 0:
account = self.account_dict.values()[0]
return account.balance, \
account.available, \
round(account.frozen * 100 / (account.balance + 0.01), 2), \
self.percent_limit
else:
return 0, 0, 0, 0
def write_log(self, msg: str):
""""""
log = LogData(msg=msg, gateway_name="RiskManager")

View File

@ -13,7 +13,7 @@
"mi_symbol": "ag2007",
"full_symbol": "AG2007",
"exchange": "SHFE",
"margin_rate": 0.08,
"margin_rate": 0.07,
"symbol_size": 15,
"price_tick": 1.0
},
@ -22,7 +22,7 @@
"mi_symbol": "al2003",
"full_symbol": "AL2003",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 5.0
},
@ -31,7 +31,7 @@
"mi_symbol": "AP005",
"full_symbol": "AP2005",
"exchange": "CZCE",
"margin_rate": 0.07,
"margin_rate": 0.08,
"symbol_size": 10,
"price_tick": 1.0
},
@ -40,7 +40,7 @@
"mi_symbol": "au2006",
"full_symbol": "AU2006",
"exchange": "SHFE",
"margin_rate": 0.08,
"margin_rate": 0.06,
"symbol_size": 1000,
"price_tick": 0.02
},
@ -55,8 +55,8 @@
},
"BB": {
"underlying_symbol": "BB",
"mi_symbol": "bb2012",
"full_symbol": "BB2012",
"mi_symbol": "bb2101",
"full_symbol": "BB2101",
"exchange": "DCE",
"margin_rate": 0.2,
"symbol_size": 500,
@ -67,7 +67,7 @@
"mi_symbol": "bu2006",
"full_symbol": "BU2006",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.09,
"symbol_size": 10,
"price_tick": 2.0
},
@ -82,8 +82,8 @@
},
"CF": {
"underlying_symbol": "CF",
"mi_symbol": "CF005",
"full_symbol": "CF2005",
"mi_symbol": "CF009",
"full_symbol": "CF2009",
"exchange": "CZCE",
"margin_rate": 0.05,
"symbol_size": 5,
@ -112,7 +112,7 @@
"mi_symbol": "cu2003",
"full_symbol": "CU2003",
"exchange": "SHFE",
"margin_rate": 0.09,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 10.0
},
@ -139,7 +139,7 @@
"mi_symbol": "eg2005",
"full_symbol": "EG2005",
"exchange": "DCE",
"margin_rate": 0.05,
"margin_rate": 0.06,
"symbol_size": 10,
"price_tick": 1.0
},
@ -148,9 +148,9 @@
"mi_symbol": "fb2005",
"full_symbol": "FB2005",
"exchange": "DCE",
"margin_rate": 0.2,
"symbol_size": 500,
"price_tick": 0.05
"margin_rate": 0.1,
"symbol_size": 10,
"price_tick": 0.5
},
"FG": {
"underlying_symbol": "FG",
@ -166,7 +166,7 @@
"mi_symbol": "fu2005",
"full_symbol": "FU2005",
"exchange": "SHFE",
"margin_rate": 0.2,
"margin_rate": 0.1,
"symbol_size": 10,
"price_tick": 1.0
},
@ -175,7 +175,7 @@
"mi_symbol": "hc2005",
"full_symbol": "HC2005",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.08,
"symbol_size": 10,
"price_tick": 1.0
},
@ -184,7 +184,7 @@
"mi_symbol": "i2005",
"full_symbol": "I2005",
"exchange": "DCE",
"margin_rate": 0.05,
"margin_rate": 0.08,
"symbol_size": 100,
"price_tick": 0.5
},
@ -193,7 +193,7 @@
"mi_symbol": "IC2003",
"full_symbol": "IC2003",
"exchange": "CFFEX",
"margin_rate": 0.1,
"margin_rate": 0.12,
"symbol_size": 200,
"price_tick": 0.2
},
@ -220,7 +220,7 @@
"mi_symbol": "j2005",
"full_symbol": "J2005",
"exchange": "DCE",
"margin_rate": 0.05,
"margin_rate": 0.08,
"symbol_size": 100,
"price_tick": 0.5
},
@ -229,7 +229,7 @@
"mi_symbol": "jd2005",
"full_symbol": "JD2005",
"exchange": "DCE",
"margin_rate": 0.08,
"margin_rate": 0.07,
"symbol_size": 10,
"price_tick": 1.0
},
@ -238,14 +238,14 @@
"mi_symbol": "jm2005",
"full_symbol": "JM2005",
"exchange": "DCE",
"margin_rate": 0.05,
"margin_rate": 0.08,
"symbol_size": 60,
"price_tick": 0.5
},
"JR": {
"underlying_symbol": "JR",
"mi_symbol": "JR011",
"full_symbol": "JR2011",
"mi_symbol": "JR101",
"full_symbol": "JR2101",
"exchange": "CZCE",
"margin_rate": 0.05,
"symbol_size": 20,
@ -328,14 +328,14 @@
"mi_symbol": "pb2003",
"full_symbol": "PB2003",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 5.0
},
"PM": {
"underlying_symbol": "PM",
"mi_symbol": "PM011",
"full_symbol": "PM2011",
"mi_symbol": "PM101",
"full_symbol": "PM2101",
"exchange": "CZCE",
"margin_rate": 0.05,
"symbol_size": 50,
@ -355,14 +355,14 @@
"mi_symbol": "rb2005",
"full_symbol": "RB2005",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.08,
"symbol_size": 10,
"price_tick": 1.0
},
"RI": {
"underlying_symbol": "RI",
"mi_symbol": "RI011",
"full_symbol": "RI2011",
"mi_symbol": "RI101",
"full_symbol": "RI2101",
"exchange": "CZCE",
"margin_rate": 0.05,
"symbol_size": 20,
@ -373,7 +373,7 @@
"mi_symbol": "RM005",
"full_symbol": "RM2005",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.06,
"symbol_size": 10,
"price_tick": 1.0
},
@ -391,7 +391,7 @@
"mi_symbol": "RS011",
"full_symbol": "RS2011",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.2,
"symbol_size": 10,
"price_tick": 1.0
},
@ -400,7 +400,7 @@
"mi_symbol": "ru2005",
"full_symbol": "RU2005",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.09,
"symbol_size": 10,
"price_tick": 5.0
},
@ -418,7 +418,7 @@
"mi_symbol": "sc2003",
"full_symbol": "SC2003",
"exchange": "INE",
"margin_rate": 0.05,
"margin_rate": 0.07,
"symbol_size": 1000,
"price_tick": 0.1
},
@ -427,7 +427,7 @@
"mi_symbol": "SF005",
"full_symbol": "SF2005",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 2.0
},
@ -436,7 +436,7 @@
"mi_symbol": "SM005",
"full_symbol": "SM2005",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 2.0
},
@ -445,7 +445,7 @@
"mi_symbol": "sn2006",
"full_symbol": "SN2006",
"exchange": "SHFE",
"margin_rate": 0.09,
"margin_rate": 0.08,
"symbol_size": 1,
"price_tick": 10.0
},
@ -490,7 +490,7 @@
"mi_symbol": "TA005",
"full_symbol": "TA2005",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.06,
"symbol_size": 5,
"price_tick": 2.0
},
@ -535,7 +535,7 @@
"mi_symbol": "WH011",
"full_symbol": "WH2011",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.07,
"symbol_size": 20,
"price_tick": 1.0
},
@ -544,7 +544,7 @@
"mi_symbol": "wr2012",
"full_symbol": "WR2012",
"exchange": "SHFE",
"margin_rate": 0.2,
"margin_rate": 0.08,
"symbol_size": 10,
"price_tick": 1.0
},
@ -562,7 +562,7 @@
"mi_symbol": "ZC005",
"full_symbol": "ZC2005",
"exchange": "CZCE",
"margin_rate": 0.05,
"margin_rate": 0.06,
"symbol_size": 100,
"price_tick": 0.2
},
@ -571,7 +571,7 @@
"mi_symbol": "zn2003",
"full_symbol": "ZN2003",
"exchange": "SHFE",
"margin_rate": 0.1,
"margin_rate": 0.07,
"symbol_size": 5,
"price_tick": 5.0
}

View File

@ -29,7 +29,7 @@ api_01 = TdxFutureData()
api_01.update_mi_contracts()
# 逐一指数合约下载并更新
for underlying_symbol in ['RB', 'J']: #api_01.future_contracts.keys():
for underlying_symbol in api_01.future_contracts.keys():
index_symbol = underlying_symbol + '99'
print(f'开始更新:{index_symbol}')
# csv数据文件名

View File

@ -137,7 +137,7 @@ class TdxFutureData(object):
last_datetime = datetime.strptime(last_datetime_str, '%Y-%m-%d %H:%M:%S')
if (datetime.now() - last_datetime).total_seconds() > 60 * 60 * 2:
self.best_ip = {}
except Exception as ex:
except Exception as ex: # noqa
self.best_ip = {}
else:
self.best_ip = {}
@ -262,6 +262,8 @@ class TdxFutureData(object):
"""
ret_bars = []
if '.' in symbol:
symbol = symbol.split('.')[0]
tdx_symbol = symbol.upper().replace('_', '')
tdx_symbol = tdx_symbol.replace('99', 'L9')
underlying_symbol = get_underlying_symbol(symbol).upper()

View File

@ -3,7 +3,6 @@
# Celery app
# 该py脚本为启动celery worker app
# 在项目根目录下,运行 celery -A vnpy.task.celery worker
import time
from celery import Celery
import sys
@ -23,7 +22,7 @@ if vnpy_root not in sys.path:
sys.path.append(vnpy_root)
# 使用本地配置的
from vnpy.trader.utility import load_json
from vnpy.trader.utility import load_json # noqa
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'celery_config.json'))
celery_config = load_json(file_path)
@ -40,17 +39,8 @@ print(u'Celery 使用redis配置:\nbroker:{}\nbackend:{}'.format(broker, backend
app = Celery('vnpy_task', broker=broker)
# 动态导入task目录下子任务
# app.conf.CELERY_IMPORTS = ['vnpy.task.celery_app.worker_started']
# app.conf.update(
# CELERY_TASK_SERIALIZER='json',
# CELERY_RESULT_SERIALIZER='json',
# CELERY_ACCEPT_CONTENT=['json'],
# CELERY_TIMEZONE='Asia/Shanghai',
# CELERY_ENABLE_UTC=True
# )
app.conf.CELERY_IMPORTS = ['vnpy.task.celery_app.worker_started']
def worker_started():
@ -59,7 +49,7 @@ def worker_started():
import socket
from vnpy.trader.util_wechat import send_wx_msg
send_wx_msg(u'{} Celery Worker 启动'.format(socket.gethostname()))
except:
except: # noqa
pass

View File

@ -136,6 +136,7 @@ class Exchange(Enum):
LOCAL = "LOCAL" # For local generated data
SPD = "SPD" # Customer Spread data
class Currency(Enum):
"""
Currency.

View File

@ -176,7 +176,7 @@ class MainEngine:
if gateway:
gateway.subscribe(req)
else:
for gateway in self.gateways.items():
for gateway in self.gateways.values():
if gateway:
gateway.subscribe(req)
@ -196,7 +196,8 @@ class MainEngine:
"""
gateway = self.get_gateway(gateway_name)
if gateway:
gateway.cancel_order(req)
return gateway.cancel_order(req)
return False
def send_orders(self, reqs: Sequence[OrderRequest], gateway_name: str):
"""
@ -588,12 +589,13 @@ class CustomContract(object):
for symbol, setting in self.setting.items():
gateway_name = setting.get('gateway_name', None)
if gateway_name is None:
gateway_name= SETTINGS.get('gateway_name','')
gateway_name = SETTINGS.get('gateway_name', '')
vn_exchange = Exchange(setting.get('exchange', 'LOCAL'))
contract = ContractData(
gateway_name=gateway_name,
symbol=symbol,
name=contract.symbol,
exchange=vn_exchange,
name=setting.get('name', symbol),
size=setting.get('size', 100),
pricetick=setting.get('price_tick', 0.01),
margin_rate=setting.get('margin_rate', 0.1)
@ -602,6 +604,7 @@ class CustomContract(object):
return d
class EmailEngine(BaseEngine):
"""
Provides email sending function for VN Trader.

View File

@ -26,4 +26,3 @@ EVENT_FUNDS_FLOW = 'eFundsFlow.'
EVENT_ERROR = 'eError'
EVENT_WARNING = 'eWarning'
EVENT_CRITICAL = 'eCritical'

View File

@ -262,7 +262,7 @@ class BaseGateway(ABC):
implementation should finish the tasks blow:
* send request to server
"""
pass
return False
def send_orders(self, reqs: Sequence[OrderRequest]):
"""

View File

@ -8,7 +8,7 @@ from logging import INFO
from .constant import Direction, Exchange, Interval, Offset, Status, Product, OptionType, OrderType
ACTIVE_STATUSES = set([Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED])
ACTIVE_STATUSES = set([Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED, Status.CANCELLING])
@dataclass
@ -339,6 +339,7 @@ class SubscribeRequest:
def __eq__(self, other):
return self.vt_symbol == other.vt_symbol
@dataclass
class OrderRequest:
"""

View File

@ -172,17 +172,20 @@ def extract_vt_symbol(vt_symbol: str):
symbol, exchange_str = vt_symbol.split(".")
return symbol, Exchange(exchange_str)
def generate_vt_symbol(symbol: str, exchange: Exchange):
"""
return vt_symbol
"""
return f"{symbol}.{exchange.value}"
def format_number(n):
"""格式化数字到字符串"""
rn = round(n, 2) # 保留两位小数
return format(rn, ',') # 加上千分符
def _get_trader_dir(temp_name: str):
"""
Get path where trader is running in.
@ -212,8 +215,9 @@ def _get_trader_dir(temp_name: str):
TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader")
sys.path.append(str(TRADER_DIR))
print(f'sys.path append: {str(TRADER_DIR)}')
if TRADER_DIR not in sys.path:
sys.path.append(str(TRADER_DIR))
print(f'sys.path append: {str(TRADER_DIR)}')
def get_file_path(filename: str):
@ -306,7 +310,7 @@ def print_dict(d: dict):
return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())])
def append_data(self, file_name: str, dict_data: dict, field_names: list = []):
def append_data(file_name: str, dict_data: dict, field_names: list = []):
"""
添加数据到csv文件中
:param file_name: csv的文件全路径
@ -360,19 +364,254 @@ def import_module_by_str(import_module_name):
comp = modules[-1]
if not hasattr(mod, comp):
loaded_modules = '.'.join([loaded_modules,comp])
loaded_modules = '.'.join([loaded_modules, comp])
print('realod {}'.format(loaded_modules))
mod = reload(loaded_modules)
else:
print('from {} import {}'.format(loaded_modules,comp))
print('from {} import {}'.format(loaded_modules, comp))
mod = getattr(mod, comp)
return mod
except Exception as ex:
print('import {} fail,{},{}'.format(import_module_name,str(ex),traceback.format_exc()))
print('import {} fail,{},{}'.format(import_module_name, str(ex), traceback.format_exc()))
return None
def save_df_to_excel(file_name, sheet_name, df):
"""
保存dataframe到execl
:param file_name: 保存的excel文件名
:param sheet_name: 保存的sheet
:param df: dataframe
:return: True/False
"""
if file_name is None or sheet_name is None or df is None:
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
from openpyxl.utils.dataframe import dataframe_to_rows
# from openpyxl.drawing.image import Image
except: # noqa
print(u'can not import openpyxl', file=sys.stderr)
if 'openpyxl' not in sys.modules:
print(u'can not import openpyxl', file=sys.stderr)
return False
try:
ws = None
try:
# 读取文件
wb = openpyxl.load_workbook(file_name)
except: # noqa
# 创建一个excel workbook
wb = openpyxl.Workbook()
ws = wb.active
ws.title = sheet_name
try:
# 定位WorkSheet
if ws is None:
ws = wb[sheet_name]
except: # noqa
# 创建一个WorkSheet
ws = wb.create_sheet()
ws.title = sheet_name
rows = dataframe_to_rows(df)
for r_idx, row in enumerate(rows, 1):
for c_idx, value in enumerate(row, 1):
ws.cell(row=r_idx, column=c_idx, value=value)
# Save the workbook
wb.save(file_name)
wb.close()
except Exception as ex:
import traceback
print(u'save_df_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr)
def save_text_to_excel(file_name, sheet_name, text):
"""
保存文本文件到excel
:param file_name:
:param sheet_name:
:param text:
:return:
"""
if file_name is None or len(sheet_name) == 0 or len(text) == 0:
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
# from openpyxl.utils.dataframe import dataframe_to_rows
# from openpyxl.drawing.image import Image
except: # noqa
print(u'can not import openpyxl', file=sys.stderr)
if 'openpyxl' not in sys.modules:
return False
try:
ws = None
try:
# 读取文件
wb = openpyxl.load_workbook(file_name)
except: # noqa
# 创建一个excel workbook
wb = openpyxl.Workbook()
ws = wb.active
ws.title = sheet_name
try:
# 定位WorkSheet
if ws is None:
ws = wb[sheet_name]
except: # noqa
# 创建一个WorkSheet
ws = wb.create_sheet()
ws.title = sheet_name
# 设置宽度,自动换行方式
ws.column_dimensions["A"].width = 120
ws['A2'].alignment = openpyxl.styles.Alignment(wrapText=True)
ws['A2'].value = text
# Save the workbook
wb.save(file_name)
wb.close()
return True
except Exception as ex:
import traceback
print(u'save_text_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr)
return False
def save_images_to_excel(file_name, sheet_name, image_names):
"""
# 保存图形文件到excel
:param file_name: excel文件名
:param sheet_name: workSheet
:param image_names: 图像文件名列表
:return:
"""
if file_name is None or len(sheet_name) == 0 or len(image_names) == 0:
return False
# ----------------------------- 扩展的功能 ---------
try:
import openpyxl
# from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image
except Exception as ex:
print(f'can not import openpyxl:{str(ex)}', file=sys.stderr)
if 'openpyxl' not in sys.modules:
return False
try:
ws = None
try:
# 读取文件
wb = openpyxl.load_workbook(file_name)
except: # noqa
# 创建一个excel workbook
wb = openpyxl.Workbook()
ws = wb.active
ws.title = sheet_name
try:
# 定位WorkSheet
if ws is None:
ws = wb[sheet_name]
except Exception as ex: # noqa
# 创建一个WorkSheet
ws = wb.create_sheet()
ws.title = sheet_name
i = 1
for image_name in image_names:
try:
# 加载图形文件
img1 = Image(image_name)
cell_id = 'A{0}'.format(i)
ws[cell_id].value = image_name
cell_id = 'A{0}'.format(i + 1)
i += 30
# 添加至对应的WorkSheet中
ws.add_image(img1, cell_id)
except Exception as ex:
print('exception loading image {}, {}'.format(image_name, str(ex)), file=sys.stderr)
return False
# Save the workbook
wb.save(file_name)
wb.close()
return True
except Exception as ex:
import traceback
print(u'save_images_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr)
return False
def display_dual_axis(df, columns1, columns2=[], invert_yaxis1=False, invert_yaxis2=False, file_name=None,
sheet_name=None,
image_name=None):
"""
显示(保存)双Y轴的走势图
:param df: DataFrame
:param columns1: y1轴
:param columns2: Y2轴
:param invert_yaxis1: Y1 轴反转
:param invert_yaxis2: Y2 轴翻转
:param file_name: 保存的excel 文件名称
:param sheet_name: excel 的sheet
:param image_name: 保存的image 文件名
:return:
"""
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = (20.0, 10.0)
df1 = df[columns1]
df1.index = list(range(len(df)))
fig, ax1 = plt.subplots()
if invert_yaxis1:
ax1.invert_yaxis()
ax1.plot(df1)
if len(columns2) > 0:
df2 = df[columns2]
df2.index = list(range(len(df)))
ax2 = ax1.twinx()
if invert_yaxis2:
ax2.invert_yaxis()
ax2.plot(df2)
# 修改x轴得label为时间
xt = ax1.get_xticks()
xt2 = [df.index[int(i)] for i in xt[1:-2]]
xt2.insert(0, '')
xt2.append('')
ax1.set_xticklabels(xt2)
# 是否保存图片到文件
if image_name is not None:
fig = plt.gcf()
fig.savefig(image_name, bbox_inches='tight')
# 插入图片到指定的excel文件sheet中并保存excel
if file_name is not None and sheet_name is not None:
save_images_to_excel(file_name, sheet_name, [image_name])
else:
plt.show()
class BarGenerator:
"""
For: