diff --git a/vnpy/app/cta_stock/__init__.py b/vnpy/app/cta_stock/__init__.py new file mode 100644 index 00000000..91963604 --- /dev/null +++ b/vnpy/app/cta_stock/__init__.py @@ -0,0 +1,34 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp +from .base import APP_NAME, StopOrder + +from .engine import CtaEngine + +from .template import ( + Exchange, + Direction, + Offset, + Status, + Color, + Interval, + TickData, + BarData, + TradeData, + OrderData, + CtaPolicy, + StockPolicy, + CtaTemplate, CtaStockTemplate) # noqa + +from vnpy.trader.utility import BarGenerator, ArrayManager # noqa + + +class CtaStockApp(BaseApp): + """""" + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "股票CTA策略" + engine_class = CtaEngine + widget_name = "CtaManager" + icon_name = "cta.ico" diff --git a/vnpy/app/cta_stock/back_testing.py b/vnpy/app/cta_stock/back_testing.py new file mode 100644 index 00000000..528373b1 --- /dev/null +++ b/vnpy/app/cta_stock/back_testing.py @@ -0,0 +1,2070 @@ +# encoding: UTF-8 + +''' +本文件中包含的是CTA模块的组合回测引擎,回测引擎的API和CTA引擎一致, +可以使用和实盘相同的代码进行回测。 +华富资产 李来佳 +''' +from __future__ import division + +import sys +import os +import importlib +import csv +import copy +import pandas as pd +import traceback +import numpy as np +import logging +import socket +import zlib +import pickle +from bson import binary + +from collections import OrderedDict, defaultdict +from datetime import datetime, timedelta +from functools import lru_cache +from pathlib import Path + +from .base import ( + EngineType, + STOPORDER_PREFIX, + StopOrder, + StopOrderStatus +) +from .template import CtaTemplate + +from vnpy.component.cta_fund_kline import FundKline + +from vnpy.trader.object import ( + BarData, + TickData, + OrderData, + TradeData, + ContractData, + PositionData +) +from vnpy.trader.constant import ( + Exchange, + Direction, + Offset, + Status, + OrderType, + Product +) +from vnpy.trader.converter import PositionHolding + +from vnpy.trader.utility import ( + get_underlying_symbol, + round_to, + extract_vt_symbol, + format_number, + import_module_by_str, + get_stock_exchange +) +from vnpy.data.stock.adjust_factor import get_all_adjust_factor + +from vnpy.trader.util_logger import setup_logger +from vnpy.data.mongo.mongo_data import MongoData +from uuid import uuid1 + + +class BackTestingEngine(object): + """ + CTA回测引擎 + 函数接口和策略引擎保持一样, + 从而实现同一套代码从回测到实盘。 + 针对1分钟bar的回测 + 或者tick级别得回测 + 提供对组合回测/批量回测得服务 + + """ + + def __init__(self, event_engine=None): + """Constructor""" + + # 绑定事件引擎 + self.event_engine = event_engine + + self.mode = 'bar' # 'bar': 根据1分钟k线进行回测, 'tick',根据分笔tick进行回测 + + # 引擎类型为回测 + self.engine_type = EngineType.BACKTESTING + self.contract_type = 'stock' # future, stock, digital + + # 回测策略相关 + self.classes = {} # 策略类,class_name: stategy_class + self.class_module_map = {} # 策略类名与模块名映射 class_name: mudule_name + self.strategies = {} # 回测策略实例, key = strategy_name, value= strategy + self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list + + self.test_name = 'stock_test_{}'.format(datetime.now().strftime('%M%S')) # 回测策略组合的实例名字 + self.daily_report_name = '' # 策略的日净值报告文件名称 + + # 回测日期相关 + self.test_start_date = '' # 组合回测启动得日期 + self.init_days = 0 # 初始化天数 + self.test_end_date = '' # 组合回测结束日期 + + self.data_start_date = None # 回测数据开始日期,datetime对象 (用于截取数据) + self.data_end_date = None # 回测数据结束日期,datetime对象 (用于截取数据) + self.strategy_start_date = None # 策略启动日期(即前面的数据用于初始化),datetime对象 + + # 回测数据相关 + self.adjust_factors = {} # 复权因子 + self.slippage = {} # 回测时假设的滑点 + self.commission_rate = {} # 回测时假设的佣金比例(适用于百分比佣金) + self.fix_commission = {} # 每手固定手续费 + self.size = {} # 合约大小,默认为1 + self.price_tick = {} # 价格最小变动 + self.volume_tick = {} # 合约委托单最小单位 + self.margin_rate = {} # 回测合约的保证金比率 + self.price_dict = {} # 登记vt_symbol对应的最新价 + self.contract_dict = {} # 登记vt_symbol得对应合约信息 + self.symbol_exchange_dict = {} # 登记symbol: exchange的对应关系 + + self.stop_order_count = 0 # 本地停止单编号 + self.stop_orders = {} # 本地停止单 + self.active_stop_orders = {} # 活动本地停止单 + + self.limit_order_count = 0 # 限价单编号 + self.limit_orders = OrderedDict() # 限价单字典 + self.active_limit_orders = OrderedDict() # 活动限价单字典,用于进行撮合用 + + self.order_strategy_dict = {} # orderid 与 strategy的映射 + + self.trade_count = 0 # 成交编号 + self.trade_dict = OrderedDict() # 用于统计成交收益时,还没处理得交易 + self.trades = OrderedDict() # 记录所有得成交记录 + self.trade_pnl_list = [] # 交易记录列表 + + self.long_position_list = [] # 多单持仓 + + self.positions = {} # 账号持仓,对象为PositionData + + # 当前最新数据,用于模拟成交用 + self.gateway_name = u'BackTest' + + self.last_bar = {} # 最新的bar + self.last_tick = {} # 最新tick + self.last_dt = None # 最新时间 + + # csvFile相关 + self.bar_interval_seconds = 60 # csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟 + + # 费用风控情况 + self.percent = 0.0 + self.percent_limit = 30 # 投资仓位比例上限 + + # 回测计算相关 + self.use_margin = True # 使用保证金模式(期货使用,计算保证金时,按照开仓价计算。股票是按照当前价计算) + + self.init_capital = 1000000 # 期初资金 + self.cur_capital = self.init_capital # 当前资金净值 + self.net_capital = self.init_capital # 实时资金净值(每日根据capital和持仓浮盈计算) + self.max_capital = self.init_capital # 资金最高净值 + self.max_net_capital = self.init_capital + self.avaliable = self.init_capital + + self.max_pnl = 0 # 最高盈利 + self.min_pnl = 0 # 最大亏损 + self.max_occupy_rate = 0 # 最大保证金占比 + self.winning_result = 0 # 盈利次数 + self.losing_result = 0 # 亏损次数 + + self.total_trade_count = 0 # 总成交数量 + self.total_winning = 0 # 总盈利 + self.total_losing = 0 # 总亏损 + self.total_turnover = 0 # 总成交金额(合约面值) + self.total_commission = 0 # 总手续费 + self.total_slippage = 0 # 总滑点 + + self.time_list = [] # 时间序列 + self.pnl_list = [] # 每笔盈亏序列 + self.capital_list = [] # 盈亏汇总的时间序列 + self.drawdown_list = [] # 回撤的时间序列 + self.drawdown_rate_list = [] # 最大回撤比例的时间序列(成交结算) + + self.max_net_capital_time = '' + self.max_drawdown_rate_time = '' + self.daily_max_drawdown_rate = 0 # 按照日结算价计算 + + self.pnl_strategy_dict = {} # 策略实例的平仓盈亏 + + self.is_plot_daily = False + self.daily_list = [] # 按日统计得序列 + self.daily_first_benchmark = None + + self.logger = None + self.strategy_loggers = {} + self.debug = False + + self.is_7x24 = False + self.logs_path = None + self.data_path = None + + self.fund_kline_dict = {} + self.active_fund_kline = False + + # 回测任务/回测结果,保存在数据库中 + self.mongo_api = None + self.task_id = None + self.test_setting = None # 回测设置 + self.strategy_setting = None # 所有回测策略得设置 + + def create_fund_kline(self, name, use_renko=False): + """ + 创建资金曲线 + :param name: 账号名,或者策略名 + :param use_renko: + :return: + """ + setting = {} + setting.update({'name': name}) + setting['para_ma1_len'] = 5 + setting['para_ma2_len'] = 10 + setting['para_ma3_len'] = 20 + setting['para_active_yb'] = True + setting['price_tick'] = 0.01 + setting['underlying_symbol'] = 'fund' + setting['is_7x24'] = self.is_7x24 + + if use_renko: + # 使用砖图,高度是资金的千分之一 + setting['height'] = self.init_capital * 0.001 + setting['use_renko'] = True + + fund_kline = FundKline(cta_engine=self, setting=setting) + self.fund_kline_dict.update({name: fund_kline}) + return fund_kline + + def get_fund_kline(self, name: str = None): + # 指定资金账号/策略名 + if name: + kline = self.fund_kline_dict.get(name, None) + return kline + + # 没有指定账号,并且存在一个或多个资金K线 + if len(self.fund_kline_dict) > 0: + # 优先找vt_setting中,配置了strategy_groud的资金K线 + kline = self.fund_kline_dict.get(self.test_name, None) + + # 找不到,返回第一个 + if kline is None: + kline = self.fund_kline_dict.values()[0] + return kline + else: + return None + + def get_account(self, vt_accountid: str = ""): + """返回账号的实时权益,可用资金,仓位比例,投资仓位比例上限""" + if self.net_capital == 0.0: + self.percent = 0.0 + + return self.net_capital, self.avaliable, self.percent, self.percent_limit + + def set_test_start_date(self, start_date: str = '20100416', init_days: int = 10): + """设置回测的启动日期""" + self.test_start_date = start_date + self.init_days = init_days + + self.data_start_date = datetime.strptime(start_date, '%Y%m%d') + + # 初始化天数 + init_time_delta = timedelta(init_days) + + self.strategy_start_date = self.data_start_date + init_time_delta + self.write_log(u'设置:回测数据开始日期:{},初始化数据为{}天,策略自动启动日期:{}' + .format(self.data_start_date, self.init_days, self.strategy_start_date)) + + def set_test_end_date(self, end_date: str = ''): + """设置回测的结束日期""" + self.test_end_date = end_date + if end_date: + self.data_end_date = datetime.strptime(end_date, '%Y%m%d') + # 若不修改时间则会导致不包含dataEndDate当天数据 + self.data_end_date.replace(hour=23, minute=59) + else: + self.data_end_date = datetime.now() + self.write_log(u'设置:回测数据结束日期:{}'.format(self.data_end_date)) + + def set_init_capital(self, capital: float): + """设置期初净值""" + self.cur_capital = capital # 资金 + self.net_capital = capital # 实时资金净值(每日根据capital和持仓浮盈计算) + self.max_capital = capital # 资金最高净值 + self.max_net_capital = capital + self.avaliable = capital + self.init_capital = capital + + def set_margin_rate(self, vt_symbol: str, margin_rate: float): + """设置某个合约得保证金比率""" + self.margin_rate.update({vt_symbol: margin_rate}) + + @lru_cache() + def get_margin_rate(self, vt_symbol: str): + return self.margin_rate.get(vt_symbol, 0.05) + + def set_slippage(self, vt_symbol: str, slippage: float): + """设置滑点点数""" + self.slippage.update({vt_symbol: slippage}) + + @lru_cache() + def get_slippage(self, vt_symbol: str): + """获取滑点""" + return self.slippage.get(vt_symbol, 0) + + def set_size(self, vt_symbol: str, size: int): + """设置合约大小""" + self.size.update({vt_symbol: size}) + + @lru_cache() + def get_size(self, vt_symbol: str): + """查询合约的size""" + return self.size.get(vt_symbol, 10) + + @lru_cache() + def get_name(self, vt_symbol: str): + """查询中文名称""" + contract = self.get_contract(vt_symbol) + if contract: + return contract.name + else: + return vt_symbol + + def set_price(self, vt_symbol: str, price: float): + self.price_dict.update({vt_symbol: price}) + + def get_price(self, vt_symbol: str): + return self.price_dict.get(vt_symbol, None) + + def set_commission_rate(self, vt_symbol: str, rate: float): + """设置佣金比例""" + self.commission_rate.update({vt_symbol: rate}) + + if rate >= 0.1: + self.fix_commission.update({vt_symbol: rate}) + + def get_commission_rate(self, vt_symbol: str): + """ 获取保证金比例,缺省万分之一""" + return self.commission_rate.get(vt_symbol, float(0.00001)) + + def get_fix_commission(self, vt_symbol: str): + return self.fix_commission.get(vt_symbol, 0) + + def set_price_tick(self, vt_symbol: str, price_tick: float): + """设置价格最小变动""" + self.price_tick.update({vt_symbol: price_tick}) + + def get_price_tick(self, vt_symbol: str): + return self.price_tick.get(vt_symbol, 0.01) + + def set_volume_tick(self, vt_symbol: str, volume_tick: float): + """设置委托单最小单位""" + self.volume_tick.update({vt_symbol: volume_tick}) + + def get_volume_tick(self, vt_symbol: str): + return self.volume_tick.get(vt_symbol, 1) + + def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int, + price_tick: float, volume_tick: float = 1, margin_rate: float = 0.1): + """设置合约信息""" + vt_symbol = '.'.join([symbol, exchange.value]) + if vt_symbol not in self.contract_dict: + c = ContractData( + gateway_name=self.gateway_name, + symbol=symbol, + exchange=exchange, + name=name, + product=product, + size=size, + pricetick=price_tick, + min_volume=volume_tick, + margin_rate=margin_rate + ) + self.contract_dict.update({vt_symbol: c}) + self.set_size(vt_symbol, size) + self.set_margin_rate(vt_symbol, margin_rate) + self.set_price_tick(vt_symbol, price_tick) + self.set_volume_tick(vt_symbol, volume_tick) + self.symbol_exchange_dict.update({symbol: exchange}) + + @lru_cache() + def get_contract(self, vt_symbol): + """获取合约配置信息""" + return self.contract_dict.get(vt_symbol) + + @lru_cache() + def get_exchange(self, symbol: str): + return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL) + + def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''): + """ 查询合约在账号的持仓""" + if not gateway_name: + gateway_name = self.gateway_name + k = f'{gateway_name}.{vt_symbol}.{direction.value}' + pos = self.positions.get(k, None) + if not pos: + contract = self.get_contract(vt_symbol) + if not contract: + self.write_log(f'{vt_symbol}合约信息不存在,构造一个') + symbol, exchange = extract_vt_symbol(vt_symbol) + if self.contract_type == 'future': + product = Product.FUTURES + elif self.contract_type == 'stock': + product = Product.EQUITY + else: + product = Product.SPOT + contract = ContractData(gateway_name=gateway_name, + name=vt_symbol, + product=product, + symbol=symbol, + exchange=exchange, + size=self.get_size(vt_symbol), + pricetick=self.get_price_tick(vt_symbol), + margin_rate=self.get_margin_rate(vt_symbol)) + pos = PositionData( + gateway_name=gateway_name, + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction + ) + self.positions[k] = pos + return pos + + def set_name(self, test_name): + """ + 设置组合的运行实例名称 + :param test_name: + :return: + """ + self.test_name = test_name + + def set_daily_report_name(self, report_file): + """ + 设置策略的日净值记录csv保存文件名(含路径) + :param report_file: 保存文件名(含路径) + :return: + """ + self.daily_report_name = report_file + + def prepare_env(self, test_setting): + """ + 根据配置参数,准备环境 + 包括: + 回测名称 ,是否debug,数据目录/日志目录, + 资金/保证金类型/仓位控制 + 回测开始/结束日期 + :param test_setting: + :return: + """ + self.test_setting = copy.copy(test_setting) + + self.output(f'回测引擎准备环境') + if 'name' in test_setting: + self.set_name(test_setting.get('name')) + + self.mode = test_setting.get('mode', 'bar') + self.output(f'采用{self.mode}方式回测') + + self.contract_type = test_setting.get('contract_type', 'stock') + self.output(f'测试合约主要为{self.contract_type}') + + self.debug = test_setting.get('debug', False) + + # 更新数据目录 + if 'data_path' in test_setting: + self.data_path = test_setting.get('data_path') + else: + self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) + + print(f'数据输出目录:{self.data_path}') + + # 更新日志目录 + if 'logs_path' in test_setting: + self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name)) + else: + self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) + print(f'日志输出目录:{self.logs_path}') + + # 创建日志 + self.create_logger(debug=self.debug) + + # 设置资金 + if 'init_capital' in test_setting: + self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital'))) + self.set_init_capital(test_setting.get('init_capital')) + + # 缺省使用保证金方式。(期货使用保证金/股票不使用保证金) + self.use_margin = test_setting.get('use_margin', False) + + # 设置最大资金使用比例 + if 'percent_limit' in test_setting: + self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit'))) + self.percent_limit = test_setting.get('percent_limit') + + if 'start_date' in test_setting: + if 'strategy_start_date' not in test_setting: + init_days = test_setting.get('init_days', 10) + self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days)) + self.set_test_start_date(test_setting.get('start_date'), init_days) + else: + start_date = test_setting.get('start_date') + strategy_start_date = test_setting.get('strategy_start_date') + self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date)) + self.test_start_date = start_date + self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d') + self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d') + + if 'end_date' in test_setting: + self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date'))) + self.set_test_end_date(test_setting.get('end_date')) + + # 准备数据 + if 'symbol_datas' in test_setting: + self.write_log(u'准备数据') + self.prepare_data(test_setting.get('symbol_datas')) + + if self.mode == 'tick': + self.tick_path = test_setting.get('tick_path', None) + + # 设置bar文件的时间间隔秒数 + if 'bar_interval_seconds' in test_setting: + self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_setting.get('bar_interval_seconds'))) + self.bar_interval_seconds = test_setting.get('bar_interval_seconds') + + # 资金曲线 + self.active_fund_kline = test_setting.get('active_fund_kline', False) + if self.active_fund_kline: + # 创建资金K线 + self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False)) + + self.is_plot_daily = test_setting.get('is_plot_daily', False) + + # 加载所有本地策略class + self.load_strategy_class() + + def prepare_data(self, data_dict): + """ + 准备组合数据 + :param data_dict: + :return: + """ + self.output('prepare_data') + + self.write_log(f'获取所有股票的复权因子') + self.adjust_factors = get_all_adjust_factor() + + if len(data_dict) == 0: + self.write_log(u'请指定回测数据和文件') + return + + for vt_symbol, symbol_data in data_dict.items(): + self.write_log(u'配置{}数据:{}'.format(vt_symbol, symbol_data)) + self.set_price_tick(vt_symbol, symbol_data.get('price_tick', 0.01)) + self.set_volume_tick(vt_symbol, symbol_data.get('min_volume', 10)) + self.set_slippage(vt_symbol, symbol_data.get('slippage', 0)) + self.set_size(vt_symbol, symbol_data.get('symbol_size', 1)) + margin_rate = symbol_data.get('margin_rate', 1) + self.set_margin_rate(vt_symbol, margin_rate) + + self.set_commission_rate(vt_symbol, symbol_data.get('commission_rate', float(0.0001))) + symbol, exchange = extract_vt_symbol(vt_symbol) + self.set_contract( + symbol=symbol, + name=symbol_data.get('name',vt_symbol), + exchange=exchange, + product=Product(symbol_data.get('product', "股票")), + size=symbol_data.get('symbol_size', 1), + price_tick=symbol_data.get('price_tick', 0.01), + volume_tick=symbol_data.get('min_volume', 10), + margin_rate=margin_rate + ) + + def stock_to_adj(self, raw_data, adj_data, adj_type): + """ + 股票数据复权转换 + :param raw_data: 不复权数据 + :param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df) + :param adj_type: 复权类型 + :return: + """ + + if adj_type == 'fore': + adj_factor = adj_data["foreAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1 + else: + adj_factor = adj_data["backAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1 + + # 把raw_data的第一个日期,插入复权因子df,使用后填充 + adj_factor.loc[raw_data.index[0]] = np.nan + adj_factor.sort_index(inplace=True) + adj_factor = adj_factor.ffill() + + adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引 + adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格 + + # 把复权因子,作为adj字段,补充到raw_data中 + raw_data['adj'] = adj_factor + + # 逐一复权高低开平和成交量 + for col in ['open', 'high', 'low', 'close']: + raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数 + raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数 + + return raw_data + + def new_tick(self, tick): + """新得tick""" + self.last_tick.update({tick.vt_symbol: tick}) + if self.last_dt is None or (tick.datetime and tick.datetime > self.last_dt): + self.last_dt = tick.datetime + + self.set_price(tick.vt_symbol, tick.last_price) + + self.cross_stop_order(tick=tick) # 撮合停止单 + self.cross_limit_order(tick=tick) # 先撮合限价单 + + # 更新账号级别资金曲线(只有持仓时,才更新) + fund_kline = self.get_fund_kline(self.test_name) + if fund_kline is not None and len(self.long_position_list) > 0: + fund_kline.update_account(self.last_dt, self.net_capital) + + for strategy in self.symbol_strategy_map.get(tick.vt_symbol, []): + # 更新策略的资金K线 + fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None) + if fund_kline: + hold_pnl, _ = fund_kline.get_hold_pnl() + if hold_pnl != 0: + fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl) + + # 推送tick到策略中 + strategy.on_tick({tick.vt_symbol: tick}) # 推送K线到策略中 + + # 到达策略启动日期,启动策略 + if not strategy.trading and self.strategy_start_date < tick.datetime: + strategy.trading = True + strategy.on_start() + self.output(u'{}策略启动交易'.format(strategy.strategy_name)) + + def new_bar(self, bar): + """新的K线""" + self.last_bar.update({bar.vt_symbol: bar}) + if self.last_dt is None or (bar.datetime and bar.datetime > self.last_dt - timedelta(seconds=self.bar_interval_seconds)): + self.last_dt = bar.datetime + timedelta(seconds=self.bar_interval_seconds) + self.set_price(bar.vt_symbol, bar.close_price) + self.cross_stop_order(bar=bar) # 撮合停止单 + self.cross_limit_order(bar=bar) # 先撮合限价单 + + # 更新账号的资金曲线(只有持仓时,才更新) + fund_kline = self.get_fund_kline(self.test_name) + if fund_kline is not None and len(self.long_position_list) > 0: + fund_kline.update_account(self.last_dt, self.net_capital) + + for strategy in self.symbol_strategy_map.get(bar.vt_symbol, []): + # 更新策略的资金K线 + fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None) + if fund_kline: + hold_pnl, _ = fund_kline.get_hold_pnl() + if hold_pnl != 0: + fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl) + + # 推送K线到策略中 + strategy.on_bar({bar.vt_symbol: bar}) # 推送K线到策略中 + + # 到达策略启动日期,启动策略 + if not strategy.trading and self.strategy_start_date < bar.datetime: + strategy.trading = True + strategy.on_start() + self.output(u'{}策略启动交易'.format(strategy.strategy_name)) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + self.write_log('加载所有策略class') + # 加载 vnpy/app/cta_strategy_pro/strategies的所有策略 + path1 = Path(__file__).parent.joinpath("strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_stock.strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(str(path)): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + elif filename.endswith(".pyd"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + elif filename.endswith(".so"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + else: + continue + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load/Reload strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + class_name = value.__name__ + if class_name not in self.classes: + self.write_log(f"加载策略类{module_name}.{class_name}") + else: + self.write_log(f"更新策略类{module_name}.{class_name}") + self.classes[class_name] = value + self.class_module_map[class_name] = module_name + return True + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_error(msg) + self.output(msg) + return False + + def load_strategy(self, strategy_name: str, strategy_setting: dict = None): + """ + 装载回测的策略 + setting是参数设置,包括 + class_name: str, 策略类名字 + vt_symbol: str, 缺省合约 + setting: {}, 策略的参数 + auto_init: True/False, 策略是否自动初始化 + auto_start: True/False, 策略是否自动启动 + """ + + # 获取策略的类名 + class_name = strategy_setting.get('class_name', None) + if class_name is None or strategy_name is None: + self.write_error(u'setting中没有class_name') + return + + # strategy_class => module.strategy_class + if '.' not in class_name: + module_name = self.class_module_map.get(class_name, None) + if module_name: + class_name = module_name + '.' + class_name + self.write_log(u'转换策略为全路径:{}'.format(class_name)) + + # 获取策略类的定义 + strategy_class = import_module_by_str(class_name) + if strategy_class is None: + self.write_error(u'加载策略模块失败:{}'.format(class_name)) + return + + vt_symbols = strategy_setting.get('vt_symbols', []) + + # 取消自动启动 + if 'auto_start' in strategy_setting: + strategy_setting.update({'auto_start': False}) + + # 策略参数设置 + setting = strategy_setting.get('setting', {}) + + # 强制更新回测为True + setting.update({'backtesting': True}) + + # 创建实例 + strategy = strategy_class(self, strategy_name, vt_symbols, setting) + + # 保存到策略实例映射表中 + self.strategies.update({strategy_name: strategy}) + + # 更新vt_symbol合约与策略的订阅关系 + for vt_symbol in vt_symbols: + if '.' in vt_symbol: + symbol, exchange = extract_vt_symbol(vt_symbol) + else: + symbol = vt_symbol + exchange_str = get_stock_exchange(code=symbol) + vt_symbol = '.'.join([symbol, exchange_str]) + + self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=vt_symbol) + + if strategy_setting.get('auto_init', False): + self.write_log(u'自动初始化策略') + strategy.on_init() + + if strategy_setting.get('auto_start', False): + self.write_log(u'自动启动策略') + strategy.on_start() + + if self.active_fund_kline: + # 创建策略实例的资金K线 + self.create_fund_kline(name=strategy_name, use_renko=False) + + def subscribe_symbol(self, strategy_name: str, vt_symbol: str, gateway_name: str = '', is_bar: bool = False): + """订阅合约""" + strategy = self.strategies.get(strategy_name, None) + if not strategy: + return False + + # 添加 合约订阅 vt_symbol <=> 策略实例 strategy 映射. + strategies = self.symbol_strategy_map[vt_symbol] + strategies.append(strategy) + return True + + # --------------------------------------------------------------------- + def save_strategy_data(self): + """保存策略数据""" + for strategy in self.strategies.values(): + self.write_log(u'save strategy data') + strategy.save_data() + + def send_order(self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool, + order_type: OrderType = OrderType.LIMIT, + gateway_name: str = None): + """发单""" + price_tick = self.get_price_tick(vt_symbol) + price = round_to(price, price_tick) + + if stop: + return self.send_local_stop_order( + strategy=strategy, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name + ) + else: + return self.send_limit_order( + strategy=strategy, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name + ) + + def send_limit_order(self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + order_type: OrderType = OrderType.LIMIT, + gateway_name: str = None + ): + + """ 发限价单""" + self.limit_order_count += 1 + order_id = str(self.limit_order_count) + symbol, exchange = extract_vt_symbol(vt_symbol) + if gateway_name is None: + gateway_name = self.gateway_name + order = OrderData( + gateway_name=gateway_name, + symbol=symbol, + exchange=exchange, + orderid=order_id, + direction=direction, + offset=offset, + type=order_type, + price=round_to(value=price, target=self.get_price_tick(vt_symbol)), + volume=volume, + status=Status.NOTTRADED, + time=str(self.last_dt) + ) + + # 保存到限价单字典中 + self.active_limit_orders[order.vt_orderid] = order + self.limit_orders[order.vt_orderid] = order + self.order_strategy_dict.update({order.vt_orderid: strategy}) + + self.write_log(f'创建限价单:{order.__dict__}') + + return [order.vt_orderid] + + def send_local_stop_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None): + + """发出本地停止单""" + self.stop_order_count += 1 + + stop_order = StopOrder( + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop_orderid=f"{STOPORDER_PREFIX}.{self.stop_order_count}", + strategy_name=strategy.strategy_name, + ) + self.write_log(f'创建本地停止单:{stop_order.__dict__}') + self.order_strategy_dict.update({stop_order.stop_orderid: strategy}) + + self.active_stop_orders[stop_order.stop_orderid] = stop_order + self.stop_orders[stop_order.stop_orderid] = stop_order + + return [stop_order.stop_orderid] + + def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): + """撤单""" + if vt_orderid.startswith(STOPORDER_PREFIX): + return self.cancel_stop_order(strategy, vt_orderid) + else: + return self.cancel_limit_order(strategy, vt_orderid) + + def cancel_limit_order(self, strategy: CtaTemplate, vt_orderid: str): + """限价单撤单""" + if vt_orderid in self.active_limit_orders: + order = self.active_limit_orders[vt_orderid] + register_strategy = self.order_strategy_dict.get(vt_orderid, None) + if register_strategy.strategy_name != strategy.strategy_name: + return False + order.status = Status.CANCELLED + order.cancel_time = str(self.last_dt) + self.active_limit_orders.pop(vt_orderid, None) + strategy.on_order(order) + return True + return False + + def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str): + """本地停止单撤单""" + if vt_orderid not in self.active_stop_orders: + return False + stop_order = self.active_stop_orders.pop(vt_orderid) + + stop_order.status = StopOrderStatus.CANCELLED + strategy.on_stop_order(stop_order) + return True + + def cancel_all(self, strategy): + """撤销某个策略的所有委托单""" + self.cancel_orders(strategy=strategy) + + def cancel_orders(self, vt_symbol: str = None, offset: Offset = None, strategy: CtaTemplate = None): + """撤销所有单""" + # Symbol参数:指定合约的撤单; + # OFFSET参数:指定Offset的撤单,缺省不填写时,为所有 + # strategy参数: 指定某个策略的单子 + + if len(self.active_limit_orders) > 0: + self.write_log(u'从所有订单中,撤销:开平:{},合约:{},策略:{}' + .format(offset, + vt_symbol if vt_symbol is not None else u'所有', + strategy.strategy_name if strategy else None)) + + for vt_orderid in list(self.active_limit_orders.keys()): + order = self.active_limit_orders.get(vt_orderid, None) + order_strategy = self.order_strategy_dict.get(vt_orderid, None) + if order is None or order_strategy is None: + continue + + if offset is None: + offset_cond = True + else: + offset_cond = order.offset == offset + + if vt_symbol is None: + symbol_cond = True + else: + symbol_cond = order.vt_symbol == vt_symbol + + if strategy is None: + strategy_cond = True + else: + strategy_cond = strategy.strategy_name == order_strategy.strategy_name + + if offset_cond and symbol_cond and strategy_cond: + self.write_log(u'撤销订单:{},{} {}@{}' + .format(vt_orderid, order.direction, order.price, order.volume)) + order.status = Status.CANCELLED + order.cancel_time = str(self.last_dt) + del self.active_limit_orders[vt_orderid] + if strategy: + strategy.on_order(order) + + # 撤销本地停止单 + for stop_orderid in list(self.active_stop_orders.keys()): + order = self.active_stop_orders.get(stop_orderid, None) + order_strategy = self.order_strategy_dict.get(stop_orderid, None) + if order is None or order_strategy is None: + continue + + if offset is None: + offset_cond = True + else: + offset_cond = order.offset == offset + + if vt_symbol is None: + symbol_cond = True + else: + symbol_cond = order.vt_symbol == vt_symbol + + if strategy is None: + strategy_cond = True + else: + strategy_cond = strategy.strategy_name == order_strategy.strategy_name + + if offset_cond and symbol_cond and strategy_cond: + self.write_log(u'撤销本地停止单:{},{} {}@{}' + .format(stop_orderid, order.direction, order.price, order.volume)) + order.status = Status.CANCELLED + order.cancel_time = str(self.last_dt) + self.active_stop_orders.pop(stop_orderid, None) + if strategy: + strategy.on_stop_order(order) + + def cross_stop_order(self, bar: BarData = None, tick: TickData = None): + """ + 本地停止单撮合 + Cross stop order with last bar/tick data. + """ + vt_symbol = bar.vt_symbol if bar else tick.vt_symbol + + for stop_orderid in list(self.active_stop_orders.keys()): + stop_order = self.active_stop_orders[stop_orderid] + strategy = self.order_strategy_dict.get(stop_orderid, None) + if stop_order.vt_symbol != vt_symbol or stop_order is None or strategy is None: + continue + + # 若买入方向停止单价格高于等于该价格,则会触发 + if bar: + long_cross_price = round_to(value=bar.low_price, target=self.get_price_tick(vt_symbol)) + long_cross_price -= self.get_price_tick(vt_symbol) + # 若卖出方向停止单价格低于等于该价格,则会触发 + sell_cross_price = round_to(value=bar.high_price, target=self.get_price_tick(vt_symbol)) + sell_cross_price += self.get_price_tick(vt_symbol) + # 在当前时间点前发出的买入委托可能的最优成交价 + long_best_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) + self.get_price_tick(vt_symbol) + + # 在当前时间点前发出的卖出委托可能的最优成交价 + sell_best_price = round_to(value=bar.open_price, + target=self.get_price_tick(vt_symbol)) - self.get_price_tick(vt_symbol) + else: + long_cross_price = tick.last_price + sell_cross_price = tick.last_price + long_best_price = tick.last_price + sell_best_price = tick.last_price + + # Check whether stop order can be triggered. + long_cross = stop_order.direction == Direction.LONG and stop_order.price <= long_cross_price + + sell_cross = stop_order.direction == Direction.SHORT and stop_order.price >= sell_cross_price + + if not long_cross and not sell_cross: + continue + + # Create order data. + self.limit_order_count += 1 + symbol, exchange = extract_vt_symbol(vt_symbol) + order = OrderData( + symbol=symbol, + exchange=exchange, + orderid=str(self.limit_order_count), + direction=stop_order.direction, + offset=stop_order.offset, + price=stop_order.price, + volume=stop_order.volume, + status=Status.ALLTRADED, + gateway_name=self.gateway_name, + ) + order.datetime = self.last_dt + self.write_log(f'停止单被触发:\n{stop_order.__dict__}\n=>委托单{order.__dict__}') + self.limit_orders[order.vt_orderid] = order + + # Create trade data. + if long_cross: + trade_price = max(stop_order.price, long_best_price) + else: + trade_price = min(stop_order.price, sell_best_price) + + self.trade_count += 1 + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=str(self.trade_count), + direction=order.direction, + offset=order.offset, + price=trade_price, + volume=order.volume, + time=self.last_dt.strftime("%Y-%m-%d %H:%M:%S"), + datetime=self.last_dt, + gateway_name=self.gateway_name, + ) + trade.strategy_name = strategy.strategy_name + trade.datetime = self.last_dt + self.write_log(f'停止单触发成交:{trade.__dict__}') + self.trade_dict[trade.vt_tradeid] = trade + self.trades[trade.vt_tradeid] = copy.copy(trade) + + # Update stop order. + stop_order.vt_orderids.append(order.vt_orderid) + stop_order.status = StopOrderStatus.TRIGGERED + + self.active_stop_orders.pop(stop_order.stop_orderid) + + # Push update to strategy. + strategy.on_stop_order(stop_order) + strategy.on_order(order) + self.append_trade(trade) + + # 更新持仓缓存数据 + pos = self.get_position(vt_symbol=trade.vt_symbol, direction=Direction.NET) + pre_volume = pos.volume + if trade.direction == Direction.LONG: + pos.volume = round(pos.volume + trade.volume, 7) + else: + pos.volume = round(pos.volume - trade.volume, 7) + self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}') + + strategy.on_trade(trade) + + def cross_limit_order(self, bar: BarData = None, tick: TickData = None): + """基于最新数据撮合限价单""" + + vt_symbol = bar.vt_symbol if bar else tick.vt_symbol + + # 遍历限价单字典中的所有限价单 + for vt_orderid in list(self.active_limit_orders.keys()): + order = self.active_limit_orders.get(vt_orderid, None) + if order.vt_symbol != vt_symbol: + continue + + strategy = self.order_strategy_dict.get(order.vt_orderid, None) + if strategy is None: + self.write_error(u'找不到vt_orderid:{}对应的策略'.format(order.vt_orderid)) + continue + if bar: + price_tick = self.get_price_tick(vt_symbol) + + buy_cross_price = round_to(value=bar.low_price, target=price_tick) + price_tick # 若买入方向限价单价格高于该价格,则会成交 + sell_cross_price = round_to(value=bar.high_price, + target=price_tick) - price_tick # 若卖出方向限价单价格低于该价格,则会成交 + buy_best_cross_price = round_to(value=bar.open_price, + target=price_tick) + price_tick # 在当前时间点前发出的买入委托可能的最优成交价 + sell_best_cross_price = round_to(value=bar.open_price, + target=price_tick) - price_tick # 在当前时间点前发出的卖出委托可能的最优成交价 + else: + buy_cross_price = tick.last_price + sell_cross_price = tick.last_price + buy_best_cross_price = tick.last_price + sell_best_cross_price = tick.last_price + + # 判断是否会成交 + buy_cross = order.direction == Direction.LONG and order.price >= buy_cross_price + sell_cross = order.direction == Direction.SHORT and order.price <= sell_cross_price + + # 如果发生了成交 + if buy_cross or sell_cross: + # 推送成交数据 + self.trade_count += 1 # 成交编号自增1 + + trade_id = str(self.trade_count) + symbol, exchange = extract_vt_symbol(vt_symbol) + trade = TradeData( + gateway_name=self.gateway_name, + symbol=symbol, + exchange=exchange, + tradeid=trade_id, + orderid=order.orderid, + direction=order.direction, + offset=order.offset, + volume=order.volume, + time=self.last_dt.strftime("%Y-%m-%d %H:%M:%S"), + datetime=self.last_dt + ) + + # 以买入为例: + # 1. 假设当根K线的OHLC分别为:100, 125, 90, 110 + # 2. 假设在上一根K线结束(也是当前K线开始)的时刻,策略发出的委托为限价105 + # 3. 则在实际中的成交价会是100而不是105,因为委托发出时市场的最优价格是100 + if buy_cross: + trade_price = min(order.price, buy_best_cross_price) + + else: + trade_price = max(order.price, sell_best_cross_price) + trade.price = trade_price + + # 记录该合约来自哪个策略实例 + trade.strategy_name = strategy.strategy_name + + # 更新持仓缓存数据 + pos = self.get_position(vt_symbol=trade.vt_symbol, direction=Direction.NET) + pre_volume = pos.volume + if trade.direction == Direction.LONG: + pos.volume = round(pos.volume + trade.volume, 7) + else: + pos.volume = round(pos.volume - trade.volume, 7) + self.write_log(f'{trade.vt_symbol} volume:{pre_volume} => {pos.volume}') + + self.trade_dict[trade.vt_tradeid] = trade + self.trades[trade.vt_tradeid] = copy.copy(trade) + self.write_log(u'vt_trade_id:{0}'.format(trade.vt_tradeid)) + + self.write_log(u'{} : crossLimitOrder: TradeId:{}'.format(trade.strategy_name, + trade.tradeid, + )) + + # 写入交易记录 + self.append_trade(trade) + + strategy.on_trade(trade) + + # 更新资金曲线 + fund_kline = self.get_fund_kline(trade.strategy_name) + if fund_kline: + fund_kline.update_trade(trade) + + # 推送委托数据 + order.traded = order.volume + order.status = Status.ALLTRADED + + strategy.on_order(order) + + # 从字典中删除该限价单 + self.active_limit_orders.pop(vt_orderid, None) + + # 实时计算模式 + self.realtime_calculate() + + def update_position_yd(self): + """更新持仓信息,把今仓=>昨仓""" + + for k, v in self.positions.items(): + if v.volume > 0: + self.write_log(f'调整{v.vt_symbol}持仓: 昨仓{v.yd_volume} => {v.volume}') + v.yd_volume = v.volume + + def get_data_path(self): + """ + 获取数据保存目录 + :return: + """ + if self.data_path is not None: + data_folder = self.data_path + else: + data_folder = os.path.abspath(os.path.join(os.getcwd(), 'data')) + self.data_path = data_folder + if not os.path.exists(data_folder): + os.makedirs(data_folder) + return data_folder + + def get_logs_path(self): + """ + 获取日志保存目录 + :return: + """ + if self.logs_path is not None: + logs_folder = self.logs_path + else: + logs_folder = os.path.abspath(os.path.join(os.getcwd(), 'log')) + self.logs_path = logs_folder + + if not os.path.exists(logs_folder): + os.makedirs(logs_folder) + + return logs_folder + + def create_logger(self, strategy_name=None, debug=False): + """ + 创建日志 + :param strategy_name 策略实例名称 + :param debug:是否详细记录日志 + :return: + """ + if strategy_name is None: + filename = os.path.abspath(os.path.join(self.get_logs_path(), '{}'.format( + self.test_name if len(self.test_name) > 0 else 'portfolio_test'))) + print(u'create logger:{}'.format(filename)) + self.logger = setup_logger(file_name=filename, + name=self.test_name, + log_level=logging.DEBUG if debug else logging.ERROR, + backtesing=True) + else: + filename = os.path.abspath( + os.path.join(self.get_logs_path(), '{}_{}'.format(self.test_name, str(strategy_name)))) + print(u'create logger:{}'.format(filename)) + self.strategy_loggers[strategy_name] = setup_logger(file_name=filename, + name=str(strategy_name), + log_level=logging.DEBUG if debug else logging.ERROR, + backtesing=True) + + def write_log(self, msg: str, strategy_name: str = None, level: int = logging.DEBUG): + """记录日志""" + # log = str(self.datetime) + ' ' + content + # self.logList.append(log) + + if strategy_name is None: + # 写入本地log日志 + if self.logger: + self.logger.log(msg=msg, level=level) + else: + self.create_logger(debug=self.debug) + else: + if strategy_name in self.strategy_loggers: + self.strategy_loggers[strategy_name].log(msg=msg, level=level) + else: + self.create_logger(strategy_name=strategy_name, debug=self.debug) + + def write_error(self, msg, strategy_name=None): + """记录异常""" + + if strategy_name is None: + if self.logger: + self.logger.error(msg) + else: + self.create_logger(debug=self.debug) + else: + if strategy_name in self.strategy_loggers: + self.strategy_loggers[strategy_name].error(msg) + else: + self.create_logger(strategy_name=strategy_name, debug=self.debug) + try: + self.strategy_loggers[strategy_name].error(msg) + except Exception as ex: + print('{}'.format(datetime.now()), file=sys.stderr) + print('could not create cta logger for {},excption:{},trace:{}'.format(strategy_name, str(ex), + traceback.format_exc())) + print(msg, file=sys.stderr) + + def output(self, content): + """输出内容""" + print(self.test_name + "\t" + content) + + def realtime_calculate(self): + """实时计算交易结果 + 支持多空仓位并存""" + + if len(self.trade_dict) < 1: + return + + # 获取所有未处理得成交单 + vt_tradeids = list(self.trade_dict.keys()) + + result_list = [] # 保存交易记录 + longid = '' + + # 对交易记录逐一处理 + for vt_tradeid in vt_tradeids: + + trade = self.trade_dict.pop(vt_tradeid, None) + if trade is None: + continue + + if trade.volume == 0: + continue + # buy trade 开多买入 + if trade.direction == Direction.LONG: + self.write_log(f'{trade.vt_symbol} buy, price:{trade.price},volume:{trade.volume}') + # 放入多单仓位队列 + self.long_position_list.append(trade) + + # sell trade + elif trade.direction == Direction.SHORT: + g_id = trade.vt_tradeid # 交易组(多个平仓数为一组) + g_result = None # 组合的交易结果 + + sell_volume = trade.volume + + while sell_volume > 0: + if len(self.long_position_list) == 0: + self.write_error(f'异常,没有{trade.vt_symbol}的多仓') + raise RuntimeError(u'realtimeCalculate2() Exception,没有开多单') + return + + pop_indexs = [i for i, val in enumerate(self.long_position_list) if + val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name] + if len(pop_indexs) < 1: + self.write_error(f'没有{trade.strategy_name}对应的symbol{trade.vt_symbol}多单数据,') + raise RuntimeError( + f'realtimeCalculate2() Exception,没有对应的symbol{trade.vt_symbol}多单数据,') + return + + cur_long_pos_list = [s_pos.volume for s_pos in self.long_position_list] + + self.write_log(u'{}当前多单:{}'.format(trade.vt_symbol, cur_long_pos_list)) + + pop_index = pop_indexs[0] + open_trade = self.long_position_list.pop(pop_index) + # 开多volume,不大于平仓volume + if sell_volume >= open_trade.volume: + self.write_log(f'{open_trade.vt_symbol},Sell Volume:{sell_volume} 满足:{open_trade.volume}') + sell_volume = sell_volume - open_trade.volume + sell_volume = round(sell_volume, 7) + self.write_log(f'{open_trade.vt_symbol},sell, price:{trade.price},volume:{open_trade.volume}') + + result = TradingResult(open_price=open_trade.price, + open_datetime=open_trade.datetime, + exit_price=trade.price, + close_datetime=trade.datetime, + volume=open_trade.volume, + rate=self.get_commission_rate(trade.vt_symbol), + slippage=self.get_slippage(trade.vt_symbol), + size=self.get_size(trade.vt_symbol), + group_id=g_id, + margin_rate=self.get_margin_rate(trade.vt_symbol), + fix_commission=self.get_fix_commission(trade.vt_symbol)) + + t = OrderedDict() + t['gid'] = g_id + t['strategy'] = open_trade.strategy_name + t['vt_symbol'] = open_trade.vt_symbol + t['open_time'] = open_trade.time + t['open_price'] = open_trade.price + t['direction'] = u'Long' + t['close_time'] = trade.time + t['close_price'] = trade.price + t['volume'] = open_trade.volume + t['profit'] = result.pnl + t['commission'] = result.commission + self.trade_pnl_list.append(t) + + # 更新策略实例的累加盈亏 + self.pnl_strategy_dict.update( + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) + + msg = u'gid:{} {}[{}:开多tid={}:{}]-[{}.平多tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ + .format(g_id, open_trade.vt_symbol, + open_trade.time, longid, open_trade.price, + trade.time, vt_tradeid, trade.price, + open_trade.volume, result.pnl, result.commission) + + self.write_log(msg) + result_list.append(result) + + if g_result is None: + if sell_volume > 0: + # 属于组合 + g_result = copy.deepcopy(result) + else: + # 更新组合的数据 + g_result.turnover = g_result.turnover + result.turnover + g_result.commission = g_result.commission + result.commission + g_result.slippage = g_result.slippage + result.slippage + g_result.pnl = g_result.pnl + result.pnl + + if sell_volume == 0: + g_result.volume = abs(trade.volume) + + # 开多volume,大于平仓volume,需要更新减少tradeDict的数量。 + else: + remain_volume = open_trade.volume - sell_volume + remain_volume = round(remain_volume, 7) + self.write_log(f'{open_trade.vt_symbol} pos: {open_trade.volume} => {remain_volume}') + + result = TradingResult(open_price=open_trade.price, + open_datetime=open_trade.datetime, + exit_price=trade.price, + close_datetime=trade.datetime, + volume=sell_volume, + rate=self.get_commission_rate(trade.vt_symbol), + slippage=self.get_slippage(trade.vt_symbol), + size=self.get_size(trade.vt_symbol), + group_id=g_id, + margin_rate=self.get_margin_rate(trade.vt_symbol), + fix_commission=self.get_fix_commission(trade.vt_symbol)) + + t = OrderedDict() + t['gid'] = g_id + t['strategy'] = open_trade.strategy_name + t['vt_symbol'] = open_trade.vt_symbol + t['open_time'] = open_trade.time + t['open_price'] = open_trade.price + t['direction'] = u'Long' + t['close_time'] = trade.time + t['close_price'] = trade.price + t['volume'] = sell_volume + t['profit'] = result.pnl + t['commission'] = result.commission + self.trade_pnl_list.append(t) + + # 更新策略实例的累加盈亏 + self.pnl_strategy_dict.update( + {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, + 0) + result.pnl}) + + msg = u'Gid:{} {}[{}:开多tid={}:{}]-[{}.平多tid={},{},vol:{}],净盈亏pnl={},手续费:{}' \ + .format(g_id, open_trade.vt_symbol, open_trade.time, longid, open_trade.price, + trade.time, vt_tradeid, trade.price, sell_volume, result.pnl, + result.commission) + + self.write_log(msg) + + # 减少开多volume,重新推进多单持仓列表中 + open_trade.volume = remain_volume + self.long_position_list.append(open_trade) + + sell_volume = 0 + result_list.append(result) + + if g_result is not None: + # 更新组合的数据 + g_result.turnover = g_result.turnover + result.turnover + g_result.commission = g_result.commission + result.commission + g_result.slippage = g_result.slippage + result.slippage + g_result.pnl = g_result.pnl + result.pnl + g_result.volume = abs(trade.volume) + + if g_result is not None: + self.write_log(u'组合净盈亏:{0}'.format(g_result.pnl)) + + # 计算仓位比例 + holding_cost = 0.0 # 持仓成本 + long_pos_dict = {} + + if len(self.long_position_list) > 0: + for t in self.long_position_list: + # 当前持仓的成本 + cur_holding_cost = t.price * t.volume + holding_cost = holding_cost + cur_holding_cost + + # 可用资金 = 当前净值 - 占用保证金 + self.avaliable = self.net_capital - holding_cost + # 当前成本占比 + self.percent = round(float(holding_cost * 100 / self.net_capital), 2) + # 更新最大成本占比 + self.max_occupy_rate = max(self.max_occupy_rate, self.percent) + + # 检查是否有平交易 + if len(result_list) == 0: + msg = u'' + if len(self.long_position_list) > 0: + msg += u'持多仓{0},'.format(str(long_pos_dict)) + + msg += u'资金占用:{0},仓位:{1}%%'.format(holding_cost, self.percent) + + self.write_log(msg) + return + + # 对交易结果汇总统计 + for result in result_list: + if result.pnl > 0: + self.winning_result += 1 + self.total_winning += result.pnl + else: + self.losing_result += 1 + self.total_losing += result.pnl + self.cur_capital += result.pnl + self.max_capital = max(self.cur_capital, self.max_capital) + self.net_capital = max(self.net_capital, self.cur_capital) + self.max_net_capital = max(self.net_capital, self.max_net_capital) + # self.maxVolume = max(self.maxVolume, result.volume) + drawdown = self.net_capital - self.max_net_capital + drawdown_rate = round(float(drawdown * 100 / self.max_net_capital), 4) + + self.pnl_list.append(result.pnl) + self.time_list.append(result.close_datetime) + self.capital_list.append(self.cur_capital) + self.drawdown_list.append(drawdown) + self.drawdown_rate_list.append(drawdown_rate) + + self.total_trade_count += 1 + self.total_turnover += result.turnover + self.total_commission += result.commission + self.total_slippage += result.slippage + + msg = u'[gid:{}] {} 交易盈亏:{},交易手续费:{}回撤:{}/{},账号平仓权益:{},持仓权益:{},累计手续费:{}' \ + .format(result.group_id, result.close_datetime, result.pnl, result.commission, drawdown, + drawdown_rate, self.cur_capital, self.net_capital, self.total_commission) + + self.write_log(msg) + + # 重新计算一次avaliable + self.avaliable = self.net_capital - holding_cost + self.percent = round(float(holding_cost * 100 / self.net_capital), 2) + + def saving_daily_data(self, d, c, m, commission, benchmark=0): + """保存每日数据""" + data = {} + data['date'] = d.strftime('%Y/%m/%d') # 日期 + data['capital'] = c # 当前平仓净值 + data['max_capital'] = m # 之前得最高净值 + today_holding_profit = 0 # 持仓浮盈 + holding_cost = 0 + + strategy_pnl = {} + for strategy in self.strategies.keys(): + strategy_pnl.update({strategy: self.pnl_strategy_dict.get(strategy, 0)}) + + positionMsg = "" + for longpos in self.long_position_list: + symbol = longpos.vt_symbol + # 计算持仓浮盈浮亏/占用保证金 + holding_profit = 0 + last_price = self.get_price(symbol) + if last_price is not None: + holding_profit = (last_price - longpos.price) * longpos.volume + holding_cost += last_price * abs(longpos.volume) + + # 账号的持仓盈亏 + today_holding_profit += holding_profit + + # 计算每个策略实例的持仓盈亏 + strategy_pnl.update({longpos.strategy_name: strategy_pnl.get(longpos.strategy_name, 0) + holding_profit}) + + positionMsg += "{},long,p={},v={},m={};".format(symbol, longpos.price, longpos.volume, holding_profit) + + data['net'] = c + today_holding_profit # 当日净值(含持仓盈亏) + data['rate'] = (c + today_holding_profit) / self.init_capital + data['holding_cost'] = holding_cost + data['occupy_rate'] = data['holding_cost'] / data['capital'] + data['commission'] = commission + + data.update(self.price_dict) + + data.update(strategy_pnl) + + self.daily_list.append(data) + + # 更新每日浮动净值 + self.net_capital = data['net'] + + # 更新最大初次持仓浮盈净值 + if data['net'] > self.max_net_capital: + self.max_net_capital = data['net'] + self.max_net_capital_time = data['date'] + drawdown_rate = round((float(self.max_net_capital - data['net']) * 100) / self.max_net_capital, 4) + if drawdown_rate > self.daily_max_drawdown_rate: + self.daily_max_drawdown_rate = drawdown_rate + self.max_drawdown_rate_time = data['date'] + + msg = u'{}: net={}, capital={} max={} margin={} commission={}, pos: {}' \ + .format(data['date'], + data['net'], c, m, + today_holding_profit, + commission, + positionMsg) + + if not self.debug: + self.output(msg) + else: + self.write_log(msg) + + # --------------------------------------------------------------------- + def export_trade_result(self): + """ + 导出交易结果(开仓-》平仓, 平仓收益) + 导出每日净值结果表 + :return: + """ + if len(self.trade_pnl_list) == 0: + self.write_log('no traded records') + return + + s = self.test_name.replace('&', '') + s = s.replace(' ', '') + trade_list_csv_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade_list.csv'.format(s))) + + self.write_log(u'save trade records to:{}'.format(trade_list_csv_file)) + import csv + csv_write_file = open(trade_list_csv_file, 'w', encoding='utf8', newline='') + + fieldnames = ['gid', 'strategy', 'vt_symbol', 'open_time', 'open_price', 'direction', 'close_time', + 'close_price', + 'volume', 'profit', 'commission'] + + writer = csv.DictWriter(f=csv_write_file, fieldnames=fieldnames, dialect='excel') + writer.writeheader() + + for row in self.trade_pnl_list: + writer.writerow(row) + + # 导出每日净值记录表 + if not self.daily_list: + return + + if self.daily_report_name == '': + daily_csv_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_daily_list.csv'.format(s))) + else: + daily_csv_file = self.daily_report_name + self.write_log(u'save daily records to:{}'.format(daily_csv_file)) + + csv_write_file2 = open(daily_csv_file, 'w', encoding='utf8', newline='') + fieldnames = ['date', 'capital', 'net', 'max_capital', 'rate', 'commission', 'holding_cost', 'occupy_rate', 'today_margin_long'] + # 添加合约的每日close价 + fieldnames.extend(sorted(self.price_dict.keys())) + # 添加策略列表 + fieldnames.extend(sorted(self.strategies.keys())) + writer2 = csv.DictWriter(f=csv_write_file2, fieldnames=fieldnames, dialect='excel') + writer2.writeheader() + + for row in self.daily_list: + writer2.writerow(row) + + if self.is_plot_daily: + # 生成净值曲线图片 + df = pd.DataFrame(self.daily_list) + df = df.set_index('date') + from vnpy.trader.utility import display_dual_axis + plot_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_plot.png'.format(s))) + + # 双坐标输出,左侧坐标是净值(比率),右侧是各策略的实际资金收益曲线 + display_dual_axis(df=df, columns1=['rate'], columns2=list(self.strategies.keys()), image_name=plot_file) + + return + + def get_result(self): + # 返回回测结果 + d = {} + d['init_capital'] = self.init_capital + d['profit'] = self.cur_capital - self.init_capital + d['max_capital'] = self.max_net_capital # 取消原 maxCapital + + if len(self.pnl_list) == 0: + return {}, [], [] + + d['max_pnl'] = max(self.pnl_list) + d['min_pnl'] = min(self.pnl_list) + + d['max_occupy_rate'] = self.max_occupy_rate + d['total_trade_count'] = self.total_trade_count + d['total_turnover'] = self.total_turnover + d['total_commission'] = self.total_commission + d['total_slippage'] = self.total_slippage + d['time_list'] = self.time_list + d['pnl_list'] = self.pnl_list + d['capital_list'] = self.capital_list + d['drawdown_list'] = self.drawdown_list + d['drawdown_rate_list'] = self.drawdown_rate_list # 净值最大回撤率列表 + d['winning_rate'] = round(100 * self.winning_result / len(self.pnl_list), 4) + + average_winning = 0 # 这里把数据都初始化为0 + average_losing = 0 + profit_loss_ratio = 0 + + if self.winning_result: + average_winning = self.total_winning / self.winning_result # 平均每笔盈利 + if self.losing_result: + average_losing = self.total_losing / self.losing_result # 平均每笔亏损 + if average_losing: + profit_loss_ratio = -average_winning / average_losing # 盈亏比 + + d['average_winning'] = average_winning + d['average_losing'] = average_losing + d['profit_loss_ratio'] = profit_loss_ratio + + # 计算Sharp + if not self.daily_list: + return {}, [], [] + + capital_net_list = [] + capital_list = [] + for row in self.daily_list: + capital_net_list.append(row['net']) + capital_list.append(row['capital']) + + capital = pd.Series(capital_net_list) + log_returns = np.log(capital).diff().fillna(0) + sharpe = (log_returns.mean() * 252) / (log_returns.std() * np.sqrt(252)) + d['sharpe'] = sharpe + + return d, capital_net_list, capital_list + + def show_backtesting_result(self): + """显示回测结果""" + + d, daily_net_capital, daily_capital = self.get_result() + + if len(d) == 0: + self.output(u'无交易结果') + return {}, '' + + # 导出交易清单 + self.export_trade_result() + + result_info = OrderedDict() + + # 输出 + self.output('-' * 30) + result_info.update({u'第一笔交易': str(d['time_list'][0])}) + self.output(u'第一笔交易:\t%s' % d['time_list'][0]) + + result_info.update({u'最后一笔交易': str(d['time_list'][-1])}) + self.output(u'最后一笔交易:\t%s' % d['time_list'][-1]) + + result_info.update({u'总交易次数': d['total_trade_count']}) + self.output(u'总交易次数:\t%s' % format_number(d['total_trade_count'])) + + result_info.update({u'期初资金': d['init_capital']}) + self.output(u'期初资金:\t%s' % format_number(d['init_capital'])) + + result_info.update({u'总盈亏': d['profit']}) + self.output(u'总盈亏:\t%s' % format_number(d['profit'])) + + result_info.update({u'资金最高净值': d['max_capital']}) + self.output(u'资金最高净值:\t%s' % format_number(d['max_capital'])) + + result_info.update({u'资金最高净值时间': str(self.max_net_capital_time)}) + self.output(u'资金最高净值时间:\t%s' % self.max_net_capital_time) + + result_info.update({u'每笔最大盈利': d['max_pnl']}) + self.output(u'每笔最大盈利:\t%s' % format_number(d['max_pnl'])) + + result_info.update({u'每笔最大亏损': d['min_pnl']}) + self.output(u'每笔最大亏损:\t%s' % format_number(d['min_pnl'])) + + result_info.update({u'净值最大回撤': min(d['drawdown_list'])}) + self.output(u'净值最大回撤: \t%s' % format_number(min(d['drawdown_list']))) + + result_info.update({u'净值最大回撤率': self.daily_max_drawdown_rate}) + self.output(u'净值最大回撤率: \t%s' % format_number(self.daily_max_drawdown_rate)) + + result_info.update({u'净值最大回撤时间': str(self.max_drawdown_rate_time)}) + self.output(u'净值最大回撤时间:\t%s' % self.max_drawdown_rate_time) + + result_info.update({u'胜率': d['winning_rate']}) + self.output(u'胜率:\t%s' % format_number(d['winning_rate'])) + + result_info.update({u'盈利交易平均值': d['average_winning']}) + self.output(u'盈利交易平均值\t%s' % format_number(d['average_winning'])) + + result_info.update({u'亏损交易平均值': d['average_losing']}) + self.output(u'亏损交易平均值\t%s' % format_number(d['average_losing'])) + + result_info.update({u'盈亏比': d['profit_loss_ratio']}) + self.output(u'盈亏比:\t%s' % format_number(d['profit_loss_ratio'])) + + result_info.update({u'最大资金占比': d['max_occupy_rate']}) + self.output(u'最大资金占比:\t%s' % format_number(d['max_occupy_rate'])) + + result_info.update({u'平均每笔盈利': d['profit'] / d['total_trade_count']}) + self.output(u'平均每笔盈利:\t%s' % format_number(d['profit'] / d['total_trade_count'])) + + result_info.update({u'平均每笔滑点成本': d['total_slippage'] / d['total_trade_count']}) + self.output(u'平均每笔滑点成本:\t%s' % format_number(d['total_slippage'] / d['total_trade_count'])) + + result_info.update({u'平均每笔佣金': d['total_commission'] / d['total_trade_count']}) + self.output(u'平均每笔佣金:\t%s' % format_number(d['total_commission'] / d['total_trade_count'])) + + result_info.update({u'Sharpe Ratio': d['sharpe']}) + self.output(u'Sharpe Ratio:\t%s' % format_number(d['sharpe'])) + + # 保存回测结果/交易记录/日线统计 至数据库 + self.save_result_to_mongo(result_info) + + return result_info + + def save_setting_to_mongo(self): + """ 保存测试设置到mongo中""" + self.task_id = self.test_setting.get('task_id', str(uuid1())) + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 移除含有"."的数据 + test_setting = {} + if self.test_setting is not None: + for k, v in self.test_setting.items(): + if k not in ['symbol_datas']: + test_setting[k] = v + strategy_setting = {} + if self.strategy_setting is not None: + for k, v in self.strategy_setting.items(): + if k not in ['symbol_datas']: + strategy_setting[k] = v + d = { + 'task_id': self.task_id, # 单实例回测任务id + 'name': self.test_name, # 回测实例名称, 策略名+参数+时间 + 'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id + 'status': 'start', + 'task_start_time': datetime.now(), # 任务开始执行时间 + 'run_host': socket.gethostname(), # 任务运行得host主机 + 'test_setting': test_setting, # 回测参数 + 'strategy_setting': strategy_setting, # 策略参数 + } + + # 保存入数据库 + self.mongo_api.db_insert( + db_name=self.gateway_name, + col_name='tasks', + d=d) + + def save_fail_to_mongo(self, fail_msg): + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'fail'}) # 更新状态未完成 + d.update({'fail_msg': fail_msg}) + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + + def save_result_to_mongo(self, result_info): + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'finish'}) # 更新状态未完成 + d.update(result_info) # 补充回测结果 + d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间 + d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录 + d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录 + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + + def put_strategy_event(self, strategy: CtaTemplate): + """发送策略更新事件,回测中忽略""" + pass + + def clear_backtesting_result(self): + """清空之前回测的结果""" + # 清空限价单相关 + self.limit_order_count = 0 + self.limit_orders.clear() + self.active_limit_orders.clear() + + # 清空成交相关 + self.trade_count = 0 + self.trade_dict.clear() + self.trades.clear() + self.trade_pnl_list = [] + + def append_trade(self, trade: TradeData): + """ + 根据策略名称,写入 logs\test_name_straetgy_name_trade.csv文件 + :param trade: + :return: + """ + strategy_name = getattr(trade, 'strategy_name', self.test_name) + trade_fields = ['symbol', 'exchange', 'vt_symbol', 'tradeid', + 'vt_tradeid', 'orderid', 'vt_orderid', + 'direction', + 'offset', 'price', 'volume', 'time'] + + d = OrderedDict() + try: + for k in trade_fields: + if k in ['exchange', 'direction', 'offset']: + d[k] = getattr(trade, k).value + else: + d[k] = getattr(trade, k, '') + + trade_file = os.path.abspath(os.path.join(self.get_logs_path(), '{}_trade.csv'.format(strategy_name))) + self.append_data(file_name=trade_file, dict_data=d) + except Exception as ex: + self.write_error(u'写入交易记录csv出错:{},{}'.format(str(ex), traceback.format_exc())) + + # 保存记录相关 + def append_data(self, file_name: str, dict_data: OrderedDict, field_names: list = None): + """ + 添加数据到csv文件中 + :param file_name: csv的文件全路径 + :param dict_data: OrderedDict + :return: + """ + if field_names is None or field_names == []: + dict_fieldnames = list(dict_data.keys()) + else: + dict_fieldnames = field_names + + try: + if not os.path.exists(file_name): + self.write_log(u'create csv file:{}'.format(file_name)) + with open(file_name, 'a', encoding='utf8', newline='') as csvWriteFile: + writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel') + self.write_log(u'write csv header:{}'.format(dict_fieldnames)) + writer.writeheader() + writer.writerow(dict_data) + else: + with open(file_name, 'a', encoding='utf8', newline='') as csvWriteFile: + writer = csv.DictWriter(f=csvWriteFile, fieldnames=dict_fieldnames, dialect='excel', + extrasaction='ignore') + writer.writerow(dict_data) + except Exception as ex: + self.write_error(u'append_data exception:{}'.format(str(ex))) + + +######################################################################## +class TradingResult(object): + """每笔交易的结果""" + + def __init__(self, open_price, open_datetime, exit_price, close_datetime, volume, rate, slippage, size, group_id, + margin_rate, fix_commission=0.0): + """Constructor""" + self.open_price = open_price # 开仓价格 + self.exit_price = exit_price # 平仓价格 + + self.open_datetime = open_datetime # 开仓时间datetime + self.close_datetime = close_datetime # 平仓时间 + + self.volume = volume # 交易数量(+/-代表方向) + self.group_id = group_id # 主交易ID(针对多手平仓) + + self.turnover = (self.open_price + self.exit_price) * abs(volume) * margin_rate # 成交金额(实际保证金金额) + if fix_commission > 0: + self.commission = fix_commission * abs(self.volume) + else: + self.commission = abs(self.turnover * rate) # 手续费成本 + self.slippage = slippage * 2 * abs(self.turnover) # 滑点成本 + self.pnl = ((self.exit_price - self.open_price) * volume + - self.commission - self.slippage) # 净盈亏 diff --git a/vnpy/app/cta_stock/base.py b/vnpy/app/cta_stock/base.py new file mode 100644 index 00000000..bb7793e5 --- /dev/null +++ b/vnpy/app/cta_stock/base.py @@ -0,0 +1,53 @@ +""" +Defines constants and objects used in CtaCrypto App. +""" + +from dataclasses import dataclass, field +from enum import Enum +from datetime import timedelta +from vnpy.trader.constant import Direction, Offset, Interval + +APP_NAME = "CtaStock" +STOPORDER_PREFIX = "STOP" + + +class StopOrderStatus(Enum): + WAITING = "等待中" + CANCELLED = "已撤销" + TRIGGERED = "已触发" + + +class EngineType(Enum): + LIVE = "实盘" + BACKTESTING = "回测" + + +class BacktestingMode(Enum): + BAR = 1 + TICK = 2 + + +@dataclass +class StopOrder: + vt_symbol: str + direction: Direction + offset: Offset + price: float + volume: float + stop_orderid: str + strategy_name: str + lock: bool = False + vt_orderids: list = field(default_factory=list) + status: StopOrderStatus = StopOrderStatus.WAITING + gateway_name: str = None + + +EVENT_CTA_LOG = "eCtaLog" +EVENT_CTA_STRATEGY = "eCtaStrategy" +EVENT_CTA_STOPORDER = "eCtaStopOrder" + +INTERVAL_DELTA_MAP = { + Interval.MINUTE: timedelta(minutes=1), + Interval.HOUR: timedelta(hours=1), + Interval.DAILY: timedelta(days=1), +} diff --git a/vnpy/app/cta_stock/engine.py b/vnpy/app/cta_stock/engine.py new file mode 100644 index 00000000..259e090e --- /dev/null +++ b/vnpy/app/cta_stock/engine.py @@ -0,0 +1,1717 @@ +""" +数字货币CTA策略运行引擎 +华富资产: +""" + +import importlib +import os +import sys +import traceback +import json +import pickle +import bz2 + +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, List, Dict +from datetime import datetime, timedelta +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from copy import copy +from functools import lru_cache +from uuid import uuid1 + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.object import ( + OrderRequest, + SubscribeRequest, + LogData, + TickData, + BarData, + PositionData, + ContractData, + HistoryRequest +) + +from vnpy.trader.event import ( + EVENT_TIMER, + EVENT_TICK, + EVENT_BAR, + EVENT_ORDER, + EVENT_TRADE, + EVENT_POSITION, + EVENT_STRATEGY_POS, + EVENT_STRATEGY_SNAPSHOT +) +from vnpy.trader.constant import ( + Direction, + OrderType, + Offset, + Status, + Interval +) +from vnpy.trader.utility import ( + load_json, + save_json, + extract_vt_symbol, + round_to, + TRADER_DIR, + get_folder_path, + get_underlying_symbol, + append_data) + +from vnpy.trader.util_logger import setup_logger, logging +from vnpy.trader.util_wechat import send_wx_msg +from vnpy.trader.converter import PositionHolding + +from .base import ( + APP_NAME, + EVENT_CTA_LOG, + EVENT_CTA_STRATEGY, + EVENT_CTA_STOPORDER, + EngineType, + StopOrder, + StopOrderStatus, + STOPORDER_PREFIX, +) +from .template import CtaTemplate +from vnpy.component.cta_position import CtaPosition + +STOP_STATUS_MAP = { + Status.SUBMITTING: StopOrderStatus.WAITING, + Status.NOTTRADED: StopOrderStatus.WAITING, + Status.PARTTRADED: StopOrderStatus.TRIGGERED, + Status.ALLTRADED: StopOrderStatus.TRIGGERED, + Status.CANCELLED: StopOrderStatus.CANCELLED, + Status.REJECTED: StopOrderStatus.CANCELLED +} + + +class CtaEngine(BaseEngine): + """ + 策略引擎【股票版】 + """ + + engine_type = EngineType.LIVE # live trading engine + + # 策略配置文件 + setting_filename = "cta_stock_setting.json" + # 引擎配置文件 + engine_filename = "cta_stock_config.json" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """ + 构造函数 + :param main_engine: 主引擎 + :param event_engine: 事件引擎 + """ + super().__init__(main_engine, event_engine, APP_NAME) + + self.engine_config = {} + + self.strategy_setting = {} # strategy_name: dict + self.strategy_data = {} # strategy_name: dict + + self.classes = {} # class_name: stategy_class + self.class_module_map = {} # class_name: mudule_name + self.strategies = {} # strategy_name: strategy + + # Strategy pos dict,key:strategy instance name, value: pos dict + self.strategy_pos_dict = {} + self.strategy_loggers = {} # strategy_name: logger + + # 未能订阅的symbols,支持策略启动时,并未接入gateway + # gateway_name.vt_symbol: set() of (strategy_name, is_bar) + self.pending_subcribe_symbol_map = defaultdict(set) + + self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list + self.bar_strategy_map = defaultdict(list) # vt_symbol: strategy list + self.strategy_symbol_map = defaultdict(set) # strategy_name: vt_symbol set + + self.orderid_strategy_map = {} # vt_orderid: strategy + self.strategy_orderid_map = defaultdict( + set) # strategy_name: orderid list + + self.stop_order_count = 0 # for generating stop_orderid + self.stop_orders = {} # stop_orderid: stop_order + + self.thread_executor = ThreadPoolExecutor(max_workers=1) # 异步线程任务执行 + self.thread_tasks = [] + + self.vt_tradeids = set() # for filtering duplicate trade + + self.positions = {} + + self.last_minute = None + + def init_engine(self): + """ + """ + self.register_event() + self.register_funcs() + + self.load_strategy_class() + self.load_strategy_setting() + + self.write_log("CTA策略数字货币引擎初始化成功") + + def close(self): + """停止所属有的策略""" + self.stop_all_strategies() + + def register_event(self): + """注册事件""" + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_BAR, self.process_bar_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + self.event_engine.register(EVENT_POSITION, self.process_position_event) + + def register_funcs(self): + """ + register the funcs to main_engine + :return: + """ + self.main_engine.get_strategy_status = self.get_strategy_status + self.main_engine.get_strategy_pos = self.get_strategy_pos + self.main_engine.compare_pos = self.compare_pos + self.main_engine.add_strategy = self.add_strategy + self.main_engine.init_strategy = self.init_strategy + self.main_engine.start_strategy = self.start_strategy + self.main_engine.stop_strategy = self.stop_strategy + self.main_engine.remove_strategy = self.remove_strategy + self.main_engine.reload_strategy = self.reload_strategy + self.main_engine.save_strategy_data = self.save_strategy_data + self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot + + # 注册到远程服务调用 + if self.main_engine.rpc_service: + self.main_engine.rpc_service.register(self.main_engine.get_strategy_status) + self.main_engine.rpc_service.register(self.main_engine.get_strategy_pos) + self.main_engine.rpc_service.register(self.main_engine.compare_pos) + self.main_engine.rpc_service.register(self.main_engine.add_strategy) + self.main_engine.rpc_service.register(self.main_engine.init_strategy) + self.main_engine.rpc_service.register(self.main_engine.start_strategy) + self.main_engine.rpc_service.register(self.main_engine.stop_strategy) + self.main_engine.rpc_service.register(self.main_engine.remove_strategy) + self.main_engine.rpc_service.register(self.main_engine.reload_strategy) + self.main_engine.rpc_service.register(self.main_engine.save_strategy_data) + self.main_engine.rpc_service.register(self.main_engine.save_strategy_snapshot) + + def process_timer_event(self, event: Event): + """ 处理定时器事件""" + all_trading = True + # 触发每个策略的定时接口 + for strategy in list(self.strategies.values()): + strategy.on_timer() + if not strategy.trading: + all_trading = False + + dt = datetime.now() + + if self.last_minute != dt.minute: + self.last_minute = dt.minute + + if all_trading: + # 主动获取所有策略得持仓信息 + all_strategy_pos = self.get_all_strategy_pos() + + # 比对仓位,使用上述获取得持仓信息,不用重复获取 + self.compare_pos(strategy_pos_list=copy(all_strategy_pos)) + + # 推送到事件 + self.put_all_strategy_pos_event(all_strategy_pos) + + def process_tick_event(self, event: Event): + """处理tick到达事件""" + tick = event.data + + key = f'{tick.gateway_name}.{tick.vt_symbol}' + v = self.pending_subcribe_symbol_map.pop(key, None) + if v: + # 这里不做tick/bar的判断了,因为基本有tick就有bar + self.write_log(f'{key} tick已经到达,移除未订阅记录:{v}') + + strategies = self.symbol_strategy_map[tick.vt_symbol] + if not strategies: + return + + self.check_stop_order(tick) + + for strategy in strategies: + if strategy.inited: + self.call_strategy_func(strategy, strategy.on_tick, {tick.vt_symbol:tick}) + + def process_bar_event(self, event: Event): + """处理bar到达事件""" + bar = event.data + strategies = self.symbol_strategy_map[bar.vt_symbol] + if not strategies: + return + for strategy in strategies: + if strategy.inited: + self.call_strategy_func(strategy, strategy.on_bar, {bar.vt_symbol: bar}) + + def process_order_event(self, event: Event): + """""" + order = event.data + + strategy = self.orderid_strategy_map.get(order.vt_orderid, None) + if not strategy: + return + + # Remove vt_orderid if order is no longer active. + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if order.vt_orderid in vt_orderids and not order.is_active(): + vt_orderids.remove(order.vt_orderid) + + # For server stop order, call strategy on_stop_order function + if order.type == OrderType.STOP: + so = StopOrder( + vt_symbol=order.vt_symbol, + direction=order.direction, + offset=order.offset, + price=order.price, + volume=order.volume, + stop_orderid=order.vt_orderid, + strategy_name=strategy.strategy_name, + status=STOP_STATUS_MAP[order.status], + vt_orderids=[order.vt_orderid], + ) + self.call_strategy_func(strategy, strategy.on_stop_order, so) + + # Call strategy on_order function + self.call_strategy_func(strategy, strategy.on_order, order) + + def process_trade_event(self, event: Event): + """""" + trade = event.data + + # Filter duplicate trade push + if trade.vt_tradeid in self.vt_tradeids: + return + self.vt_tradeids.add(trade.vt_tradeid) + + strategy = self.orderid_strategy_map.get(trade.vt_orderid, None) + if not strategy: + return + + # Update strategy pos before calling on_trade method + # 取消外部干预策略pos,由策略自行完成更新 + # if trade.direction == Direction.LONG: + # strategy.pos += trade.volume + # else: + # strategy.pos -= trade.volume + # 根据策略名称,写入 data\straetgy_name_trade.csv文件 + strategy_name = getattr(strategy, 'strategy_name') + trade_fields = ['datetime', 'symbol', 'exchange', 'vt_symbol', 'tradeid', 'vt_tradeid', 'orderid', 'vt_orderid', + 'direction', 'offset', 'price', 'volume', 'idx_price'] + trade_dict = OrderedDict() + try: + for k in trade_fields: + if k == 'datetime': + dt = getattr(trade, 'datetime') + if isinstance(dt, datetime): + trade_dict[k] = dt.strftime('%Y-%m-%d %H:%M:%S') + else: + trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, 'time', '') + if k in ['exchange', 'direction', 'offset']: + trade_dict[k] = getattr(trade, k).value + else: + trade_dict[k] = getattr(trade, k, '') + + # 添加指数价格 + symbol = trade_dict.get('symbol') + idx_symbol = get_underlying_symbol(symbol).upper() + '99.' + trade_dict.get('exchange') + idx_price = self.get_price(idx_symbol) + if idx_price: + trade_dict.update({'idx_price': idx_price}) + else: + trade_dict.update({'idx_price': trade_dict.get('price')}) + + if strategy_name is not None: + trade_file = str(get_folder_path('data').joinpath('{}_trade.csv'.format(strategy_name))) + append_data(file_name=trade_file, dict_data=trade_dict) + except Exception as ex: + self.write_error(u'写入交易记录csv出错:{},{}'.format(str(ex), traceback.format_exc())) + + self.call_strategy_func(strategy, strategy.on_trade, trade) + + # Sync strategy variables to data file + # 取消此功能,由策略自身完成数据持久化 + # self.sync_strategy_data(strategy) + + # Update GUI + self.put_strategy_event(strategy) + + def process_position_event(self, event: Event): + """""" + position = event.data + + self.positions.update({position.vt_positionid: position}) + + def check_unsubscribed_symbols(self): + """检查未订阅合约""" + + for key in self.pending_subcribe_symbol_map.keys(): + # gateway_name.symbol.exchange = > gateway_name, vt_symbol + keys = key.split('.') + gateway_name = keys[0] + vt_symbol = '.'.join(keys[1:]) + + contract = self.main_engine.get_contract(vt_symbol) + is_bar = True if vt_symbol in self.bar_strategy_map else False + if contract: + dt = datetime.now() + + self.write_log(f'重新提交合约{vt_symbol}订阅请求') + for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]): + self.subscribe_symbol(strategy_name=strategy_name, + vt_symbol=vt_symbol, + gateway_name=gateway_name, + is_bar=is_bar) + else: + try: + self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口') + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest(symbol=symbol, exchange=exchange) + req.is_bar = is_bar + self.main_engine.subscribe(req, gateway_name) + + except Exception as ex: + self.write_error( + u'重新订阅{}.{}异常:{},{}'.format(gateway_name, vt_symbol, str(ex), traceback.format_exc())) + return + + def check_stop_order(self, tick: TickData): + """""" + for stop_order in list(self.stop_orders.values()): + if stop_order.vt_symbol != tick.vt_symbol: + continue + + long_triggered = stop_order.direction == Direction.LONG and tick.last_price >= stop_order.price + short_triggered = stop_order.direction == Direction.SHORT and tick.last_price <= stop_order.price + + if long_triggered or short_triggered: + strategy = self.strategies[stop_order.strategy_name] + + # To get excuted immediately after stop order is + # triggered, use limit price if available, otherwise + # use ask_price_5 or bid_price_5 + if stop_order.direction == Direction.LONG: + if tick.limit_up: + price = tick.limit_up + else: + price = tick.ask_price_5 + else: + if tick.limit_down: + price = tick.limit_down + else: + price = tick.bid_price_5 + + contract = self.main_engine.get_contract(stop_order.vt_symbol) + + vt_orderids = self.send_limit_order( + strategy=strategy, + contract=contract, + direction=stop_order.direction, + offset=stop_order.offset, + price=price, + volume=stop_order.volume + ) + + # Update stop order status if placed successfully + if vt_orderids: + # Remove from relation map. + self.stop_orders.pop(stop_order.stop_orderid) + + strategy_vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if stop_order.stop_orderid in strategy_vt_orderids: + strategy_vt_orderids.remove(stop_order.stop_orderid) + + # Change stop order status to cancelled and update to strategy. + stop_order.status = StopOrderStatus.TRIGGERED + stop_order.vt_orderids = vt_orderids + + self.call_strategy_func( + strategy, strategy.on_stop_order, stop_order + ) + self.put_stop_order_event(stop_order) + + def send_server_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + type: OrderType, + gateway_name: str = None + ): + """ + Send a new order to server. + """ + # Create request and send order. + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + offset=offset, + type=type, + price=price, + volume=volume, + strategy_name=strategy.strategy_name + ) + + # 如果没有指定网关,则使用合约信息内的网关 + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + + # Send Orders + vt_orderids = [] + + vt_orderid = self.main_engine.send_order( + req, gateway_name) + + # Check if sending order successful + if not vt_orderid: + vt_orderids + + vt_orderids.append(vt_orderid) + + # Save relationship between orderid and strategy. + self.orderid_strategy_map[vt_orderid] = strategy + self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid) + + return vt_orderids + + def send_limit_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a limit order to server. + """ + return self.send_server_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + type=OrderType.LIMIT, + gateway_name=gateway_name + ) + + def send_fak_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a limit order to server. + """ + return self.send_server_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + type=OrderType.FAK, + gateway_name=gateway_name + ) + + def send_server_stop_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a stop order to server. + + Should only be used if stop order supported + on the trading server. + """ + return self.send_server_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + type=OrderType.STOP, + gateway_name=gateway_name + ) + + def send_local_stop_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Create a new local stop order. + """ + self.stop_order_count += 1 + stop_orderid = f"{STOPORDER_PREFIX}.{self.stop_order_count}" + + stop_order = StopOrder( + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop_orderid=stop_orderid, + strategy_name=strategy.strategy_name, + gateway_name=gateway_name + ) + + self.stop_orders[stop_orderid] = stop_order + + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + vt_orderids.add(stop_orderid) + + self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) + self.put_stop_order_event(stop_order) + + return [stop_orderid] + + def cancel_server_order(self, strategy: CtaTemplate, vt_orderid: str): + """ + Cancel existing order by vt_orderid. + """ + order = self.main_engine.get_order(vt_orderid) + if not order: + self.write_log(msg=f"撤单失败,找不到委托{vt_orderid}", + strategy_name=strategy.strategy_name, + level=logging.ERROR) + return False + + req = order.create_cancel_request() + return self.main_engine.cancel_order(req, order.gateway_name) + + def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str): + """ + Cancel a local stop order. + """ + stop_order = self.stop_orders.get(stop_orderid, None) + if not stop_order: + return False + strategy = self.strategies[stop_order.strategy_name] + + # Remove from relation map. + self.stop_orders.pop(stop_orderid) + + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if stop_orderid in vt_orderids: + vt_orderids.remove(stop_orderid) + + # Change stop order status to cancelled and update to strategy. + stop_order.status = StopOrderStatus.CANCELLED + + self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) + self.put_stop_order_event(stop_order) + return True + + def send_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool, + order_type: OrderType = OrderType.LIMIT, + gateway_name: str = None + ): + """ + 该方法供策略使用,发送委托。 + """ + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(msg=f"委托失败,找不到合约:{vt_symbol}", + strategy_name=strategy.strategy_name, + level=logging.ERROR) + return "" + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + # Round order price and volume to nearest incremental value + price = round_to(price, contract.pricetick) + volume = round_to(volume, contract.min_volume) + + if stop: + if contract.stop_supported: + # 发送服务器停止单 + return self.send_server_stop_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name) + else: + # 创建本地停止单 + return self.send_local_stop_order( + strategy=strategy, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name) + if order_type == OrderType.FAK: + return self.send_fak_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name) + else: + return self.send_limit_order( + strategy=strategy, + contract=contract, + direction=direction, + offset=offset, + price=price, + volume=volume, + gateway_name=gateway_name) + + def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): + """ + """ + if vt_orderid.startswith(STOPORDER_PREFIX): + return self.cancel_local_stop_order(strategy, vt_orderid) + else: + return self.cancel_server_order(strategy, vt_orderid) + + def cancel_all(self, strategy: CtaTemplate): + """ + Cancel all active orders of a strategy. + """ + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if not vt_orderids: + return + + for vt_orderid in copy(vt_orderids): + self.cancel_order(strategy, vt_orderid) + + def subscribe_symbol(self, strategy_name: str, vt_symbol: str, gateway_name: str = '', is_bar: bool = False): + """订阅合约""" + strategy = self.strategies.get(strategy_name, None) + if not strategy: + return False + + contract = self.main_engine.get_contract(vt_symbol) + if contract: + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + req = SubscribeRequest( + symbol=contract.symbol, exchange=contract.exchange) + self.main_engine.subscribe(req, gateway_name) + else: + self.write_log(msg=f"找不到合约{vt_symbol},添加到待订阅列表", + strategy_name=strategy.strategy_name) + self.pending_subcribe_symbol_map[f'{gateway_name}.{vt_symbol}'].add((strategy_name, is_bar)) + try: + self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口') + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest(symbol=symbol, exchange=exchange) + req.is_bar = is_bar + self.main_engine.subscribe(req, gateway_name) + + except Exception as ex: + self.write_error(u'重新订阅{}异常:{},{}'.format(vt_symbol, str(ex), traceback.format_exc())) + + # 如果是订阅bar + if is_bar: + strategies = self.bar_strategy_map[vt_symbol] + if strategy not in strategies: + strategies.append(strategy) + self.bar_strategy_map.update({vt_symbol: strategies}) + else: + # 添加 合约订阅 vt_symbol <=> 策略实例 strategy 映射. + strategies = self.symbol_strategy_map[vt_symbol] + if strategy not in strategies: + strategies.append(strategy) + + # 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射 + subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name] + subscribe_symbol_set.add(vt_symbol) + + return True + + @lru_cache() + def get_name(self, vt_symbol: str): + """查询合约的name""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return vt_symbol + return contract.name + + @lru_cache() + def get_size(self, vt_symbol: str): + """查询合约的size""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 10 + return contract.size + + @lru_cache() + def get_margin_rate(self, vt_symbol: str): + """查询保证金比率""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 0.1 + if contract.margin_rate == 0: + return 0.1 + return contract.margin_rate + + @lru_cache() + def get_price_tick(self, vt_symbol: str): + """查询价格最小跳动""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 0.01 + + return contract.pricetick + + @lru_cache() + def get_volume_tick(self, vt_symbol: str): + """查询合约最小成交单位""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 100 + + return contract.min_volume + + def get_tick(self, vt_symbol: str): + """获取合约得最新tick""" + return self.main_engine.get_tick(vt_symbol) + + def get_price(self, vt_symbol: str): + """查询合约的最新价格""" + tick = self.main_engine.get_tick(vt_symbol) + if tick: + return tick.last_price + + return None + + def get_contract(self, vt_symbol): + return self.main_engine.get_contract(vt_symbol) + + def get_account(self, vt_accountid: str = ""): + """ 查询账号的资金""" + # 如果启动风控,则使用风控中的最大仓位 + if self.main_engine.rm_engine: + return self.main_engine.rm_engine.get_account(vt_accountid) + + if len(vt_accountid) > 0: + account = self.main_engine.get_account(vt_accountid) + return account.balance, account.available, round(account.frozen * 100 / (account.balance + 0.01), 2), 100 + else: + accounts = self.main_engine.get_all_accounts() + if len(accounts) > 0: + account = accounts[0] + return account.balance, account.available, round(account.frozen * 100 / (account.balance + 0.01), + 2), 100 + else: + return 0, 0, 0, 0 + + def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = '') -> PositionData: + """ 查询合约在账号的持仓,需要指定方向""" + contract = self.main_engine.get_contract(vt_symbol) + if contract: + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + + vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}" + return self.main_engine.get_position(vt_position_id) + + def get_engine_type(self): + """""" + return self.engine_type + + @lru_cache() + def get_data_path(self): + data_path = os.path.abspath(os.path.join(TRADER_DIR, 'data')) + return data_path + + @lru_cache() + def get_logs_path(self): + log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log')) + return log_path + + def load_bar( + self, + vt_symbol: str, + days: int, + interval: Interval, + callback: Callable[[BarData], None] + ): + """""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + bars = [] + + # Query bars from gateway if available + contract = self.main_engine.get_contract(vt_symbol) + + if contract and contract.history_data: + req = HistoryRequest( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end + ) + bars = self.main_engine.query_history(req, contract.gateway_name) + + for bar in bars: + if bar.trading_day: + bar.trading_day = bar.datetime.strftime('%Y-%m-%d') + + callback(bar) + + def call_strategy_func( + self, strategy: CtaTemplate, func: Callable, params: Any = None + ): + """ + Call function of a strategy and catch any exception raised. + """ + try: + if params: + func(params) + else: + func() + except Exception: + strategy.trading = False + strategy.inited = False + + msg = f"触发异常已停止\n{traceback.format_exc()}" + self.write_log(msg=msg, + strategy_name=strategy.strategy_name, + level=logging.CRITICAL) + + def add_strategy( + self, class_name: str, + strategy_name: str, + vt_symbols: List[str], + setting: dict, + auto_init: bool = False, + auto_start: bool = False + ): + """ + Add a new strategy. + """ + if strategy_name in self.strategies: + msg = f"创建策略失败,存在重名{strategy_name}" + self.write_log(msg=msg, + level=logging.CRITICAL) + return False, msg + + strategy_class = self.classes.get(class_name, None) + if not strategy_class: + msg = f"创建策略失败,找不到策略类{class_name}" + self.write_log(msg=msg, + level=logging.CRITICAL) + return False, msg + + self.write_log(f'开始添加策略类{class_name},实例名:{strategy_name}') + strategy = strategy_class(self, strategy_name, vt_symbols, setting) + self.strategies[strategy_name] = strategy + + # Add vt_symbol to strategy map. + subscribe_symbol_set = self.strategy_symbol_map[strategy_name] + for vt_symbol in vt_symbols: + strategies = self.symbol_strategy_map[vt_symbol] + strategies.append(strategy) + subscribe_symbol_set.add(vt_symbol) + + # Update to setting file. + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) + + self.put_strategy_event(strategy) + + # 判断设置中是否由自动初始化和自动启动项目 + if auto_init: + self.init_strategy(strategy_name, auto_start=auto_start) + + return True, f'成功添加{strategy_name}' + + def init_strategy(self, strategy_name: str, auto_start: bool = False): + """ + Init a strategy. + """ + task = self.thread_executor.submit(self._init_strategy, strategy_name, auto_start) + self.thread_tasks.append(task) + + def _init_strategy(self, strategy_name: str, auto_start: bool = False): + """ + Init strategies in queue. + """ + strategy = self.strategies[strategy_name] + + if strategy.inited: + self.write_error(f"{strategy_name}已经完成初始化,禁止重复操作") + return + + self.write_log(f"{strategy_name}开始执行初始化") + + # Call on_init function of strategy + self.call_strategy_func(strategy, strategy.on_init) + + # Restore strategy data(variables) + # Pro 版本不使用自动恢复除了内部数据功能,由策略自身初始化时完成 + # data = self.strategy_data.get(strategy_name, None) + # if data: + # for name in strategy.variables: + # value = data.get(name, None) + # if value: + # setattr(strategy, name, value) + + # Subscribe market data 订阅缺省的vt_symbol, 如果有其他合约需要订阅,由策略内部初始化时提交订阅即可。 + for vt_symbol in strategy.vt_symbols: + self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=vt_symbol) + + # Put event to update init completed status. + strategy.inited = True + self.put_strategy_event(strategy) + self.write_log(f"{strategy_name}初始化完成") + + # 初始化后,自动启动策略交易 + if auto_start: + self.start_strategy(strategy_name) + + def start_strategy(self, strategy_name: str): + """ + Start a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.inited: + msg = f"策略{strategy.strategy_name}启动失败,请先初始化" + self.write_error(msg) + return False, msg + + if strategy.trading: + msg = f"{strategy_name}已经启动,请勿重复操作" + self.write_error(msg) + return False, msg + + self.call_strategy_func(strategy, strategy.on_start) + strategy.trading = True + + self.put_strategy_event(strategy) + + return True, f'成功启动策略{strategy_name}' + + def stop_strategy(self, strategy_name: str): + """ + Stop a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.trading: + msg = f'{strategy_name}策略实例已处于停止交易状态' + self.write_log(msg) + return False, msg + + # Call on_stop function of the strategy + self.write_log(f'调用{strategy_name}的on_stop,停止交易') + self.call_strategy_func(strategy, strategy.on_stop) + + # Change trading status of strategy to False + strategy.trading = False + + # Cancel all orders of the strategy + self.write_log(f'撤销{strategy_name}所有委托') + self.cancel_all(strategy) + + # Sync strategy variables to data file + # 取消此功能,由策略自身完成数据的持久化 + # self.sync_strategy_data(strategy) + + # Update GUI + self.put_strategy_event(strategy) + return True, f'成功停止策略{strategy_name}' + + def edit_strategy(self, strategy_name: str, setting: dict): + """ + Edit parameters of a strategy. + 风险警示: 该方法强行干预策略的配置 + """ + strategy = self.strategies[strategy_name] + auto_init = setting.pop('auto_init', False) + auto_start = setting.pop('auto_start', False) + + strategy.update_setting(setting) + + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) + self.put_strategy_event(strategy) + + def remove_strategy(self, strategy_name: str): + """ + Remove a strategy. + """ + strategy = self.strategies[strategy_name] + if strategy.trading: + err_msg = f"策略{strategy.strategy_name}移除失败,请先停止" + self.write_error(err_msg) + return False, err_msg + + # Remove setting + self.remove_strategy_setting(strategy_name) + + # 移除订阅合约与策略的关联关系 + for vt_symbol in self.strategy_symbol_map[strategy_name]: + # Remove from symbol strategy map + self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系') + strategies = self.symbol_strategy_map[vt_symbol] + strategies.remove(strategy) + + # Remove from active orderid map + if strategy_name in self.strategy_orderid_map: + vt_orderids = self.strategy_orderid_map.pop(strategy_name) + self.write_log(f'移除{strategy_name}的所有委托订单映射关系') + # Remove vt_orderid strategy map + for vt_orderid in vt_orderids: + if vt_orderid in self.orderid_strategy_map: + self.orderid_strategy_map.pop(vt_orderid) + + # Remove from strategies + self.write_log(f'移除{strategy_name}策略实例') + self.strategies.pop(strategy_name) + + return True, f'成功移除{strategy_name}策略实例' + + def reload_strategy(self, strategy_name: str, vt_symbols: List[str] = [], setting: dict = {}): + """ + 重新加载策略 + 一般使用于在线更新策略代码,或者更新策略参数,需要重新启动策略 + :param strategy_name: + :param setting: + :return: + """ + self.write_log(f'开始重新加载策略{strategy_name}') + + # 优先判断重启的策略,是否已经加载 + if strategy_name not in self.strategies or strategy_name not in self.strategy_setting: + err_msg = f"{strategy_name}不在运行策略中,不能重启" + self.write_error(err_msg) + return False, err_msg + + # 从本地配置文件中读取 + if len(setting) == 0: + strategies_setting = load_json(self.setting_filename) + old_strategy_config = strategies_setting.get(strategy_name, {}) + else: + old_strategy_config = copy(self.strategy_setting[strategy_name]) + + class_name = old_strategy_config.get('class_name') + if len(vt_symbols) == 0: + vt_symbols = old_strategy_config.get('vt_symbols') + if len(setting) == 0: + setting = old_strategy_config.get('setting') + + module_name = self.class_module_map[class_name] + # 重新load class module + if not self.load_strategy_class_from_module(module_name): + err_msg = f'不能加载模块:{module_name}' + self.write_error(err_msg) + return False, err_msg + + # 停止当前策略实例的运行,撤单 + self.stop_strategy(strategy_name) + + # 移除运行中的策略实例 + self.remove_strategy(strategy_name) + + # 重新添加策略 + self.add_strategy(class_name=class_name, + strategy_name=strategy_name, + vt_symbols=vt_symbols, + setting=setting, + auto_init=old_strategy_config.get('auto_init', False), + auto_start=old_strategy_config.get('auto_start', False)) + + msg = f'成功重载策略{strategy_name}' + self.write_log(msg) + return True, msg + + def save_strategy_data(self, select_name: str = 'ALL'): + """ save strategy data""" + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_data, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存数据\n' + has_executed = True + else: + self.write_log(f'{strategy_name}未初始化/未启动交易,不进行保存数据') + return has_executed, msg + + def thread_save_strategy_data(self, strategy_name): + """异步线程保存策略数据""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + try: + # 保存策略数据 + strategy.sync_data() + except Exception as ex: + self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + + def save_strategy_snapshot(self, select_name: str = 'ALL'): + """ + 保存策略K线切片数据 + :param select_name: + :return: + """ + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + if not hasattr(strategy, 'get_klines_snapshot'): + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_snapshot, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存K线切片\n' + has_executed = True + + return has_executed, msg + + def thread_save_strategy_snapshot(self, strategy_name): + """异步线程保存策略切片""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + + try: + # 5.保存策略切片 + snapshot = strategy.get_klines_snapshot() + if not snapshot: + self.write_log(f'{strategy_name}返回得K线切片数据为空') + return + + # 剩下工作:保存本地文件/数据库 + snapshot_folder = get_folder_path(f'data/snapshots/{strategy_name}') + snapshot_file = snapshot_folder.joinpath('{}.pkb2'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))) + with bz2.BZ2File(str(snapshot_file), 'wb') as f: + pickle.dump(snapshot, f) + self.write_log(u'切片保存成功:{}'.format(str(snapshot_file))) + + # 通过事件方式,传导到account_recorder + snapshot.update({ + 'account_id': self.engine_config.get('accountid', '-'), + 'strategy_group': self.engine_config.get('strategy_group', self.engine_name), + 'guid': str(uuid1()) + }) + event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot) + self.event_engine.put(event) + + except Exception as ex: + self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + # 加载 vnpy/app/cta_strategy_pro/strategies的所有策略 + path1 = Path(__file__).parent.joinpath("strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_stock.strategies") + + # 加载 当前运行目录下strategies子目录的所有策略 + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(str(path)): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + elif filename.endswith(".pyd"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + elif filename.endswith(".so"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + else: + continue + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load/Reload strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + class_name = value.__name__ + if class_name not in self.classes: + self.write_log(f"加载策略类{module_name}.{class_name}") + else: + self.write_log(f"更新策略类{module_name}.{class_name}") + self.classes[class_name] = value + self.class_module_map[class_name] = module_name + return True + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg=msg, level=logging.CRITICAL) + return False + + def load_strategy_data(self): + """ + Load strategy data from json file. + """ + print(f'load_strategy_data 此功能已取消,由策略自身完成数据的持久化加载', file=sys.stderr) + return + # self.strategy_data = load_json(self.data_filename) + + def sync_strategy_data(self, strategy: CtaTemplate): + """ + Sync strategy data into json file. + """ + # data = strategy.get_variables() + # data.pop("inited") # Strategy status (inited, trading) should not be synced. + # data.pop("trading") + # self.strategy_data[strategy.strategy_name] = data + # save_json(self.data_filename, self.strategy_data) + print(f'sync_strategy_data此功能已取消,由策略自身完成数据的持久化保存', file=sys.stderr) + + def get_all_strategy_class_names(self): + """ + Return names of strategy classes loaded. + """ + return list(self.classes.keys()) + + def get_strategy_status(self): + """ + return strategy inited/trading status + :param strategy_name: + :return: + """ + return {k: {'inited': v.inited, 'trading': v.trading} for k, v in self.strategies.items()} + + def get_strategy_pos(self, name, strategy=None): + """ + 获取策略的持仓字典 + :param name:策略名 + :return: [ {},{}] + """ + # 兼容处理,如果strategy是None,通过name获取 + if strategy is None: + if name not in self.strategies: + self.write_log(u'get_strategy_pos 策略实例不存在:' + name) + return [] + # 获取策略实例 + strategy = self.strategies[name] + + pos_list = [] + + if strategy.inited: + # 如果策略具有getPositions得方法,则调用该方法 + if hasattr(strategy, 'get_positions'): + pos_list = strategy.get_positions() + for pos in pos_list: + vt_symbol = pos.get('vt_symbol', None) + if vt_symbol: + symbol, exchange = extract_vt_symbol(vt_symbol) + pos.update({'symbol': symbol}) + pos.update({'name': self.get_name(vt_symbol)}) + + # 如果策略有 positions 属性 + elif hasattr(strategy, 'positions') and issubclass(strategy.positions, Dict): + symbol, exchange = extract_vt_symbol(strategy.vt_symbol) + # 多仓 + for vt_symbol, pos in strategy.get_positions.items(): + long_pos = {} + long_pos['vt_symbol'] = vt_symbol + long_pos['symbol'] = vt_symbol.split('.')[0] + long_pos['name'] = self.get_name(vt_symbol) + long_pos['direction'] = 'long' + long_pos['volume'] = pos.volume + long_pos['price'] = pos.price + long_pos['pnl'] = pos.pnl + if long_pos['volume'] > 0: + pos_list.append(long_pos) + + # update local pos dict + self.strategy_pos_dict.update({name: pos_list}) + + return pos_list + + def get_all_strategy_pos(self): + """ + 获取所有得策略仓位明细 + """ + strategy_pos_list = [] + for strategy_name in list(self.strategies.keys()): + d = OrderedDict() + d['accountid'] = self.engine_config.get('accountid', '-') + d['strategy_group'] = self.engine_config.get('strategy_group', '-') + d['strategy_name'] = strategy_name + dt = datetime.now() + d['date'] = dt.strftime('%Y%m%d') + d['hour'] = dt.hour + d['datetime'] = datetime.now() + strategy = self.strategies.get(strategy_name) + d['inited'] = strategy.inited + d['trading'] = strategy.trading + try: + d['pos'] = self.get_strategy_pos(name=strategy_name) + except Exception as ex: + self.write_error( + u'get_strategy_pos exception:{},{}'.format(str(ex), traceback.format_exc())) + d['pos'] = [] + strategy_pos_list.append(d) + + return strategy_pos_list + + def get_strategy_class_parameters(self, class_name: str): + """ + Get default parameters of a strategy class. + """ + strategy_class = self.classes[class_name] + + parameters = {} + for name in strategy_class.parameters: + parameters[name] = getattr(strategy_class, name) + + return parameters + + def get_strategy_parameters(self, strategy_name): + """ + Get parameters of a strategy. + """ + strategy = self.strategies[strategy_name] + strategy_config = self.strategy_setting.get(strategy_name, {}) + d = {} + d.update({'auto_init': strategy_config.get('auto_init', False)}) + d.update({'auto_start': strategy_config.get('auto_start', False)}) + d.update(strategy.get_parameters()) + return d + + def compare_pos(self): + """ + 对比账号&策略的持仓,不同的话则发出微信提醒 + :return: + """ + # 当前没有接入网关 + if len(self.main_engine.gateways) == 0: + return False, u'当前没有接入网关' + + self.write_log(u'开始对比账号&策略的持仓') + + # 获取当前策略得持仓 + strategy_pos_list = self.get_all_strategy_pos() + self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) + + # 需要进行对比得合约集合(来自策略持仓/账号持仓) + vt_symbols = set() + + # 账号的持仓处理 => account_pos + + compare_pos = dict() # vt_symbol: {'账号多单': xx,'策略多单':[]} + + for position in list(self.positions.values()): + # gateway_name.symbol.exchange => symbol.exchange + vt_symbol = position.vt_symbol + vt_symbols.add(vt_symbol) + + compare_pos[vt_symbol] = OrderedDict( + { + '账号多单': position.volume, + '策略多单': 0, + '多单策略': [] + } + ) + + # 逐一根据策略仓位,与Account_pos进行处理比对 + for strategy_pos in strategy_pos_list: + for pos in strategy_pos.get('pos', []): + vt_symbol = pos.get('vt_symbol') + if not vt_symbol: + continue + vt_symbols.add(vt_symbol) + symbol_pos = compare_pos.get(vt_symbol, None) + if symbol_pos is None: + self.write_log(u'账号持仓信息获取不到{},创建一个'.format(vt_symbol)) + symbol_pos = OrderedDict( + { + '账号多单': 0, + '策略多单': 0, + '多单策略': [] + } + ) + + if pos.get('direction') == 'long': + symbol_pos.update({'策略多单': symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0))}) + symbol_pos['多单策略'].append( + u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) + self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) + + pos_compare_result = '' + # 精简输出 + compare_info = '' + for vt_symbol in sorted(vt_symbols): + # 发送不一致得结果 + symbol_pos = compare_pos.pop(vt_symbol) + d_long = { + 'account_id': self.engine_config.get('account_id', '-'), + 'vt_symbol': vt_symbol, + 'direction': Direction.LONG.value, + 'strategy_list': symbol_pos.get('多单策略', [])} + + # 多单都一致 + if round(symbol_pos['账号多单'], 7) == round(symbol_pos['策略多单'], 7): + msg = u'{}多单都一致.{}\n'.format(vt_symbol, json.dumps(symbol_pos, indent=2, ensure_ascii=False)) + self.write_log(msg) + compare_info += msg + else: + pos_compare_result += '\n{}: '.format(vt_symbol) + # 多单不一致 + if round(symbol_pos['策略多单'], 7) != round(symbol_pos['账号多单'], 7): + msg = '{}多单[账号({}), 策略{},共({})], ' \ + .format(vt_symbol, + symbol_pos['账号多单'], + symbol_pos['多单策略'], + symbol_pos['策略多单']) + + pos_compare_result += msg + self.write_error(u'{}不一致:{}'.format(vt_symbol, msg)) + compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg) + + # 不匹配,输入到stdErr通道 + if pos_compare_result != '': + msg = u'账户{}持仓不匹配: {}' \ + .format(self.engine_config.get('account_id', '-'), + pos_compare_result) + try: + from vnpy.trader.util_wechat import send_wx_msg + send_wx_msg(content=msg) + except Exception: # noqa + pass + ret_msg = u'持仓不匹配: {}' \ + .format(pos_compare_result) + self.write_error(ret_msg) + return True, compare_info + ret_msg + else: + self.write_log(u'账户持仓与策略一致') + return True, compare_info + + def init_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.init_strategy(strategy_name) + + def start_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.start_strategy(strategy_name) + + def stop_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.stop_strategy(strategy_name) + + def load_strategy_setting(self): + """ + Load setting file. + """ + # 读取引擎得配置 + self.engine_config = load_json(self.engine_filename) + + # 读取策略得配置 + self.strategy_setting = load_json(self.setting_filename) + + for strategy_name, strategy_config in self.strategy_setting.items(): + self.add_strategy( + class_name=strategy_config["class_name"], + strategy_name=strategy_name, + vt_symbol=strategy_config["vt_symbol"], + setting=strategy_config["setting"], + auto_init=strategy_config.get('auto_init', False), + auto_start=strategy_config.get('auto_start', False) + ) + + def update_strategy_setting(self, strategy_name: str, setting: dict, auto_init: bool = False, + auto_start: bool = False): + """ + Update setting file. + """ + strategy = self.strategies[strategy_name] + # 原配置 + old_config = self.strategy_setting.get('strategy_name', {}) + new_config = { + "class_name": strategy.__class__.__name__, + "vt_symbol": strategy.vt_symbol, + "auto_init": auto_init, + "auto_start": auto_start, + "setting": setting + } + + if old_config: + self.write_log(f'{strategy_name} 配置变更:\n{old_config} \n=> \n{new_config}') + + self.strategy_setting[strategy_name] = new_config + + save_json(self.setting_filename, self.strategy_setting) + + def remove_strategy_setting(self, strategy_name: str): + """ + Update setting file. + """ + if strategy_name not in self.strategy_setting: + return + self.write_log(f'移除CTA数字货币引擎{strategy_name}的配置') + self.strategy_setting.pop(strategy_name) + save_json(self.setting_filename, self.strategy_setting) + + def put_stop_order_event(self, stop_order: StopOrder): + """ + Put an event to update stop order status. + """ + event = Event(EVENT_CTA_STOPORDER, stop_order) + self.event_engine.put(event) + + def put_strategy_event(self, strategy: CtaTemplate): + """ + Put an event to update strategy status. + """ + data = strategy.get_data() + event = Event(EVENT_CTA_STRATEGY, data) + self.event_engine.put(event) + + def put_all_strategy_pos_event(self, strategy_pos_list: list = []): + """推送所有策略得持仓事件""" + for strategy_pos in strategy_pos_list: + event = Event(EVENT_STRATEGY_POS, copy(strategy_pos)) + self.event_engine.put(event) + + def write_log(self, msg: str, strategy_name: str = '', level: int = logging.INFO): + """ + Create cta engine log event. + """ + # 推送至全局CTA_LOG Event + log = LogData(msg=f"{strategy_name}: {msg}" if strategy_name else msg, + gateway_name="CtaStrategy", + level=level) + event = Event(type=EVENT_CTA_LOG, data=log) + self.event_engine.put(event) + + # 保存单独的策略日志 + if strategy_name: + strategy_logger = self.strategy_loggers.get(strategy_name, None) + if not strategy_logger: + log_path = get_folder_path('log') + log_filename = str(log_path.joinpath(str(strategy_name))) + print(u'create logger:{}'.format(log_filename)) + self.strategy_loggers[strategy_name] = setup_logger(file_name=log_filename, + name=str(strategy_name)) + strategy_logger = self.strategy_loggers.get(strategy_name) + if strategy_logger: + strategy_logger.log(level, msg) + else: + if self.logger: + self.logger.log(level, msg) + + # 如果日志数据异常,错误和告警,输出至sys.stderr + if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]: + print(f"{strategy_name}: {msg}" if strategy_name else msg, file=sys.stderr) + + def write_error(self, msg: str, strategy_name: str = ''): + """写入错误日志""" + self.write_log(msg=msg, strategy_name=strategy_name, level=logging.ERROR) + + def send_email(self, msg: str, strategy: CtaTemplate = None): + """ + Send email to default receiver. + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "CTA策略数字货币引擎" + + self.main_engine.send_email(subject, msg) + + def send_wechat(self, msg: str, strategy: CtaTemplate = None): + """ + send wechat message to default receiver + :param msg: + :param strategy: + :return: + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "CTACRYPTO引擎" + + send_wx_msg(content=f'{subject}:{msg}') diff --git a/vnpy/app/cta_stock/portfolio_testing.py b/vnpy/app/cta_stock/portfolio_testing.py new file mode 100644 index 00000000..58078afd --- /dev/null +++ b/vnpy/app/cta_stock/portfolio_testing.py @@ -0,0 +1,484 @@ +# encoding: UTF-8 + +''' +本文件中包含的是CTA模块的组合回测引擎,回测引擎的API和CTA引擎一致, +可以使用和实盘相同的代码进行回测。 +华富资产 李来佳 +''' +from __future__ import division + +import sys +import os +import gc +import pandas as pd +import traceback +import random +import bz2 +import pickle + +from datetime import datetime, timedelta +from time import sleep + +from vnpy.trader.object import ( + TickData, + BarData, + RenkoBarData, +) +from vnpy.trader.constant import ( + Exchange, +) + +from vnpy.trader.utility import ( + get_trading_date, + extract_vt_symbol, +) + +from .back_testing import BackTestingEngine + + +class PortfolioTestingEngine(BackTestingEngine): + """ + CTA组合回测引擎, 使用回测引擎作为父类 + 函数接口和策略引擎保持一样, + 从而实现同一套代码从回测到实盘。 + 针对1分钟bar的回测 或者tick回测 + 导入CTA_Settings + + """ + + def __init__(self, event_engine=None): + """Constructor""" + super().__init__(event_engine) + + self.bar_csv_file = {} + self.bar_df_dict = {} # 历史数据的df,回测用 + self.bar_df = None # 历史数据的df,时间+symbol作为组合索引 + self.bar_interval_seconds = 60 # bar csv文件,属于K线类型,K线的周期(秒数),缺省是1分钟 + + self.tick_path = None # tick级别回测, 路径 + + def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None): + """ + 加载回测bar数据到DataFrame + 1. 增加前复权/后复权 + :param vt_symbol: + :param bar_file: + :param data_start_date: + :param data_end_date: + :return: + """ + self.output(u'loading {} from {}'.format(vt_symbol, bar_file)) + if vt_symbol in self.bar_df_dict: + return True + + if bar_file is None or not os.path.exists(bar_file): + self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file)) + return False + + try: + data_types = { + "datetime": str, + "open": float, + "high": float, + "low": float, + "close": float, + "open_interest": float, + "volume": float, + "instrument_id": str, + "symbol": str, + "total_turnover": float, + "limit_down": float, + "limit_up": float, + "trading_day": str, + "date": str, + "time": str + } + # 加载csv文件 =》 dateframe + symbol_df = pd.read_csv(bar_file, dtype=data_types) + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") + + # 裁剪数据 + symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] + + # 复权转换 + adj_list = self.adjust_factors.get(vt_symbol, []) + # 按照结束日期,裁剪复权记录 + adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date] + + if adj_list: + self.write_log(f'需要对{vt_symbol}进行前复权处理') + for row in adj_list: + row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'}) + # list -> dataframe, 转换复权日期格式 + adj_data = pd.DataFrame(adj_list) + adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S") + adj_data = adj_data.set_index("dividOperateDate") + # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 + symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore') + + # 添加到待合并dataframe dict中 + self.bar_df_dict.update({vt_symbol: symbol_df}) + + except Exception as ex: + self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) + self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) + return False + + return True + + def comine_bar_df(self): + """ + 合并所有回测合约的bar DataFrame =》集中的DataFrame + 把bar_df_dict =》bar_df + :return: + """ + self.output('comine_df') + self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() + self.bar_df_dict.clear() + + def prepare_env(self, test_setting): + """ + 回测环境准备 + :param test_setting: + :return: + """ + self.output(f'准备组合回测环境') + + # 调用父类回测环境 + super().prepare_env(test_setting) + + def prepare_data(self, data_dict): + """ + 准备组合数据 + :param data_dict: 合约得配置参数 + :return: + """ + # 调用回测引擎,跟新合约得数据 + super().prepare_data(data_dict) + + if len(data_dict) == 0: + self.write_log(u'请指定回测数据和文件') + return + + if self.mode == 'tick': + return + + # 检查/更新需要回测的bar文件 + for vt_symbol, symbol_info in data_dict.items(): + self.write_log(u'配置{}数据:{}'.format(vt_symbol, symbol_info)) + + bar_file = symbol_info.get('bar_file', None) + + if bar_file is None: + self.write_error(u'{}没有配置数据文件') + continue + + if not os.path.isfile(bar_file): + self.write_log(u'{0}文件不存在'.format(bar_file)) + continue + + self.bar_csv_file.update({vt_symbol: bar_file}) + + def run_portfolio_test(self, strategy_setting: dict = {}): + """ + 运行组合回测 + """ + if not self.strategy_start_date: + self.write_error(u'回测开始日期未设置。') + return + + if len(strategy_setting) == 0: + self.write_error('未提供有效配置策略实例') + return + + self.cur_capital = self.init_capital # 更新设置期初资金 + if not self.data_end_date: + self.data_end_date = datetime.today() + + # 保存回测设置/策略设置/任务ID至数据库 + self.save_setting_to_mongo() + + self.write_log(u'开始组合回测') + + for strategy_name, strategy_setting in strategy_setting.items(): + self.load_strategy(strategy_name, strategy_setting) + + self.write_log(u'策略初始化完成') + + self.write_log(u'开始回放数据') + + self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date)) + + if self.mode == 'bar': + self.run_bar_test() + else: + self.run_tick_test() + + def run_bar_test(self): + """使用bar进行组合回测""" + testdays = (self.data_end_date - self.data_start_date).days + + if testdays < 1: + self.write_log(u'回测时间不足') + return + + # 加载数据 + for vt_symbol in self.symbol_strategy_map.keys(): + self.load_bar_csv_to_df(vt_symbol, self.bar_csv_file.get(vt_symbol)) + + # 合并数据 + self.comine_bar_df() + + last_trading_day = None + bars_dt = None + bars_same_dt = [] + + gc_collect_days = 0 + + try: + for (dt, vt_symbol), bar_data in self.bar_df.iterrows(): + symbol, exchange = extract_vt_symbol(vt_symbol) + if symbol.startswith('future_renko'): + bar_datetime = dt + bar = RenkoBarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) + bar.seconds = float(bar_data.get('seconds', 0)) + bar.high_seconds = float(bar_data.get('high_seconds', 0)) # 当前Bar的上限秒数 + bar.low_seconds = float(bar_data.get('low_seconds', 0)) # 当前bar的下限秒数 + bar.height = float(bar_data.get('height', 0)) # 当前Bar的高度限制 + bar.up_band = float(bar_data.get('up_band', 0)) # 高位区域的基线 + bar.down_band = float(bar_data.get('down_band', 0)) # 低位区域的基线 + bar.low_time = bar_data.get('low_time', None) # 最后一次进入低位区域的时间 + bar.high_time = bar_data.get('high_time', None) # 最后一次进入高位区域的时间 + else: + bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds) + + bar = BarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) + if 'open' in bar_data: + bar.open_price = float(bar_data['open']) + bar.close_price = float(bar_data['close']) + bar.high_price = float(bar_data['high']) + bar.low_price = float(bar_data['low']) + else: + bar.open_price = float(bar_data['open_price']) + bar.close_price = float(bar_data['close_price']) + bar.high_price = float(bar_data['high_price']) + bar.low_price = float(bar_data['low_price']) + + bar.volume = int(bar_data['volume']) + bar.date = dt.strftime('%Y-%m-%d') + bar.time = dt.strftime('%H:%M:%S') + str_td = str(bar_data.get('trading_day', '')) + if len(str_td) == 8: + bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8] + else: + bar.trading_day = bar.date + + if last_trading_day != bar.trading_day: + self.output(u'回测数据日期:{},资金:{}'.format(bar.trading_day, self.net_capital)) + if self.strategy_start_date > bar.datetime: + last_trading_day = bar.trading_day + + # bar时间与队列时间一致,添加到队列中 + if dt == bars_dt: + bars_same_dt.append(bar) + continue + else: + # bar时间与队列时间不一致,先推送队列的bars + random.shuffle(bars_same_dt) + for _bar_ in bars_same_dt: + self.new_bar(_bar_) + + # 创建新的队列 + bars_same_dt = [bar] + bars_dt = dt + + # 更新每日净值 + if self.strategy_start_date <= dt <= self.data_end_date: + if last_trading_day != bar.trading_day: + if last_trading_day is not None: + self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital, + self.max_net_capital, self.total_commission) + last_trading_day = bar.trading_day + + # 第二个交易日,撤单 + self.cancel_orders() + # 更新持仓缓存 + self.update_position_yd() + + gc_collect_days += 1 + if gc_collect_days >= 10: + # 执行内存回收 + gc.collect() + sleep(1) + gc_collect_days = 0 + + if self.net_capital < 0: + self.write_error(u'净值低于0,回测停止') + self.output(u'净值低于0,回测停止') + return + + self.write_log(u'bar数据回放完成') + if last_trading_day is not None: + self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital, + self.max_net_capital, self.total_commission) + except Exception as ex: + self.write_error(u'回测异常导致停止:{}'.format(str(ex))) + self.write_error(u'{},{}'.format(str(ex), traceback.format_exc())) + print(str(ex), file=sys.stderr) + traceback.print_exc() + return + + def load_bz2_cache(self, cache_folder, cache_symbol, cache_date): + """加载缓存bz2数据""" + if not os.path.exists(cache_folder): + self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder)) + return None + cache_folder_year_month = os.path.join(cache_folder, cache_date[:6]) + if not os.path.exists(cache_folder_year_month): + self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month)) + return None + + cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date)) + if not os.path.isfile(cache_file): + cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date)) + if not os.path.isfile(cache_file): + self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file)) + return None + + with bz2.BZ2File(cache_file, 'rb') as f: + data = pickle.load(f) + return data + + return None + + def get_day_tick_df(self, test_day): + """获取某一天得所有合约tick""" + tick_data_dict = {} + + for vt_symbol in list(self.symbol_strategy_map.keys()): + symbol, exchange = extract_vt_symbol(vt_symbol) + tick_list = self.load_bz2_cache(cache_folder=self.tick_path, + cache_symbol=symbol, + cache_date=test_day.strftime('%Y%m%d')) + if not tick_list or len(tick_list) == 0: + continue + + symbol_tick_df = pd.DataFrame(tick_list) + # 缓存文件中,datetime字段,已经是datetime格式 + # 暂时根据时间去重,没有汇总volume + symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True) + symbol_tick_df.set_index('datetime', inplace=True) + + tick_data_dict.update({vt_symbol: symbol_tick_df}) + + if len(tick_data_dict) == 0: + return None + + tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index() + + return tick_df + + def run_tick_test(self): + """运行tick级别组合回测""" + testdays = (self.data_end_date - self.data_start_date).days + + if testdays < 1: + self.write_log(u'回测时间不足') + return + + gc_collect_days = 0 + + # 循环每一天 + for i in range(0, testdays): + test_day = self.data_start_date + timedelta(days=i) + + combined_df = self.get_day_tick_df(test_day) + + if combined_df is None: + continue + + try: + for (dt, vt_symbol), tick_data in combined_df.iterrows(): + symbol, exchange = extract_vt_symbol(vt_symbol) + tick = TickData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=dt, + date=dt.strftime('%Y-%m-%d'), + time=dt.strftime('%H:%M:%S.%f'), + trading_day=test_day.strftime('%Y-%m-%d'), + last_price=tick_data['price'], + volume=tick_data['volume'] + ) + + self.new_tick(tick) + + # 结束一个交易日后,更新每日净值 + self.saving_daily_data(test_day, + self.cur_capital, + self.max_net_capital, + self.total_commission) + + self.cancel_orders() + # 更新持仓缓存 + self.update_position_yd() + + gc_collect_days += 1 + if gc_collect_days >= 10: + # 执行内存回收 + gc.collect() + sleep(1) + gc_collect_days = 0 + + if self.net_capital < 0: + self.write_error(u'净值低于0,回测停止') + self.output(u'净值低于0,回测停止') + return + + except Exception as ex: + self.write_error(u'回测异常导致停止:{}'.format(str(ex))) + self.write_error(u'{},{}'.format(str(ex), traceback.format_exc())) + print(str(ex), file=sys.stderr) + traceback.print_exc() + return + + self.write_log(u'tick数据回放完成') + + +def single_test(test_setting: dict, strategy_setting: dict): + """ + 单一回测 + : test_setting, 组合回测所需的配置,包括合约信息,数据bar信息,回测时间,资金等。 + :strategy_setting, dict, 一个或多个策略配置 + """ + # 创建组合回测引擎 + engine = PortfolioTestingEngine() + + engine.prepare_env(test_setting) + try: + engine.run_portfolio_test(strategy_setting) + # 回测结果,保存 + engine.show_backtesting_result() + + except Exception as ex: + print('组合回测异常{}'.format(str(ex))) + traceback.print_exc() + engine.save_fail_to_mongo(f'回测异常{str(ex)}') + return False + + print('测试结束') + return True diff --git a/vnpy/app/cta_stock/strategies/__init__.py b/vnpy/app/cta_stock/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/cta_stock/strategies/readme.md b/vnpy/app/cta_stock/strategies/readme.md new file mode 100644 index 00000000..06492d81 --- /dev/null +++ b/vnpy/app/cta_stock/strategies/readme.md @@ -0,0 +1,35 @@ +策略加密 + +#windows 下加密并运行 + +1.安装Visual StudioComunity 2017,下载地址: + + https://visualstudio.microsoft.com/zh-hans/vs/older-downloads/ + 安装时请勾选“使用C++的桌面开发”。 + +2. 在Python环境中安装Cython,打开cmd后输入运行pip install cython即可。 + +3. 在”管理员”模式的命令行窗口,在策略所在目录,运行: + + cythonize -i demo_strategy.py + + 编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.pyd + + 改名=> demo_strategy.pyd + + 放置 demo_strategy.pyd到windows 生产环境的 strateies目录下。 + +#centos/ubuntu 下加密并运行 + + +1. 在Python环境中安装Cython,运行pip install cython即可。 + +3. 在策略所在目录,运行: + + cythonize -i demo_strategy.py + + 编译完成后,Demo文件夹下会多出2个新的文件,其中就有已加密的策略文件demo_strategy.cp37-win_amd64.so + + 改名=> demo_strategy.so + + 放置 demo_strategy.so 到centos/ubuntu 生产环境的 strateies目录下。 diff --git a/vnpy/app/cta_stock/template.py b/vnpy/app/cta_stock/template.py new file mode 100644 index 00000000..753321d1 --- /dev/null +++ b/vnpy/app/cta_stock/template.py @@ -0,0 +1,1261 @@ +"""""" +import os +import sys +import uuid +import bz2 +import pickle +import traceback +import zlib +import json +from abc import ABC +from copy import copy +from typing import Any, Callable, List, Dict +from logging import INFO, ERROR +from datetime import datetime +from vnpy.trader.constant import Interval, Direction, Offset, Status, OrderType, Exchange, Color +from vnpy.trader.object import BarData, TickData, OrderData, TradeData, PositionData +from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_underlying_symbol, round_to + +from .base import StopOrder +from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade +from vnpy.component.cta_position import CtaPosition +from vnpy.component.cta_policy import CtaPolicy + +class CtaTemplate(ABC): + """CTA股票策略模板""" + + author = "华富资产" + parameters = [] + variables = [] + + # 保存委托单编号和相关委托单的字典 + # key为委托单编号 + # value为该合约相关的委托单 + active_orders = {} + + def __init__( + self, + cta_engine: Any, + strategy_name: str, + vt_symbols: List[str], + setting: dict, + ): + """""" + self.cta_engine = cta_engine + self.strategy_name = strategy_name + self.vt_symbols = vt_symbols + + self.inited = False # 是否初始化完毕 + self.trading = False # 是否开始交易 + self.positions = {} # 持仓,vt_symbol: position data + self.entrust = 0 # 是否正在委托, 0, 无委托 , 1, 委托方向是LONG, -1, 委托方向是SHORT + + self.tick_dict = {} # 记录所有on_tick传入最新tick + self.active_orders = {} + # Copy a new variables list here to avoid duplicate insert when multiple + # strategy instances are created with the same strategy class. + self.variables = copy(self.variables) + self.variables.insert(0, "inited") + self.variables.insert(1, "trading") + self.variables.insert(2, "entrust") + + def update_setting(self, setting: dict): + """ + Update strategy parameter wtih value in setting dict. + """ + for name in self.parameters: + if name in setting: + setattr(self, name, setting[name]) + + @classmethod + def get_class_parameters(cls): + """ + Get default parameters dict of strategy class. + """ + class_parameters = {} + for name in cls.parameters: + class_parameters[name] = getattr(cls, name) + return class_parameters + + def get_parameters(self): + """ + Get strategy parameters dict. + """ + strategy_parameters = {} + for name in self.parameters: + strategy_parameters[name] = getattr(self, name) + return strategy_parameters + + def get_variables(self): + """ + Get strategy variables dict. + """ + strategy_variables = {} + for name in self.variables: + strategy_variables[name] = getattr(self, name) + return strategy_variables + + def get_data(self): + """ + Get strategy data. + """ + strategy_data = { + "strategy_name": self.strategy_name, + "vt_symbols": self.vt_symbols, + "class_name": self.__class__.__name__, + "author": self.author, + "parameters": self.get_parameters(), + "variables": self.get_variables(), + } + return strategy_data + + def get_positions(self): + """ 返回持仓数量""" + pos_list = [] + for k, v in self.positions.items(): + pos_list.append({ + "vt_symbol": k, + "direction": "long", + "volume": v.volume, + "price": v.price, + 'pnl': v.pnl + }) + + return pos_list + + @virtual + def on_timer(self): + pass + + @virtual + def on_init(self): + """ + Callback when strategy is inited. + """ + pass + + @virtual + def on_start(self): + """ + Callback when strategy is started. + """ + pass + + @virtual + def on_stop(self): + """ + Callback when strategy is stopped. + """ + pass + + @virtual + def on_tick(self, tick_dict: Dict[str, TickData]): + """ + Callback of new tick data update. + """ + pass + + @virtual + def on_bar(self, bar_dict: Dict[str, BarData]): + """ + Callback of new bar data update. + """ + pass + + @virtual + def on_trade(self, trade: TradeData): + """ + Callback of new trade data update. + """ + pass + + @virtual + def on_order(self, order: OrderData): + """ + Callback of new order data update. + """ + pass + + @virtual + def on_stop_order(self, stop_order: StopOrder): + """ + Callback of stop order update. + """ + pass + + def before_trading(self): + """开盘前/初始化后调用一次""" + self.write_log('开盘前调用') + + def after_trading(self): + """收盘后调用一次""" + self.write_log('收盘后调用') + + def buy(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send buy order to open a long position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_upper_limit(vt_symbol): + self.write_error(u'涨停价不做FAK/FOK委托') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.LONG, + offset=Offset.OPEN, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def sell(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send sell order to close a long position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_lower_limit(vt_symbol): + self.write_error(u'跌停价不做FAK/FOK sell委托') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.SHORT, + offset=Offset.CLOSE, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def send_order( + self, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool = False, + order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, + grid: CtaGrid = None + ): + """ + Send a new order. + """ + if vt_symbol == '': + return [] + + if not self.trading: + return [] + + vt_orderids = self.cta_engine.send_order( + strategy=self, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop=stop, + order_type=order_type + ) + + if order_time is None: + order_time = datetime.now() + + for vt_orderid in vt_orderids: + d = { + 'direction': direction, + 'offset': offset, + 'vt_symbol': vt_symbol, + 'price': price, + 'volume': volume, + 'order_type': order_type, + 'traded': 0, + 'order_time': order_time, + 'status': Status.SUBMITTING + } + if grid: + d.update({'grid': grid}) + grid.order_ids.append(vt_orderid) + grid.order_time = order_time + self.active_orders.update({vt_orderid: d}) + if direction == Direction.LONG: + self.entrust = 1 + elif direction == Direction.SHORT: + self.entrust = -1 + return vt_orderids + + def cancel_order(self, vt_orderid: str): + """ + Cancel an existing order. + """ + if self.trading: + return self.cta_engine.cancel_order(self, vt_orderid) + + return False + + def cancel_all(self): + """ + Cancel all orders sent by strategy. + """ + if self.trading: + self.cta_engine.cancel_all(self) + + def is_upper_limit(self, symbol): + """是否涨停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_up is None or tick.limit_up == 0: + return False + if tick.bid_price_1 == tick.limit_up: + return True + + def is_lower_limit(self, symbol): + """是否跌停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_down is None or tick.limit_down == 0: + return False + if tick.ask_price_1 == tick.limit_down: + return True + + def write_log(self, msg: str, level: int = INFO): + """ + Write a log message. + """ + self.cta_engine.write_log(msg=msg, strategy_name=self.strategy_name, level=level) + + def write_error(self, msg: str): + """write error log message""" + self.write_log(msg=msg, level=ERROR) + + def get_engine_type(self): + """ + Return whether the cta_engine is backtesting or live trading. + """ + return self.cta_engine.get_engine_type() + + def put_event(self): + """ + Put an strategy data event for ui update. + """ + if self.inited: + self.cta_engine.put_strategy_event(self) + + def send_email(self, msg): + """ + Send email to default receiver. + """ + if self.inited: + self.cta_engine.send_email(msg, self) + + def sync_data(self): + """ + Sync strategy variables value into disk storage. + """ + if self.trading: + self.cta_engine.sync_strategy_data(self) + +class StockPolicy(CtaPolicy): + + def __init__(self, strategy): + super().__init__(strategy) + self.cur_trading_date = None # 已执行pre_trading方法后更新的当前交易日 + self.signals = {} # kline_name: { 'last_signal': '', 'last_signal_time': datetime } + self.sub_tns = {} # 子事务 + + def from_json(self, json_data): + """将数据从json_data中恢复""" + super().from_json(json_data) + + self.cur_trading_date = json_data.get('cur_trading_date', None) + self.sub_tns = json_data.get('sub_tns') + signals = json_data.get('signals', {}) + for kline_name, signal in signals: + last_signal = signal.get('last_signal', "") + str_ast_signal_time = signal.get('last_signal_time', "") + try: + if len(str_ast_signal_time) > 0: + last_signal_time = datetime.strptime(str_ast_signal_time, '%Y-%m-%d %H:%M:%S') + else: + last_signal_time = None + except Exception as ex: + last_signal_time = None + self.signals.update({kline_name: {'last_signal': last_signal, 'last_signal_time': last_signal_time}}) + + + def to_json(self): + """转换至json文件""" + j = super().to_json() + j['cur_trading_date'] = self.cur_trading_date + j['sub_tns'] = self.sub_tns + d = {} + for kline_name, signal in self.signals.items(): + last_signal_time = signal.get('last_signal_time', None) + d.update({kline_name: + {'last_signal': signal.get('last_signal', ''), + 'last_signal_time': last_signal_time.strftime( + '%Y-%m-%d %H:%M:%S') if last_signal_time is not None else "" + } + }) + j['singlals'] = d + return j + +class CtaStockTemplate(CtaTemplate): + """ + 股票增强模板 + """ + + # 委托类型 + order_type = OrderType.LIMIT + cancel_seconds = 120 # 撤单时间(秒) + + # 资金相关 + max_invest_rate = 0.1 # 最大仓位(0~1) + max_invest_margin = 0 # 资金上限 0,不限制 + + # 是否回测状态 + backtesting = False + + # 逻辑过程日志 + dist_fieldnames = ['datetime', 'symbol', 'name', 'volume', 'price', + 'operation', 'signal', 'stop_price', 'target_price', + 'long_pos'] + + def __init__(self, cta_engine, strategy_name, vt_symbols, setting): + """""" + + self.policy = None # 事务执行组件 + self.gt = None # 网格交易组件(使用了dn_grids,作为买入/持仓/卖出任务) + self.klines = {} # K线组件字典: kline_name: kline + self.positions = {} # 策略内持仓记录, vt_symbol: PositionData + self.order_type = OrderType.LIMIT + self.cancel_seconds = 120 # 撤单时间(秒) + + # 资金相关 + self.max_invest_rate = 0.1 # 最大仓位(0~1) + self.max_invest_margin = 0 # 资金上限 0,不限制 + + # 是否回测状态 + backtesting = False + + self.cur_datetime: datetime = None # 当前Tick时间 + self.last_minute = None # 最后的分钟,用于on_tick内每分钟处理的逻辑 + + super().__init__( + cta_engine, strategy_name, vt_symbols, setting + ) + + self.policy = StockPolicy(self) # 事务执行组件 + self.gt = CtaGridTrade(strategy=self) # 网格持久化模块 + + if 'backtesting' not in self.parameters: + self.parameters.append('backtesting') + + def update_setting(self, setting: dict): + """ + Update strategy parameter wtih value in setting dict. + """ + for name in self.parameters: + if name in setting: + setattr(self, name, setting[name]) + + def save_klines_to_cache(self, kline_names: list = []): + """ + 保存K线数据到缓存 + :param kline_names: 一般为self.klines的keys + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + # 获取保存路径 + save_path = self.cta_engine.get_data_path() + # 保存缓存的文件名 + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + with bz2.BZ2File(file_name, 'wb') as f: + klines = {} + for kline_name in kline_names: + kline = self.klines.get(kline_name, None) + # if kline: + # kline.strategy = None + # kline.cb_on_bar = None + klines.update({kline_name: kline}) + pickle.dump(klines, f) + + def load_klines_from_cache(self, kline_names: list = []): + """ + 从缓存加载K线数据 + :param kline_names: + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + save_path = self.cta_engine.get_data_path() + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + try: + last_bar_dt = None + with bz2.BZ2File(file_name, 'rb') as f: + klines = pickle.load(f) + # 逐一恢复K线 + for kline_name in kline_names: + # 缓存的k线实例 + cache_kline = klines.get(kline_name, None) + # 当前策略实例的K线实例 + strategy_kline = self.klines.get(kline_name, None) + + if cache_kline and strategy_kline: + # 临时保存当前的回调函数 + cb_on_bar = strategy_kline.cb_on_bar + # 缓存实例数据 =》 当前实例数据 + strategy_kline.__dict__.update(cache_kline.__dict__) + + # 所有K线的最后时间 + if last_bar_dt and strategy_kline.cur_datetime: + last_bar_dt = max(last_bar_dt, strategy_kline.cur_datetime) + else: + last_bar_dt = strategy_kline.cur_datetime + + # 重新绑定k线策略与on_bar回调函数 + strategy_kline.strategy = self + strategy_kline.cb_on_bar = cb_on_bar + + self.write_log(f'恢复{kline_name}缓存数据,最新bar结束时间:{last_bar_dt}') + + self.write_log(u'加载缓存k线数据完毕') + return last_bar_dt + except Exception as ex: + self.write_error(f'加载缓存K线数据失败:{str(ex)}') + return None + + def get_klines_snapshot(self): + """返回当前klines的切片数据""" + try: + d = { + 'strategy': self.strategy_name, + 'datetime': datetime.now()} + klines = {} + for kline_name in sorted(self.klines.keys()): + klines.update({kline_name: self.klines.get(kline_name).get_data()}) + kline_names = list(klines.keys()) + binary_data = zlib.compress(pickle.dumps(klines)) + d.update({'kline_names': kline_names, 'klines': binary_data, 'zlib': True}) + return d + except Exception as ex: + self.write_error(f'获取klines切片数据失败:{str(ex)}') + return {} + + def init_policy(self): + """初始化Policy""" + self.write_log(u'init_policy(),初始化执行逻辑') + self.policy.load() + self.write_log('{}'.format(json.dumps(self.policy.to_json(),indent=2, ensure_ascii=True))) + + def init_position(self): + """ + 初始化Position + 使用网格的持久化,获取开仓状态的持仓,更新 + :return: + """ + self.write_log(u'init_position(),初始化持仓') + + if len(self.gt.dn_grids) <= 0: + # 加载已开仓的多数据,网格JSON + long_grids = self.gt.load(direction=Direction.LONG, open_status_filter=[True, False]) + if len(long_grids) == 0: + self.write_log(u'没有持久化的多单数据') + self.gt.dn_grids = [] + else: + self.gt.dn_grids = long_grids + for lg in long_grids: + if len(lg.order_ids) > 0: + self.write_log(f'清除委托单:{lg.order_ids}') + [self.cta_engine.cancel_order(vt_orderid) for vt_orderid in lg.order_ids] + lg.order_ids = [] + if lg.open_status: + pos = self.get_position(lg.vt_symbol) + pos.volume += lg.volume + self.write_log(u'加载持仓多单[{},价格:{},数量:{}手, 开仓时间:{}' + .format(lg.vt_symbol, lg.open_price, lg.volume, lg.open_time)) + + self.gt.save() + self.display_grids() + + def get_position(self, vt_symbol) -> PositionData: + """ + 获取策略某vt_symbol持仓() + :return: + """ + pos = self.positions.get(vt_symbol) + if pos is None: + symbol, exchange = extract_vt_symbol(vt_symbol) + contract = self.cta_engine.get_contract(vt_symbol) + pos = PositionData( + gateway_name=contract.gateway_name if contract else '', + symbol=symbol, + exchange=exchange, + direction=Direction.NET + ) + self.positions.update({vt_symbol: pos}) + + return pos + + def compare_pos(self): + """比较仓位""" + for vt_symbol, position in self.positions.items(): + name = self.cta_engine.get_name(vt_symbol) + acc_pos = self.cta_engine.get_position(vt_symbol=vt_symbol, Direction=Direction.NET) + if position.volume > 0: + if not acc_pos: + self.write_error(f'账号中,没有{name}[{vt_symbol}]的持仓') + continue + if acc_pos.volume < position.volume: + self.write_error(f'{name}[{vt_symbol}]的账号持仓{acc_pos} 小于策略持仓:{position.volume}') + + def before_trading(self, dt: datetime = None): + """开盘前/初始化后调用一次""" + self.write_log(f'{self.strategy_name}开盘前检查') + + self.compare_pos() + + if not self.backtesting: + self.policy.cur_trading_date = datetime.strftime('%Y-%m-%d') + else: + if dt: + self.policy.cur_trading_date = dt.strftime('%Y-%m-%d') + + def after_trading(self): + """收盘后调用一次""" + self.write_log(f'{self.strategy_name}收盘后调用') + self.compare_pos() + + def on_trade(self, trade: TradeData): + """交易更新""" + self.write_log(u'{},交易更新:{}' + .format(self.cur_datetime, + trade.__dict__)) + + dist_record = dict() + if self.backtesting: + dist_record['datetime'] = trade.time + else: + dist_record['datetime'] = ' '.join([self.cur_datetime.strftime('%Y-%m-%d'), trade.time]) + dist_record['volume'] = trade.volume + dist_record['price'] = trade.price + dist_record['symbol'] = trade.vt_symbol + pos = self.get_position(trade.vt_symbol) + if trade.direction == Direction.LONG: + dist_record['operation'] = 'buy' + pos.volume += trade.volume + + if trade.direction == Direction.SHORT: + dist_record['operation'] = 'sell' + pos.volume -= trade.volume + + self.save_dist(dist_record) + + def on_order(self, order: OrderData): + """报单更新""" + # 未执行的订单中,存在是异常,删除 + self.write_log(u'{}报单更新,{}'.format(self.cur_datetime, order.__dict__)) + + if order.vt_orderid in self.active_orders: + + if order.volume == order.traded and order.status in [Status.ALLTRADED]: + self.on_order_all_traded(order) + + elif order.offset == Offset.OPEN and order.status in [Status.CANCELLED]: + # 开仓委托单被撤销 + self.on_order_open_canceled(order) + + elif order.offset != Offset.OPEN and order.status in [Status.CANCELLED]: + # 平仓委托单被撤销 + self.on_order_close_canceled(order) + + elif order.status == Status.REJECTED: + if order.offset == Offset.OPEN: + self.write_error(u'{}委托单开{}被拒,price:{},total:{},traded:{},status:{}' + .format(order.vt_symbol, order.direction, order.price, order.volume, + order.traded, order.status)) + self.on_order_open_canceled(order) + else: + self.write_error(u'OnOrder({})委托单平{}被拒,price:{},total:{},traded:{},status:{}' + .format(order.vt_symbol, order.direction, order.price, order.volume, + order.traded, order.status)) + self.on_order_close_canceled(order) + else: + self.write_log(u'委托单未完成,total:{},traded:{},tradeStatus:{}' + .format(order.volume, order.traded, order.status)) + else: + self.write_error(u'委托单{}不在策略的未完成订单列表中:{}'.format(order.vt_orderid, self.active_orders)) + + def on_order_all_traded(self, order: OrderData): + """ + 订单全部成交 + :param order: + :return: + """ + self.write_log(u'{},委托单:{}全部完成'.format(order.time, order.vt_orderid)) + order_info = self.active_orders[order.vt_orderid] + + # 通过vt_orderid,找到对应的网格 + grid = order_info.get('grid', None) + if grid is not None: + # 移除当前委托单 + if order.vt_orderid in grid.order_ids: + grid.order_ids.remove(order.vt_orderid) + + # 网格的所有委托单已经执行完毕 + if len(grid.order_ids) == 0: + grid.order_status = False + grid.traded_volume = 0 + + # 平仓完毕(cover, sell) + if order.offset != Offset.OPEN: + grid.open_status = False + grid.close_status = True + + self.write_log(f'{grid.direction.value}单已平仓完毕,order_price:{order.price}' + + f',volume:{order.volume}') + + self.write_log(f'移除网格:{grid.to_json()}') + self.gt.remove_grids_by_ids(direction=grid.direction, ids=[grid.id]) + + # 开仓完毕( buy, sell) + else: + grid.open_status = True + grid.open_time = self.cur_datetime + self.write_log(f'{grid.direction.value}单已开仓完毕,order_price:{order.price}' + + f',volume:{order.volume}') + + # 网格的所有委托单部分执行完毕 + else: + old_traded_volume = grid.traded_volume + grid.traded_volume += order.volume + grid.traded_volume = round(grid.traded_volume, 7) + + self.write_log(f'{grid.direction.value}单部分{order.offset}仓,' + + f'网格volume:{grid.volume}, traded_volume:{old_traded_volume}=>{grid.traded_volume}') + + self.write_log(f'剩余委托单号:{grid.order_ids}') + + # 在策略得活动订单中,移除 + self.active_orders.pop(order.vt_orderid, None) + + def on_order_open_canceled(self, order: OrderData): + """ + 委托开仓单撤销 + :param order: + :return: + """ + self.write_log(u'委托开仓单撤销:{}'.format(order.__dict__)) + + if order.vt_orderid not in self.active_orders: + self.write_error(u'{}不在未完成的委托单中{}。'.format(order.vt_orderid, self.active_orders)) + return + + old_order = self.active_orders[order.vt_orderid] + self.write_log(u'{} 委托信息:{}'.format(order.vt_orderid, old_order)) + # 更新成交数量 + old_order['traded'] = order.traded + # 获取订单对应的网格 + grid = old_order.get('grid', None) + + # 状态 =》 撤单 + pre_status = old_order.get('status', Status.NOTTRADED) + old_order.update({'status': Status.CANCELLED}) + self.write_log(u'委托单状态:{}=>{}'.format(pre_status, old_order.get('status'))) + + if grid: + if order.vt_orderid in grid.order_ids: + self.write_log(f'移除网格的开仓委托单:{order.vt_orderid}') + grid.order_ids.remove(order.vt_orderid) + + if order.traded > 0: + self.write_log(f'撤单中有成交,网格累计成交:{grid.traded_volume} => {grid.traded_volume + order.traded}') + grid.traded_volume += order.traded + + self.gt.save() + + self.active_orders.update({order.vt_orderid: old_order}) + + self.display_grids() + + def on_order_close_canceled(self, order: OrderData): + """委托平仓单撤销""" + self.write_log(u'委托平仓单撤销:{}'.format(order.__dict__)) + + if order.vt_orderid not in self.active_orders: + self.write_error(u'{}不在未完成的委托单中:{}。'.format(order.vt_orderid, self.active_orders)) + return + + # 更新 + old_order = self.active_orders[order.vt_orderid] + self.write_log(u'{} 订单信息:{}'.format(order.vt_orderid, old_order)) + + old_order['traded'] = order.traded + grid = old_order.get('grid', None) + + pre_status = old_order.get('status', Status.NOTTRADED) + old_order.update({'status': Status.CANCELLED}) + self.write_log(u'委托单状态:{}=>{}'.format(pre_status, old_order.get('status'))) + + if grid: + if order.vt_orderid in grid.order_ids: + self.write_log(f'移除网格的平仓委托单:{order.vt_orderid}') + grid.order_ids.remove(order.vt_orderid) + + if order.traded > 0: + self.write_log(f'撤单中有成交,网格累计成交:{grid.traded_volume} => {grid.traded_volume + order.traded}') + grid.traded_volume += order.traded + + self.gt.save() + + self.active_orders.update({order.vt_orderid: old_order}) + self.display_grids() + + def on_stop_order(self, stop_order: StopOrder): + self.write_log(f'停止单触发:{stop_order.__dict__}') + + def grid_check_stop(self): + """ + 网格逐一止损/止盈检查 (根据指数价格进行止损止盈) + :return: + """ + if self.entrust != 0: + return + + if not self.trading and not self.inited: + self.write_error(u'当前不允许交易') + return + + remove_gids = [] + # 多单网格逐一止损/止盈检查: + long_grids = self.gt.get_opened_grids(direction=Direction.LONG) + for lg in long_grids: + + if lg.close_status or lg.order_status or not lg.open_status: + continue + + cur_price = self.cta_engine.get_price(lg.vt_symbol) + if lg.stop_price > 0 and lg.stop_price > cur_price > 0: + # 调用平仓模块 + self.write_log(u'{} {}当前价:{} 触发止损线{},开仓价:{},v:{}'. + format(self.cur_datetime, + lg.vt_symbol, + cur_price, + lg.stop_price, + lg.open_price, + lg.volume)) + + if lg.traded_volume > 0: + lg.volume -= lg.traded_volume + lg.traded_volume = 0 + if lg.volume <= 0: + remove_gids.append(lg.id) + lg.open_status = False + lg.order_status = False + lg.close_status = False + continue + + lg.order_status = True + lg.close_status = True + self.write_log(f'{lg.vt_symbol} 数量:{lg.volume},准备卖出') + + if len(remove_gids) > 0: + self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids) + self.gt.save() + + def tns_excute_sell_grids(self): + """ + 事务执行卖出网格 + 1、找出所有order_status=True,open_status=Talse, close_status=True的网格。 + 2、比对volume和traded volume, 如果两者得数量差,大于min_trade_volume,继续发单 + :return: + """ + if not self.trading: + return + + if self.cur_datetime and 9 <= self.cur_datetime.hour <= 14: + if self.cur_datetime.hour == 12: + return + if self.cur_datetime.hour == 9 and self.cur_datetime.minute < 30: + return + if self.cur_datetime.hour == 11 and self.cur_datetime.minute >= 30: + return + + ordering_grid = None + for grid in self.gt.dn_grids: + # 排除: 未开仓/非平仓/非委托的网格 + if not grid.open_status or not grid.close_status or not grid.open_status: + continue + + # 排除存在委托单号的网格 + if len(grid.order_ids) > 0: + continue + + if grid.volume == grid.traded_volume: + self.write_log(u'网格计划卖出:{},已成交:{}'.format(grid.volume, grid.traded_volume)) + self.tns_finish_sell_grid(grid) + continue + + # 定位到首个满足条件的网格,跳出循环 + ordering_grid = grid + break + + # 没有满足条件的网格 + if ordering_grid is None: + return + + acc_symbol_pos = self.cta_engine.get_position( + vt_symbol=ordering_grid.vt_symbol, + direction=Direction.NET) + if acc_symbol_pos is None: + self.write_error(u'当前{}持仓查询不到'.format(ordering_grid.vt_symbol)) + return + + vt_symbol = ordering_grid.vt_symbol + sell_volume = ordering_grid.volume - ordering_grid.traded_volume + + if sell_volume > acc_symbol_pos.volume: + self.write_error(u'账号{}持仓{},不满足减仓目标:{}' + .format(vt_symbol, acc_symbol_pos.volume, sell_volume)) + return + + # 实盘运行时,要加入市场买卖量的判断 + if not self.backtesting: + symbol_tick = self.cta_engine.get_tick(vt_symbol) + symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) + # 根据市场计算,前5档买单数量 + if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, + symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \ + and all( + [symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4, + symbol_tick.bid_volume_5]): + market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5 + market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5 + org_sell_volume = sell_volume + if market_bid_volumes > 0 and market_ask_volumes > 0 and org_sell_volume >= 2 * symbol_volume_tick: + sell_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, sell_volume) + sell_volume = max(round_to(value=sell_volume, target=symbol_volume_tick), symbol_volume_tick) + if org_sell_volume != sell_volume: + self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume)) + + # 获取当前价格 + sell_price = self.cta_engine.get_price(vt_symbol) - self.cta_engine.get_price_tick(vt_symbol) + # 发出委托卖出 + vt_orderids = self.sell( + vt_symbol=vt_symbol, + price=sell_price, + volume=sell_volume, + order_time=self.cur_datetime, + grid=ordering_grid) + if vt_orderids is None or len(vt_orderids) == 0: + self.write_error(f'委托卖出失败,{vt_symbol} 委托价:{sell_price} 数量:{sell_volume}') + return + else: + self.write_log(f'已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}') + + + def tns_finish_sell_grid(self, grid): + """ + 事务完成卖出网格 + :param grid: + :return: + """ + self.write_log( + u'卖出网格执行完毕,price:{},v:{},traded:{},type:'.format(grid.open_price, grid.volume, grid.traded_volume, grid.type)) + grid.order_status = False + grid.open_status = False + volume = grid.volume + traded_volume = grid.traded_volume + if grid.traded_volume > 0: + grid.volume = grid.traded_volume + grid.traded_volume = 0 + self.write_log(u'{} {} {} 委托状态为: {},完成状态:{} v:{}=>{},traded:{}=>{}' + .format(grid.type, grid.direction, grid.vt_symbol, + grid.order_status, grid.open_status, + volume, grid.volume, + traded_volume, grid.traded_volume)) + + dist_record = dict() + dist_record['volume'] = grid.volume + dist_record['price'] = self.cta_engine.get_price(grid.vt_symbol) + dist_record['operation'] = 'execute finished' + dist_record['signal'] = grid.type + self.save_dist(dist_record) + + id = grid.id + self.write_log(u'移除卖出网格:{}'.format(id)) + self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=[id]) + self.gt.save() + self.policy.save() + + def tns_execute_buy_grids(self): + """ + 事务执行买入网格 + :return: + """ + if not self.trading: + return + if self.cur_datetime and 9 <= self.cur_datetime.hour <= 14: + if self.cur_datetime.hour == 12: + return + if self.cur_datetime.hour == 9 and self.cur_datetime.minute < 30: + return + if self.cur_datetime.hour == 11 and self.cur_datetime.minute >= 30: + return + + ordering_grid = None + for grid in self.gt.dn_grids: + # 排除已经执行完毕(处于开仓状态)的网格, 或者处于平仓状态的网格 + if grid.open_status or grid.close_status: + continue + # 排除非委托状态的网格 + if not grid.order_status: + continue + + # 排除存在委托单号的网格 + if len(grid.order_ids) > 0: + continue + + if grid.volume == grid.traded_volume: + self.write_log(u'网格计划买入:{},已成交:{}'.format(grid.volume, grid.traded_volume)) + self.tns_finish_buy_grid(grid) + return + + # 定位到首个满足条件的网格,跳出循环 + ordering_grid = grid + break + + # 没有满足条件的网格 + if ordering_grid is None: + return + + balance, availiable, _, _ = self.cta_engine.get_account() + if availiable <= 0: + self.write_error(u'当前可用资金不足'.format(availiable)) + return + vt_symbol = ordering_grid.vt_symbol + cur_price = self.cta_engine.get_price(vt_symbol) + buy_volume = ordering_grid.volume - ordering_grid.traded_volume + min_trade_volume = self.cta_engine.get_volume_tick(vt_symbol) + if availiable < buy_volume * cur_price: + self.write_error(f'可用资金{availiable},不满足买入{vt_symbol},数量:{buy_volume} X价格{cur_price}') + max_buy_volume = int(availiable / cur_price) + max_buy_volume = max_buy_volume - max_buy_volume % min_trade_volume + if max_buy_volume <= min_trade_volume: + return + # 计划买入数量,与可用资金买入数量的差别 + diff_volume = buy_volume - max_buy_volume + # 降低计划买入数量 + self.write_log(f'总计划{vt_symbol}买入数量:{ordering_grid.volume}=>{ordering_grid.volume - diff_volume}') + ordering_grid.volume -= diff_volume + self.gt.save() + buy_volume = max_buy_volume + + # 实盘运行时,要加入市场买卖量的判断 + if not self.backtesting: + symbol_tick = self.cta_engine.get_tick(vt_symbol) + # 根据市场计算,前5档买单数量 + if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, + symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \ + and all( + [symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4, + symbol_tick.bid_volume_5]): + market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5 + market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5 + if market_bid_volumes > 0 and market_ask_volumes > 0: + buy_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, buy_volume) + buy_volume = max(buy_volume - buy_volume % min_trade_volume, min_trade_volume) + + buy_price = cur_price + self.cta_engine.get_price_tick(vt_symbol) + + vt_orderids = self.buy( + vt_symbol=vt_symbol, + price=buy_price, + volume=buy_volume, + order_time=self.cur_datetime, + grid=ordering_grid) + if vt_orderids is None or len(vt_orderids) == 0: + self.write_error(f'委托买入失败,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}') + return + else: + self.write_error(f'已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}') + + def tns_finish_buy_grid(self, grid): + """ + 事务完成买入网格 + :return: + """ + self.write_log(u'事务完成买入网格:{},计划数量:{},计划价格:{},实际数量:{}' + .format(grid.type, grid.volume, grid.openPrice, grid.traded_volume)) + if grid.volume != grid.traded_volume: + grid.volume = grid.traded_volume + grid.traded_volume = 0 + grid.open_status = True + grid.order_status = False + grid.open_time = self.cur_datetime + + dist_record = dict() + dist_record['symbol'] = grid.vt_symbol + dist_record['volume'] = grid.volume + dist_record['price'] = self.cta_engine.get_price(grid.vt_symbol) + dist_record['operation'] = '{} finished'.format(grid.type) + dist_record['signal'] = grid.type + self.save_dist(dist_record) + + self.gt.save() + + def cancel_all_orders(self): + """ + 重载撤销所有正在进行得委托 + :return: + """ + self.write_log(u'撤销所有正在进行得委托') + self.tns_cancel_logic(dt=datetime.now(), force=True) + + def tns_cancel_logic(self, dt, force=False): + "撤单逻辑""" + if len(self.active_orders) < 1: + self.entrust = 0 + return + + canceled_ids = [] + + for vt_orderid in list(self.active_orders.keys()): + order_info = self.active_orders[vt_orderid] + order_vt_symbol = order_info.get('vt_symbol') + order_time = order_info['order_time'] + order_volume = order_info['volume'] - order_info['traded'] + # order_price = order_info['price'] + # order_direction = order_info['direction'] + # order_offset = order_info['offset'] + order_grid = order_info['grid'] + order_status = order_info.get('status', Status.NOTTRADED) + order_type = order_info.get('order_type', OrderType.LIMIT) + over_seconds = (dt - order_time).total_seconds() + + # 只处理未成交的限价委托单 + if order_status in [Status.SUBMITTING, Status.NOTTRADED] and order_type == OrderType.LIMIT: + if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交 + self.write_log(u'超时{}秒未成交,取消委托单:vt_orderid:{},order:{}' + .format(over_seconds, vt_orderid, order_info)) + order_info.update({'status': Status.CANCELLING}) + self.active_orders.update({vt_orderid: order_info}) + ret = self.cancel_order(str(vt_orderid)) + if not ret: + self.write_log(u'撤单失败,更新状态为撤单成功') + order_info.update({'status': Status.CANCELLED}) + self.active_orders.update({vt_orderid: order_info}) + if order_grid: + if vt_orderid in order_grid.order_ids: + order_grid.order_ids.remove(vt_orderid) + + continue + + # 处理状态为‘撤销’的委托单 + elif order_status == Status.CANCELLED: + self.write_log(u'委托单{}已成功撤单,删除{}'.format(vt_orderid, order_info)) + canceled_ids.append(vt_orderid) + + # 删除撤单的订单 + for vt_orderid in canceled_ids: + self.write_log(u'删除orderID:{0}'.format(vt_orderid)) + self.active_orders.pop(vt_orderid, None) + + if len(self.active_orders) == 0: + self.entrust = 0 + + def display_grids(self): + """更新网格显示信息""" + if not self.inited: + return + + opening_info = "" + closing_info = "" + holding_info = "" + + for grid in self.gt.dn_grids: + name = self.cta_engine.get_name(grid.vt_symbol) + + if not grid.open_status and grid.order_status: + opening_info += f'网格{grid.type},买入状态:{name}[{grid.vt_symbol}], [已买入:{grid.traded_volume} => 目标:{grid.volume}, 委托时间:{grid.order_time}\n' + continue + + if grid.open_status and not grid.close_status: + holding_info += f'网格{grid.type},持有状态:{name}[{grid.vt_symbol}],[数量:{grid.volume}, 开仓时间:{grid.open_time}]\n' + continue + + if grid.open_status and grid.close_status: + closing_info += f'网格{grid.type},卖出状态:{name}[{grid.vt_symbol}], [已卖出:{grid.traded_volume} => 目标:{grid.volume}, 委托时间:{grid.order_time}\n' + + if len(opening_info) > 0: + self.write_log(opening_info) + if len(holding_info) > 0: + self.write_log(holding_info) + if len(closing_info) > 0: + self.write_log(closing_info) + + def display_tns(self): + """显示事务的过程记录=》 log""" + if not self.inited: + return + if hasattr(self, 'policy'): + policy = getattr(self, 'policy') + op = getattr(policy, 'to_json', None) + if callable(op): + self.write_log(u'当前Policy:{}'.format(json.dumps(policy.to_json(), indent=2, ensure_ascii=False))) + + def save_dist(self, dist_data): + """ + 保存策略逻辑过程记录=》 csv文件按 + :param dist_data: + :return: + """ + if self.backtesting: + save_path = self.cta_engine.get_logs_path() + else: + save_path = self.cta_engine.get_data_path() + try: + + if 'datetime' not in dist_data: + dist_data.update({'datetime': self.cur_datetime}) + if 'long_pos' not in dist_data: + vt_symbol = dist_data.get('symbol') + if vt_symbol: + pos = self.get_position(vt_symbol) + dist_data.update({'long_pos': pos.volume}) + if 'name' not in dist_data: + dist_data['name'] = self.cta_engine.get_name(vt_symbol) + + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv')) + append_data(file_name=file_name, dict_data=dist_data, field_names=self.dist_fieldnames) + except Exception as ex: + self.write_error(u'save_dist 异常:{} {}'.format(str(ex), traceback.format_exc())) + + def save_tns(self, tns_data): + """ + 保存多空事务记录=》csv文件,便于后续分析 + :param tns_data: + :return: + """ + if self.backtesting: + save_path = self.cta_engine.get_logs_path() + else: + save_path = self.cta_engine.get_data_path() + + try: + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_tns.csv')) + append_data(file_name=file_name, dict_data=tns_data) + except Exception as ex: + self.write_error(u'save_tns 异常:{} {}'.format(str(ex), traceback.format_exc())) + + def send_wechat(self, msg: str): + """实盘时才发送微信""" + if self.backtesting: + return + self.cta_engine.send_wechat(msg=msg, strategy=self) diff --git a/vnpy/app/cta_stock/ui/__init__.py b/vnpy/app/cta_stock/ui/__init__.py new file mode 100644 index 00000000..592d401a --- /dev/null +++ b/vnpy/app/cta_stock/ui/__init__.py @@ -0,0 +1 @@ +from .widget import CtaManager diff --git a/vnpy/app/cta_stock/ui/cta.ico b/vnpy/app/cta_stock/ui/cta.ico new file mode 100644 index 00000000..25cbaa73 Binary files /dev/null and b/vnpy/app/cta_stock/ui/cta.ico differ diff --git a/vnpy/app/cta_stock/ui/widget.py b/vnpy/app/cta_stock/ui/widget.py new file mode 100644 index 00000000..935f912d --- /dev/null +++ b/vnpy/app/cta_stock/ui/widget.py @@ -0,0 +1,464 @@ +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtCore, QtGui, QtWidgets +from vnpy.trader.ui.widget import ( + BaseCell, + EnumCell, + MsgCell, + TimeCell, + BaseMonitor +) +from ..base import ( + APP_NAME, + EVENT_CTA_LOG, + EVENT_CTA_STOPORDER, + EVENT_CTA_STRATEGY +) +from ..engine import CtaEngine + + +class CtaManager(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_strategy = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + super(CtaManager, self).__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + self.cta_engine = main_engine.get_engine(APP_NAME) + + self.managers = {} + + self.init_ui() + self.register_event() + self.cta_engine.init_engine() + self.update_class_combo() + + def init_ui(self): + """""" + self.setWindowTitle("CTA策略") + + # Create widgets + self.class_combo = QtWidgets.QComboBox() + + add_button = QtWidgets.QPushButton("添加策略") + add_button.clicked.connect(self.add_strategy) + + init_button = QtWidgets.QPushButton("全部初始化") + init_button.clicked.connect(self.cta_engine.init_all_strategies) + + start_button = QtWidgets.QPushButton("全部启动") + start_button.clicked.connect(self.cta_engine.start_all_strategies) + + stop_button = QtWidgets.QPushButton("全部停止") + stop_button.clicked.connect(self.cta_engine.stop_all_strategies) + + clear_button = QtWidgets.QPushButton("清空日志") + clear_button.clicked.connect(self.clear_log) + + self.scroll_layout = QtWidgets.QVBoxLayout() + self.scroll_layout.addStretch() + + scroll_widget = QtWidgets.QWidget() + scroll_widget.setLayout(self.scroll_layout) + + scroll_area = QtWidgets.QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(scroll_widget) + + self.log_monitor = LogMonitor(self.main_engine, self.event_engine) + + self.stop_order_monitor = StopOrderMonitor( + self.main_engine, self.event_engine + ) + + # Set layout + hbox1 = QtWidgets.QHBoxLayout() + hbox1.addWidget(self.class_combo) + hbox1.addWidget(add_button) + hbox1.addStretch() + hbox1.addWidget(init_button) + hbox1.addWidget(start_button) + hbox1.addWidget(stop_button) + hbox1.addWidget(clear_button) + + grid = QtWidgets.QGridLayout() + grid.addWidget(scroll_area, 0, 0, 2, 1) + grid.addWidget(self.stop_order_monitor, 0, 1) + grid.addWidget(self.log_monitor, 1, 1) + + vbox = QtWidgets.QVBoxLayout() + vbox.addLayout(hbox1) + vbox.addLayout(grid) + + self.setLayout(vbox) + + def update_class_combo(self): + """""" + self.class_combo.addItems( + self.cta_engine.get_all_strategy_class_names() + ) + + def register_event(self): + """""" + self.signal_strategy.connect(self.process_strategy_event) + + self.event_engine.register( + EVENT_CTA_STRATEGY, self.signal_strategy.emit + ) + + def process_strategy_event(self, event): + """ + Update strategy status onto its monitor. + """ + data = event.data + strategy_name = data["strategy_name"] + + if strategy_name in self.managers: + manager = self.managers[strategy_name] + manager.update_data(data) + else: + manager = StrategyManager(self, self.cta_engine, data) + self.scroll_layout.insertWidget(0, manager) + self.managers[strategy_name] = manager + + def remove_strategy(self, strategy_name): + """""" + manager = self.managers.pop(strategy_name) + manager.deleteLater() + + def add_strategy(self): + """""" + class_name = str(self.class_combo.currentText()) + if not class_name: + return + + parameters = self.cta_engine.get_strategy_class_parameters(class_name) + editor = SettingEditor(parameters, class_name=class_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + vt_symbol = setting.pop("vt_symbol") + strategy_name = setting.pop("strategy_name") + auto_init = setting.pop("auto_init", False) + auto_start = setting.pop("auto_start", False) + self.cta_engine.add_strategy( + class_name, strategy_name, vt_symbol, setting, auto_init, auto_start + ) + + def clear_log(self): + """""" + self.log_monitor.setRowCount(0) + + def show(self): + """""" + self.showMaximized() + + +class StrategyManager(QtWidgets.QFrame): + """ + Manager for a strategy + """ + + def __init__( + self, cta_manager: CtaManager, cta_engine: CtaEngine, data: dict + ): + """""" + super(StrategyManager, self).__init__() + + self.cta_manager = cta_manager + self.cta_engine = cta_engine + + self.strategy_name = data["strategy_name"] + self._data = data + + self.init_ui() + + def init_ui(self): + """""" + self.setFixedHeight(300) + self.setFrameShape(self.Box) + self.setLineWidth(1) + + init_button = QtWidgets.QPushButton("初始化") + init_button.clicked.connect(self.init_strategy) + + start_button = QtWidgets.QPushButton("启动") + start_button.clicked.connect(self.start_strategy) + + stop_button = QtWidgets.QPushButton("停止") + stop_button.clicked.connect(self.stop_strategy) + + edit_button = QtWidgets.QPushButton("编辑") + edit_button.clicked.connect(self.edit_strategy) + + remove_button = QtWidgets.QPushButton("移除") + remove_button.clicked.connect(self.remove_strategy) + + reload_button = QtWidgets.QPushButton("重载") + reload_button.clicked.connect(self.reload_strategy) + + save_button = QtWidgets.QPushButton("保存") + save_button.clicked.connect(self.save_strategy) + + strategy_name = self._data["strategy_name"] + vt_symbol = self._data["vt_symbol"] + class_name = self._data["class_name"] + author = self._data["author"] + + label_text = ( + f"{strategy_name} - {vt_symbol} ({class_name} by {author})" + ) + label = QtWidgets.QLabel(label_text) + label.setAlignment(QtCore.Qt.AlignCenter) + + self.parameters_monitor = DataMonitor(self._data["parameters"]) + self.variables_monitor = DataMonitor(self._data["variables"]) + + hbox = QtWidgets.QHBoxLayout() + hbox.addWidget(init_button) + hbox.addWidget(start_button) + hbox.addWidget(stop_button) + hbox.addWidget(edit_button) + hbox.addWidget(remove_button) + hbox.addWidget(reload_button) + hbox.addWidget(save_button) + + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(label) + vbox.addLayout(hbox) + vbox.addWidget(self.parameters_monitor) + vbox.addWidget(self.variables_monitor) + self.setLayout(vbox) + + def update_data(self, data: dict): + """""" + self._data = data + + self.parameters_monitor.update_data(data["parameters"]) + self.variables_monitor.update_data(data["variables"]) + + def init_strategy(self): + """""" + self.cta_engine.init_strategy(self.strategy_name) + + def start_strategy(self): + """""" + self.cta_engine.start_strategy(self.strategy_name) + + def stop_strategy(self): + """""" + self.cta_engine.stop_strategy(self.strategy_name) + + def edit_strategy(self): + """""" + strategy_name = self._data["strategy_name"] + + parameters = self.cta_engine.get_strategy_parameters(strategy_name) + editor = SettingEditor(parameters, strategy_name=strategy_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + self.cta_engine.edit_strategy(strategy_name, setting) + + def remove_strategy(self): + """""" + result = self.cta_engine.remove_strategy(self.strategy_name) + + # Only remove strategy gui manager if it has been removed from engine + if result: + self.cta_manager.remove_strategy(self.strategy_name) + + def reload_strategy(self): + """重新加载策略""" + self.cta_engine.reload_strategy(self.strategy_name) + + def save_strategy(self): + self.cta_engine.save_strategy_data(self.strategy_name) + + +class DataMonitor(QtWidgets.QTableWidget): + """ + Table monitor for parameters and variables. + """ + + def __init__(self, data: dict): + """""" + super(DataMonitor, self).__init__() + + self._data = data + self.cells = {} + + self.init_ui() + + def init_ui(self): + """""" + labels = list(self._data.keys()) + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + + self.setRowCount(1) + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.Stretch + ) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + for column, name in enumerate(self._data.keys()): + value = self._data[name] + + cell = QtWidgets.QTableWidgetItem(str(value)) + cell.setTextAlignment(QtCore.Qt.AlignCenter) + + self.setItem(0, column, cell) + self.cells[name] = cell + + def update_data(self, data: dict): + """""" + for name, value in data.items(): + cell = self.cells[name] + cell.setText(str(value)) + + +class StopOrderMonitor(BaseMonitor): + """ + Monitor for local stop order. + """ + + event_type = EVENT_CTA_STOPORDER + data_key = "stop_orderid" + sorting = True + + headers = { + "stop_orderid": { + "display": "停止委托号", + "cell": BaseCell, + "update": False, + }, + "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True}, + "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False}, + "direction": {"display": "方向", "cell": EnumCell, "update": False}, + "offset": {"display": "开平", "cell": EnumCell, "update": False}, + "price": {"display": "价格", "cell": BaseCell, "update": False}, + "volume": {"display": "数量", "cell": BaseCell, "update": False}, + "status": {"display": "状态", "cell": EnumCell, "update": True}, + "lock": {"display": "锁仓", "cell": BaseCell, "update": False}, + "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False}, + } + + +class LogMonitor(BaseMonitor): + """ + Monitor for log data. + """ + + event_type = EVENT_CTA_LOG + data_key = "" + sorting = False + + headers = { + "time": {"display": "时间", "cell": TimeCell, "update": False}, + "msg": {"display": "信息", "cell": MsgCell, "update": False}, + } + + def init_ui(self): + """ + Stretch last column. + """ + super(LogMonitor, self).init_ui() + + self.horizontalHeader().setSectionResizeMode( + 1, QtWidgets.QHeaderView.Stretch + ) + + def insert_new_row(self, data): + """ + Insert a new row at the top of table. + """ + super(LogMonitor, self).insert_new_row(data) + self.resizeRowToContents(0) + + +class SettingEditor(QtWidgets.QDialog): + """ + For creating new strategy and editing strategy parameters. + """ + + def __init__( + self, parameters: dict, strategy_name: str = "", class_name: str = "" + ): + """""" + super(SettingEditor, self).__init__() + + self.parameters = parameters + self.strategy_name = strategy_name + self.class_name = class_name + + self.edits = {} + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + + # Add vt_symbol and name edit if add new strategy + if self.class_name: + self.setWindowTitle(f"添加策略:{self.class_name}") + button_text = "添加" + parameters = {"strategy_name": "", "vt_symbol": "", "auto_init": True, "auto_start": True} + parameters.update(self.parameters) + + else: + self.setWindowTitle(f"参数编辑:{self.strategy_name}") + button_text = "确定" + parameters = self.parameters + + for name, value in parameters.items(): + type_ = type(value) + + edit = QtWidgets.QLineEdit(str(value)) + if type_ is int: + validator = QtGui.QIntValidator() + edit.setValidator(validator) + elif type_ is float: + validator = QtGui.QDoubleValidator() + edit.setValidator(validator) + + form.addRow(f"{name} {type_}", edit) + + self.edits[name] = (edit, type_) + + button = QtWidgets.QPushButton(button_text) + button.clicked.connect(self.accept) + form.addRow(button) + + self.setLayout(form) + + def get_setting(self): + """""" + setting = {} + + if self.class_name: + setting["class_name"] = self.class_name + + for name, tp in self.edits.items(): + edit, type_ = tp + value_text = edit.text() + + if type_ == bool: + if value_text == "True": + value = True + else: + value = False + else: + value = type_(value_text) + + setting[name] = value + + return setting