From f9aad85db96e921e04d756f6bd205ce7bf58e540 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Sun, 2 Feb 2020 19:09:05 +0800 Subject: [PATCH] [bug fix] --- Q_n_A.md | 55 +- examples/vn_trader/ar_setting.json | 16 + .../vn_trader/cta_strategy_pro_config.json | 4 + examples/vn_trader/run.py | 36 +- vnpy/app/cta_strategy_pro/__init__.py | 9 +- vnpy/app/cta_strategy_pro/base.py | 31 +- vnpy/app/cta_strategy_pro/cta_grid_trade.py | 90 ++- vnpy/app/cta_strategy_pro/cta_line_bar.py | 452 +----------- vnpy/app/cta_strategy_pro/cta_policy.py | 16 +- vnpy/app/cta_strategy_pro/cta_position.py | 13 +- vnpy/app/cta_strategy_pro/engine.py | 62 +- .../app/cta_strategy_pro/portfolio_testing.py | 685 +++++++++++++----- .../strategies/turtle_signal_strategy.py | 4 +- .../strategies/turtle_signal_strategy_v2.py | 94 ++- vnpy/app/cta_strategy_pro/template.py | 669 +++++++++++++++-- vnpy/app/cta_strategy_pro/test_line_bar_01.py | 65 ++ vnpy/app/cta_strategy_pro/test_line_bar_02.py | 442 +++++++++++ vnpy/app/index_tick_publisher/engine.py | 2 - vnpy/app/risk_manager/engine.py | 27 +- vnpy/data/tdx/future_contracts.json | 82 +-- vnpy/data/tdx/refill_tdx_future_bars.py | 2 +- vnpy/data/tdx/tdx_future_data.py | 4 +- vnpy/task/celery_app.py | 16 +- vnpy/trader/constant.py | 69 +- vnpy/trader/engine.py | 25 +- vnpy/trader/event.py | 1 - vnpy/trader/gateway.py | 2 +- vnpy/trader/object.py | 35 +- vnpy/trader/utility.py | 251 ++++++- 29 files changed, 2309 insertions(+), 950 deletions(-) create mode 100644 examples/vn_trader/ar_setting.json create mode 100644 examples/vn_trader/cta_strategy_pro_config.json create mode 100644 vnpy/app/cta_strategy_pro/test_line_bar_01.py create mode 100644 vnpy/app/cta_strategy_pro/test_line_bar_02.py diff --git a/Q_n_A.md b/Q_n_A.md index 7c17cc86..9b4b0162 100644 --- a/Q_n_A.md +++ b/Q_n_A.md @@ -2,59 +2,67 @@ -------------------------------------------------------------------------------------------- ###FAQ: -#3、碰到的问题:找不到vnpy.xx.xx(原2.7环境) +3、碰到的问题:找不到vnpy.xx.xx(原2.7环境) + 可能你使用了vnpy的原版,安装到conda环境中了。需要先卸载 pip uninstall vnpy -#4、碰到的问题:importError: libGL.so.1: cannot open shared object file: No such file or directory +4、碰到的问题:importError: libGL.so.1: cannot open shared object file: No such file or directory + ubuntu下:sudo apt install libgl1-mesa-glx centOS下:sudo yum install mesa-libGL.x86_64 -#5、碰到的问题:version `GLIBCXX_3.4.21' not found +5、碰到的问题:version `GLIBCXX_3.4.21' not found + conda install libgcc 若出现更高版本需求,参见第10点 -#6、碰到的问题:在3.7 env下安装RqPlus时,报错:talib/common.c:242:28: fatal error: ta-lib/ta_defs.h: No such file or directory +6、碰到的问题:在3.7 env下安装RqPlus时,报错:talib/common.c:242:28: fatal error: ta-lib/ta_defs.h: No such file or directory + locate ta_defs.h 找到地址:/home/user/anaconda3/pkgs/ta-lib-0.4.9-np111py27_0/include/ta-lib # 复制一份到/usr/include目录下 sudo cp /home/user/anaconda3/pkgs/ta-lib-0.4.9-np111py27_0/include/ta-lib /usr/include -R -#7、碰到的问题:安装某些组件,提示 cl.exe Not found +7、碰到的问题:安装某些组件,提示 cl.exe Not found + 如果你安装了VC基础组件,需要增加一个用户环境变量,把"C:\Program Files (x86)\Microsoft Visual Studio\Shared\14.0\VC\bin" 添加到path变量中 -#8、Install Ta-Lib -如果你用py37虚拟环境 -source activate py37 +8、Install Ta-Lib + + 如果你用py37虚拟环境 + source activate py37 + + conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ + conda config --set show_channel_urls yes + conda install -c quantopian ta-lib=0.4.9 -conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ -conda config --set show_channel_urls yes -conda install -c quantopian ta-lib=0.4.9 +9、数字货币的增量安装 -# 9、数字货币的增量安装 -conda install scipy + conda install scipy + + pip install autobahn + pip install twisted + 若出现找不到rc.exe, 请先使用vs x86&x64界面,激活py37后,再运行 + pip install pyOpenSSL -pip install autobahn -pip install twisted -若出现找不到rc.exe, 请先使用vs x86&x64界面,激活py37后,再运行 -pip install pyOpenSSL +10、升级gcc -# 10、升级gcc 使用奇正MOM的CTP API时,提示`GLIBCXX_3.4.22' not found,当前centos最高版本是 3.4.21,通过yum不能升级,需要手工下载升级。 wget http://ftp.de.debian.org/debian/pool/main/g/gcc-9/libstdc++6_9.2.1-8_amd64.deb -解压 - + 解压 ar -x libstdc++6_9.2.1-8_amd64.deb (就是 ar 命令,不是tar) tar -xvf data.tar.xz -安装 + 安装 删除: rm /usr/lib64/libstdc++.so.6 拷贝: cp usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.28 /usr/lib64/ 连接: ln -s /usr/lib64/libstdc++.so.6.0.28 /usr/lib64/libstdc++.so.6 -结果 + + 结果 strings /usr/lib64/libstdc++.so.6 | grep GLIBCXX @@ -86,7 +94,8 @@ pip install pyOpenSSL GLIBCXX_3.4.25 GLIBCXX_DEBUG_MESSAGE_LENGTH -# 11、升级glibc +11、升级glibc + 使用奇正MOM的CTP API时,提示`GLIBC_2.18' not found,当前centos最高版本是 3.4.21,通过yum不能升级,需要手工下载升级。 root 用户登录 diff --git a/examples/vn_trader/ar_setting.json b/examples/vn_trader/ar_setting.json new file mode 100644 index 00000000..fc8fc3f9 --- /dev/null +++ b/examples/vn_trader/ar_setting.json @@ -0,0 +1,16 @@ +{ + "mongo_db": + { + "host": "192.168.0.207", + "port": 27017 + }, + "accounts": + { + + "ctp": + { + "copy_history_trades": true, + "copy_history_orders": true + } + } +} diff --git a/examples/vn_trader/cta_strategy_pro_config.json b/examples/vn_trader/cta_strategy_pro_config.json new file mode 100644 index 00000000..f049e2a2 --- /dev/null +++ b/examples/vn_trader/cta_strategy_pro_config.json @@ -0,0 +1,4 @@ +{ + "accountid" : "112022", + "strategy_group": "win01" +} diff --git a/examples/vn_trader/run.py b/examples/vn_trader/run.py index 1c15d1bc..8dca1cac 100644 --- a/examples/vn_trader/run.py +++ b/examples/vn_trader/run.py @@ -5,10 +5,10 @@ from vnpy.trader.engine import MainEngine from vnpy.trader.ui import MainWindow, create_qapp # from vnpy.gateway.binance import BinanceGateway -from vnpy.gateway.bitmex import BitmexGateway +#from vnpy.gateway.bitmex import BitmexGateway # from vnpy.gateway.futu import FutuGateway # from vnpy.gateway.ib import IbGateway -# from vnpy.gateway.ctp import CtpGateway +from vnpy.gateway.ctp import CtpGateway # from vnpy.gateway.ctptest import CtptestGateway # from vnpy.gateway.mini import MiniGateway # from vnpy.gateway.sopt import SoptGateway @@ -31,30 +31,33 @@ from vnpy.gateway.bitmex import BitmexGateway # from vnpy.gateway.coinbase import CoinbaseGateway # from vnpy.gateway.bitstamp import BitstampGateway # from vnpy.gateway.gateios import GateiosGateway -from vnpy.gateway.bybit import BybitGateway +# from vnpy.gateway.bybit import BybitGateway -# from vnpy.app.cta_strategy import CtaStrategyApp +from vnpy.app.cta_strategy_pro import CtaStrategyProApp # from vnpy.app.csv_loader import CsvLoaderApp # from vnpy.app.algo_trading import AlgoTradingApp from vnpy.app.cta_backtester import CtaBacktesterApp # from vnpy.app.data_recorder import DataRecorderApp -# from vnpy.app.risk_manager import RiskManagerApp +from vnpy.app.tick_recorder import TickRecorderApp +from vnpy.app.risk_manager import RiskManagerApp # from vnpy.app.script_trader import ScriptTraderApp # from vnpy.app.rpc_service import RpcServiceApp # from vnpy.app.spread_trading import SpreadTradingApp -from vnpy.app.portfolio_manager import PortfolioManagerApp - +# from vnpy.app.portfolio_manager import PortfolioManagerApp +from vnpy.app.account_recorder import AccountRecorderApp def main(): """""" qapp = create_qapp() - event_engine = EventEngine() + event_engine = EventEngine(debug=True, over_ms=200) main_engine = MainEngine(event_engine) # main_engine.add_gateway(BinanceGateway) - # main_engine.add_gateway(CtpGateway) + main_engine.add_gateway(CtpGateway, gateway_name='ctp') + main_engine.add_gateway(CtpGateway, gateway_name='ctp_yh01') + # main_engine.add_gateway(CtptestGateway) # main_engine.add_gateway(MiniGateway) # main_engine.add_gateway(SoptGateway) @@ -62,7 +65,7 @@ def main(): # main_engine.add_gateway(FemasGateway) # main_engine.add_gateway(IbGateway) # main_engine.add_gateway(FutuGateway) - main_engine.add_gateway(BitmexGateway) + #main_engine.add_gateway(BitmexGateway) # main_engine.add_gateway(TigerGateway) # main_engine.add_gateway(OesGateway) # main_engine.add_gateway(OkexGateway) @@ -80,18 +83,21 @@ def main(): # main_engine.add_gateway(CoinbaseGateway) # main_engine.add_gateway(BitstampGateway) # main_engine.add_gateway(GateiosGateway) - main_engine.add_gateway(BybitGateway) + #main_engine.add_gateway(BybitGateway) - # main_engine.add_app(CtaStrategyApp) - main_engine.add_app(CtaBacktesterApp) + #main_engine.add_app(CtaStrategyApp) + main_engine.add_app(CtaStrategyProApp) + #main_engine.add_app(CtaBacktesterApp) # main_engine.add_app(CsvLoaderApp) # main_engine.add_app(AlgoTradingApp) # main_engine.add_app(DataRecorderApp) - # main_engine.add_app(RiskManagerApp) + # main_engine.add_app(TickRecorderApp) + main_engine.add_app(RiskManagerApp) # main_engine.add_app(ScriptTraderApp) # main_engine.add_app(RpcServiceApp) # main_engine.add_app(SpreadTradingApp) - main_engine.add_app(PortfolioManagerApp) + # main_engine.add_app(PortfolioManagerApp) + # main_engine.add_app(AccountRecorderApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/app/cta_strategy_pro/__init__.py b/vnpy/app/cta_strategy_pro/__init__.py index 29adc5cf..6cd18448 100644 --- a/vnpy/app/cta_strategy_pro/__init__.py +++ b/vnpy/app/cta_strategy_pro/__init__.py @@ -1,14 +1,17 @@ from pathlib import Path from vnpy.trader.app import BaseApp -from vnpy.trader.constant import Direction,Offset +from vnpy.trader.constant import Direction,Offset,Status from vnpy.trader.object import TickData, BarData, TradeData, OrderData from vnpy.trader.utility import BarGenerator, ArrayManager - +from .cta_position import CtaPosition +from .cta_line_bar import CtaLineBar, CtaMinuteBar, CtaHourBar, CtaDayBar, CtaWeekBar +from .cta_policy import CtaPolicy +from .cta_grid_trade import CtaGrid, CtaGridTrade from .base import APP_NAME, StopOrder from .engine import CtaEngine -from .template import CtaTemplate, CtaSignal, TargetPosTemplate +from .template import CtaTemplate, CtaSignal, TargetPosTemplate, CtaProTemplate class CtaStrategyProApp(BaseApp): """""" diff --git a/vnpy/app/cta_strategy_pro/base.py b/vnpy/app/cta_strategy_pro/base.py index b22eb7e1..24cd2b33 100644 --- a/vnpy/app/cta_strategy_pro/base.py +++ b/vnpy/app/cta_strategy_pro/base.py @@ -1,11 +1,11 @@ """ Defines constants and objects used in CtaStrategyPro App. """ - +from abc import ABC from dataclasses import dataclass, field from enum import Enum from datetime import timedelta - +from logging import INFO, ERROR from vnpy.trader.constant import Direction, Offset, Interval APP_NAME = "CtaStrategyPro" @@ -92,3 +92,30 @@ INTERVAL_DELTA_MAP = { Interval.HOUR: timedelta(hours=1), Interval.DAILY: timedelta(days=1), } + +class CtaComponent(ABC): + """ CTA策略基础组件""" + + def __init__(self, strategy=None, **kwargs): + """ + 构造 + :param strategy: + """ + self.strategy = strategy + + # ---------------------------------------------------------------------- + def write_log(self, content: str): + """记录日志""" + if self.strategy: + self.strategy.write_log(msg=content, level=INFO) + else: + print(content) + + # ---------------------------------------------------------------------- + def write_error(self, content: str, level: int = ERROR): + """记录错误日志""" + if self.strategy: + self.strategy.write_log(msg=content, level=level) + else: + print(content, file=sys.stderr) + diff --git a/vnpy/app/cta_strategy_pro/cta_grid_trade.py b/vnpy/app/cta_strategy_pro/cta_grid_trade.py index 1845fe2f..71bc9a7e 100644 --- a/vnpy/app/cta_strategy_pro/cta_grid_trade.py +++ b/vnpy/app/cta_strategy_pro/cta_grid_trade.py @@ -10,10 +10,9 @@ import traceback from collections import OrderedDict from datetime import datetime from dataclasses import dataclass, field - +from typing import List from vnpy.trader.utility import get_folder_path -from vnpy.app.cta_strategy_pro.base import Direction -from vnpy.app.cta_strategy_pro.template import CtaComponent +from vnpy.app.cta_strategy_pro.base import Direction, CtaComponent """ 网格交易,用于套利单 @@ -35,37 +34,47 @@ TREND_GRID = 'trend' # 趋势网格 LOCK_GRID = 'lock' # 对锁网格 -@dataclass class CtaGrid(object): """网格类 它是网格交易的最小单元 包括交易方向,开仓价格,平仓价格,止损价格,开仓状态,平仓状态 """ - id: str = str(uuid.uuid1()) # gid - direction: Direction = Direction.NET # 交易方向(LONG:多,正套;SHORT:空,反套) - open_price: float = 0 # 开仓价格 - close_price: float = 0 # 止盈价格 - stop_price: float = 0 # 止损价格 + def __init__(self, + direction: Direction = None, + open_price: float = 0, + close_price: float = 0, + stop_price: float = 0, + vt_symbol: str = '', + volume: float = 0, + traded_volume: float = 0, + order_status: bool = False, + open_status: bool = False, + close_status: bool = False, + open_time: datetime = None, + order_time: datetime = None, + reuse_count: int = 0, + type: str = '' + ): - vt_symbol: str = '' # 品种合约 - volume: float = 0 # 开仓数量( 兼容数字货币 ) - - traded_volume: float = 0 # 已成交数量 开仓时,为开仓数量,平仓时,为平仓数量 - - order_status: bool = False # 挂单状态: True,已挂单,False,未挂单 - order_ids: list[str] = field(default_factory=list) # order_id list - open_status: bool = False # 开仓状态 - close_status: bool = False # 平仓状态 - - open_time: datetime = None # 开仓时间 - order_time: datetime = None # 委托时间 - - lock_grid_ids: list[str] = field(default_factory=list) # 锁单的网格,[gid,gid] - reuse_count: int = 0 # 重用次数(0, 平仓后是否删除) - type: str = '' # 网格类型标签 - - snapshot: dict = field(default_factory=dict) # 切片数据,如记录开仓点时的某些状态数据 + self.id: str = str(uuid.uuid1()) # gid + self.direction = direction # 交易方向(LONG:多,正套;SHORT:空,反套) + self.open_price = open_price # 开仓价格 + self.close_price = close_price # 止盈价格 + self.stop_price = stop_price # 止损价格 + self.vt_symbol = vt_symbol # 品种合约 + self.volume = volume # 开仓数量( 兼容数字货币 ) + self.traded_volume = traded_volume # 已成交数量 开仓时,为开仓数量,平仓时,为平仓数量 + self.order_status = order_status # 挂单状态: True,已挂单,False,未挂单 + self.order_ids = [] # order_id list + self.open_status = open_status # 开仓状态 + self.close_status = close_status # 平仓状态 + self.open_time = open_time # 开仓时间 + self.order_time = order_time # 委托时间 + self.lock_grid_ids = [] # 锁单的网格,[gid,gid] + self.reuse_count = reuse_count # 重用次数(0, 平仓后是否删除) + self.type = type # 网格类型标签 + self.snapshot = {} # 切片数据,如记录开仓点时的某些状态数据 def to_json(self): """输出JSON""" @@ -156,10 +165,10 @@ class CtaGridTrade(CtaComponent): vol,网格开仓数 minDiff, 最小价格跳动 """ - super(CtaGridTrade).__init__(strategy=strategy) + super(CtaGridTrade, self).__init__(strategy=strategy) self.price_tick = kwargs.get('price_tick', 1) - self.jsonName = self.strategy.name # 策略名称 + self.json_name = self.strategy.strategy_name # 策略名称 self.max_lots = kwargs.get('max_lots', 10) # 缺省网格数量 self.grid_height = kwargs.get('grid_height', 10 * self.price_tick) # 最小网格高度 self.grid_win = kwargs.get('grid_win', 10 * self.price_tick) # 最小止盈高度 @@ -176,7 +185,8 @@ class CtaGridTrade(CtaComponent): self.max_up_open_price = 0.0 # 上网格最高开仓价 self.min_dn_open_price = 0.0 # 下网格最小开仓价 - self.json_file_path = os.path.join(get_folder_path('data'), f'{self.jsonName}_Grids.json') # 网格的路径 + # 网格json文件的路径 + self.json_file_path = os.path.join(get_folder_path('data'), f'{self.json_name}_Grids.json') def get_volume_rate(self, idx: int = 0): """获取网格索引对应的开仓数量比例""" @@ -554,7 +564,7 @@ class CtaGridTrade(CtaComponent): if direction == Direction.SHORT: for x in self.up_grids[:]: - if x.id in id: + if x.id in ids: self.write_log(u'清除上网格[open={},close={},stop={},volume={}]' .format(x.open_price, x.close_price, x.stop_price, x.volume)) self.up_grids.remove(x) @@ -873,12 +883,12 @@ class CtaGridTrade(CtaComponent): grids_save_path = get_folder_path('data') # 确保json名字与策略一致 - if self.jsonName != self.strategy.name: - self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.jsonName, self.strategy.name)) - self.jsonName = self.strategy.name + if self.json_name != self.strategy.strategy_name: + self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.json_name, self.strategy.strategy_name)) + self.json_name = self.strategy.strategy_name # 新版网格持久化文件 - grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.jsonName)) + grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.json_name)) self.json_file_path = grid_json_file data = {} @@ -907,12 +917,12 @@ class CtaGridTrade(CtaComponent): data = {} grids_save_path = get_folder_path('data') - if self.jsonName != self.strategy.name: - self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.jsonName, self.strategy.name)) - self.jsonName = self.strategy.name + if self.json_name != self.strategy.strategy_name: + self.write_log(u'JsonName {} 与 上层策略名{} 不一致.'.format(self.json_name, self.strategy.strategy_name)) + self.json_name = self.strategy.strategy_name # 若json文件不存在,就保存一个;若存在,就优先使用数据文件 - grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.jsonName)) + grid_json_file = os.path.join(grids_save_path, u'{}_Grids.json'.format(self.json_name)) if not os.path.exists(grid_json_file): data['up_grids'] = [] data['dn_grids'] = [] @@ -970,7 +980,7 @@ class CtaGridTrade(CtaComponent): data_folder = get_folder_path('data') - self.jsonName = new_name + self.json_name = new_name # 旧文件 old_json_file = os.path.join(data_folder, u'{0}_Grids.json'.format(old_name)) diff --git a/vnpy/app/cta_strategy_pro/cta_line_bar.py b/vnpy/app/cta_strategy_pro/cta_line_bar.py index 134d84d6..c8f4170b 100644 --- a/vnpy/app/cta_strategy_pro/cta_line_bar.py +++ b/vnpy/app/cta_strategy_pro/cta_line_bar.py @@ -1016,7 +1016,12 @@ class CtaLineBar(object): def first_tick(self, tick: TickData): """ K线的第一个Tick数据""" - self.cur_bar = BarData() # 创建新的K线 + self.cur_bar = BarData( + gateway_name=tick.gateway_name, + symbol=tick.symbol, + exchange=tick.exchange, + datetime=tick.datetime + ) # 创建新的K线 # 计算K线的整点分钟周期,这里周期最小是1分钟。如果你是采用非整点分钟,例如1.5分钟,请把这段注解掉 if self.minute_interval and self.interval == Interval.SECOND: self.minute_interval = int(self.bar_interval / 60) @@ -1506,7 +1511,7 @@ class CtaLineBar(object): if self.para_ma1_len > 0: count_len = min(self.bar_len, self.para_ma1_len) if count_len > 0: - close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len) + close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len) self._rt_ma1 = round(float(close_ma_array[-1]), self.round_n) # 计算斜率 @@ -1517,7 +1522,7 @@ class CtaLineBar(object): if self.para_ma2_len > 0: count_len = min(self.bar_len, self.para_ma2_len) if count_len > 0: - close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len) + close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len) self._rt_ma2 = round(float(close_ma_array[-1]), self.round_n) # 计算斜率 @@ -1528,7 +1533,7 @@ class CtaLineBar(object): if self.para_ma3_len > 0: count_len = min(self.bar_len, self.para_ma3_len) if count_len > 0: - close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close]), count_len) + close_ma_array = ta.MA(np.append(self.close_array[-count_len:], [self.line_bar[-1].close_price]), count_len) self._rt_ma3 = round(float(close_ma_array[-1]), self.round_n) # 计算斜率 @@ -1846,27 +1851,27 @@ class CtaLineBar(object): # 计算 ATR if self.para_atr1_len > 0: count_len = min(self.bar_len, self.para_atr1_len) - self.cur_atr1 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:], - self.close_array[-count_len:], count_len) - self.cur_atr1 = round(self.cur_atr1, self.round_n) + cur_atr1 = ta.ATR(self.high_array[-count_len * 2:], self.low_array[-count_len * 2:], + self.close_array[-count_len * 2:], count_len) + self.cur_atr1 = round(cur_atr1[-1], self.round_n) if len(self.line_atr1) > self.max_hold_bars: del self.line_atr1[0] self.line_atr1.append(self.cur_atr1) if self.para_atr2_len > 0: count_len = min(self.bar_len, self.para_atr2_len) - self.cur_atr2 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:], - self.close_array[-count_len:], count_len) - self.cur_atr2 = round(self.cur_atr2, self.round_n) + cur_atr2 = ta.ATR(self.high_array[-count_len * 2:], self.low_array[-count_len * 2:], + self.close_array[-count_len * 2:], count_len) + self.cur_atr2 = round(cur_atr2[-1], self.round_n) if len(self.line_atr2) > self.max_hold_bars: del self.line_atr2[0] self.line_atr2.append(self.cur_atr2) if self.para_atr3_len > 0: count_len = min(self.bar_len, self.para_atr3_len) - self.cur_atr3 = ta.ATR(self.high_array[-count_len:], self.low_array[-count_len:], - self.close_array[-count_len:], count_len) - self.cur_atr3 = round(self.cur_atr3, self.round_n) + cur_atr3 = ta.ATR(self.high_array[-count_len * 2 :], self.low_array[-count_len * 2:], + self.close_array[-count_len * 2:], count_len) + self.cur_atr3 = round(cur_atr3[-1], self.round_n) if len(self.line_atr3) > self.max_hold_bars: del self.line_atr3[0] @@ -5398,424 +5403,5 @@ class CtaWeekBar(CtaLineBar): # 实时计算 self.rt_executed = False - self.lastTick = tick + self.last_tick = tick - -class test_strategy(object): - - def __init__(self): - - self.price_tick = 1 - self.underlying_symbol = 'I' - self.vt_symbol = 'I99' - - self.lineM5 = None - self.lineM30 = None - self.lineH1 = None - self.lineH2 = None - self.lineD = None - self.lineW = None - - self.TMinuteInterval = 1 - - self.save_m30_bars = [] - self.save_h1_bars = [] - self.save_h2_bars = [] - self.save_d_bars = [] - - self.save_w_bars = [] - - def createM5(self): - """使用ctalinbar,创建5分钟K线""" - lineM5Setting = {} - lineM5Setting['name'] = u'M5' - lineM5Setting['interval'] = Interval.MINUTE - lineM5Setting['bar_interval'] = 5 - lineM5Setting['mode'] = CtaLineBar.TICK_MODE - lineM5Setting['price_tick'] = self.price_tick - lineM5Setting['underlying_symbol'] = self.underlying_symbol - self.lineM5 = CtaLineBar(self, self.onBarM5, lineM5Setting) - - def onBarM5(self, bar): - self.write_log(self.lineM5.get_last_bar_str()) - - def createlineM30_with_macd(self): - """使用CtaLineBar,创建30分钟时间""" - # 创建M30 K线 - lineM30Setting = {} - lineM30Setting['name'] = u'M30' - lineM30Setting['interval'] = Interval.MINUTE - lineM30Setting['bar_interval'] = 30 - lineM30Setting['para_macd_fast_len'] = 26 - lineM30Setting['para_macd_slow_len'] = 12 - lineM30Setting['para_macd_signal_len'] = 9 - lineM30Setting['mode'] = CtaLineBar.TICK_MODE - lineM30Setting['price_tick'] = self.price_tick - lineM30Setting['underlying_symbol'] = self.underlying_symbol - self.lineM30 = CtaLineBar(self, self.onBarM30MACD, lineM30Setting) - - def onBarM30MACD(self, bar): - self.write_log(self.lineM30.get_last_bar_str()) - - def createLineM30(self): - """使用ctaMinuteBar, 测试内部自动写入csv文件""" - # 创建M30 K线 - lineM30Setting = {} - lineM30Setting['name'] = u'M30' - lineM30Setting['interval'] = Interval.MINUTE - lineM30Setting['bar_interval'] = 30 - lineM30Setting['para_pre_len'] = 10 - lineM30Setting['para_ma1_len'] = 5 - lineM30Setting['para_ma2_len'] = 10 - lineM30Setting['para_ma3_len'] = 60 - lineM30Setting['para_active_yb'] = True - lineM30Setting['para_active_skd'] = True - lineM30Setting['mode'] = CtaLineBar.TICK_MODE - lineM30Setting['price_tick'] = self.price_tick - lineM30Setting['underlying_symbol'] = self.underlying_symbol - self.lineM30 = CtaMinuteBar(self, self.onBarM30, lineM30Setting) - - # 写入文件 - self.lineM30.export_filename = os.path.abspath( - os.path.join(os.getcwd(), - u'export_{}_{}.csv'.format(self.vt_symbol, self.lineM30.name))) - - self.lineM30.export_fields = [ - {'name': 'datetime', 'source': 'bar', 'attr': 'datetime', 'type_': 'datetime'}, - {'name': 'open', 'source': 'bar', 'attr': 'open_price', 'type_': 'float'}, - {'name': 'high', 'source': 'bar', 'attr': 'high_price', 'type_': 'float'}, - {'name': 'low', 'source': 'bar', 'attr': 'low_price', 'type_': 'float'}, - {'name': 'close', 'source': 'bar', 'attr': 'close_price', 'type_': 'float'}, - {'name': 'turnover', 'source': 'bar', 'attr': 'turnover', 'type_': 'float'}, - {'name': 'volume', 'source': 'bar', 'attr': 'volume', 'type_': 'float'}, - {'name': 'open_interest', 'source': 'bar', 'attr': 'open_interest', 'type_': 'float'}, - {'name': 'kf', 'source': 'line_bar', 'attr': 'line_statemean', 'type_': 'list'} - ] - - def createLineH1(self): - # 创建2小时K线 - lineH1Setting = {} - lineH1Setting['name'] = u'H1' - lineH1Setting['interval'] = Interval.HOUR - lineH1Setting['bar_interval'] = 1 - lineH1Setting['para_pre_len'] = 10 - lineH1Setting['para_ema1_len'] = 5 - lineH1Setting['para_ema2_len'] = 10 - lineH1Setting['para_ema3_len'] = 60 - lineH1Setting['para_active_yb'] = True - lineH1Setting['para_active_skd'] = True - lineH1Setting['mode'] = CtaLineBar.TICK_MODE - lineH1Setting['price_tick'] = self.price_tick - lineH1Setting['underlying_symbol'] = self.underlying_symbol - self.lineH1 = CtaLineBar(self, self.onBarH1, lineH1Setting) - - def createLineH2(self): - # 创建2小时K线 - lineH2Setting = {} - lineH2Setting['name'] = u'H2' - lineH2Setting['interval'] = Interval.HOUR - lineH2Setting['bar_interval'] = 2 - lineH2Setting['para_pre_len'] = 5 - lineH2Setting['para_ma1_len'] = 5 - lineH2Setting['para_ma2_len'] = 10 - lineH2Setting['para_ma3_len'] = 18 - lineH2Setting['para_active_yb'] = True - lineH2Setting['para_active_skd'] = True - lineH2Setting['mode'] = CtaLineBar.TICK_MODE - lineH2Setting['price_tick'] = self.price_tick - lineH2Setting['underlying_symbol'] = self.underlying_symbol - self.lineH2 = CtaHourBar(self, self.onBarH2, lineH2Setting) - - def createLineD(self): - # 创建的日K线 - lineDaySetting = {} - lineDaySetting['name'] = u'D1' - lineDaySetting['bar_interval'] = 1 - lineDaySetting['para_pre_len'] = 5 - lineDaySetting['para_art1_len'] = 26 - lineDaySetting['para_ma1_len'] = 5 - lineDaySetting['para_ma2_len'] = 10 - lineDaySetting['para_ma3_len'] = 18 - lineDaySetting['para_active_yb'] = True - lineDaySetting['para_active_skd'] = True - lineDaySetting['mode'] = CtaDayBar.TICK_MODE - lineDaySetting['price_tick'] = self.price_tick - lineDaySetting['underlying_symbol'] = self.underlying_symbol - self.lineD = CtaDayBar(self, self.onBarD, lineDaySetting) - - def createLineW(self): - """创建周线""" - lineWeekSetting = {} - lineWeekSetting['name'] = u'W1' - lineWeekSetting['para_pre_len'] = 5 - lineWeekSetting['para_art1_len'] = 26 - lineWeekSetting['para_ma1_len'] = 5 - lineWeekSetting['para_ma2_len'] = 10 - lineWeekSetting['para_ma3_len'] = 18 - lineWeekSetting['para_active_yb'] = True - lineWeekSetting['para_active_skd'] = True - lineWeekSetting['mode'] = CtaDayBar.TICK_MODE - lineWeekSetting['price_tick'] = self.price_tick - lineWeekSetting['underlying_symbol'] = self.underlying_symbol - self.lineW = CtaWeekBar(self, self.onBarW, lineWeekSetting) - - def onBar(self, bar): - # print(u'tradingDay:{},dt:{},o:{},h:{},l:{},c:{},v:{}'.format(bar.trading_day,bar.datetime, bar.open, bar.high, bar.low_price, bar.close_price, bar.volume)) - if self.lineW: - self.lineW.add_bar(bar, bar_freq=self.TMinuteInterval) - if self.lineD: - self.lineD.add_bar(bar, bar_freq=self.TMinuteInterval) - if self.lineH2: - self.lineH2.add_bar(bar, bar_freq=self.TMinuteInterval) - - if self.lineH1: - self.lineH1.add_bar(bar, bar_freq=self.TMinuteInterval) - - if self.lineM30: - self.lineM30.add_bar(bar, bar_freq=self.TMinuteInterval) - - if self.lineM5: - self.lineM5.add_bar(bar, bar_freq=self.TMinuteInterval) - - # if self.lineH2: - # self.lineH2.skd_is_high_dead_cross(runtime=True, high_skd=30) - # self.lineH2.skd_is_low_golden_cross(runtime=True, low_skd=70) - - def onBarM30(self, bar): - self.write_log(self.lineM30.get_last_bar_str()) - - self.save_m30_bars.append({ - 'datetime': bar.datetime, - 'open': bar.open_price, - 'high': bar.high_price, - 'low': bar.low_price, - 'close': bar.close_price, - 'turnover': 0, - 'volume': bar.volume, - 'open_interest': 0, - 'ma5': self.lineM30.line_ma1[-1] if len(self.lineM30.line_ma1) > 0 else bar.close_price, - 'ma10': self.lineM30.line_ma2[-1] if len(self.lineM30.line_ma2) > 0 else bar.close_price, - 'ma60': self.lineM30.line_ma3[-1] if len(self.lineM30.line_ma3) > 0 else bar.close_price, - 'sk': self.lineM30.line_sk[-1] if len(self.lineM30.line_sk) > 0 else 0, - 'sd': self.lineM30.line_sd[-1] if len(self.lineM30.line_sd) > 0 else 0 - }) - - def onBarH1(self, bar): - self.write_log(self.lineH1.get_last_bar_str()) - - self.save_h1_bars.append({ - 'datetime': bar.datetime, - 'open': bar.open_price, - 'high': bar.high_price, - 'low': bar.low_price, - 'close': bar.close_price, - 'turnover': 0, - 'volume': bar.volume, - 'open_interest': 0, - 'ema5': self.lineH1.line_ema1[-1] if len(self.lineH1.line_ema1) > 0 else bar.close_price, - 'ema10': self.lineH1.line_ema2[-1] if len(self.lineH1.line_ema2) > 0 else bar.close_price, - 'ema60': self.lineH1.line_ema3[-1] if len(self.lineH1.line_ema3) > 0 else bar.close_price, - 'sk': self.lineH1.line_sk[-1] if len(self.lineH1.line_sk) > 0 else 0, - 'sd': self.lineH1.line_sd[-1] if len(self.lineH1.line_sd) > 0 else 0 - }) - - def onBarH2(self, bar): - self.write_log(self.lineH2.get_last_bar_str()) - - self.save_h2_bars.append({ - 'datetime': bar.datetime, - 'open': bar.open_price, - 'high': bar.high_price, - 'low': bar.low_price, - 'close': bar.close_price, - 'turnover': 0, - 'volume': bar.volume, - 'open_interest': 0, - 'ma5': self.lineH2.line_ma1[-1] if len(self.lineH2.line_ma1) > 0 else bar.close_price, - 'ma10': self.lineH2.line_ma2[-1] if len(self.lineH2.line_ma2) > 0 else bar.close_price, - 'ma18': self.lineH2.line_ma3[-1] if len(self.lineH2.line_ma3) > 0 else bar.close_price, - 'sk': self.lineH2.line_sk[-1] if len(self.lineH2.line_sk) > 0 else 0, - 'sd': self.lineH2.line_sd[-1] if len(self.lineH2.line_sd) > 0 else 0 - }) - - def onBarD(self, bar): - self.write_log(self.lineD.get_last_bar_str()) - self.save_d_bars.append({ - 'datetime': bar.datetime, - 'open': bar.open_price, - 'high': bar.high_price, - 'low': bar.low_price, - 'close': bar.close_price, - 'turnover': 0, - 'volume': bar.volume, - 'open_interest': 0, - 'ma5': self.lineD.line_ma1[-1] if len(self.lineD.line_ma1) > 0 else bar.close_price, - 'ma10': self.lineD.line_ma2[-1] if len(self.lineD.line_ma2) > 0 else bar.close_price, - 'ma18': self.lineD.line_ma3[-1] if len(self.lineD.line_ma3) > 0 else bar.close_price, - 'sk': self.lineD.line_sk[-1] if len(self.lineD.line_sk) > 0 else 0, - 'sd': self.lineD.line_sd[-1] if len(self.lineD.line_sd) > 0 else 0 - }) - - def onBarW(self, bar): - self.write_log(self.lineW.get_last_bar_str()) - self.save_w_bars.append({ - 'datetime': bar.datetime, - 'open': bar.open_price, - 'high': bar.high_price, - 'low': bar.low_price, - 'close': bar.close_price, - 'turnover': 0, - 'volume': bar.volume, - 'open_interest': 0, - 'ma5': self.lineW.line_ma1[-1] if len(self.lineW.line_ma1) > 0 else bar.close_price, - 'ma10': self.lineW.line_ma2[-1] if len(self.lineW.line_ma2) > 0 else bar.close_price, - 'ma18': self.lineW.line_ma3[-1] if len(self.lineW.line_ma3) > 0 else bar.close_price, - 'sk': self.lineW.line_sk[-1] if len(self.lineW.line_sk) > 0 else 0, - 'sd': self.lineW.line_sd[-1] if len(self.lineW.line_sd) > 0 else 0 - }) - - def on_tick(self, tick): - print(u'{0},{1},ap:{2},av:{3},bp:{4},bv:{5}'.format(tick.datetime, tick.last_price, tick.ask_price_1, - tick.ask_volume_1, tick.bid_price_1, tick.bid_volume_1)) - - def write_log(self, content): - print(content) - - def saveData(self): - - if len(self.save_m30_bars) > 0: - outputFile = '{}_m30.csv'.format(self.vt_symbol) - with open(outputFile, 'w', encoding='utf8', newline='') as f: - fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', - 'ma5', 'ma10', 'ma60', 'sk', 'sd'] - writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') - writer.writeheader() - for row in self.save_m30_bars: - writer.writerow(row) - - if len(self.save_h1_bars) > 0: - outputFile = '{}_h1.csv'.format(self.vt_symbol) - with open(outputFile, 'w', encoding='utf8', newline='') as f: - fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', - 'ema5', 'ema10', 'ema60', 'sk', 'sd'] - writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') - writer.writeheader() - for row in self.save_h1_bars: - writer.writerow(row) - - if len(self.save_h2_bars) > 0: - outputFile = '{}_h2.csv'.format(self.vt_symbol) - with open(outputFile, 'w', encoding='utf8', newline='') as f: - fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', - 'ma5', 'ma10', 'ma18', 'sk', 'sd'] - writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') - writer.writeheader() - for row in self.save_h2_bars: - writer.writerow(row) - - if len(self.save_d_bars) > 0: - outputFile = '{}_d.csv'.format(self.vt_symbol) - with open(outputFile, 'w', encoding='utf8', newline='') as f: - fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', - 'ma5', 'ma10', 'ma18', 'sk', 'sd'] - writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') - writer.writeheader() - for row in self.save_d_bars: - writer.writerow(row) - - if len(self.save_w_bars) > 0: - outputFile = '{}_w.csv'.format(self.vt_symbol) - with open(outputFile, 'w', encoding='utf8', newline='') as f: - fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', - 'ma5', 'ma10', 'ma18', 'sk', 'sd'] - writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') - writer.writeheader() - for row in self.save_w_bars: - writer.writerow(row) - - -if __name__ == '__main__': - t = test_strategy() - t.price_tick = 0.5 - t.underlying_symbol = 'J' - t.vt_symbol = 'J99' - - # t.createM5() - # t.createLineW() - - # t.createlineM30_with_macd() - - # 创建M30线 - # t.createLineM30() - - # 回测1小时线 - # t.createLineH1() - - # 回测2小时线 - # t.createLineH2() - - # 回测日线 - # t.createLineD() - - # 测试周线 - t.createLineW() - - # vnpy/app/cta_strategy_pro/ - vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) - - filename = os.path.abspath(os.path.join(vnpy_root, 'bar_data/{}_20160101_20190517_1m.csv'.format(t.vt_symbol))) - csv_bar_seconds = 60 # csv 文件内,bar的时间间隔60秒 - - import csv - - csvfile = open(filename, 'r', encoding='utf8') - reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",") - last_tradingDay = None - for row in reader: - try: - dt = datetime.strptime(row['index'], '%Y-%m-%d %H:%M:%S') - timedelta(seconds=csv_bar_seconds) - - bar = BarData( - gateway_name='', - symbol=t.vt_symbol, - exchange=Exchange.LOCAL, - datetime=dt, - interval=Interval.MINUTE, - open_price=round_to(float(row['open']), t.price_tick), - high_price=round_to(float(row['high']), t.price_tick), - low_price=round_to(float(row['low']), t.price_tick), - close_price=round_to(float(row['close']), t.price_tick), - volume=float(row['volume']) - ) - - if 'trading_date' in row: - bar.trading_day = row['trading_date'] - if len(bar.trading_day) == 8 and '-' not in bar.trading_day: - bar.trading_day = bar.trading_day[0:4] + '-' + bar.trading_day[4:6] + '-' + bar.trading_day[6:8] - else: - if bar.datetime.hour >= 21: - if bar.datetime.isoweekday() == 5: - # 星期五=》星期一 - bar.trading_day = (dt + timedelta(days=3)).strftime('%Y-%m-%d') - else: - # 第二天 - bar.trading_day = (dt + timedelta(days=1)).strftime('%Y-%m-%d') - elif bar.datetime.hour < 8 and bar.datetime.isoweekday() == 6: - # 星期六=>星期一 - bar.trading_day = (dt + timedelta(days=2)).strftime('%Y-%m-%d') - else: - bar.trading_day = bar.datetime.strftime('%Y-%m-%d') - - t.onBar(bar) - # 测试 实时计算值 - # sk, sd = t.lineM30.getRuntimeSKD() - - # 测试实时计算值 - # if bar.datetime.minute==1: - # print('rt_Dif:{}'.format(t.lineM30.rt_Dif)) - except Exception as ex: - t.write_log(u'{0}:{1}'.format(Exception, ex)) - traceback.print_exc() - break - - t.saveData() diff --git a/vnpy/app/cta_strategy_pro/cta_policy.py b/vnpy/app/cta_strategy_pro/cta_policy.py index d6066cff..a9e0749d 100644 --- a/vnpy/app/cta_strategy_pro/cta_policy.py +++ b/vnpy/app/cta_strategy_pro/cta_policy.py @@ -4,7 +4,7 @@ import os import json from datetime import datetime from collections import OrderedDict -from vnpy.app.cta_strategy_pro.template import CtaComponent +from vnpy.app.cta_strategy_pro.base import CtaComponent from vnpy.trader.utility import get_folder_path @@ -18,12 +18,12 @@ class CtaPolicy(CtaComponent): 构造 :param strategy: """ - super(CtaPolicy).__init__(strategy=strategy) + super(CtaPolicy,self).__init__(strategy=strategy, kwargs=kwargs) self.create_time = None self.save_time = None - def toJson(self): + def to_json(self): """ 将数据转换成dict datetime =》 string @@ -36,7 +36,7 @@ class CtaPolicy(CtaComponent): return j - def fromJson(self, json_data): + def from_json(self, json_data): """ 将数据从json_data中恢复 :param json_data: @@ -67,7 +67,7 @@ class CtaPolicy(CtaComponent): 从持久化文件中获取 :return: """ - json_file = os.path.abspath(os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.name))) + json_file = os.path.abspath(os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.strategy_name))) json_data = {} if os.path.exists(json_file): @@ -80,7 +80,7 @@ class CtaPolicy(CtaComponent): json_data = {} # 从持久化文件恢复数据 - self.fromJson(json_data) + self.from_json(json_data) def save(self): """ @@ -88,14 +88,14 @@ class CtaPolicy(CtaComponent): :return: """ json_file = os.path.abspath( - os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.name))) + os.path.join(get_folder_path('data'), u'{}_Policy.json'.format(self.strategy.strategy_name))) try: # 修改为:回测时不保存 if self.strategy and self.strategy.backtesting: return - json_data = self.toJson() + json_data = self.to_json() json_data['save_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S') with open(json_file, 'w') as f: data = json.dumps(json_data, indent=4) diff --git a/vnpy/app/cta_strategy_pro/cta_position.py b/vnpy/app/cta_strategy_pro/cta_position.py index a1a47b38..b38cc2ca 100644 --- a/vnpy/app/cta_strategy_pro/cta_position.py +++ b/vnpy/app/cta_strategy_pro/cta_position.py @@ -2,8 +2,7 @@ import sys -from vnpy.app.cta_strategy_pro.base import Direction -from vnpy.app.cta_strategy_pro.template import CtaComponent +from vnpy.app.cta_strategy_pro.base import Direction, CtaComponent class CtaPosition(CtaComponent): @@ -15,13 +14,13 @@ class CtaPosition(CtaComponent): """ def __init__(self, strategy, **kwargs): - super(CtaComponent).__init__(strategy=strategy) + super(CtaPosition, self).__init__(strategy=strategy, kwargs=kwargs) self.long_pos = 0 # 多仓持仓(正数) self.short_pos = 0 # 空仓持仓(负数) self.pos = 0 # 持仓状态 0:空仓/对空平等; >=1 净多仓 ;<=-1 净空仓 self.maxPos = sys.maxsize # 最大持仓量(多仓+空仓总量) - def open_pos(self, direction: Direction, volume: int): + def open_pos(self, direction: Direction, volume: float): """开、加仓""" # volume: 正整数 @@ -31,7 +30,7 @@ class CtaPosition(CtaComponent): # 更新 self.write_log(f'多仓:{self.long_pos}->{self.long_pos + volume}') - self.write_log(u'净:{self.pos}->{self.pos + volume}') + self.write_log(f'净:{self.pos}->{self.pos + volume}') self.long_pos += volume self.pos += volume @@ -40,13 +39,13 @@ class CtaPosition(CtaComponent): self.write_error(content=f'开仓异常,净:{self.pos},空:{self.short_pos},加空:{volume},超过:{self.maxPos}') self.write_log(f'空仓:{self.short_pos}->{self.short_pos - volume}') - self.write_log(u'净:{self.pos}->{self.pos - volume}') + self.write_log(f'净:{self.pos}->{self.pos - volume}') self.short_pos -= volume self.pos -= volume return True - def close_pos(self, direction: Direction, volume): + def close_pos(self, direction: Direction, volume:float): """平、减仓""" # vol: 正整数 diff --git a/vnpy/app/cta_strategy_pro/engine.py b/vnpy/app/cta_strategy_pro/engine.py index c1d902b8..f503aa1e 100644 --- a/vnpy/app/cta_strategy_pro/engine.py +++ b/vnpy/app/cta_strategy_pro/engine.py @@ -43,13 +43,17 @@ from vnpy.trader.constant import ( Status ) from vnpy.trader.utility import ( - load_json, save_json, + load_json, + save_json, extract_vt_symbol, - round_to, get_folder_path, + round_to, + TRADER_DIR, + get_folder_path, get_underlying_symbol, append_data) from vnpy.trader.util_logger import setup_logger, logging +from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.converter import OffsetConverter from .base import ( @@ -88,7 +92,6 @@ class CtaEngine(BaseEngine): 6、支持指定gateway的交易。主引擎可接入多个gateway """ - engine_type = EngineType.LIVE # live trading engine engine_type = EngineType.LIVE # live trading engine # 策略配置文件 @@ -264,13 +267,17 @@ class CtaEngine(BaseEngine): # strategy.pos -= trade.volume # 根据策略名称,写入 data\straetgy_name_trade.csv文件 strategy_name = getattr(strategy, 'name') - trade_fields = ['time', 'symbol', 'exchange', 'vt_symbol', 'tradeid', 'vt_tradeid', 'orderid', 'vt_orderid', + trade_fields = ['datetime', 'symbol', 'exchange', 'vt_symbol', 'tradeid', 'vt_tradeid', 'orderid', 'vt_orderid', 'direction', 'offset', 'price', 'volume', 'idx_price'] trade_dict = OrderedDict() try: for k in trade_fields: - if k == 'time': - trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, k, '') + if k == 'datetime': + dt = getattr(trade, 'datetime') + if isinstance(dt, datetime): + trade_dict[k] = dt.strftime('%Y-%m-%d %H:%M:%S') + else: + trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, 'time', '') if k in ['exchange', 'direction', 'offset']: trade_dict[k] = getattr(trade, k).value else: @@ -586,12 +593,12 @@ class CtaEngine(BaseEngine): order = self.main_engine.get_order(vt_orderid) if not order: self.write_log(msg=f"撤单失败,找不到委托{vt_orderid}", - strategy_Name=strategy.name, + strategy_Name=strategy.strategy_name, level=logging.ERROR) - return + return False req = order.create_cancel_request() - self.main_engine.cancel_order(req, order.gateway_name) + return self.main_engine.cancel_order(req, order.gateway_name) def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str): """ @@ -599,7 +606,7 @@ class CtaEngine(BaseEngine): """ stop_order = self.stop_orders.get(stop_orderid, None) if not stop_order: - return + return False strategy = self.strategies[stop_order.strategy_name] # Remove from relation map. @@ -614,6 +621,7 @@ class CtaEngine(BaseEngine): self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) self.put_stop_order_event(stop_order) + return True def send_order( self, @@ -634,7 +642,7 @@ class CtaEngine(BaseEngine): contract = self.main_engine.get_contract(vt_symbol) if not contract: self.write_log(msg=f"委托失败,找不到合约:{vt_symbol}", - strategy_name=strategy.name, + strategy_name=strategy.strategy_name, level=logging.ERROR) return "" if contract.gateway_name and not gateway_name: @@ -659,9 +667,9 @@ class CtaEngine(BaseEngine): """ """ if vt_orderid.startswith(STOPORDER_PREFIX): - self.cancel_local_stop_order(strategy, vt_orderid) + return self.cancel_local_stop_order(strategy, vt_orderid) else: - self.cancel_server_order(strategy, vt_orderid) + return self.cancel_server_order(strategy, vt_orderid) def cancel_all(self, strategy: CtaTemplate): """ @@ -689,7 +697,7 @@ class CtaEngine(BaseEngine): self.main_engine.subscribe(req, gateway_name) else: self.write_log(msg=f"找不到合约{vt_symbol},添加到待订阅列表", - strategy_name=strategy.name) + strategy_name=strategy.strategy_name) self.pending_subcribe_symbol_map[f'{gateway_name}.{vt_symbol}'].add((strategy_name, is_bar)) try: self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口') @@ -713,7 +721,7 @@ class CtaEngine(BaseEngine): strategies.append(strategy) # 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射 - subscribe_symbol_set = self.strategy_symbol_map[strategy.name] + subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name] subscribe_symbol_set.add(contract.vt_symbol) return True @@ -786,6 +794,16 @@ class CtaEngine(BaseEngine): """""" return self.engine_type + @lru_cache() + def get_data_path(self): + data_path = os.path.abspath(os.path.join(TRADER_DIR, 'data')) + return data_path + + @lru_cache() + def get_logs_path(self): + log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log')) + return log_path + def call_strategy_func( self, strategy: CtaTemplate, func: Callable, params: Any = None ): @@ -1389,3 +1407,17 @@ class CtaEngine(BaseEngine): subject = "CTA策略引擎" self.main_engine.send_email(subject, msg) + + def send_wechat(self, msg: str, strategy: CtaTemplate = None): + """ + send wechat message to default receiver + :param msg: + :param strategy: + :return: + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "CTAPRO引擎" + + send_wx_msg(content=f'{subject}:{msg}') diff --git a/vnpy/app/cta_strategy_pro/portfolio_testing.py b/vnpy/app/cta_strategy_pro/portfolio_testing.py index 39ca3458..6c03dbbe 100644 --- a/vnpy/app/cta_strategy_pro/portfolio_testing.py +++ b/vnpy/app/cta_strategy_pro/portfolio_testing.py @@ -14,29 +14,22 @@ import importlib import csv import copy import pandas as pd -import re import traceback -import decimal import numpy as np import random import logging -from collections import OrderedDict,defaultdict +from collections import OrderedDict, defaultdict from datetime import datetime, timedelta from functools import lru_cache from pathlib import Path from time import sleep -cta_engine_path = os.path.abspath(os.path.dirname(__file__)) -vnpy_root = os.path.abspath(os.path.join(cta_engine_path, '..', '..', '..', '..')) - from .base import ( - BacktestingMode, EngineType, STOPORDER_PREFIX, StopOrder, - StopOrderStatus, - INTERVAL_DELTA_MAP + StopOrderStatus ) from .template import CtaTemplate @@ -44,6 +37,7 @@ from .cta_fund_kline import FundKline from vnpy.trader.object import ( BarData, + RenkoBarData, OrderData, TradeData, ContractData @@ -63,15 +57,13 @@ from vnpy.trader.utility import ( get_underlying_symbol, round_to, extract_vt_symbol, - format_number + format_number, + import_module_by_str ) from vnpy.trader.util_logger import setup_logger -from vnpy.data.tdx.tdx_common import get_future_contracts - -######################################################################## class PortfolioTestingEngine(object): """ CTA组合回测引擎 @@ -114,7 +106,7 @@ class PortfolioTestingEngine(object): self.margin_rate = {} # 回测合约的保证金比率 self.price_dict = {} # 登记vt_symbol对应的最新价 self.contract_dict = {} # 登记vt_symbol得对应合约信息 - self.symbol_exchange_dict = {} # 登记symbol: exchange的对应关系 + self.symbol_exchange_dict = {} # 登记symbol: exchange的对应关系 self.bar_csv_file = {} self.bar_df_dict = {} # 历史数据的df,回测用 @@ -124,6 +116,10 @@ class PortfolioTestingEngine(object): self.data_end_date = None # 回测数据结束日期,datetime对象 (用于截取数据) self.strategy_start_date = None # 策略启动日期(即前面的数据用于初始化),datetime对象 + self.stop_order_count = 0 # 本地停止单编号 + self.stop_orders = {} # 本地停止单 + self.active_stop_orders = {} # 活动本地停止单 + self.limit_order_count = 0 # 限价单编号 self.limit_orders = OrderedDict() # 限价单字典 self.active_limit_orders = OrderedDict() # 活动限价单字典,用于进行撮合用 @@ -146,7 +142,7 @@ class PortfolioTestingEngine(object): self.gateway_name = u'BackTest' self.last_bar = {} # 最新的bar - self.last_dt = None + self.last_dt = None # 最新时间 # csvFile相关 self.bar_interval_seconds = 60 # csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟 @@ -246,7 +242,7 @@ class PortfolioTestingEngine(object): else: return None - def get_account(self): + def get_account(self, vt_accountid: str = ""): """返回账号的实时权益,可用资金,仓位比例,投资仓位比例上限""" if self.net_capital == 0.0: self.percent = 0.0 @@ -342,7 +338,7 @@ class PortfolioTestingEngine(object): def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int, price_tick: float): """设置合约信息""" - vt_symbol = '.'.join(symbol, exchange.value) + vt_symbol = '.'.join([symbol, exchange.value]) if vt_symbol not in self.contract_dict: c = ContractData( gateway_name=self.gateway_name, @@ -358,6 +354,7 @@ class PortfolioTestingEngine(object): # self.set_margin_rate(vt_symbol, ) self.set_price_tick(vt_symbol, price_tick) self.symbol_exchange_dict.update({symbol: exchange}) + @lru_cache() def get_contract(self, vt_symbol): """获取合约配置信息""" @@ -383,25 +380,48 @@ class PortfolioTestingEngine(object): """ self.daily_report_name = report_file - def load_csv_to_df(self, symbol, bar_file, data_start_date=None, data_end_date=None): + def load_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None): """回测数据初始化""" - self.output(u'loading {} from {}'.format(symbol, bar_file)) - if symbol in self.bar_df_dict: + self.output(u'loading {} from {}'.format(vt_symbol, bar_file)) + if vt_symbol in self.bar_df_dict: return True if not os.path.isfile(bar_file): - self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(symbol, bar_file)) + self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file)) return False try: - symbol_df = pd.read_csv(bar_file).set_index("index").rename(index=pd.to_datetime) + data_types = { + "datetime": str, + "open": float, + "high": float, + "low": float, + "close": float, + "open_interest": float, + "volume": float, + "instrument_id": str, + "symbol": str, + "total_turnover": float, + "limit_down": float, + "limit_up": float, + "trading_day": str, + "date": str, + "time": str + } + # 加载csv文件 =》 dateframe + symbol_df = pd.read_csv(bar_file, dtype=data_types) + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") # 裁剪数据 symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] - self.bar_df_dict.update({symbol: symbol_df}) + self.bar_df_dict.update({vt_symbol: symbol_df}) except Exception as ex: - self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(symbol, bar_file, ex)) + self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) + self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) return False return True @@ -421,12 +441,25 @@ class PortfolioTestingEngine(object): self.set_name(test_settings.get('name')) self.debug = test_settings.get('debug', False) - # 创建日志 - self.create_logger(debug=test_settings.get('debug', False)) + # 更新数据目录 - self.data_path = os.path.abspath(os.path.join(vnpy_root, test_settings.get('data_path', 'data'))) + if 'data_path' in test_settings: + self.data_path = test_settings.get('data_path') + else: + self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) + + print(f'数据输出目录:{self.data_path}') + # 更新日志目录 - self.logs_path = os.path.abspath(os.path.join(vnpy_root, test_settings.get('logs_path', 'logs'))) + if 'logs_path' in test_settings: + self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name)) + else: + self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) + print(f'日志输出目录:{self.logs_path}') + + # 创建日志 + self.create_logger(debug=self.debug) + # 设置资金 if 'init_capital' in test_settings: self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital'))) @@ -472,6 +505,8 @@ class PortfolioTestingEngine(object): # 创建资金K线 self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) + self.load_strategy_class() + def prepare_data(self, data_dict): """ 准备组合数据 @@ -491,12 +526,21 @@ class PortfolioTestingEngine(object): self.set_slippage(symbol, symbol_data.get('slippage', 0)) - self.set_size(symbol, symbol_data.get('size', 10)) + self.set_size(symbol, symbol_data.get('symbol_size', 10)) self.set_margin_rate(symbol, symbol_data.get('margin_rate', 0.1)) self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001))) + self.set_contract( + symbol=symbol, + name=symbol, + exchange=Exchange(symbol_data.get('exchange', 'LOCAL')), + product=Product(symbol_data.get('product', "期货")), + size=symbol_data.get('symbol_size', 10), + price_tick=symbol_data.get('price_tick', 1) + ) + bar_file = symbol_data.get('bar_file', None) if bar_file is None: @@ -545,15 +589,18 @@ class PortfolioTestingEngine(object): self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date)) # 加载数据 - for symbol in self.symbol_strategy_map.keys(): - self.load_csv_to_df(symbol, self.bar_csv_file.get(symbol)) + for vt_symbol in self.symbol_strategy_map.keys(): + symbol, exchange = extract_vt_symbol(vt_symbol) + self.load_csv_to_df(vt_symbol, self.bar_csv_file.get(symbol)) # 为套利合约提取主动 / 被动合约 - if symbol.endswith('SPD') or symbol.endswith('SPD99'): + if exchange == Exchange.SPD: try: active_symbol, active_rate, passive_symbol, passive_rate, spd_type = symbol.split('-') - self.load_csv_to_df(active_symbol, self.bar_csv_file.get(active_symbol)) - self.load_csv_to_df(passive_symbol, self.bar_csv_file.get(passive_symbol)) + active_vt_symbol = '.'.join([active_symbol, self.get_exchange(symbol=active_symbol)]) + passive_vt_symbol = '.'.join([passive_symbol, self.get_exchange(symbol=passive_symbol)]) + self.load_csv_to_df(active_vt_symbol, self.bar_csv_file.get(active_symbol)) + self.load_csv_to_df(passive_vt_symbol, self.bar_csv_file.get(passive_symbol)) except Exception as ex: self.write_error(u'为套利合约提取主动/被动合约出现异常:{}'.format(str(ex))) @@ -567,19 +614,33 @@ class PortfolioTestingEngine(object): gc_collect_days = 0 try: - for (dt, symbol), bar_data in self.bar_df.iterrows(): - - if symbol.startwith('future_renko'): + for (dt, vt_symbol), bar_data in self.bar_df.iterrows(): + symbol, exchange = extract_vt_symbol(vt_symbol) + if symbol.startswith('future_renko'): bar_datetime = dt + bar = RenkoBarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) + bar.seconds = float(bar_data.get('seconds', 0)) + bar.high_seconds = float(bar_data.get('high_seconds', 0)) # 当前Bar的上限秒数 + bar.low_seconds = float(bar_data.get('low_seconds', 0)) # 当前bar的下限秒数 + bar.height = float(bar_data.get('height', 0)) # 当前Bar的高度限制 + bar.up_band = float(bar_data.get('up_band', 0)) # 高位区域的基线 + bar.down_band = float(bar_data.get('down_band', 0)) # 低位区域的基线 + bar.low_time = bar_data.get('low_time', None) # 最后一次进入低位区域的时间 + bar.high_time = bar_data.get('high_time', None) # 最后一次进入高位区域的时间 else: bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds) - bar = BarData( - gateway_name='backtesting', - symbol=symbol, - exchange=Exchange.LOCAL, - datetime=bar_datetime - ) + bar = BarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) bar.open_price = float(bar_data['open']) bar.close_price = float(bar_data['close']) @@ -635,6 +696,7 @@ class PortfolioTestingEngine(object): if self.net_capital < 0: self.write_error(u'净值低于0,回测停止') + self.output(u'净值低于0,回测停止') return self.write_log(u'数据回放完成') @@ -652,7 +714,8 @@ class PortfolioTestingEngine(object): """新的K线""" self.last_bar.update({bar.vt_symbol: bar}) self.last_dt = bar.datetime - self.set_price(bar.vt_symbol, bar.close) + self.set_price(bar.vt_symbol, bar.close_price) + self.cross_stop_order(bar) # 撮合停止单 self.cross_limit_order(bar) # 先撮合限价单 # 更新资金曲线(只有持仓时,才更新) @@ -660,11 +723,9 @@ class PortfolioTestingEngine(object): if fund_kline is not None and (len(self.long_position_list) > 0 or len(self.short_position_list) > 0): fund_kline.update_account(self.last_dt, self.net_capital) - self.set_price({bar.vt_symbol: bar.close}) - for strategy in self.symbol_strategy_map.get(bar.vt_symbol, []): # 更新策略的资金K线 - fund_kline = self.fund_kline_dict.get(strategy.name, None) + fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None) if fund_kline: hold_pnl = fund_kline.get_hold_pnl() if hold_pnl != 0: @@ -677,7 +738,7 @@ class PortfolioTestingEngine(object): if not strategy.trading and self.strategy_start_date < bar.datetime: strategy.trading = True strategy.on_start() - self.output(u'{}策略启动交易'.format(strategy.name)) + self.output(u'{}策略启动交易'.format(strategy.strategy_name)) def load_strategy_class(self): """ @@ -724,7 +785,8 @@ class PortfolioTestingEngine(object): return True except: # noqa msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" - self.write_log(msg=msg, level=logging.CRITICAL) + self.write_error(msg) + self.output(msg) return False def load_strategy(self, strategy_name: str, strategy_setting: dict = None): @@ -752,19 +814,25 @@ class PortfolioTestingEngine(object): self.write_log(u'转换策略为全路径:{}'.format(class_name)) # 获取策略类的定义 - strategy_class = self.load_strategy_class_from_module(class_name) + strategy_class = import_module_by_str(class_name) if strategy_class is None: self.write_error(u'加载策略模块失败:{}'.format(class_name)) return # 处理 vt_symbol vt_symbol = strategy_setting.get('vt_symbol') - symbol, exchange = extract_vt_symbol(vt_symbol) + if '.' in vt_symbol: + symbol, exchange = extract_vt_symbol(vt_symbol) + else: + symbol = vt_symbol + underly_symbol = get_underlying_symbol(symbol).upper() + exchange = self.get_exchange(f'{underly_symbol}99') + vt_symbol = '.'.join([symbol, exchange.value]) # 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约 if '99' not in symbol and exchange != Exchange.SPD: - underly_symbol = get_underlying_symbol(symbol) - self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.'+ exchange.value)) + underly_symbol = get_underlying_symbol(symbol).upper() + self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value)) vt_symbol = underly_symbol.upper() + '99.' + exchange.value strategy_setting.update({'vt_symbol': vt_symbol}) @@ -790,7 +858,7 @@ class PortfolioTestingEngine(object): strategy_setting.update({'backtesting': True}) # 策略参数设置 - setting = strategy_setting.get('setting',{}) + setting = strategy_setting.get('setting', {}) # 创建实例 strategy = strategy_class(self, strategy_name, vt_symbol, setting) @@ -799,7 +867,7 @@ class PortfolioTestingEngine(object): self.strategies.update({strategy_name: strategy}) # 更新vt_symbol合约与策略的订阅关系 - self.subscribe_symbol(strategy_name=strategy_name,vt_symbol=vt_symbol) + self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=vt_symbol) if strategy_setting.get('auto_init', False): self.write_log(u'自动初始化策略') @@ -813,7 +881,6 @@ class PortfolioTestingEngine(object): # 创建策略实例的资金K线 self.create_fund_kline(name=strategy_name, use_renko=False) - def subscribe_symbol(self, strategy_name: str, vt_symbol: str, gateway_name: str = '', is_bar: bool = False): """订阅合约""" strategy = self.strategies.get(strategy_name, None) @@ -844,7 +911,43 @@ class PortfolioTestingEngine(object): order_type: OrderType = OrderType.LIMIT, gateway_name: str = None): """发单""" + price_tick = self.get_price_tick(vt_symbol) + price = round_to(price, price_tick) + if stop: + return self.send_local_stop_order( + strategy=strategy, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + lock=lock, + gateway_name=gateway_name + ) + else: + return self.send_limit_order( + strategy=strategy, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + lock=lock, + gateway_name=gateway_name + ) + + def send_limit_order(self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + lock: bool, + order_type: OrderType = OrderType.LIMIT, + gateway_name: str = None + ): self.limit_order_count += 1 order_id = str(self.limit_order_count) symbol, exchange = extract_vt_symbol(vt_symbol) @@ -869,62 +972,244 @@ class PortfolioTestingEngine(object): self.limit_orders[order.vt_orderid] = order self.order_strategy_dict.update({order.vt_orderid: strategy}) - self.write_log( - u'{},{},{},p:{},v:{},ref:[{}]'.format(vt_symbol, direction, offset, price, volume, order.vt_orderid)) + self.write_log(f'创建限价单:{order.__dict__}') return [order.vt_orderid] - def cancel_order(self, vt_orderid): + def send_local_stop_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + lock: bool, + gateway_name: str = None): + + """""" + self.stop_order_count += 1 + + stop_order = StopOrder( + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}", + strategy_name=strategy.strategy_name, + ) + self.write_log(f'创建本地停止单:{stop_order.__dict__}') + self.order_strategy_dict.update({stop_order.stop_orderid: strategy}) + + self.active_stop_orders[stop_order.stop_orderid] = stop_order + self.stop_orders[stop_order.stop_orderid] = stop_order + + return [stop_order.stop_orderid] + + def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): """撤单""" + if vt_orderid.startswith(STOPORDER_PREFIX): + return self.cancel_stop_order(strategy, vt_orderid) + else: + return self.cancel_limit_order(strategy, vt_orderid) + + def cancel_limit_order(self, strategy: CtaTemplate, vt_orderid: str): + """限价单撤单""" if vt_orderid in self.active_limit_orders: order = self.active_limit_orders[vt_orderid] - strategy = self.order_strategy_dict.get(vt_orderid, None) + register_strategy = self.order_strategy_dict.get(vt_orderid, None) + if register_strategy.strategy_name != strategy.strategy_name: + return False order.status = Status.CANCELLED order.cancelTime = str(self.last_dt) self.active_limit_orders.pop(vt_orderid, None) - if strategy: - strategy.on_order(order) + strategy.on_order(order) + return True + return False - def cancel_orders(self, vt_symbol: str = None, offset: Offset = None): + def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str): + """本地停止单撤单""" + if vt_orderid not in self.active_stop_orders: + return False + stop_order = self.active_stop_orders.pop(vt_orderid) + + stop_order.status = StopOrderStatus.CANCELLED + strategy.on_stop_order(stop_order) + return True + + def cancel_all(self, strategy): + """撤销某个策略的所有委托单""" + self.cancel_orders(strategy=strategy) + + def cancel_orders(self, vt_symbol: str = None, offset: Offset = None, strategy: CtaTemplate = None): """撤销所有单""" # Symbol参数:指定合约的撤单; # OFFSET参数:指定Offset的撤单,缺省不填写时,为所有 + # strategy参数: 指定某个策略的单子 + if len(self.active_limit_orders) > 0: - self.write_log(u'从所有订单中撤销{0}\{1}'.format(offset, vt_symbol if vt_symbol is not None else u'所有')) + self.write_log(u'从所有订单中,撤销:开平:{},合约:{},策略:{}' + .format(offset, + vt_symbol if vt_symbol is not None else u'所有', + strategy.strategy_name if strategy else None)) for vt_orderid in list(self.active_limit_orders.keys()): order = self.active_limit_orders.get(vt_orderid, None) - strategy = self.order_strategy_dict.get(vt_orderid, None) - if order is None or strategy is None: + order_strategy = self.order_strategy_dict.get(vt_orderid, None) + if order is None or order_strategy is None: continue if offset is None: - offsetCond = True + offset_cond = True else: - offsetCond = order.offset == offset + offset_cond = order.offset == offset if vt_symbol is None: symbol_cond = True else: symbol_cond = order.vt_symbol == vt_symbol - if symbol_cond and offsetCond: + + if strategy is None: + strategy_cond = True + else: + strategy_cond = strategy.strategy_name == order_strategy.strategy_name + + if offset_cond and symbol_cond and strategy_cond: self.write_log( - u'撤销订单:{0},{1} {2}@{3}'.format(vt_orderid, order.direction, order.price, order.volume)) + u'撤销订单:{},{} {}@{}'.format(vt_orderid, order.direction, order.price, order.volume)) order.status = Status.CANCELLED - order.cancelTime = str(self.last_dt) + order.cancel_time = str(self.last_dt) del self.active_limit_orders[vt_orderid] if strategy: - strategy.onOrder(order) + strategy.on_order(order) - def send_stop_order(self, vt_symbol, orderType, price, volume, strategy): - """发停止单(本地实现)""" + for stop_orderid in list(self.active_stop_orders.keys()): + order = self.active_stop_orders.get(stop_orderid, None) + order_strategy = self.order_strategy_dict.get(stop_orderid, None) + if order is None or order_strategy is None: + continue - self.write_error(u'暂不支持本地停止单功能') - return '' + if offset is None: + offset_cond = True + else: + offset_cond = order.offset == offset - def cancel_stop_order(self, stopOrderID): - """撤销停止单""" - pass + if vt_symbol is None: + symbol_cond = True + else: + symbol_cond = order.vt_symbol == vt_symbol + + if strategy is None: + strategy_cond = True + else: + strategy_cond = strategy.strategy_name == order_strategy.strategy_name + + if offset_cond and symbol_cond and strategy_cond: + self.write_log( + u'撤销本地停止单:{},{} {}@{}'.format(stop_orderid, order.direction, order.price, order.volume)) + order.status = Status.CANCELLED + order.cancel_time = str(self.last_dt) + self.active_stop_orders.pop(stop_orderid, None) + if strategy: + strategy.on_stop_order(order) + + def cross_stop_order(self, bar): + """ + Cross stop order with last bar/tick data. + """ + vt_symbol = bar.vt_symbol + for stop_orderid in list(self.active_stop_orders.keys()): + stop_order = self.active_stop_orders[stop_orderid] + strategy = self.order_strategy_dict.get(stop_orderid, None) + if stop_order.vt_symbol != vt_symbol or stop_order is None or strategy is None: + continue + # 若买入方向停止单价格高于等于该价格,则会触发 + long_cross_price = round_to(value=bar.low_price, target=self.get_price_tick(vt_symbol)) + long_cross_price -= self.get_price_tick(vt_symbol) + + # 若卖出方向停止单价格低于等于该价格,则会触发 + short_cross_price = round_to(value=bar.high_price, target=self.get_price_tick(vt_symbol)) + short_cross_price += self.get_price_tick(vt_symbol) + + # 在当前时间点前发出的买入委托可能的最优成交价 + long_best_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) + self.get_price_tick(vt_symbol) + + # 在当前时间点前发出的卖出委托可能的最优成交价 + short_best_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) - self.get_price_tick(vt_symbol) + + # Check whether stop order can be triggered. + long_cross = ( + stop_order.direction == Direction.LONG + and stop_order.price <= long_cross_price + ) + + short_cross = ( + stop_order.direction == Direction.SHORT + and stop_order.price >= short_cross_price + ) + + if not long_cross and not short_cross: + continue + + # Create order data. + self.limit_order_count += 1 + symbol, exchange = extract_vt_symbol(vt_symbol) + order = OrderData( + symbol=symbol, + exchange=exchange, + orderid=str(self.limit_order_count), + direction=stop_order.direction, + offset=stop_order.offset, + price=stop_order.price, + volume=stop_order.volume, + status=Status.ALLTRADED, + gateway_name=self.gateway_name, + ) + order.datetime = self.last_dt + self.write_log(f'停止单被触发:\n{stop_order.__dict__}\n=>委托单{order.__dict__}') + self.limit_orders[order.vt_orderid] = order + + # Create trade data. + if long_cross: + trade_price = max(stop_order.price, long_best_price) + else: + trade_price = min(stop_order.price, short_best_price) + + self.trade_count += 1 + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=str(self.trade_count), + direction=order.direction, + offset=order.offset, + price=trade_price, + volume=order.volume, + time=self.last_dt.strftime("%Y-%m-%d %H:%M:%S"), + datetime=self.last_dt, + gateway_name=self.gateway_name, + ) + trade.strategy_name = strategy.strategy_name + trade.datetime = self.last_dt + self.write_log(f'停止单触发成交:{trade.__dict__}') + self.trade_dict[trade.vt_tradeid] = trade + self.trades[trade.vt_tradeid] = copy.copy(trade) + + # Update stop order. + stop_order.vt_orderids.append(order.vt_orderid) + stop_order.status = StopOrderStatus.TRIGGERED + + self.active_stop_orders.pop(stop_order.stop_orderid) + + # Push update to strategy. + strategy.on_stop_order(stop_order) + strategy.on_order(order) + self.append_trade(trade) + strategy.on_trade(trade) def cross_limit_order(self, bar): """基于最新数据撮合限价单""" @@ -932,9 +1217,8 @@ class PortfolioTestingEngine(object): vt_symbol = bar.vt_symbol # 遍历限价单字典中的所有限价单 - workingLimitOrderDictClone = copy.deepcopy(self.active_limit_orders) - for orderID, order in list(workingLimitOrderDictClone.items()): - + for vt_orderid in list(self.active_limit_orders.keys()): + order = self.active_limit_orders.get(vt_orderid, None) if order.vt_symbol != vt_symbol: continue @@ -943,25 +1227,25 @@ class PortfolioTestingEngine(object): self.write_error(u'找不到vt_orderid:{}对应的策略'.format(order.vt_orderid)) continue - buyCrossPrice = round_to(value=bar.low, - target=self.get_price_tick(vt_symbol)) + self.get_price_tick( + buy_cross_price = round_to(value=bar.low_price, + target=self.get_price_tick(vt_symbol)) + self.get_price_tick( vt_symbol) # 若买入方向限价单价格高于该价格,则会成交 - sellCrossPrice = round_to(value=bar.high, - target=self.get_price_tick(vt_symbol)) - self.get_price_tick( + sell_cross_price = round_to(value=bar.high_price, + target=self.get_price_tick(vt_symbol)) - self.get_price_tick( vt_symbol) # 若卖出方向限价单价格低于该价格,则会成交 - buyBestCrossPrice = round_to(value=bar.open, - target=self.get_price_tick(vt_symbol)) + self.get_price_tick( + buy_best_cross_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) + self.get_price_tick( vt_symbol) # 在当前时间点前发出的买入委托可能的最优成交价 - sellBestCrossPrice = round_to(value=bar.open, - target=self.get_price_tick(vt_symbol)) - self.get_price_tick( + sell_best_cross_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) - self.get_price_tick( vt_symbol) # 在当前时间点前发出的卖出委托可能的最优成交价 # 判断是否会成交 - buyCross = order.direction == Direction.LONG and order.price >= buyCrossPrice - sellCross = order.direction == Direction.SHORT and order.price <= sellCrossPrice + buy_cross = order.direction == Direction.LONG and order.price >= buy_cross_price + sell_cross = order.direction == Direction.SHORT and order.price <= sell_cross_price # 如果发生了成交 - if buyCross or sellCross: + if buy_cross or sell_cross: # 推送成交数据 self.trade_count += 1 # 成交编号自增1 @@ -976,27 +1260,29 @@ class PortfolioTestingEngine(object): direction=order.direction, offset=order.offset, volume=order.volume, - time=str(self.last_dt) + time=self.last_dt.strftime("%Y-%m-%d %H:%M:%S"), + datetime=self.last_dt ) # 以买入为例: # 1. 假设当根K线的OHLC分别为:100, 125, 90, 110 # 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105 # 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100 - if buyCross: - trade_rice = min(order.price, buyBestCrossPrice) + if buy_cross: + trade_price = min(order.price, buy_best_cross_price) else: - trade_price = max(order.price, sellBestCrossPrice) + trade_price = max(order.price, sell_best_cross_price) + trade.price = trade_price # 记录该合约来自哪个策略实例 - trade.strategy = strategy.name + trade.strategy_name = strategy.strategy_name - strategy.onTrade(trade) + strategy.on_trade(trade) for cov_trade in self.convert_spd_trade(trade): self.trade_dict[cov_trade.vt_tradeid] = cov_trade - self.trades[cov_trade.vt_tradeid] = cov_trade + self.trades[cov_trade.vt_tradeid] = copy.copy(cov_trade) self.write_log(u'vt_trade_id:{0}'.format(cov_trade.vt_tradeid)) # 更新持仓缓存数据 @@ -1005,16 +1291,16 @@ class PortfolioTestingEngine(object): pos_buffer = PositionHolding(self.get_contract(vt_symbol)) self.pos_holding_dict[cov_trade.vt_symbol] = pos_buffer pos_buffer.update_trade(cov_trade) - self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(cov_trade.strategy, - cov_trade.tradeID, - pos_buffer.toStr())) + self.write_log(u'{} : crossLimitOrder: TradeId:{}, posBuffer = {}'.format(cov_trade.strategy_name, + cov_trade.tradeid, + pos_buffer.to_str())) # 写入交易记录 self.append_trade(cov_trade) # 更新资金曲线 if 'SPD' not in cov_trade.vt_symbol: - fund_kline = self.get_fund_kline(cov_trade.strategy) + fund_kline = self.get_fund_kline(cov_trade.strategy_name) if fund_kline: fund_kline.update_trade(cov_trade) @@ -1025,10 +1311,7 @@ class PortfolioTestingEngine(object): strategy.on_order(order) # 从字典中删除该限价单 - try: - del self.active_limit_orders[orderID] - except Exception as ex: - self.write_error(u'crossLimitOrder exception:{},{}'.format(str(ex), traceback.format_exc())) + self.active_limit_orders.pop(vt_orderid, None) # 实时计算模式 self.realtime_calculate() @@ -1057,7 +1340,8 @@ class PortfolioTestingEngine(object): strategy_name=trade.strategy_name, price=self.get_price(active_vt_symbol), volume=int(trade.volume * active_rate), - time=trade.time + time=trade.time, + datetime=trade.datetime ) # 被动腿成交记录 @@ -1069,7 +1353,9 @@ class PortfolioTestingEngine(object): tradeid='spd_pas_' + str(trade.tradeid), direction=Direction.LONG if trade.direction == Direction.SHORT else Direction.SHORT, offset=trade.offset, - strategy_name=trade.strategy_name + strategy_name=trade.strategy_name, + time=trade.time, + datetime=trade.datetime ) # 根据套利合约的类型+主合约的价格,反向推导出被动合约的价格 @@ -1126,7 +1412,7 @@ class PortfolioTestingEngine(object): if self.logs_path is not None: logs_folder = self.logs_path else: - logs_folder = os.path.abspath(os.path.join(os.getcwd(), 'logs')) + logs_folder = os.path.abspath(os.path.join(os.getcwd(), 'log')) self.logs_path = logs_folder if not os.path.exists(logs_folder): @@ -1144,9 +1430,10 @@ class PortfolioTestingEngine(object): if strategy_name is None: filename = os.path.abspath(os.path.join(self.get_logs_path(), '{}'.format( self.test_name if len(self.test_name) > 0 else 'portfolio_test'))) + print(u'create logger:{}'.format(filename)) self.logger = setup_logger(file_name=filename, name=self.test_name, - level=logging.DEBUG if debug else logging.ERROR, + log_level=logging.DEBUG if debug else logging.ERROR, backtesing=True) else: filename = os.path.abspath( @@ -1154,10 +1441,10 @@ class PortfolioTestingEngine(object): print(u'create logger:{}'.format(filename)) self.strategy_loggers[strategy_name] = setup_logger(file_name=filename, name=str(strategy_name), - level=logging.DEBUG if debug else logging.ERROR, + log_level=logging.DEBUG if debug else logging.ERROR, backtesing=True) - def write_log(self, content, strategy_name=None): + def write_log(self, msg: str, strategy_name: str = None, level: int = logging.DEBUG): """记录日志""" # log = str(self.datetime) + ' ' + content # self.logList.append(log) @@ -1165,39 +1452,39 @@ class PortfolioTestingEngine(object): if strategy_name is None: # 写入本地log日志 if self.logger: - self.logger.info(content) + self.logger.log(msg=msg, level=level) else: - self.create_logger() + self.create_logger(debug=self.debug) else: if strategy_name in self.strategy_loggers: - self.strategy_loggers[strategy_name].info(content) + self.strategy_loggers[strategy_name].log(msg=msg, level=level) else: - self.create_logger(strategy_name=strategy_name) + self.create_logger(strategy_name=strategy_name, debug=self.debug) - def write_error(self, content, strategy_name=None): + def write_error(self, msg, strategy_name=None): """记录异常""" if strategy_name is None: if self.logger: - self.logger.error(content) + self.logger.error(msg) else: - self.create_logger() + self.create_logger(debug=self.debug) else: if strategy_name in self.strategy_loggers: - self.strategy_loggers[strategy_name].error(content) + self.strategy_loggers[strategy_name].error(msg) else: - self.create_logger(strategy_name=strategy_name) + self.create_logger(strategy_name=strategy_name, debug=self.debug) try: - self.strategy_loggers[strategy_name].error(content) + self.strategy_loggers[strategy_name].error(msg) except Exception as ex: print('{}'.format(datetime.now()), file=sys.stderr) print('could not create cta logger for {},excption:{},trace:{}'.format(strategy_name, str(ex), traceback.format_exc())) - print(content, file=sys.stderr) + print(msg, file=sys.stderr) def output(self, content): """输出内容""" - print(str(datetime.now()) + "\t" + content) + print(self.test_name + "\t" + content) def realtime_calculate(self): """实时计算交易结果 @@ -1244,14 +1531,15 @@ class PortfolioTestingEngine(object): cur_short_pos_list = [s_pos.volume for s_pos in self.short_position_list] - self.write_log(u'当前空单:{}'.format(cur_short_pos_list)) - + self.write_log(u'{}当前空单:{}'.format(trade.vt_symbol, cur_short_pos_list)) + if len(cur_short_pos_list) > 3: + a = 1 # 来自同一策略,同一合约才能撮合 pop_indexs = [i for i, val in enumerate(self.short_position_list) if - val.vt_symbol == trade.vt_symbol and val.strategy == trade.strategy] + val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name] if len(pop_indexs) < 1: - self.write_error(u'异常,{}没有对应symbol:{}的空单持仓'.format(trade.strategy, trade.vt_symbol)) + self.write_error(u'异常,{}没有对应symbol:{}的空单持仓'.format(trade.strategy_name, trade.vt_symbol)) raise Exception(u'realtimeCalculate2() Exception,没有对应symbol:{0}的空单持仓'.format(trade.vt_symbol)) return @@ -1282,12 +1570,12 @@ class PortfolioTestingEngine(object): t = OrderedDict() t['gid'] = g_id - t['strategy'] = open_trade.strategy + t['strategy'] = open_trade.strategy_name t['vt_symbol'] = open_trade.vt_symbol - t['open_time'] = open_trade.tradeTime + t['open_time'] = open_trade.time t['open_price'] = open_trade.price t['direction'] = u'Short' - t['close_time'] = trade.tradeTime + t['close_time'] = trade.time t['close_price'] = trade.price t['volume'] = open_trade.volume t['profit'] = result.pnl @@ -1298,7 +1586,8 @@ class PortfolioTestingEngine(object): if not open_trade.vt_symbol.endswith('SPD'): # 更新策略实例的累加盈亏 self.pnl_strategy_dict.update( - {open_trade.strategy: self.pnl_strategy_dict.get(open_trade.strategy, 0) + result.pnl}) + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) msg = u'gid:{} {}[{}:开空tid={}:{}]-[{}.平空tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ .format(g_id, open_trade.vt_symbol, open_trade.time, shortid, open_trade.price, @@ -1343,12 +1632,12 @@ class PortfolioTestingEngine(object): t = OrderedDict() t['gid'] = g_id - t['strategy'] = open_trade.strategy + t['strategy'] = open_trade.strategy_name t['vt_symbol'] = open_trade.vt_symbol - t['open_time'] = open_trade.tradeTime + t['open_time'] = open_trade.time t['open_price'] = open_trade.price t['direction'] = u'Short' - t['close_time'] = trade.tradeTime + t['close_time'] = trade.time t['close_price'] = trade.price t['volume'] = cover_volume t['profit'] = result.pnl @@ -1359,11 +1648,12 @@ class PortfolioTestingEngine(object): if not (open_trade.vt_symbol.endswith('SPD') or open_trade.vt_symbol.endswith('SPD99')): # 更新策略实例的累加盈亏 self.pnl_strategy_dict.update( - {open_trade.strategy: self.pnl_strategy_dict.get(open_trade.strategy, 0) + result.pnl}) + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) msg = u'gid:{} {}[{}:开空tid={}:{}]-[{}.平空tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ - .format(g_id, open_trade.vt_symbol, open_trade.tradeTime, shortid, open_trade.price, - trade.tradeTime, vt_tradeid, trade.price, + .format(g_id, open_trade.vt_symbol, open_trade.time, shortid, open_trade.price, + trade.time, vt_tradeid, trade.price, cover_volume, result.pnl, result.commission) self.write_log(msg) @@ -1409,13 +1699,19 @@ class PortfolioTestingEngine(object): return pop_indexs = [i for i, val in enumerate(self.long_position_list) if - val.vt_symbol == trade.vt_symbol and val.strategy == trade.strategy] + val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name] if len(pop_indexs) < 1: - self.write_error(f'没有{trade.strategy}对应的symbol{trade.vt_symbol}多单数据,') + self.write_error(f'没有{trade.strategy_name}对应的symbol{trade.vt_symbol}多单数据,') raise RuntimeError( f'realtimeCalculate2() Exception,没有对应的symbol{trade.vt_symbol}多单数据,') return + cur_long_pos_list = [s_pos.volume for s_pos in self.long_position_list] + + self.write_log(u'{}当前多单:{}'.format(trade.vt_symbol, cur_long_pos_list)) + if len(cur_long_pos_list) > 3: + a = 1 + pop_index = pop_indexs[0] open_trade = self.long_position_list.pop(pop_index) # 开多volume,不大于平仓volume @@ -1438,12 +1734,12 @@ class PortfolioTestingEngine(object): t = OrderedDict() t['gid'] = g_id - t['strategy'] = open_trade.strategy + t['strategy'] = open_trade.strategy_name t['vt_symbol'] = open_trade.vt_symbol - t['open_time'] = open_trade.tradeTime + t['open_time'] = open_trade.time t['open_price'] = open_trade.price t['direction'] = u'Long' - t['close_time'] = trade.tradeTime + t['close_time'] = trade.time t['close_price'] = trade.price t['volume'] = open_trade.volume t['profit'] = result.pnl @@ -1454,12 +1750,13 @@ class PortfolioTestingEngine(object): if not (open_trade.vt_symbol.endswith('SPD') or open_trade.vt_symbol.endswith('SPD99')): # 更新策略实例的累加盈亏 self.pnl_strategy_dict.update( - {open_trade.strategy: self.pnl_strategy_dict.get(open_trade.strategy, 0) + result.pnl}) + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) msg = u'gid:{} {}[{}:开多tid={}:{}]-[{}.平多tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ .format(g_id, open_trade.vt_symbol, - open_trade.tradeTime, longid, open_trade.price, - trade.tradeTime, vt_tradeid, trade.price, + open_trade.time, longid, open_trade.price, + trade.time, vt_tradeid, trade.price, open_trade.volume, result.pnl, result.commission) self.write_log(msg) @@ -1497,12 +1794,12 @@ class PortfolioTestingEngine(object): t = OrderedDict() t['gid'] = g_id - t['strategy'] = open_trade.strategy + t['strategy'] = open_trade.strategy_name t['vt_symbol'] = open_trade.vt_symbol - t['open_time'] = open_trade.tradeTime + t['open_time'] = open_trade.time t['open_price'] = open_trade.price t['direction'] = u'Long' - t['close_time'] = trade.tradeTime + t['close_time'] = trade.time t['close_price'] = trade.price t['volume'] = sell_volume t['profit'] = result.pnl @@ -1513,11 +1810,12 @@ class PortfolioTestingEngine(object): if not (open_trade.vt_symbol.endswith('SPD') or open_trade.vt_symbol.endswith('SPD99')): # 更新策略实例的累加盈亏 self.pnl_strategy_dict.update( - {open_trade.strategy: self.pnl_strategy_dict.get(open_trade.strategy, 0) + result.pnl}) + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) msg = u'Gid:{} {}[{}:开多tid={}:{}]-[{}.平多tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ - .format(g_id, open_trade.vt_symbol, open_trade.tradeTime, longid, open_trade.price, - trade.tradeTime, vt_tradeid, trade.price, sell_volume, result.pnl, + .format(g_id, open_trade.vt_symbol, open_trade.time, longid, open_trade.price, + trade.time, vt_tradeid, trade.price, sell_volume, result.pnl, result.commission) self.write_log(msg) @@ -1698,7 +1996,7 @@ class PortfolioTestingEngine(object): today_holding_profit += holding_profit # 计算每个策略实例的持仓盈亏 - strategy_pnl.update({longpos.strategy: strategy_pnl.get(longpos.strategy, 0) + holding_profit}) + strategy_pnl.update({longpos.strategy_name: strategy_pnl.get(longpos.strategy_name, 0) + holding_profit}) positionMsg += "{},long,p={},v={},m={};".format(symbol, longpos.price, longpos.volume, holding_profit) @@ -1709,7 +2007,7 @@ class PortfolioTestingEngine(object): symbol = shortpos.vt_symbol # 计算持仓浮盈浮亏/占用保证金 holding_profit = 0 - last_price = self.get_price(symbol, None) + last_price = self.get_price(symbol) if last_price is not None: holding_profit = (shortpos.price - last_price) * shortpos.volume * self.get_size(symbol) short_pos_occupy_money += last_price * abs(shortpos.volume) * self.get_size( @@ -1718,7 +2016,7 @@ class PortfolioTestingEngine(object): # 账号的持仓盈亏 today_holding_profit += holding_profit # 计算每个策略实例的持仓盈亏 - strategy_pnl.update({shortpos.strategy: strategy_pnl.get(shortpos.strategy, 0) + holding_profit}) + strategy_pnl.update({shortpos.strategy_name: strategy_pnl.get(shortpos.strategy_name, 0) + holding_profit}) positionMsg += "{},short,p={},v={},m={};".format(symbol, shortpos.price, shortpos.volume, holding_profit) @@ -1752,7 +2050,7 @@ class PortfolioTestingEngine(object): positionMsg)) # --------------------------------------------------------------------- - def export_trade_result(self): + def export_trade_result(self, is_plot_daily=False): """ 导出交易结果(开仓-》平仓, 平仓收益) 导出每日净值结果表 @@ -1762,7 +2060,6 @@ class PortfolioTestingEngine(object): self.write_log('no traded records') return - s = '' s = self.test_name.replace('&', '') s = s.replace(' ', '') trade_list_csv_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade_list.csv'.format(s))) @@ -1800,6 +2097,16 @@ class PortfolioTestingEngine(object): for row in self.daily_list: writer2.writerow(row) + if is_plot_daily: + # 生成净值曲线图片 + df = pd.DataFrame(self.daily_list) + df = df.set_index('date') + from vnpy.trader.utility import display_dual_axis + plot_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_plot.png'.format(s))) + + # 双坐标输出,左侧坐标是净值(比率),右侧是各策略的实际资金收益曲线 + display_dual_axis(df=df, columns1=['rate'], columns2=list(self.strategies.keys()), image_name=plot_file) + return def get_result(self): @@ -1847,17 +2154,17 @@ class PortfolioTestingEngine(object): return {}, [], [] capital_net_list = [] - capitalList = [] + capital_list = [] for row in self.daily_list: capital_net_list.append(row['net']) - capitalList.append(row['capital']) + capital_list.append(row['capital']) capital = pd.Series(capital_net_list) log_returns = np.log(capital).diff().fillna(0) sharpe = (log_returns.mean() * 252) / (log_returns.std() * np.sqrt(252)) d['sharpe'] = sharpe - return d, capital_net_list, capitalList + return d, capital_net_list, capital_list def show_backtesting_result(self, is_plot_daily=False): """显示回测结果""" @@ -1869,17 +2176,17 @@ class PortfolioTestingEngine(object): return {}, '' # 导出交易清单 - self.export_trade_result() + self.export_trade_result(is_plot_daily) result_info = OrderedDict() # 输出 self.output('-' * 30) - result_info.update({u'第一笔交易': str(d['timeList'][0])}) - self.output(u'第一笔交易:\t%s' % d['timeList'][0]) + result_info.update({u'第一笔交易': str(d['time_list'][0])}) + self.output(u'第一笔交易:\t%s' % d['time_list'][0]) - result_info.update({u'最后一笔交易': str(d['timeList'][-1])}) - self.output(u'最后一笔交易:\t%s' % d['timeList'][-1]) + result_info.update({u'最后一笔交易': str(d['time_list'][-1])}) + self.output(u'最后一笔交易:\t%s' % d['time_list'][-1]) result_info.update({u'总交易次数': d['total_trade_count']}) self.output(u'总交易次数:\t%s' % format_number(d['total_trade_count'])) @@ -1902,11 +2209,10 @@ class PortfolioTestingEngine(object): result_info.update({u'每笔最大亏损': d['min_pnl']}) self.output(u'每笔最大亏损:\t%s' % format_number(d['min_pnl'])) - result_info.update({u'净值最大回撤': min(d['drawdown_dist'])}) - self.output(u'净值最大回撤: \t%s' % format_number(min(d['drawdown_dist']))) + result_info.update({u'净值最大回撤': min(d['drawdown_list'])}) + self.output(u'净值最大回撤: \t%s' % format_number(min(d['drawdown_list']))) result_info.update({u'净值最大回撤率': self.daily_max_drawdown_rate}) - # self.writeCtaNotification(u'净值最大回撤率: \t%s' % formatNumber(max(d['drawdownRateList']))) self.output(u'净值最大回撤率: \t%s' % format_number(self.daily_max_drawdown_rate)) result_info.update({u'净值最大回撤时间': str(self.max_drawdown_rate_time)}) @@ -1927,8 +2233,8 @@ class PortfolioTestingEngine(object): result_info.update({u'最大资金占比': d['max_occupy_rate']}) self.output(u'最大资金占比:\t%s' % format_number(d['max_occupy_rate'])) - result_info.update({u'平均每笔盈利': d['capital'] / d['total_trade_count']}) - self.output(u'平均每笔盈利:\t%s' % format_number(d['capital'] / d['total_trade_count'])) + result_info.update({u'平均每笔盈利': d['profit'] / d['total_trade_count']}) + self.output(u'平均每笔盈利:\t%s' % format_number(d['profit'] / d['total_trade_count'])) result_info.update({u'平均每笔滑点成本': d['total_slippage'] / d['total_trade_count']}) self.output(u'平均每笔滑点成本:\t%s' % format_number(d['total_slippage'] / d['total_trade_count'])) @@ -1970,10 +2276,7 @@ class PortfolioTestingEngine(object): for k in trade_fields: d[k] = getattr(trade, k, '') - trade_folder = os.path.abspath(os.path.join(self.get_logs_path(), self.test_name)) - if not os.path.exists(trade_folder): - os.makedirs(trade_folder) - trade_file = os.path.abspath(os.path.join(trade_folder, '{}_trade.csv'.format(strategy_name))) + trade_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade.csv'.format(strategy_name))) self.append_data(file_name=trade_file, dict_data=d) except Exception as ex: self.write_error(u'写入交易记录csv出错:{},{}'.format(str(ex), traceback.format_exc())) @@ -2034,3 +2337,33 @@ class TradingResult(object): - self.commission - self.slippage) # 净盈亏 +def single_test(test_setting: dict, strategy_setting: dict): + """ + 单一回测 + : test_setting, 组合回测所需的配置,包括合约信息,数据bar信息,回测时间,资金等。 + :strategy_setting, dict, 一个或多个策略配置 + """ + # 创建事件引擎 + from vnpy.event.engine import EventEngine + event_engine = EventEngine() + event_engine.start() + + # 创建组合回测引擎 + engine = PortfolioTestingEngine(event_engine) + + engine.prepare_env(test_setting) + try: + engine.run_portfolio_test(strategy_setting) + # 回测结果,保存 + result_info = engine.show_backtesting_result(is_plot_daily=test_setting.get('is_plot_daily', False)) + + except Exception as ex: + print('组合回测异常{}'.format(str(ex))) + traceback.print_exc() + return False + + if event_engine: + event_engine.stop() + + print('测试结束') + return True diff --git a/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy.py b/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy.py index 3fcfc92c..0e79cd6a 100644 --- a/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy.py +++ b/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy.py @@ -33,7 +33,7 @@ class TurtleSignalStrategy(CtaTemplate): long_stop = 0 short_stop = 0 - parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"] + parameters = ["x_minute", "entry_window", "exit_window", "atr_window", "fixed_size"] variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"] def __init__(self, cta_engine, strategy_name, vt_symbol, setting): @@ -52,7 +52,7 @@ class TurtleSignalStrategy(CtaTemplate): Callback when strategy is inited. """ self.write_log("策略初始化") - #self.load_bar(20) + # self.load_bar(20) def on_start(self): """ diff --git a/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy_v2.py b/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy_v2.py index 03680104..56830643 100644 --- a/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy_v2.py +++ b/vnpy/app/cta_strategy_pro/strategies/turtle_signal_strategy_v2.py @@ -11,6 +11,8 @@ from vnpy.app.cta_strategy_pro import ( ArrayManager, ) +from vnpy.trader.utility import round_to + class TurtleSignalStrategy_v2(CtaTemplate): """""" @@ -22,7 +24,7 @@ class TurtleSignalStrategy_v2(CtaTemplate): atr_window = 20 fixed_size = 1 invest_pos = 1 - invest_percent = 10 # 投资比例 + invest_percent = 10 # 投资比例 entry_up = 0 entry_down = 0 @@ -35,7 +37,7 @@ class TurtleSignalStrategy_v2(CtaTemplate): long_stop = 0 short_stop = 0 - parameters = ["x_minuite", "entry_window", "exit_window", "atr_window", "fixed_size"] + parameters = ["x_minute", "entry_window", "exit_window", "atr_window", "fixed_size"] variables = ["entry_up", "entry_down", "exit_up", "exit_down", "atr_value"] def __init__(self, cta_engine, strategy_name, vt_symbol, setting): @@ -47,6 +49,7 @@ class TurtleSignalStrategy_v2(CtaTemplate): # 获取合约乘数,保证金比例 self.symbol_size = self.cta_engine.get_size(self.vt_symbol) self.symbol_margin_rate = self.cta_engine.get_margin_rate(self.vt_symbol) + self.symbol_price_tick = self.cta_engine.get_price_tick(self.vt_symbol) self.bg = BarGenerator(self.on_bar, window=self.x_minute) self.am = ArrayManager() @@ -58,7 +61,7 @@ class TurtleSignalStrategy_v2(CtaTemplate): Callback when strategy is inited. """ self.write_log("策略初始化") - #self.load_bar(20) + # self.load_bar(20) def on_start(self): """ @@ -98,8 +101,12 @@ class TurtleSignalStrategy_v2(CtaTemplate): self.exit_up, self.exit_down = self.am.donchian(self.exit_window) + if bar.datetime.strftime('%Y-%m-%d %H') == '2016-03-07 09': + a = 1 # noqa + if not self.pos: self.atr_value = self.am.atr(self.atr_window) + self.atr_value = max(4 * self.symbol_price_tick, self.atr_value) self.long_entry = 0 self.short_entry = 0 @@ -112,13 +119,17 @@ class TurtleSignalStrategy_v2(CtaTemplate): self.send_buy_orders(self.entry_up) sell_price = max(self.long_stop, self.exit_down) - self.sell(sell_price, abs(self.pos), True) + refs = self.sell(sell_price, abs(self.pos), True) + if len(refs) > 0: + self.write_log(f'平多委托编号:{refs}') elif self.pos < 0: self.send_short_orders(self.entry_down) cover_price = min(self.short_stop, self.exit_up) - ret = self.cover(cover_price, abs(self.pos), True) + refs = self.cover(cover_price, abs(self.pos), True) + if len(refs) > 0: + self.write_log(f'平空委托编号:{refs}') self.put_event() @@ -161,48 +172,77 @@ class TurtleSignalStrategy_v2(CtaTemplate): def send_buy_orders(self, price): """""" - if self.pos >= 4: - return - if self.cur_mi_price <= price - self.atr_value/2: + if self.cur_mi_price <= price - self.atr_value / 2: return self.update_invest_pos() - t = self.pos / self.invest_pos + t = int(self.pos / self.invest_pos) + + if t >= 4: + return if t < 1: - self.buy(price, self.invest_pos, True) + refs = self.buy(price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'买入委托编号:{refs}') - if t < 2: - self.buy(price + self.atr_value * 0.5, self.invest_pos, True) + if t == 1 and self.cur_mi_price > price: + buy_price = round_to(price + self.atr_value * 0.5 , self.symbol_price_tick) + self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price)) + refs = self.buy(buy_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'买入委托编号:{refs}') - if t < 3: - self.buy(price + self.atr_value, self.invest_pos, True) + if t == 2 and self.cur_mi_price > price + self.atr_value * 0.5: + buy_price = round_to(price + self.atr_value, self.symbol_price_tick) + self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price)) + refs = self.buy(buy_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'买入委托编号:{refs}') - if t < 4: - self.buy(price + self.atr_value * 1.5, self.invest_pos, True) + if t == 3 and self.cur_mi_price > price + self.atr_value: + buy_price = round_to(price + self.atr_value * 1.5, self.symbol_price_tick) + self.write_log(u'发出做多停止单,触发价格为: {}'.format(buy_price)) + refs = self.buy(buy_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'买入委托编号:{refs}') def send_short_orders(self, price): """""" - if self.pos <= -4: - return - if self.cur_mi_price >= price + self.atr_value / 2: return self.update_invest_pos() - t = self.pos / self.invest_pos + t = int(self.pos / self.invest_pos) + + if t <= -4: + return if t > -1: - self.short(price, self.invest_pos, True) + refs = self.short(price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'卖出委托编号:{refs}') - if t > -2: - self.short(price - self.atr_value * 0.5, self.invest_pos, True) + if t == -1 and self.cur_mi_price < price: + short_price = round_to(price - self.atr_value * 0.5, self.symbol_price_tick) + self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price)) + refs = self.short(short_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'卖出委托编号:{refs}') - if t > -3: - self.short(price - self.atr_value, self.invest_pos, True) + if t == -2 and self.cur_mi_price < price + self.atr_value * 0.5: + short_price = round_to(price - self.atr_value, self.symbol_price_tick) + self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price)) + refs = self.short(short_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'卖出委托编号:{refs}') - if t > -4: - self.short(price - self.atr_value * 1.5, self.invest_pos, True) + if t == -3 and self.cur_mi_price < price + self.atr_value: + short_price = round_to(price - self.atr_value * 1.5, self.symbol_price_tick) + self.write_log(u'发出做空停止单,触发价格为: {}'.format(short_price)) + refs = self.short(short_price, self.invest_pos, True) + if len(refs) > 0: + self.write_log(f'卖出委托编号:{refs}') diff --git a/vnpy/app/cta_strategy_pro/template.py b/vnpy/app/cta_strategy_pro/template.py index b2a0b3b7..049866ee 100644 --- a/vnpy/app/cta_strategy_pro/template.py +++ b/vnpy/app/cta_strategy_pro/template.py @@ -1,40 +1,23 @@ """""" -import sys +import os +import uuid +import bz2 +import pickle +import copy +import traceback + from abc import ABC from copy import copy from typing import Any, Callable from logging import INFO, ERROR -from vnpy.trader.constant import Interval, Direction, Offset +from datetime import datetime +from vnpy.trader.constant import Interval, Direction, Offset, Status from vnpy.trader.object import BarData, TickData, OrderData, TradeData -from vnpy.trader.utility import virtual +from vnpy.trader.utility import virtual, append_data, extract_vt_symbol,get_underlying_symbol from .base import StopOrder, EngineType - - -class CtaComponent(ABC): - """ CTA策略基础组件""" - def __init__(self, strategy=None, **kwargs): - """ - 构造 - :param strategy: - """ - self.strategy = strategy - - # ---------------------------------------------------------------------- - def write_log(self, content: str): - """记录日志""" - if self.strategy: - self.strategy.write_log(msg=content, level=INFO) - else: - print(content) - - # ---------------------------------------------------------------------- - def write_error(self, content: str, level: int = ERROR): - """记录错误日志""" - if self.strategy: - self.strategy.write_log(msg=content, level=level) - else: - print(content, file=sys.stderr) +from .cta_grid_trade import CtaGrid, CtaGridTrade +from .cta_position import CtaPosition class CtaTemplate(ABC): @@ -45,20 +28,28 @@ class CtaTemplate(ABC): variables = [] def __init__( - self, - cta_engine: Any, - strategy_name: str, - vt_symbol: str, - setting: dict, + self, + cta_engine: Any, + strategy_name: str, + vt_symbol: str, + setting: dict, ): """""" self.cta_engine = cta_engine self.strategy_name = strategy_name self.vt_symbol = vt_symbol - self.inited = False - self.trading = False - self.pos = 0 + self.inited = False # 是否初始化完毕 + self.trading = False # 是否开始交易 + self.pos = 0 # 持仓/仓差 + self.entrust = 0 # 是否正在委托, 0, 无委托 , 1, 委托方向是LONG, -1, 委托方向是SHORT + + self.tick_dict = {} # 记录所有on_tick传入最新tick + + # 保存委托单编号和相关委托单的字典 + # key为委托单编号 + # value为该合约相关的委托单 + self.active_orders = {} # Copy a new variables list here to avoid duplicate insert when multiple # strategy instances are created with the same strategy class. @@ -66,8 +57,7 @@ class CtaTemplate(ABC): self.variables.insert(0, "inited") self.variables.insert(1, "trading") self.variables.insert(2, "pos") - - self.update_setting(setting) + self.variables.insert(3, "entrust") def update_setting(self, setting: dict): """ @@ -119,6 +109,23 @@ class CtaTemplate(ABC): } return strategy_data + def get_positions(self): + """ 返回持仓数量""" + pos_list = [] + if self.pos > 0: + pos_list.append({ + "vt_symbol": self.vt_symbol, + "direction": "long", + "volume": self.pos + }) + elif self.pos < 0: + pos_list.append({ + "vt_symbol": self.vt_symbol, + "direction": "short", + "volume": abs(self.pos) + }) + + @virtual def on_timer(self): pass @@ -178,19 +185,24 @@ class CtaTemplate(ABC): """ pass - def buy(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''): + def buy(self, price: float, volume: float, stop: bool = False, lock: bool = False, + vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None): """ Send buy order to open a long position. """ + return self.send_order(vt_symbol=vt_symbol, direction=Direction.LONG, offset=Offset.OPEN, price=price, volume=volume, stop=stop, - lock=lock) + lock=lock, + order_time=order_time, + grid=grid) - def sell(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''): + def sell(self, price: float, volume: float, stop: bool = False, lock: bool = False, + vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None): """ Send sell order to close a long position. """ @@ -200,9 +212,12 @@ class CtaTemplate(ABC): price=price, volume=volume, stop=stop, - lock=lock) + lock=lock, + order_time=order_time, + grid=grid) - def short(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''): + def short(self, price: float, volume: float, stop: bool = False, lock: bool = False, + vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None): """ Send short order to open as short position. """ @@ -212,9 +227,12 @@ class CtaTemplate(ABC): price=price, volume=volume, stop=stop, - lock=lock) + lock=lock, + order_time=order_time, + grid=grid) - def cover(self, price: float, volume: float, stop: bool = False, lock: bool = False, vt_symbol: str = ''): + def cover(self, price: float, volume: float, stop: bool = False, lock: bool = False, + vt_symbol: str = '', order_time: datetime = None, grid: CtaGrid = None): """ Send cover order to close a short position. """ @@ -224,17 +242,21 @@ class CtaTemplate(ABC): price=price, volume=volume, stop=stop, - lock=lock) + lock=lock, + order_time=order_time, + grid=grid) def send_order( - self, - vt_symbol: str, - direction: Direction, - offset: Offset, - price: float, - volume: float, - stop: bool = False, - lock: bool = False + self, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool = False, + lock: bool = False, + order_time: datetime = None, + grid: CtaGrid = None ): """ Send a new order. @@ -243,20 +265,44 @@ class CtaTemplate(ABC): if vt_symbol == '': vt_symbol = self.vt_symbol - if self.trading: - vt_orderids = self.cta_engine.send_order( - self, vt_symbol, direction, offset, price, volume, stop, lock - ) - return vt_orderids - else: + if not self.trading: return [] + vt_orderids = self.cta_engine.send_order( + self, vt_symbol, direction, offset, price, volume, stop, lock + ) + + if order_time is None: + order_time = datetime.now() + + for vt_orderid in vt_orderids: + d = { + 'direction': direction.value, + 'offset': offset.value, + 'vt_symbol': vt_symbol, + 'price': price, + 'volume': volume, + 'traded': 0, + 'order_time': order_time, + 'status': Status.SUBMITTING + } + if grid: + d.update({'grid': grid}) + grid.order_ids.append(vt_orderid) + self.active_orders.update({vt_orderid: d}) + if direction == Direction.LONG: + self.entrust = 1 + elif direction == Direction.SHORT: + self.entrust = -1 + return vt_orderids def cancel_order(self, vt_orderid: str): """ Cancel an existing order. """ if self.trading: - self.cta_engine.cancel_order(self, vt_orderid) + return self.cta_engine.cancel_order(self, vt_orderid) + + return False def cancel_all(self): """ @@ -265,12 +311,32 @@ class CtaTemplate(ABC): if self.trading: self.cta_engine.cancel_all(self) + def is_upper_limit(self, symbol): + """是否涨停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_up is None or tick.limit_up == 0: + return False + if tick.bid_price_1 == tick.limit_up: + return True + + def is_lower_limit(self, symbol): + """是否跌停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_down is None or tick.limit_down == 0: + return False + if tick.ask_price_1 == tick.limit_down: + return True + def write_log(self, msg: str, level: int = INFO): """ Write a log message. """ self.cta_engine.write_log(msg=msg, strategy_name=self.strategy_name, level=level) + def write_error(self, msg: str): + """write error log message""" + self.write_log(msg=msg, level=ERROR) + def get_engine_type(self): """ Return whether the cta_engine is backtesting or live trading. @@ -278,10 +344,10 @@ class CtaTemplate(ABC): return self.cta_engine.get_engine_type() def load_bar( - self, - days: int, - interval: Interval = Interval.MINUTE, - callback: Callable = None, + self, + days: int, + interval: Interval = Interval.MINUTE, + callback: Callable = None, ): """ Load historical bar data for initializing strategy. @@ -452,3 +518,474 @@ class TargetPosTemplate(CtaTemplate): else: vt_orderids = self.short(short_price, abs(pos_change)) self.vt_orderids.extend(vt_orderids) + + +class CtaProTemplate(CtaTemplate): + """ + 增强模板 + """ + + # 逻辑过程日志 + dist_fieldnames = ['datetime', 'symbol', 'volume', 'price', + 'operation', 'signal', 'stop_price', 'target_price', + 'long_pos', 'short_pos'] + + def __init__(self, cta_engine, strategy_name, vt_symbol, setting): + """""" + super(CtaProTemplate, self).__init__( + cta_engine, strategy_name, vt_symbol, setting + ) + + self.idx_symbol = None # 指数合约 + + self.price_tick = 1 # 商品的最小价格跳动 + self.symbol_size = 10 # 商品得合约乘数 + + self.cur_datetime = None # 当前Tick时间 + + self.cur_mi_tick = None # 最新的主力合约tick( vt_symbol) + self.cur_99_tick = None # 最新得指数合约tick( idx_symbol) + + self.cur_mi_price = None # 当前价(主力合约 vt_symbol) + self.cur_99_price = None # 当前价(tick时,根据tick更新,onBar回测时,根据bar.close更新) + + self.cancel_seconds = 120 # 撤单时间(秒) + + self.backtesting = False + + self.klines = {} # K线字典: kline_name: kline + + # 增加仓位管理模块 + self.position = CtaPosition(strategy=self) + + # 增加网格持久化模块 + self.gt = CtaGridTrade(strategy=self) + + # 增加指数合约 + if 'idx_symbol' not in self.parameters: + self.parameters.append('idx_symbol') + + if 'backtesting' not in self.parameters: + self.parameters.append('backtesting') + + def update_setting(self, setting: dict): + """ + Update strategy parameter wtih value in setting dict. + """ + for name in self.parameters: + if name in setting: + setattr(self, name, setting[name]) + + if self.idx_symbol is None: + symbol, exchange = extract_vt_symbol(self.vt_symbol) + self.idx_symbol = get_underlying_symbol(symbol).upper() + '99.' + exchange.value + if self.vt_symbol != self.idx_symbol: + self.write_log(f'指数合约:{self.idx_symbol}, 主力合约:{self.vt_symbol}') + self.price_tick = self.cta_engine.get_price_tick(self.vt_symbol) + self.symbol_size = self.cta_engine.get_size(self.vt_symbol) + + def save_klines_to_cache(self, kline_names: list = []): + """ + 保存K线数据到缓存 + :param kline_names: 一般为self.klines的keys + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + # 获取保存路径 + save_path = self.cta_engine.get_data_path() + # 保存缓存的文件名 + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + with bz2.BZ2File(file_name, 'wb') as f: + klines = {} + for kline_name in kline_names: + klines.update({kline_name: self.klines.get(kline_name, None)}) + pickle.dump(klines, f) + + def load_klines_from_cache(self, kline_names: list = []): + """ + 从缓存加载K线数据 + :param kline_names: + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + save_path = self.cta_engine.get_data_path() + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + try: + last_bar_dt = None + with bz2.BZ2File(file_name, 'rb') as f: + klines = pickle.load(f) + # 逐一恢复K线 + for kline_name in kline_names: + # 缓存的k线实例 + cache_kline = klines.get(kline_name, None) + # 当前策略实例的K线实例 + strategy_kline = self.klines.get(kline_name, None) + + if cache_kline and strategy_kline: + # 临时保存当前的回调函数 + cb_on_bar = strategy_kline.cb_on_bar + # 缓存实例数据 =》 当前实例数据 + strategy_kline.__dict__.update(cache_kline.__dict__) + + # 所有K线的最后时间 + if last_bar_dt and strategy_kline.cur_datetime: + last_bar_dt = max(last_bar_dt, strategy_kline.cur_datetime) + else: + last_bar_dt = strategy_kline.cur_datetime + + # 重新绑定k线策略与on_bar回调函数 + strategy_kline.strategy = self + strategy_kline.cb_on_bar = cb_on_bar + + self.write_log(f'恢复{kline_name}缓存数据,最新bar结束时间:{last_bar_dt}') + + self.write_log(u'加载缓存k线数据完毕') + return last_bar_dt + except Exception as ex: + self.write_error(f'加载缓存K线数据失败:{str(ex)}') + return None + + def get_klines_snapshot(self): + """返回当前klines的切片数据""" + try: + d = { + 'strategy': self.strategy_name, + 'datetime': datetime.now()} + + for kline_name in self.klines.keys(): + d.update({kline_name: self.klines.get(kline_name).get_data()}) + return d + except Exception as ex: + self.write_error(u'获取klines切片数据失败') + return {} + + def init_position(self): + """ + 初始化Positin + 使用网格的持久化,获取开仓状态的多空单,更新 + :return: + """ + self.write_log(u'init_position(),初始化持仓') + pos_symbols = set() + if len(self.gt.up_grids) <= 0: + self.position.short_pos = 0 + # 加载已开仓的空单数据,网格JSON + short_grids = self.gt.load(direction=Direction.SHORT, open_status_filter=[True]) + if len(short_grids) == 0: + self.write_log(u'没有持久化的空单数据') + self.gt.up_grids = [] + + else: + self.gt.up_grids = short_grids + for sg in short_grids: + if len(sg.order_ids) > 0 or sg.order_status: + self.write_log(f'重置委托状态:{sg.order_status},清除委托单:{sg.order_ids}') + sg.order_status = False + sg.order_ids = [] + + short_symbol = sg.snapshot.get('mi_symbol', self.vt_symbol) + pos_symbols.add(short_symbol) + self.write_log(u'加载持仓空单[{},价格:{}],[指数:{},价格:{}],数量:{}手' + .format(short_symbol, sg.snapshot.get('open_price'), + self.idx_symbol, sg.open_price, sg.volume)) + self.position.short_pos -= sg.volume + + self.write_log(u'持久化空单,共持仓:{}手'.format(abs(self.position.short_pos))) + + if len(self.gt.dn_grids) <= 0: + # 加载已开仓的多数据,网格JSON + self.position.long_pos = 0 + long_grids = self.gt.load(direction=Direction.LONG, open_status_filter=[True]) + if len(long_grids) == 0: + self.write_log(u'没有持久化的多单数据') + self.gt.dn_grids = [] + else: + self.gt.dn_grids = long_grids + for lg in long_grids: + + if len(lg.order_ids) > 0 or lg.order_status: + self.write_log(f'重置委托状态:{lg.order_status},清除委托单:{lg.order_ids}') + lg.order_status = False + lg.order_ids = [] + # lg.type = self.line.name + long_symbol = lg.snapshot.get('mi_symbol', self.vt_symbol) + pos_symbols.add(long_symbol) + + self.write_log(u'加载持仓多单[{},价格:{}],[指数{},价格:{}],数量:{}手' + .format(lg.snapshot.get('miSymbol'), lg.snapshot.get('open_price'), + self.idx_symbol, lg.open_price, lg.volume)) + self.position.long_pos += lg.volume + + self.write_log(f'持久化多单,共持仓:{self.position.long_pos}手') + + self.position.pos = self.position.long_pos + self.position.short_pos + + self.write_log( + u'{}加载持久化数据完成,多单:{},空单:{},共:{}手' + .format(self.strategy_name, + self.position.long_pos, + abs(self.position.short_pos), + self.position.pos)) + self.pos = self.position.pos + self.gt.save() + self.display_grids() + + if not self.backtesting: + pos_symbols.add(self.vt_symbol) + pos_symbols.add(self.idx_symbol) + # 如果持仓的合约,不在self.vt_symbol中,需要订阅 + for symbol in list(pos_symbols): + self.write_log(f'新增订阅合约:{symbol}') + self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=symbol) + + def get_positions(self): + """ + 获取策略当前持仓(重构,使用主力合约) + :return: [{'vt_symbol':symbol,'direction':direction,'volume':volume] + """ + if not self.position: + return [] + pos_list = [] + + if self.position.long_pos > 0: + for g in self.gt.get_opened_grids(direction=Direction.LONG): + vt_symbol = g.snapshot.get('mi_symbol', self.vt_symbol) + open_price = g.snapshot.get('open_price', g.openPrice) + pos_list.append({'vt_symbol': vt_symbol, + 'direction': 'long', + 'volume': g.volume - g.traded_volume, + 'price': open_price}) + + if abs(self.position.short_pos) > 0: + for g in self.gt.get_opened_grids(direction=Direction.SHORT): + vt_symbol = g.snapshot.get('mi_symbol', self.vt_symbol) + open_price = g.snapshot.get('open_price', g.open_price) + pos_list.append({'vt_symbol': vt_symbol, + 'direction': 'short', + 'volume': abs(g.volume - g.traded_volume), + 'price': open_price}) + + if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10: + self.write_log(u'当前持仓:{}'.format(pos_list)) + return pos_list + + def tns_cancel_logic(self, dt, force=False): + "撤单逻辑""" + if len(self.active_orders) < 1: + self.entrust = 0 + return + + for vt_orderid in list(self.active_orders.keys()): + order_info = self.active_orders.get(vt_orderid) + if order_info.get('status', None) in [Status.CANCELLED, Status.REJECTED]: + self.active_orders.pop(vt_orderid, None) + continue + + order_time = order_info.get('order_time') + over_ms = (dt - order_time).total_seconds() + if (over_ms > self.cancel_seconds) \ + or force: # 超过设置的时间还未成交 + self.write_log(f'{dt}, 超时{over_ms}秒未成交,取消委托单:{order_info}') + + if self.cancel_order(vt_orderid): + order_info.update({'status': Status.CANCELLING}) + else: + order_info.update({'status': Status.CANCELLED}) + + if len(self.active_orders) < 1: + self.entrust = 0 + + def tns_switch_long_pos(self): + """切换合约,从持仓的非主力合约,切换至主力合约""" + + if self.entrust != 0 and self.position.long_pos == 0: + return + + if self.cur_mi_price == 0: + return + + none_mi_grid = None + none_mi_symbol = None + + # 找出非主力合约的持仓网格 + for g in self.gt.get_opened_grids(direction=Direction.LONG): + none_mi_symbol = g.snapshot.get('mi_symbol') + if none_mi_symbol is None or none_mi_symbol == self.vt_symbol: + # 如果持仓的合约,跟策略配置的vt_symbol一致,则不处理 + continue + if not g.open_status or g.order_status or g.volume - g.traded_volume <= 0: + continue + none_mi_grid = g + if g.traded_volume > 0 and g.volume - g.traded_volume > 0: + g.volume -= g.traded_volume + g.traded_volume = 0 + break + if none_mi_grid is None: + return + + # 找到行情中非主力合约/主力合约的最新价 + none_mi_tick = self.tick_dict.get(none_mi_symbol) + mi_tick = self.tick_dict.get(self.vt_symbol, None) + if none_mi_tick is None or mi_tick is None: + return + + # 如果涨停价,不做卖出 + if self.is_upper_limit(none_mi_symbol) or self.is_upper_limit(self.vt_symbol): + return + none_mi_price = max(none_mi_tick.last_price, none_mi_tick.bid_price_1) + + grid = copy.copy(none_mi_grid) + + # 委托卖出非主力合约 + order_ids = self.sell(price=none_mi_price, volume=none_mi_grid.volume, vt_symbol=none_mi_symbol, + grid=none_mi_grid) + if len(order_ids) > 0: + self.write_log(f'切换合约,委托卖出非主力合约{none_mi_symbol}持仓:{none_mi_grid.volume}') + + # 添加买入主力合约 + grid.id = str(uuid.uuid1()) + grid.snapshot.update({'mi_symbol': self.vt_symbol, 'open_price': self.cur_mi_price}) + self.gt.dn_grids.append(grid) + + order_ids = self.buy(price=self.cur_mi_price, volume=grid.volume, vt_symbol=self.vt_symbol, grid=grid) + if len(order_ids) > 0: + self.write_log(u'切换合约,委托买入主力合约:{},价格:{},数量:{}' + .format(self.vt_symbol, self.cur_mi_price, grid.volume)) + self.gt.save() + else: + self.write_error(f'委托买入主力合约:{self.vt_symbol}失败') + else: + self.write_error(f'委托卖出非主力合约:{none_mi_symbol}失败') + + def tns_switch_short_pos(self): + """切换合约,从持仓的非主力合约,切换至主力合约""" + if self.entrust != 0 and self.position.short_pos == 0: + return + + if self.cur_mi_price == 0: + return + + none_mi_grid = None + none_mi_symbol = None + + # 找出非主力合约的持仓网格 + for g in self.gt.get_opened_grids(direction=Direction.SHORT): + none_mi_symbol = g.snapshot.get('miSymbol') + if none_mi_symbol is None or none_mi_symbol == self.vt_symbol: + continue + if not g.open_status or g.order_status or g.volume - g.traded_volume <= 0: + continue + none_mi_grid = g + if g.traded_volume > 0 and g.volume - g.traded_volume > 0: + g.volume -= g.traded_volume + g.traded_volume = 0 + break + + # 找不到与主力合约不一致的持仓网格 + if none_mi_grid is None: + return + + # 找到行情中非主力合约的最新价 + none_mi_tick = self.tick_dict.get(none_mi_symbol) + mi_tick = self.tick_dict.get(self.vt_symbol, None) + if none_mi_tick is None or mi_tick is None: + return + + # 如果跌停价,不做cover + if self.is_lower_limit(none_mi_symbol) or self.is_lower_limit(self.vt_symbol): + return + none_mi_price = max(none_mi_tick.last_price, none_mi_tick.bid_price_1) + + grid = copy.copy(none_mi_grid) + # 委托平空非主力合约 + order_ids = self.cover(price=none_mi_price, volume=none_mi_grid.volume, vt_symbol=self.vt_symbol, + grid=none_mi_grid) + if len(order_ids) > 0: + self.write_log(f'委托平空非主力合约{none_mi_symbol}持仓:{none_mi_grid.volume}') + + # 添加卖出主力合约 + grid.id = str(uuid.uuid1()) + grid.snapshot.update({'mi_symbol': self.vt_symbol, 'open_price': self.cur_mi_price}) + self.gt.up_grids.append(grid) + order_ids = self.short(price=self.cur_mi_price, volume=grid.volume, vt_symbol=self.vt_symbol, grid=grid) + if len(order_ids) > 0: + self.write_log(f'委托做空主力合约:{self.vt_symbol},价格:{self.cur_mi_price},数量:{grid.volume}') + self.gt.save() + else: + self.write_error(f'委托做空主力合约:{self.vt_symbol}失败') + else: + self.write_error(f'委托平空非主力合约:{none_mi_symbol}失败') + + def display_grids(self): + """更新网格显示信息""" + if not self.inited: + return + + up_grids_info = self.gt.to_str(direction=Direction.SHORT) + if len(self.gt.up_grids) > 0: + self.write_log(up_grids_info) + + dn_grids_info = self.gt.to_str(direction=Direction.LONG) + if len(self.gt.dn_grids) > 0: + self.write_log(dn_grids_info) + + def display_tns(self): + """显示事务的过程记录=》 log""" + if not self.inited: + return + self.write_log(u'{} 当前指数{}价格:{},当前主力{}价格:{}' + .format(self.cur_datetime, + self.idx_symbol, self.cur_99_price, + self.vt_symbol, self.cur_mi_price)) + if hasattr(self, 'policy'): + policy = getattr(self, 'policy') + op = getattr(policy, 'to_json', None) + if callable(op): + self.write_log(u'当前Policy:{}'.format(policy.to_json())) + + def save_dist(self, dist_data): + """ + 保存策略逻辑过程记录=》 csv文件按 + :param dist_data: + :return: + """ + if self.backtesting: + save_path = self.cta_engine.get_logs_path() + else: + save_path = self.cta_engine.get_data_path() + try: + if self.position: + dist_data.update({'long_pos': self.position.long_pos}) + dist_data.update({'short_pos': self.position.short_pos}) + + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv')) + append_data(file_name=file_name, dict_data=dist_data, field_names=self.dist_fieldnames) + except Exception as ex: + self.write_error(u'save_dist 异常:{} {}'.format(str(ex), traceback.format_exc())) + + def save_tns(self, tns_data): + """ + 保存多空事务记录=》csv文件,便于后续分析 + :param tns_data: + :return: + """ + if self.backtesting: + save_path = self.cta_engine.get_logs_path() + else: + save_path = self.cta_engine.get_data_path() + + try: + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_tns.csv')) + append_data(file_name=file_name, dict_data=tns_data) + except Exception as ex: + self.write_error(u'save_tns 异常:{} {}'.format(str(ex), traceback.format_exc())) + + def send_wechat(self, msg: str): + """实盘时才发送微信""" + if self.backtesting: + return + self.cta_engine.send_wechat(msg=msg, strategy=self) diff --git a/vnpy/app/cta_strategy_pro/test_line_bar_01.py b/vnpy/app/cta_strategy_pro/test_line_bar_01.py new file mode 100644 index 00000000..07a045ca --- /dev/null +++ b/vnpy/app/cta_strategy_pro/test_line_bar_01.py @@ -0,0 +1,65 @@ +# flake8: noqa + +# 测试 app.cta_strategy_pro.CtaLineBar组件 +# 从通达信获取历史交易记录,模拟tick。推送至line_bar + +import os +import sys +import json + +vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if vnpy_root not in sys.path: + print(f'sys.path apppend:{vnpy_root}') + sys.path.append(vnpy_root) + +os.environ["VNPY_TESTING"] = "1" + +from vnpy.trader.constant import Interval, Exchange +from vnpy.data.tdx.tdx_common import FakeStrategy +from vnpy.data.tdx.tdx_future_data import TdxFutureData +from vnpy.app.cta_strategy_pro.cta_line_bar import CtaLineBar +from vnpy.trader.object import TickData +from vnpy.trader.utility import get_trading_date + +t1 = FakeStrategy() + +tdx_api = TdxFutureData(strategy=t1) + + +def on_bar(bar): + print(f'{bar.__dict__}') + + +# 创建10秒周期的k线 +kline_setting = {} +kline_setting["name"] = "S10" +kline_setting['interval'] = Interval.SECOND +kline_setting['bar_interval'] = 10 +kline_setting['price_tick'] = 0.5 +kline_setting['underlying_symbol'] = 'J' +kline_s10 = CtaLineBar(strategy=t1, cb_on_bar=on_bar, setting=kline_setting) + +ret, result = tdx_api.get_history_transaction_data('J99', '20200106') +# for data in result[0:10] + result[-10:]: +# print(data) + +for data in result: + dt = data['datetime'] + price = float(data['price']) + volume = float(data['volume']) + + tick = TickData( + gateway_name='tdx', + datetime=dt, + last_price=price, + volume=volume, + symbol='J99', + exchange=Exchange('DCE'), + date=dt.strftime('%Y-%m-%d'), + time=dt.strftime('%H:%M:%S'), + trading_day=get_trading_date(dt) + ) + + kline_s10.on_tick(tick) + +os._exit(0) diff --git a/vnpy/app/cta_strategy_pro/test_line_bar_02.py b/vnpy/app/cta_strategy_pro/test_line_bar_02.py new file mode 100644 index 00000000..ef3568da --- /dev/null +++ b/vnpy/app/cta_strategy_pro/test_line_bar_02.py @@ -0,0 +1,442 @@ +# flake8: noqa + +# 测试 app.cta_strategy_pro.CtaLineBar组件 +# 从通达信获取历史交易记录,模拟tick。推送至line_bar + +import os +import sys +import json +import traceback +from datetime import datetime, timedelta + +vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +if vnpy_root not in sys.path: + print(f'sys.path apppend:{vnpy_root}') + sys.path.append(vnpy_root) + +os.environ["VNPY_TESTING"] = "1" + +from vnpy.trader.constant import Interval, Exchange +from vnpy.trader.object import BarData +from vnpy.app.cta_strategy_pro.cta_line_bar import ( + CtaLineBar, + CtaMinuteBar, + CtaHourBar, + CtaDayBar, + CtaWeekBar) +from vnpy.trader.utility import round_to + +class test_strategy(object): + + def __init__(self): + + self.price_tick = 1 + self.underlying_symbol = 'I' + self.vt_symbol = 'I99' + + self.lineM5 = None + self.lineM30 = None + self.lineH1 = None + self.lineH2 = None + self.lineD = None + self.lineW = None + + self.TMinuteInterval = 1 + + self.save_m30_bars = [] + self.save_h1_bars = [] + self.save_h2_bars = [] + self.save_d_bars = [] + + self.save_w_bars = [] + + def createM5(self): + """使用ctalinbar,创建5分钟K线""" + lineM5Setting = {} + lineM5Setting['name'] = u'M5' + lineM5Setting['interval'] = Interval.MINUTE + lineM5Setting['bar_interval'] = 5 + lineM5Setting['mode'] = CtaLineBar.TICK_MODE + lineM5Setting['price_tick'] = self.price_tick + lineM5Setting['underlying_symbol'] = self.underlying_symbol + self.lineM5 = CtaLineBar(self, self.onBarM5, lineM5Setting) + + def onBarM5(self, bar): + self.write_log(self.lineM5.get_last_bar_str()) + + def createlineM30_with_macd(self): + """使用CtaLineBar,创建30分钟时间""" + # 创建M30 K线 + lineM30Setting = {} + lineM30Setting['name'] = u'M30' + lineM30Setting['interval'] = Interval.MINUTE + lineM30Setting['bar_interval'] = 30 + lineM30Setting['para_macd_fast_len'] = 26 + lineM30Setting['para_macd_slow_len'] = 12 + lineM30Setting['para_macd_signal_len'] = 9 + lineM30Setting['price_tick'] = self.price_tick + lineM30Setting['underlying_symbol'] = self.underlying_symbol + self.lineM30 = CtaLineBar(self, self.onBarM30MACD, lineM30Setting) + + def onBarM30MACD(self, bar): + self.write_log(self.lineM30.get_last_bar_str()) + + def createLineM30(self): + """使用ctaMinuteBar, 测试内部自动写入csv文件""" + # 创建M30 K线 + lineM30Setting = {} + lineM30Setting['name'] = u'M30' + lineM30Setting['interval'] = Interval.MINUTE + lineM30Setting['bar_interval'] = 30 + lineM30Setting['para_pre_len'] = 10 + lineM30Setting['para_ma1_len'] = 5 + lineM30Setting['para_ma2_len'] = 10 + lineM30Setting['para_ma3_len'] = 60 + lineM30Setting['para_active_yb'] = True + lineM30Setting['para_active_skd'] = True + lineM30Setting['price_tick'] = self.price_tick + lineM30Setting['underlying_symbol'] = self.underlying_symbol + self.lineM30 = CtaMinuteBar(self, self.onBarM30, lineM30Setting) + + # 写入文件 + self.lineM30.export_filename = os.path.abspath( + os.path.join(os.getcwd(), + u'export_{}_{}.csv'.format(self.vt_symbol, self.lineM30.name))) + + self.lineM30.export_fields = [ + {'name': 'datetime', 'source': 'bar', 'attr': 'datetime', 'type_': 'datetime'}, + {'name': 'open', 'source': 'bar', 'attr': 'open_price', 'type_': 'float'}, + {'name': 'high', 'source': 'bar', 'attr': 'high_price', 'type_': 'float'}, + {'name': 'low', 'source': 'bar', 'attr': 'low_price', 'type_': 'float'}, + {'name': 'close', 'source': 'bar', 'attr': 'close_price', 'type_': 'float'}, + {'name': 'turnover', 'source': 'bar', 'attr': 'turnover', 'type_': 'float'}, + {'name': 'volume', 'source': 'bar', 'attr': 'volume', 'type_': 'float'}, + {'name': 'open_interest', 'source': 'bar', 'attr': 'open_interest', 'type_': 'float'}, + {'name': 'kf', 'source': 'line_bar', 'attr': 'line_statemean', 'type_': 'list'} + ] + + def createLineH1(self): + # 创建2小时K线 + lineH1Setting = {} + lineH1Setting['name'] = u'H1' + lineH1Setting['interval'] = Interval.HOUR + lineH1Setting['bar_interval'] = 1 + lineH1Setting['para_pre_len'] = 10 + lineH1Setting['para_ema1_len'] = 5 + lineH1Setting['para_ema2_len'] = 10 + lineH1Setting['para_ema3_len'] = 60 + lineH1Setting['para_active_yb'] = True + lineH1Setting['para_active_skd'] = True + lineH1Setting['price_tick'] = self.price_tick + lineH1Setting['underlying_symbol'] = self.underlying_symbol + self.lineH1 = CtaLineBar(self, self.onBarH1, lineH1Setting) + + def createLineH2(self): + # 创建2小时K线 + lineH2Setting = {} + lineH2Setting['name'] = u'H2' + lineH2Setting['interval'] = Interval.HOUR + lineH2Setting['bar_interval'] = 2 + lineH2Setting['para_pre_len'] = 5 + lineH2Setting['para_ma1_len'] = 5 + lineH2Setting['para_ma2_len'] = 10 + lineH2Setting['para_ma3_len'] = 18 + lineH2Setting['para_active_yb'] = True + lineH2Setting['para_active_skd'] = True + lineH2Setting['mode'] = CtaLineBar.TICK_MODE + lineH2Setting['price_tick'] = self.price_tick + lineH2Setting['underlying_symbol'] = self.underlying_symbol + self.lineH2 = CtaHourBar(self, self.onBarH2, lineH2Setting) + + def createLineD(self): + # 创建的日K线 + lineDaySetting = {} + lineDaySetting['name'] = u'D1' + lineDaySetting['bar_interval'] = 1 + lineDaySetting['para_pre_len'] = 5 + lineDaySetting['para_art1_len'] = 26 + lineDaySetting['para_ma1_len'] = 5 + lineDaySetting['para_ma2_len'] = 10 + lineDaySetting['para_ma3_len'] = 18 + lineDaySetting['para_active_yb'] = True + lineDaySetting['para_active_skd'] = True + lineDaySetting['price_tick'] = self.price_tick + lineDaySetting['underlying_symbol'] = self.underlying_symbol + self.lineD = CtaDayBar(self, self.onBarD, lineDaySetting) + + def createLineW(self): + """创建周线""" + lineWeekSetting = {} + lineWeekSetting['name'] = u'W1' + lineWeekSetting['para_pre_len'] = 5 + lineWeekSetting['para_art1_len'] = 26 + lineWeekSetting['para_ma1_len'] = 5 + lineWeekSetting['para_ma2_len'] = 10 + lineWeekSetting['para_ma3_len'] = 18 + lineWeekSetting['para_active_yb'] = True + lineWeekSetting['para_active_skd'] = True + lineWeekSetting['mode'] = CtaDayBar.TICK_MODE + lineWeekSetting['price_tick'] = self.price_tick + lineWeekSetting['underlying_symbol'] = self.underlying_symbol + self.lineW = CtaWeekBar(self, self.onBarW, lineWeekSetting) + + def onBar(self, bar): + # print(u'tradingDay:{},dt:{},o:{},h:{},l:{},c:{},v:{}'.format(bar.trading_day,bar.datetime, bar.open, bar.high, bar.low_price, bar.close_price, bar.volume)) + if self.lineW: + self.lineW.add_bar(bar, bar_freq=self.TMinuteInterval) + if self.lineD: + self.lineD.add_bar(bar, bar_freq=self.TMinuteInterval) + if self.lineH2: + self.lineH2.add_bar(bar, bar_freq=self.TMinuteInterval) + + if self.lineH1: + self.lineH1.add_bar(bar, bar_freq=self.TMinuteInterval) + + if self.lineM30: + self.lineM30.add_bar(bar, bar_freq=self.TMinuteInterval) + + if self.lineM5: + self.lineM5.add_bar(bar, bar_freq=self.TMinuteInterval) + + # if self.lineH2: + # self.lineH2.skd_is_high_dead_cross(runtime=True, high_skd=30) + # self.lineH2.skd_is_low_golden_cross(runtime=True, low_skd=70) + + def onBarM30(self, bar): + self.write_log(self.lineM30.get_last_bar_str()) + + self.save_m30_bars.append({ + 'datetime': bar.datetime, + 'open': bar.open_price, + 'high': bar.high_price, + 'low': bar.low_price, + 'close': bar.close_price, + 'turnover': 0, + 'volume': bar.volume, + 'open_interest': 0, + 'ma5': self.lineM30.line_ma1[-1] if len(self.lineM30.line_ma1) > 0 else bar.close_price, + 'ma10': self.lineM30.line_ma2[-1] if len(self.lineM30.line_ma2) > 0 else bar.close_price, + 'ma60': self.lineM30.line_ma3[-1] if len(self.lineM30.line_ma3) > 0 else bar.close_price, + 'sk': self.lineM30.line_sk[-1] if len(self.lineM30.line_sk) > 0 else 0, + 'sd': self.lineM30.line_sd[-1] if len(self.lineM30.line_sd) > 0 else 0 + }) + + def onBarH1(self, bar): + self.write_log(self.lineH1.get_last_bar_str()) + + self.save_h1_bars.append({ + 'datetime': bar.datetime, + 'open': bar.open_price, + 'high': bar.high_price, + 'low': bar.low_price, + 'close': bar.close_price, + 'turnover': 0, + 'volume': bar.volume, + 'open_interest': 0, + 'ema5': self.lineH1.line_ema1[-1] if len(self.lineH1.line_ema1) > 0 else bar.close_price, + 'ema10': self.lineH1.line_ema2[-1] if len(self.lineH1.line_ema2) > 0 else bar.close_price, + 'ema60': self.lineH1.line_ema3[-1] if len(self.lineH1.line_ema3) > 0 else bar.close_price, + 'sk': self.lineH1.line_sk[-1] if len(self.lineH1.line_sk) > 0 else 0, + 'sd': self.lineH1.line_sd[-1] if len(self.lineH1.line_sd) > 0 else 0 + }) + + def onBarH2(self, bar): + self.write_log(self.lineH2.get_last_bar_str()) + + self.save_h2_bars.append({ + 'datetime': bar.datetime, + 'open': bar.open_price, + 'high': bar.high_price, + 'low': bar.low_price, + 'close': bar.close_price, + 'turnover': 0, + 'volume': bar.volume, + 'open_interest': 0, + 'ma5': self.lineH2.line_ma1[-1] if len(self.lineH2.line_ma1) > 0 else bar.close_price, + 'ma10': self.lineH2.line_ma2[-1] if len(self.lineH2.line_ma2) > 0 else bar.close_price, + 'ma18': self.lineH2.line_ma3[-1] if len(self.lineH2.line_ma3) > 0 else bar.close_price, + 'sk': self.lineH2.line_sk[-1] if len(self.lineH2.line_sk) > 0 else 0, + 'sd': self.lineH2.line_sd[-1] if len(self.lineH2.line_sd) > 0 else 0 + }) + + def onBarD(self, bar): + self.write_log(self.lineD.get_last_bar_str()) + self.save_d_bars.append({ + 'datetime': bar.datetime, + 'open': bar.open_price, + 'high': bar.high_price, + 'low': bar.low_price, + 'close': bar.close_price, + 'turnover': 0, + 'volume': bar.volume, + 'open_interest': 0, + 'ma5': self.lineD.line_ma1[-1] if len(self.lineD.line_ma1) > 0 else bar.close_price, + 'ma10': self.lineD.line_ma2[-1] if len(self.lineD.line_ma2) > 0 else bar.close_price, + 'ma18': self.lineD.line_ma3[-1] if len(self.lineD.line_ma3) > 0 else bar.close_price, + 'sk': self.lineD.line_sk[-1] if len(self.lineD.line_sk) > 0 else 0, + 'sd': self.lineD.line_sd[-1] if len(self.lineD.line_sd) > 0 else 0 + }) + + def onBarW(self, bar): + self.write_log(self.lineW.get_last_bar_str()) + self.save_w_bars.append({ + 'datetime': bar.datetime, + 'open': bar.open_price, + 'high': bar.high_price, + 'low': bar.low_price, + 'close': bar.close_price, + 'turnover': 0, + 'volume': bar.volume, + 'open_interest': 0, + 'ma5': self.lineW.line_ma1[-1] if len(self.lineW.line_ma1) > 0 else bar.close_price, + 'ma10': self.lineW.line_ma2[-1] if len(self.lineW.line_ma2) > 0 else bar.close_price, + 'ma18': self.lineW.line_ma3[-1] if len(self.lineW.line_ma3) > 0 else bar.close_price, + 'sk': self.lineW.line_sk[-1] if len(self.lineW.line_sk) > 0 else 0, + 'sd': self.lineW.line_sd[-1] if len(self.lineW.line_sd) > 0 else 0 + }) + + def on_tick(self, tick): + print(u'{0},{1},ap:{2},av:{3},bp:{4},bv:{5}'.format(tick.datetime, tick.last_price, tick.ask_price_1, + tick.ask_volume_1, tick.bid_price_1, tick.bid_volume_1)) + + def write_log(self, content): + print(content) + + def saveData(self): + + if len(self.save_m30_bars) > 0: + outputFile = '{}_m30.csv'.format(self.vt_symbol) + with open(outputFile, 'w', encoding='utf8', newline='') as f: + fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', + 'ma5', 'ma10', 'ma60', 'sk', 'sd'] + writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + for row in self.save_m30_bars: + writer.writerow(row) + + if len(self.save_h1_bars) > 0: + outputFile = '{}_h1.csv'.format(self.vt_symbol) + with open(outputFile, 'w', encoding='utf8', newline='') as f: + fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', + 'ema5', 'ema10', 'ema60', 'sk', 'sd'] + writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + for row in self.save_h1_bars: + writer.writerow(row) + + if len(self.save_h2_bars) > 0: + outputFile = '{}_h2.csv'.format(self.vt_symbol) + with open(outputFile, 'w', encoding='utf8', newline='') as f: + fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', + 'ma5', 'ma10', 'ma18', 'sk', 'sd'] + writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + for row in self.save_h2_bars: + writer.writerow(row) + + if len(self.save_d_bars) > 0: + outputFile = '{}_d.csv'.format(self.vt_symbol) + with open(outputFile, 'w', encoding='utf8', newline='') as f: + fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', + 'ma5', 'ma10', 'ma18', 'sk', 'sd'] + writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + for row in self.save_d_bars: + writer.writerow(row) + + if len(self.save_w_bars) > 0: + outputFile = '{}_w.csv'.format(self.vt_symbol) + with open(outputFile, 'w', encoding='utf8', newline='') as f: + fieldnames = ['datetime', 'open', 'high', 'low', 'close', 'turnover', 'volume', 'open_interest', + 'ma5', 'ma10', 'ma18', 'sk', 'sd'] + writer = csv.DictWriter(f=f, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + for row in self.save_w_bars: + writer.writerow(row) + + +if __name__ == '__main__': + t = test_strategy() + t.price_tick = 0.5 + t.underlying_symbol = 'J' + t.vt_symbol = 'J99' + + # t.createM5() + # t.createLineW() + + # t.createlineM30_with_macd() + + # 创建M30线 + # t.createLineM30() + + # 回测1小时线 + # t.createLineH1() + + # 回测2小时线 + # t.createLineH2() + + # 回测日线 + # t.createLineD() + + # 测试周线 + t.createLineW() + + # vnpy/app/cta_strategy_pro/ + vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + + filename = os.path.abspath(os.path.join(vnpy_root, 'bar_data/{}_20160101_1m.csv'.format(t.vt_symbol))) + csv_bar_seconds = 60 # csv 文件内,bar的时间间隔60秒 + + import csv + + csvfile = open(filename, 'r', encoding='utf8') + reader = csv.DictReader((line.replace('\0', '') for line in csvfile), delimiter=",") + last_tradingDay = None + for row in reader: + try: + dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S') - timedelta(seconds=csv_bar_seconds) + + bar = BarData( + gateway_name='', + symbol=t.vt_symbol, + exchange=Exchange.LOCAL, + datetime=dt, + interval=Interval.MINUTE, + open_price=round_to(float(row['open']), t.price_tick), + high_price=round_to(float(row['high']), t.price_tick), + low_price=round_to(float(row['low']), t.price_tick), + close_price=round_to(float(row['close']), t.price_tick), + volume=float(row['volume']) + ) + + if 'trading_date' in row: + bar.trading_day = row['trading_date'] + if len(bar.trading_day) == 8 and '-' not in bar.trading_day: + bar.trading_day = bar.trading_day[0:4] + '-' + bar.trading_day[4:6] + '-' + bar.trading_day[6:8] + else: + if bar.datetime.hour >= 21: + if bar.datetime.isoweekday() == 5: + # 星期五=》星期一 + bar.trading_day = (dt + timedelta(days=3)).strftime('%Y-%m-%d') + else: + # 第二天 + bar.trading_day = (dt + timedelta(days=1)).strftime('%Y-%m-%d') + elif bar.datetime.hour < 8 and bar.datetime.isoweekday() == 6: + # 星期六=>星期一 + bar.trading_day = (dt + timedelta(days=2)).strftime('%Y-%m-%d') + else: + bar.trading_day = bar.datetime.strftime('%Y-%m-%d') + + t.onBar(bar) + # 测试 实时计算值 + # sk, sd = t.lineM30.getRuntimeSKD() + + # 测试实时计算值 + # if bar.datetime.minute==1: + # print('rt_Dif:{}'.format(t.lineM30.rt_Dif)) + except Exception as ex: + t.write_log(u'{0}:{1}'.format(Exception, ex)) + traceback.print_exc() + break + + t.saveData() diff --git a/vnpy/app/index_tick_publisher/engine.py b/vnpy/app/index_tick_publisher/engine.py index 2a9447a1..d9ee536e 100644 --- a/vnpy/app/index_tick_publisher/engine.py +++ b/vnpy/app/index_tick_publisher/engine.py @@ -3,8 +3,6 @@ # 通达信指数行情发布器 # 华富资产 -import os -import sys import copy import json import traceback diff --git a/vnpy/app/risk_manager/engine.py b/vnpy/app/risk_manager/engine.py index 37493165..3a54fdde 100644 --- a/vnpy/app/risk_manager/engine.py +++ b/vnpy/app/risk_manager/engine.py @@ -11,7 +11,6 @@ from vnpy.trader.event import EVENT_TRADE, EVENT_ORDER, EVENT_LOG, EVENT_ACCOUNT from vnpy.trader.constant import Status from vnpy.trader.utility import load_json, save_json - APP_NAME = "RiskManager" @@ -42,12 +41,12 @@ class RiskManagerEngine(BaseEngine): self.active_order_limit = 500 # 总仓位相关(0~100+) - self.percent_limit = 100 # 仓位比例限制 + self.percent_limit = 100 # 仓位比例限制 self.last_over_time = None # 启动风控后,最后一次超过仓位限制的时间 - self.account_dict = {} # 资金账号信息 - self.gateway_dict = {} # 记录gateway对应的仓位比例 - self.currency_list = [] # 资金账号风控管理得币种 + self.account_dict = {} # 资金账号信息 + self.gateway_dict = {} # 记录gateway对应的仓位比例 + self.currency_list = [] # 资金账号风控管理得币种 self.load_setting() self.register_event() @@ -178,6 +177,24 @@ class RiskManagerEngine(BaseEngine): account.balance, account_percent, self.percent_limit) self.write_log(msg) + def get_account(self, vt_accountid: str = ""): + """获取账号的当前净值,可用资金,账号当前仓位百分比,允许的最大仓位百分比""" + if vt_accountid: + account = self.account_dict.get(vt_accountid, None) + if account: + return account.balance, \ + account.available, \ + round(account.frozen * 100 / (account.balance + 0.01), 2), \ + self.percent_limit + if len(self.account_dict.values()) > 0: + account = self.account_dict.values()[0] + return account.balance, \ + account.available, \ + round(account.frozen * 100 / (account.balance + 0.01), 2), \ + self.percent_limit + else: + return 0, 0, 0, 0 + def write_log(self, msg: str): """""" log = LogData(msg=msg, gateway_name="RiskManager") diff --git a/vnpy/data/tdx/future_contracts.json b/vnpy/data/tdx/future_contracts.json index 518ff4fa..fc235b2a 100644 --- a/vnpy/data/tdx/future_contracts.json +++ b/vnpy/data/tdx/future_contracts.json @@ -13,7 +13,7 @@ "mi_symbol": "ag2007", "full_symbol": "AG2007", "exchange": "SHFE", - "margin_rate": 0.08, + "margin_rate": 0.07, "symbol_size": 15, "price_tick": 1.0 }, @@ -22,7 +22,7 @@ "mi_symbol": "al2003", "full_symbol": "AL2003", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 5.0 }, @@ -31,7 +31,7 @@ "mi_symbol": "AP005", "full_symbol": "AP2005", "exchange": "CZCE", - "margin_rate": 0.07, + "margin_rate": 0.08, "symbol_size": 10, "price_tick": 1.0 }, @@ -40,7 +40,7 @@ "mi_symbol": "au2006", "full_symbol": "AU2006", "exchange": "SHFE", - "margin_rate": 0.08, + "margin_rate": 0.06, "symbol_size": 1000, "price_tick": 0.02 }, @@ -55,8 +55,8 @@ }, "BB": { "underlying_symbol": "BB", - "mi_symbol": "bb2012", - "full_symbol": "BB2012", + "mi_symbol": "bb2101", + "full_symbol": "BB2101", "exchange": "DCE", "margin_rate": 0.2, "symbol_size": 500, @@ -67,7 +67,7 @@ "mi_symbol": "bu2006", "full_symbol": "BU2006", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.09, "symbol_size": 10, "price_tick": 2.0 }, @@ -82,8 +82,8 @@ }, "CF": { "underlying_symbol": "CF", - "mi_symbol": "CF005", - "full_symbol": "CF2005", + "mi_symbol": "CF009", + "full_symbol": "CF2009", "exchange": "CZCE", "margin_rate": 0.05, "symbol_size": 5, @@ -112,7 +112,7 @@ "mi_symbol": "cu2003", "full_symbol": "CU2003", "exchange": "SHFE", - "margin_rate": 0.09, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 10.0 }, @@ -139,7 +139,7 @@ "mi_symbol": "eg2005", "full_symbol": "EG2005", "exchange": "DCE", - "margin_rate": 0.05, + "margin_rate": 0.06, "symbol_size": 10, "price_tick": 1.0 }, @@ -148,9 +148,9 @@ "mi_symbol": "fb2005", "full_symbol": "FB2005", "exchange": "DCE", - "margin_rate": 0.2, - "symbol_size": 500, - "price_tick": 0.05 + "margin_rate": 0.1, + "symbol_size": 10, + "price_tick": 0.5 }, "FG": { "underlying_symbol": "FG", @@ -166,7 +166,7 @@ "mi_symbol": "fu2005", "full_symbol": "FU2005", "exchange": "SHFE", - "margin_rate": 0.2, + "margin_rate": 0.1, "symbol_size": 10, "price_tick": 1.0 }, @@ -175,7 +175,7 @@ "mi_symbol": "hc2005", "full_symbol": "HC2005", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.08, "symbol_size": 10, "price_tick": 1.0 }, @@ -184,7 +184,7 @@ "mi_symbol": "i2005", "full_symbol": "I2005", "exchange": "DCE", - "margin_rate": 0.05, + "margin_rate": 0.08, "symbol_size": 100, "price_tick": 0.5 }, @@ -193,7 +193,7 @@ "mi_symbol": "IC2003", "full_symbol": "IC2003", "exchange": "CFFEX", - "margin_rate": 0.1, + "margin_rate": 0.12, "symbol_size": 200, "price_tick": 0.2 }, @@ -220,7 +220,7 @@ "mi_symbol": "j2005", "full_symbol": "J2005", "exchange": "DCE", - "margin_rate": 0.05, + "margin_rate": 0.08, "symbol_size": 100, "price_tick": 0.5 }, @@ -229,7 +229,7 @@ "mi_symbol": "jd2005", "full_symbol": "JD2005", "exchange": "DCE", - "margin_rate": 0.08, + "margin_rate": 0.07, "symbol_size": 10, "price_tick": 1.0 }, @@ -238,14 +238,14 @@ "mi_symbol": "jm2005", "full_symbol": "JM2005", "exchange": "DCE", - "margin_rate": 0.05, + "margin_rate": 0.08, "symbol_size": 60, "price_tick": 0.5 }, "JR": { "underlying_symbol": "JR", - "mi_symbol": "JR011", - "full_symbol": "JR2011", + "mi_symbol": "JR101", + "full_symbol": "JR2101", "exchange": "CZCE", "margin_rate": 0.05, "symbol_size": 20, @@ -328,14 +328,14 @@ "mi_symbol": "pb2003", "full_symbol": "PB2003", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 5.0 }, "PM": { "underlying_symbol": "PM", - "mi_symbol": "PM011", - "full_symbol": "PM2011", + "mi_symbol": "PM101", + "full_symbol": "PM2101", "exchange": "CZCE", "margin_rate": 0.05, "symbol_size": 50, @@ -355,14 +355,14 @@ "mi_symbol": "rb2005", "full_symbol": "RB2005", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.08, "symbol_size": 10, "price_tick": 1.0 }, "RI": { "underlying_symbol": "RI", - "mi_symbol": "RI011", - "full_symbol": "RI2011", + "mi_symbol": "RI101", + "full_symbol": "RI2101", "exchange": "CZCE", "margin_rate": 0.05, "symbol_size": 20, @@ -373,7 +373,7 @@ "mi_symbol": "RM005", "full_symbol": "RM2005", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.06, "symbol_size": 10, "price_tick": 1.0 }, @@ -391,7 +391,7 @@ "mi_symbol": "RS011", "full_symbol": "RS2011", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.2, "symbol_size": 10, "price_tick": 1.0 }, @@ -400,7 +400,7 @@ "mi_symbol": "ru2005", "full_symbol": "RU2005", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.09, "symbol_size": 10, "price_tick": 5.0 }, @@ -418,7 +418,7 @@ "mi_symbol": "sc2003", "full_symbol": "SC2003", "exchange": "INE", - "margin_rate": 0.05, + "margin_rate": 0.07, "symbol_size": 1000, "price_tick": 0.1 }, @@ -427,7 +427,7 @@ "mi_symbol": "SF005", "full_symbol": "SF2005", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 2.0 }, @@ -436,7 +436,7 @@ "mi_symbol": "SM005", "full_symbol": "SM2005", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 2.0 }, @@ -445,7 +445,7 @@ "mi_symbol": "sn2006", "full_symbol": "SN2006", "exchange": "SHFE", - "margin_rate": 0.09, + "margin_rate": 0.08, "symbol_size": 1, "price_tick": 10.0 }, @@ -490,7 +490,7 @@ "mi_symbol": "TA005", "full_symbol": "TA2005", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.06, "symbol_size": 5, "price_tick": 2.0 }, @@ -535,7 +535,7 @@ "mi_symbol": "WH011", "full_symbol": "WH2011", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.07, "symbol_size": 20, "price_tick": 1.0 }, @@ -544,7 +544,7 @@ "mi_symbol": "wr2012", "full_symbol": "WR2012", "exchange": "SHFE", - "margin_rate": 0.2, + "margin_rate": 0.08, "symbol_size": 10, "price_tick": 1.0 }, @@ -562,7 +562,7 @@ "mi_symbol": "ZC005", "full_symbol": "ZC2005", "exchange": "CZCE", - "margin_rate": 0.05, + "margin_rate": 0.06, "symbol_size": 100, "price_tick": 0.2 }, @@ -571,7 +571,7 @@ "mi_symbol": "zn2003", "full_symbol": "ZN2003", "exchange": "SHFE", - "margin_rate": 0.1, + "margin_rate": 0.07, "symbol_size": 5, "price_tick": 5.0 } diff --git a/vnpy/data/tdx/refill_tdx_future_bars.py b/vnpy/data/tdx/refill_tdx_future_bars.py index 0d6a9594..74eb6f54 100644 --- a/vnpy/data/tdx/refill_tdx_future_bars.py +++ b/vnpy/data/tdx/refill_tdx_future_bars.py @@ -29,7 +29,7 @@ api_01 = TdxFutureData() api_01.update_mi_contracts() # 逐一指数合约下载并更新 -for underlying_symbol in ['RB', 'J']: #api_01.future_contracts.keys(): +for underlying_symbol in api_01.future_contracts.keys(): index_symbol = underlying_symbol + '99' print(f'开始更新:{index_symbol}') # csv数据文件名 diff --git a/vnpy/data/tdx/tdx_future_data.py b/vnpy/data/tdx/tdx_future_data.py index 1b817fea..2b3d4078 100644 --- a/vnpy/data/tdx/tdx_future_data.py +++ b/vnpy/data/tdx/tdx_future_data.py @@ -137,7 +137,7 @@ class TdxFutureData(object): last_datetime = datetime.strptime(last_datetime_str, '%Y-%m-%d %H:%M:%S') if (datetime.now() - last_datetime).total_seconds() > 60 * 60 * 2: self.best_ip = {} - except Exception as ex: + except Exception as ex: # noqa self.best_ip = {} else: self.best_ip = {} @@ -262,6 +262,8 @@ class TdxFutureData(object): """ ret_bars = [] + if '.' in symbol: + symbol = symbol.split('.')[0] tdx_symbol = symbol.upper().replace('_', '') tdx_symbol = tdx_symbol.replace('99', 'L9') underlying_symbol = get_underlying_symbol(symbol).upper() diff --git a/vnpy/task/celery_app.py b/vnpy/task/celery_app.py index e22b30c3..b45503d4 100644 --- a/vnpy/task/celery_app.py +++ b/vnpy/task/celery_app.py @@ -3,7 +3,6 @@ # Celery app # 该py脚本,为启动celery worker app # 在项目根目录下,运行 celery -A vnpy.task.celery worker -import time from celery import Celery import sys @@ -23,7 +22,7 @@ if vnpy_root not in sys.path: sys.path.append(vnpy_root) # 使用本地配置的 -from vnpy.trader.utility import load_json +from vnpy.trader.utility import load_json # noqa file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'celery_config.json')) celery_config = load_json(file_path) @@ -40,17 +39,8 @@ print(u'Celery 使用redis配置:\nbroker:{}\nbackend:{}'.format(broker, backend app = Celery('vnpy_task', broker=broker) - # 动态导入task目录下子任务 -# app.conf.CELERY_IMPORTS = ['vnpy.task.celery_app.worker_started'] - -# app.conf.update( -# CELERY_TASK_SERIALIZER='json', -# CELERY_RESULT_SERIALIZER='json', -# CELERY_ACCEPT_CONTENT=['json'], -# CELERY_TIMEZONE='Asia/Shanghai', -# CELERY_ENABLE_UTC=True -# ) +app.conf.CELERY_IMPORTS = ['vnpy.task.celery_app.worker_started'] def worker_started(): @@ -59,7 +49,7 @@ def worker_started(): import socket from vnpy.trader.util_wechat import send_wx_msg send_wx_msg(u'{} Celery Worker 启动'.format(socket.gethostname())) - except: + except: # noqa pass diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 5aa2d84d..6c9d7ebc 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -86,40 +86,40 @@ class Exchange(Enum): Exchange. """ # Chinese - CFFEX = "CFFEX" # China Financial Futures Exchange - SHFE = "SHFE" # Shanghai Futures Exchange - CZCE = "CZCE" # Zhengzhou Commodity Exchange - DCE = "DCE" # Dalian Commodity Exchange - INE = "INE" # Shanghai International Energy Exchange - SSE = "SSE" # Shanghai Stock Exchange - SZSE = "SZSE" # Shenzhen Stock Exchange - SGE = "SGE" # Shanghai Gold Exchange - WXE = "WXE" # Wuxi Steel Exchange + CFFEX = "CFFEX" # China Financial Futures Exchange + SHFE = "SHFE" # Shanghai Futures Exchange + CZCE = "CZCE" # Zhengzhou Commodity Exchange + DCE = "DCE" # Dalian Commodity Exchange + INE = "INE" # Shanghai International Energy Exchange + SSE = "SSE" # Shanghai Stock Exchange + SZSE = "SZSE" # Shenzhen Stock Exchange + SGE = "SGE" # Shanghai Gold Exchange + WXE = "WXE" # Wuxi Steel Exchange # Global - SMART = "SMART" # Smart Router for US stocks - NYMEX = "NYMEX" # New York Mercantile Exchange - COMEX = "COMEX" # a division of theNew York Mercantile Exchange - GLOBEX = "GLOBEX" # Globex of CME - IDEALPRO = "IDEALPRO" # Forex ECN of Interactive Brokers - CME = "CME" # Chicago Mercantile Exchange - ICE = "ICE" # Intercontinental Exchange - SEHK = "SEHK" # Stock Exchange of Hong Kong - HKFE = "HKFE" # Hong Kong Futures Exchange - SGX = "SGX" # Singapore Global Exchange - CBOT = "CBT" # Chicago Board of Trade - CBOE = "CBOE" # Chicago Board Options Exchange - CFE = "CFE" # CBOE Futures Exchange - DME = "DME" # Dubai Mercantile Exchange - EUREX = "EUX" # Eurex Exchange - APEX = "APEX" # Asia Pacific Exchange - LME = "LME" # London Metal Exchange - BMD = "BMD" # Bursa Malaysia Derivatives - TOCOM = "TOCOM" # Tokyo Commodity Exchange - EUNX = "EUNX" # Euronext Exchange - KRX = "KRX" # Korean Exchange + SMART = "SMART" # Smart Router for US stocks + NYMEX = "NYMEX" # New York Mercantile Exchange + COMEX = "COMEX" # a division of theNew York Mercantile Exchange + GLOBEX = "GLOBEX" # Globex of CME + IDEALPRO = "IDEALPRO" # Forex ECN of Interactive Brokers + CME = "CME" # Chicago Mercantile Exchange + ICE = "ICE" # Intercontinental Exchange + SEHK = "SEHK" # Stock Exchange of Hong Kong + HKFE = "HKFE" # Hong Kong Futures Exchange + SGX = "SGX" # Singapore Global Exchange + CBOT = "CBT" # Chicago Board of Trade + CBOE = "CBOE" # Chicago Board Options Exchange + CFE = "CFE" # CBOE Futures Exchange + DME = "DME" # Dubai Mercantile Exchange + EUREX = "EUX" # Eurex Exchange + APEX = "APEX" # Asia Pacific Exchange + LME = "LME" # London Metal Exchange + BMD = "BMD" # Bursa Malaysia Derivatives + TOCOM = "TOCOM" # Tokyo Commodity Exchange + EUNX = "EUNX" # Euronext Exchange + KRX = "KRX" # Korean Exchange - OANDA = "OANDA" # oanda.com + OANDA = "OANDA" # oanda.com # CryptoCurrency BITMEX = "BITMEX" @@ -127,14 +127,15 @@ class Exchange(Enum): HUOBI = "HUOBI" BITFINEX = "BITFINEX" BINANCE = "BINANCE" - BYBIT = "BYBIT" # bybit.com + BYBIT = "BYBIT" # bybit.com COINBASE = "COINBASE" GATEIO = "GATEIO" BITSTAMP = "BITSTAMP" # Special Function - LOCAL = "LOCAL" # For local generated data - SPD = "SPD" # Customer Spread data + LOCAL = "LOCAL" # For local generated data + SPD = "SPD" # Customer Spread data + class Currency(Enum): """ diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 2dde15e8..7e70ce93 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -57,8 +57,8 @@ class MainEngine: self.rm_engine = None - os.chdir(TRADER_DIR) # Change working directory - self.init_engines() # Initialize function engines + os.chdir(TRADER_DIR) # Change working directory + self.init_engines() # Initialize function engines def add_engine(self, engine_class: Any): """ @@ -176,7 +176,7 @@ class MainEngine: if gateway: gateway.subscribe(req) else: - for gateway in self.gateways.items(): + for gateway in self.gateways.values(): if gateway: gateway.subscribe(req) @@ -196,7 +196,8 @@ class MainEngine: """ gateway = self.get_gateway(gateway_name) if gateway: - gateway.cancel_order(req) + return gateway.cancel_order(req) + return False def send_orders(self, reqs: Sequence[OrderRequest], gateway_name: str): """ @@ -245,10 +246,10 @@ class BaseEngine(ABC): """ def __init__( - self, - main_engine: MainEngine, - event_engine: EventEngine, - engine_name: str, + self, + main_engine: MainEngine, + event_engine: EventEngine, + engine_name: str, ): """""" self.main_engine = main_engine @@ -588,12 +589,13 @@ class CustomContract(object): for symbol, setting in self.setting.items(): gateway_name = setting.get('gateway_name', None) if gateway_name is None: - gateway_name= SETTINGS.get('gateway_name','') + gateway_name = SETTINGS.get('gateway_name', '') vn_exchange = Exchange(setting.get('exchange', 'LOCAL')) contract = ContractData( gateway_name=gateway_name, symbol=symbol, - name=contract.symbol, + exchange=vn_exchange, + name=setting.get('name', symbol), size=setting.get('size', 100), pricetick=setting.get('price_tick', 0.01), margin_rate=setting.get('margin_rate', 0.1) @@ -602,6 +604,7 @@ class CustomContract(object): return d + class EmailEngine(BaseEngine): """ Provides email sending function for VN Trader. @@ -642,7 +645,7 @@ class EmailEngine(BaseEngine): msg = self.queue.get(block=True, timeout=1) with smtplib.SMTP_SSL( - SETTINGS["email.server"], SETTINGS["email.port"] + SETTINGS["email.server"], SETTINGS["email.port"] ) as smtp: smtp.login( SETTINGS["email.username"], SETTINGS["email.password"] diff --git a/vnpy/trader/event.py b/vnpy/trader/event.py index f698ffa1..5d9db8f2 100644 --- a/vnpy/trader/event.py +++ b/vnpy/trader/event.py @@ -26,4 +26,3 @@ EVENT_FUNDS_FLOW = 'eFundsFlow.' EVENT_ERROR = 'eError' EVENT_WARNING = 'eWarning' EVENT_CRITICAL = 'eCritical' - diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index b389ac39..df0a9ccb 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -262,7 +262,7 @@ class BaseGateway(ABC): implementation should finish the tasks blow: * send request to server """ - pass + return False def send_orders(self, reqs: Sequence[OrderRequest]): """ diff --git a/vnpy/trader/object.py b/vnpy/trader/object.py index d96eede2..800f54f4 100644 --- a/vnpy/trader/object.py +++ b/vnpy/trader/object.py @@ -8,7 +8,7 @@ from logging import INFO from .constant import Direction, Exchange, Interval, Offset, Status, Product, OptionType, OrderType -ACTIVE_STATUSES = set([Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED]) +ACTIVE_STATUSES = set([Status.SUBMITTING, Status.NOTTRADED, Status.PARTTRADED, Status.CANCELLING]) @dataclass @@ -232,15 +232,15 @@ class AccountData(BaseData): """ accountid: str - pre_balance: float = 0 # 昨净值 - balance: float = 0 # 当前净值 - frozen: float = 0 # 冻结资金 - currency: str = "" # 币种 - commission: float = 0 # 手续费 - margin: float = 0 # 使用保证金 + pre_balance: float = 0 # 昨净值 + balance: float = 0 # 当前净值 + frozen: float = 0 # 冻结资金 + currency: str = "" # 币种 + commission: float = 0 # 手续费 + margin: float = 0 # 使用保证金 close_profit: float = 0 # 平仓盈亏 holding_profit: float = 0 # 持仓盈亏 - trading_day: str = "" # 当前交易日 + trading_day: str = "" # 当前交易日 def __post_init__(self): """""" @@ -256,18 +256,18 @@ class VtFundsFlowData(BaseData): accountid: str # 账户代码 exchange: Exchange = None - currency: str = "" # 币种 - trade_date: str = "" # 成交日期 - trade_price: float = 0 # 成交价格 + currency: str = "" # 币种 + trade_date: str = "" # 成交日期 + trade_price: float = 0 # 成交价格 trade_volume: float = 0 # 成交数量 trade_amount: float = 0 # 发生金额( 正数代表卖出,或者转入资金,获取分红等,负数代表买入股票或者出金) - fund_remain: float = 0 # 资金余额 - contract_id: str = "" # 合同编号 + fund_remain: float = 0 # 资金余额 + contract_id: str = "" # 合同编号 business_name: str = "" # 业务名称 - symbol: str = "" # 合约代码(证券代码) - holder_id: str = "" # 股东代码 - direction: str = "" # 买卖类别:转,买,卖.. - comment: str = "" # 备注 + symbol: str = "" # 合约代码(证券代码) + holder_id: str = "" # 股东代码 + direction: str = "" # 买卖类别:转,买,卖.. + comment: str = "" # 备注 def __post_init__(self): if self.exchange: @@ -339,6 +339,7 @@ class SubscribeRequest: def __eq__(self, other): return self.vt_symbol == other.vt_symbol + @dataclass class OrderRequest: """ diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 404e1bc6..718deb6a 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -172,17 +172,20 @@ def extract_vt_symbol(vt_symbol: str): symbol, exchange_str = vt_symbol.split(".") return symbol, Exchange(exchange_str) + def generate_vt_symbol(symbol: str, exchange: Exchange): """ return vt_symbol """ return f"{symbol}.{exchange.value}" + def format_number(n): """格式化数字到字符串""" rn = round(n, 2) # 保留两位小数 return format(rn, ',') # 加上千分符 + def _get_trader_dir(temp_name: str): """ Get path where trader is running in. @@ -212,8 +215,9 @@ def _get_trader_dir(temp_name: str): TRADER_DIR, TEMP_DIR = _get_trader_dir(".vntrader") -sys.path.append(str(TRADER_DIR)) -print(f'sys.path append: {str(TRADER_DIR)}') +if TRADER_DIR not in sys.path: + sys.path.append(str(TRADER_DIR)) + print(f'sys.path append: {str(TRADER_DIR)}') def get_file_path(filename: str): @@ -306,7 +310,7 @@ def print_dict(d: dict): return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())]) -def append_data(self, file_name: str, dict_data: dict, field_names: list = []): +def append_data(file_name: str, dict_data: dict, field_names: list = []): """ 添加数据到csv文件中 :param file_name: csv的文件全路径 @@ -360,19 +364,254 @@ def import_module_by_str(import_module_name): comp = modules[-1] if not hasattr(mod, comp): - loaded_modules = '.'.join([loaded_modules,comp]) + loaded_modules = '.'.join([loaded_modules, comp]) print('realod {}'.format(loaded_modules)) mod = reload(loaded_modules) else: - print('from {} import {}'.format(loaded_modules,comp)) + print('from {} import {}'.format(loaded_modules, comp)) mod = getattr(mod, comp) return mod except Exception as ex: - print('import {} fail,{},{}'.format(import_module_name,str(ex),traceback.format_exc())) + print('import {} fail,{},{}'.format(import_module_name, str(ex), traceback.format_exc())) return None + +def save_df_to_excel(file_name, sheet_name, df): + """ + 保存dataframe到execl + :param file_name: 保存的excel文件名 + :param sheet_name: 保存的sheet + :param df: dataframe + :return: True/False + """ + if file_name is None or sheet_name is None or df is None: + return False + + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + from openpyxl.utils.dataframe import dataframe_to_rows + # from openpyxl.drawing.image import Image + except: # noqa + print(u'can not import openpyxl', file=sys.stderr) + + if 'openpyxl' not in sys.modules: + print(u'can not import openpyxl', file=sys.stderr) + return False + + try: + ws = None + + try: + # 读取文件 + wb = openpyxl.load_workbook(file_name) + except: # noqa + # 创建一个excel workbook + wb = openpyxl.Workbook() + ws = wb.active + ws.title = sheet_name + try: + # 定位WorkSheet + if ws is None: + ws = wb[sheet_name] + except: # noqa + # 创建一个WorkSheet + ws = wb.create_sheet() + ws.title = sheet_name + + rows = dataframe_to_rows(df) + for r_idx, row in enumerate(rows, 1): + for c_idx, value in enumerate(row, 1): + ws.cell(row=r_idx, column=c_idx, value=value) + + # Save the workbook + wb.save(file_name) + wb.close() + except Exception as ex: + import traceback + print(u'save_df_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr) + + +def save_text_to_excel(file_name, sheet_name, text): + """ + 保存文本文件到excel + :param file_name: + :param sheet_name: + :param text: + :return: + """ + if file_name is None or len(sheet_name) == 0 or len(text) == 0: + return False + + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + # from openpyxl.utils.dataframe import dataframe_to_rows + # from openpyxl.drawing.image import Image + except: # noqa + print(u'can not import openpyxl', file=sys.stderr) + + if 'openpyxl' not in sys.modules: + return False + + try: + ws = None + try: + # 读取文件 + wb = openpyxl.load_workbook(file_name) + except: # noqa + # 创建一个excel workbook + wb = openpyxl.Workbook() + ws = wb.active + ws.title = sheet_name + try: + # 定位WorkSheet + if ws is None: + ws = wb[sheet_name] + except: # noqa + # 创建一个WorkSheet + ws = wb.create_sheet() + ws.title = sheet_name + + # 设置宽度,自动换行方式 + ws.column_dimensions["A"].width = 120 + ws['A2'].alignment = openpyxl.styles.Alignment(wrapText=True) + ws['A2'].value = text + + # Save the workbook + wb.save(file_name) + wb.close() + return True + except Exception as ex: + import traceback + print(u'save_text_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr) + return False + + +def save_images_to_excel(file_name, sheet_name, image_names): + """ + # 保存图形文件到excel + :param file_name: excel文件名 + :param sheet_name: workSheet + :param image_names: 图像文件名列表 + :return: + """ + if file_name is None or len(sheet_name) == 0 or len(image_names) == 0: + return False + # ----------------------------- 扩展的功能 --------- + try: + import openpyxl + # from openpyxl.utils.dataframe import dataframe_to_rows + from openpyxl.drawing.image import Image + except Exception as ex: + print(f'can not import openpyxl:{str(ex)}', file=sys.stderr) + + if 'openpyxl' not in sys.modules: + return False + try: + ws = None + + try: + # 读取文件 + wb = openpyxl.load_workbook(file_name) + except: # noqa + # 创建一个excel workbook + wb = openpyxl.Workbook() + ws = wb.active + ws.title = sheet_name + try: + # 定位WorkSheet + if ws is None: + ws = wb[sheet_name] + except Exception as ex: # noqa + # 创建一个WorkSheet + ws = wb.create_sheet() + ws.title = sheet_name + + i = 1 + + for image_name in image_names: + try: + # 加载图形文件 + img1 = Image(image_name) + + cell_id = 'A{0}'.format(i) + ws[cell_id].value = image_name + cell_id = 'A{0}'.format(i + 1) + + i += 30 + + # 添加至对应的WorkSheet中 + ws.add_image(img1, cell_id) + except Exception as ex: + print('exception loading image {}, {}'.format(image_name, str(ex)), file=sys.stderr) + return False + + # Save the workbook + wb.save(file_name) + wb.close() + return True + except Exception as ex: + import traceback + print(u'save_images_to_excel exception:{}'.format(str(ex)), traceback.format_exc(), file=sys.stderr) + return False + + +def display_dual_axis(df, columns1, columns2=[], invert_yaxis1=False, invert_yaxis2=False, file_name=None, + sheet_name=None, + image_name=None): + """ + 显示(保存)双Y轴的走势图 + :param df: DataFrame + :param columns1: y1轴 + :param columns2: Y2轴 + :param invert_yaxis1: Y1 轴反转 + :param invert_yaxis2: Y2 轴翻转 + :param file_name: 保存的excel 文件名称 + :param sheet_name: excel 的sheet + :param image_name: 保存的image 文件名 + :return: + """ + + import matplotlib + import matplotlib.pyplot as plt + matplotlib.rcParams['figure.figsize'] = (20.0, 10.0) + + df1 = df[columns1] + df1.index = list(range(len(df))) + fig, ax1 = plt.subplots() + if invert_yaxis1: + ax1.invert_yaxis() + ax1.plot(df1) + + if len(columns2) > 0: + df2 = df[columns2] + df2.index = list(range(len(df))) + ax2 = ax1.twinx() + if invert_yaxis2: + ax2.invert_yaxis() + ax2.plot(df2) + + # 修改x轴得label为时间 + xt = ax1.get_xticks() + xt2 = [df.index[int(i)] for i in xt[1:-2]] + xt2.insert(0, '') + xt2.append('') + ax1.set_xticklabels(xt2) + + # 是否保存图片到文件 + if image_name is not None: + fig = plt.gcf() + fig.savefig(image_name, bbox_inches='tight') + + # 插入图片到指定的excel文件sheet中并保存excel + if file_name is not None and sheet_name is not None: + save_images_to_excel(file_name, sheet_name, [image_name]) + else: + plt.show() + class BarGenerator: """ For: