diff --git a/vnpy/app/cta_stock/back_testing.py b/vnpy/app/cta_stock/back_testing.py index 3ec531c4..4522ad92 100644 --- a/vnpy/app/cta_stock/back_testing.py +++ b/vnpy/app/cta_stock/back_testing.py @@ -589,7 +589,7 @@ class BackTestingEngine(object): 股票数据复权转换 :param raw_data: 不复权数据 :param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df) - :param adj_type: 复权类型 + :param adj_type: 复权类型: fore 前复权, 其他:后复权 :return: """ @@ -1735,7 +1735,8 @@ class BackTestingEngine(object): # 返回回测结果 d = {} d['init_capital'] = self.init_capital - d['profit'] = self.cur_capital - self.init_capital + d['net_capital'] = self.net_capital + d['profit'] = self.net_capital - self.init_capital d['max_capital'] = self.max_net_capital # 取消原 maxCapital if len(self.pnl_list) == 0: @@ -1816,6 +1817,9 @@ class BackTestingEngine(object): result_info.update({u'期初资金': d['init_capital']}) self.output(u'期初资金:\t%s' % format_number(d['init_capital'])) + result_info.update({u'期末资金': d['net_capital']}) + self.output(u'期末资金:\t%s' % format_number(d['net_capital'])) + result_info.update({u'总盈亏': d['profit']}) self.output(u'总盈亏:\t%s' % format_number(d['profit'])) diff --git a/vnpy/app/cta_stock/engine.py b/vnpy/app/cta_stock/engine.py index bb3868f4..459a2868 100644 --- a/vnpy/app/cta_stock/engine.py +++ b/vnpy/app/cta_stock/engine.py @@ -10,6 +10,8 @@ import traceback import json import pickle import bz2 +import pandas as pd +import numpy as np from collections import defaultdict from pathlib import Path @@ -60,12 +62,16 @@ from vnpy.trader.utility import ( get_folder_path, get_underlying_symbol, append_data, - import_module_by_str) + import_module_by_str, +get_csv_last_dt) from vnpy.trader.util_logger import setup_logger, logging from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.converter import PositionHolding - +from vnpy.data.mongo.mongo_data import MongoData +from vnpy.trader.setting import SETTINGS +from vnpy.data.stock.adjust_factor import get_all_adjust_factor +from vnpy.data.stock.stock_base import get_stock_base from .base import ( APP_NAME, EVENT_CTA_LOG, @@ -115,9 +121,14 @@ class CtaEngine(BaseEngine): # "trade_2_wx": true # 是否交易记录转发至微信通知 # "event_log: false # 是否转发日志到event bus,显示在图形界面 # "snapshot2file": false # 是否保存切片到文件 - self.engine_config = {} + # "get_pos_from_db": false # 是否从数据库中获取所有策略得持仓信息。(因为其他RPC进程也运行策略,推送到Mongodb) + # "compare_pos": True, 是否比对仓位 + # "compare_pos_bypass_names": ["name1","name2"] # 比对仓位时,过滤账号中含有字符得股票,例如自动申购得转债等 + self.engine_config = load_json(self.engine_filename) + # 是否激活 write_log写入event bus(比较耗资源) - self.event_log = False + self.event_log = self.engine_config.get('event_log', False) + self.strategy_setting = {} # strategy_name: dict self.strategy_data = {} # strategy_name: dict @@ -157,6 +168,40 @@ class CtaEngine(BaseEngine): self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar) + self.stock_adjust_factors = get_all_adjust_factor() + + self.mongo_data = None + + # 获取全量股票信息 + self.write_log(f'获取全量股票信息') + self.symbol_dict = get_stock_base() + self.write_log(f'共{len(self.symbol_dict)}个股票') + # 除权因子 + self.write_log(f'获取所有除权因子') + self.adjust_factor_dict = get_all_adjust_factor() + self.write_log(f'共{len(self.adjust_factor_dict)}条除权信息') + + # 寻找数据文件所在目录 + vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + self.write_log(f'项目所在目录:{vnpy_root}') + self.bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data')) + if os.path.exists(self.bar_data_folder): + SSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SSE')) + if os.path.exists(SSE_folder): + self.write_log(f'上交所bar数据目录:{SSE_folder}') + else: + self.write_error(f'不存在上交所数据目录:{SSE_folder}') + + SZSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SZSE')) + if os.path.exists(SZSE_folder): + self.write_log(f'深交所bar数据目录:{SZSE_folder}') + else: + self.write_error(f'不存在深交所数据目录:{SZSE_folder}') + else: + self.write_error(f'不存在bar数据目录:{self.bar_data_folder}') + self.bar_data_folder = None + + def init_engine(self): """ """ @@ -168,6 +213,25 @@ class CtaEngine(BaseEngine): self.write_log("CTA策略股票引擎初始化成功") + if self.engine_config.get('get_pos_from_db',False): + self.init_mongo_data() + + def init_mongo_data(self): + """初始化hams数据库""" + host = SETTINGS.get('hams.host', 'localhost') + port = SETTINGS.get('hams.port', 27017) + self.write_log(f'初始化hams数据库连接:{host}:{port}') + try: + # Mongo数据连接客户端 + self.mongo_data = MongoData(host=host, port=port) + + if self.mongo_data and self.mongo_data.db_has_connected: + self.write_log(f'连接成功') + else: + self.write_error(f'HAMS数据库{host}:{port}连接异常.') + except Exception as ex: + self.write_error(f'HAMS数据库{host}:{port}连接异常.{str(ex)}') + def close(self): """停止所属有的策略""" self.stop_all_strategies() @@ -227,12 +291,12 @@ class CtaEngine(BaseEngine): self.health_check() # 在国内股市开盘期间做检查 - if '0930' < self.last_minute < '1530': + if '0930' < self.last_minute < '1830': # 主动获取所有策略得持仓信息 all_strategy_pos = self.get_all_strategy_pos() # 每5分钟检查一次 - if dt.minute % 10 == 0: + if dt.minute % 10 == 0 and self.engine_config.get('compare_pos', True): # 比对仓位,使用上述获取得持仓信息,不用重复获取 self.compare_pos(strategy_pos_list=copy(all_strategy_pos)) @@ -881,6 +945,25 @@ class CtaEngine(BaseEngine): def get_contract(self, vt_symbol): return self.main_engine.get_contract(vt_symbol) + def get_adjust_factor(self, vt_symbol, check_date=None): + """ + 获取[check_date前]除权因子 + :param vt_symbol: + :param check_date: 某一指定日期 + :return: + """ + stock_adjust_factor_list = self.stock_adjust_factors.get(vt_symbol, []) + if len(stock_adjust_factor_list) == 0: + return None + stock_adjust_factor_list.reverse() + if check_date is None: + check_date = datetime.now().strftime('%Y-%m-%d') + + for d in stock_adjust_factor_list: + if d.get("dividOperateDate","") < check_date: + return d + return None + def get_account(self, vt_accountid: str = ""): """ 查询账号的资金""" # 如果启动风控,则使用风控中的最大仓位 @@ -955,6 +1038,253 @@ class CtaEngine(BaseEngine): callback(bar) + + def get_bars( + self, + vt_symbol: str, + days: int, + interval: Interval, + interval_num: int = 1 + ): + """获取历史记录""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + bars = [] + + # 检查股票代码 + if vt_symbol not in self.symbol_dict: + self.write_error(f'{vt_symbol}不在基础配置股票信息中') + return bars + + # 检查数据文件目录 + if not self.bar_data_folder: + self.write_error(f'没有bar数据目录') + return bars + # 按照交易所的存放目录 + bar_file_folder = os.path.abspath(os.path.join(self.bar_data_folder, f'{exchange.value}')) + + resample_min = False + resample_hour = False + resample_day = False + file_interval_num = 1 + # 只有1,5,15,30分钟,日线数据 + if interval == Interval.MINUTE: + # 如果存在相应的分钟文件,直接读取 + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}m.csv')) + if interval_num in [1, 5, 15, 30] and os.path.exists(bar_file_path): + file_interval_num = interval + # 需要resample + else: + resample_min = True + if interval_num > 5: + file_interval_num = 5 + + elif interval == Interval.HOUR: + file_interval_num = 5 + resample_hour = True + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv')) + elif interval == Interval.DAILY: + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}d.csv')) + if not os.path.exists(bar_file_path): + file_interval_num = 5 + resample_day = True + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv')) + else: + self.write_error(f'目前仅支持分钟,小时,日线数据') + return bars + + bar_interval_seconds = interval_num * 60 + + if not os.path.exists(bar_file_path): + self.write_error(f'没有bar数据文件:{bar_file_path}') + return bars + + try: + data_types = { + "datetime": str, + "open": float, + "high": float, + "low": float, + "close": float, + "volume": float, + "amount": float, + "symbol": str, + "trading_day": str, + "date": str, + "time": str + } + + symbol_df = None + qfq_bar_file_path = bar_file_path.replace('.csv', '_qfq.csv') + use_qfq_file = False + last_qfq_dt = get_csv_last_dt(qfq_bar_file_path) + if last_qfq_dt is not None: + last_dt = get_csv_last_dt(bar_file_path) + + if last_qfq_dt == last_dt: + use_qfq_file = True + + if use_qfq_file: + self.write_log(f'使用前复权文件:{qfq_bar_file_path}') + symbol_df = pd.read_csv(qfq_bar_file_path, dtype=data_types) + else: + # 加载csv文件 =》 dateframe + self.write_log(f'使用未复权文件:{bar_file_path}') + symbol_df = pd.read_csv(bar_file_path, dtype=data_types) + + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") + + # 裁剪数据 + symbol_df = symbol_df.loc[start:end] + + if resample_day: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}day') + symbol_df = self.resample_bars(df=symbol_df, to_day=True) + elif resample_hour: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}hour') + symbol_df = self.resample_bars(df=symbol_df, x_hour=interval_num) + elif resample_min: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}m') + symbol_df = self.resample_bars(df=symbol_df, x_min=interval_num) + + if len(symbol_df) == 0: + return bars + + if not use_qfq_file: + # 复权转换 + adj_list = self.adjust_factor_dict.get(vt_symbol, []) + # 按照结束日期,裁剪复权记录 + adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= end.strftime('%Y%m%d')] + + if len(adj_list) > 0: + self.write_log(f'需要对{vt_symbol}进行前复权处理') + for row in adj_list: + row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:31:00'}) + # list -> dataframe, 转换复权日期格式 + adj_data = pd.DataFrame(adj_list) + adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S") + adj_data = adj_data.set_index("dividOperateDate") + # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 + symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore') + + for dt, bar_data in symbol_df.iterrows(): + bar_datetime = dt #- timedelta(seconds=bar_interval_seconds) + + bar = BarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) + if 'open' in bar_data: + bar.open_price = float(bar_data['open']) + bar.close_price = float(bar_data['close']) + bar.high_price = float(bar_data['high']) + bar.low_price = float(bar_data['low']) + else: + bar.open_price = float(bar_data['open_price']) + bar.close_price = float(bar_data['close_price']) + bar.high_price = float(bar_data['high_price']) + bar.low_price = float(bar_data['low_price']) + + bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0 + bar.date = dt.strftime('%Y-%m-%d') + bar.time = dt.strftime('%H:%M:%S') + str_td = str(bar_data.get('trading_day', '')) + if len(str_td) == 8: + bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8] + else: + bar.trading_day = bar.date + + bars.append(bar) + + except Exception as ex: + self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file_path, ex)) + self.write_error(traceback.format_exc()) + return bars + + return bars + + def stock_to_adj(self, raw_data, adj_data, adj_type): + """ + 股票数据复权转换 + :param raw_data: 不复权数据 + :param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df) + :param adj_type: 复权类型 + :return: + """ + + if adj_type == 'fore': + adj_factor = adj_data["foreAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1 + else: + adj_factor = adj_data["backAdjustFactor"] + adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1 + + # 把raw_data的第一个日期,插入复权因子df,使用后填充 + if adj_factor.index[0] != raw_data.index[0]: + adj_factor.loc[raw_data.index[0]] = np.nan + adj_factor.sort_index(inplace=True) + adj_factor = adj_factor.ffill() + + adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引 + adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格 + + # 把复权因子,作为adj字段,补充到raw_data中 + raw_data['adj'] = adj_factor + + # 逐一复权高低开平和成交量 + for col in ['open', 'high', 'low', 'close']: + raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数 + raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数 + + return raw_data + + def resample_bars(self, df, x_min=None, x_hour=None, to_day=False): + """ + 重建x分钟K线(或日线) + :param df: 输入分钟数 + :param x_min: 5, 15, 30, 60 + :param x_hour: 1, 2, 3, 4 + :param include_day: 重建日线, True得时候,不会重建分钟数 + :return: + """ + # 设置df数据中每列的规则 + ohlc_rule = { + 'open': 'first', # open列:序列中第一个的值 + 'high': 'max', # high列:序列中最大的值 + 'low': 'min', # low列:序列中最小的值 + 'close': 'last', # close列:序列中最后一个的值 + 'volume': 'sum', # volume列:将所有序列里的volume值作和 + 'amount': 'sum', # amount列:将所有序列里的amount值作和 + "symbol": 'first', + "trading_date": 'first', + "date": 'first', + "time": 'first' + } + + if isinstance(x_min, int) and not to_day: + # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, + how='any') + return df_target + if isinstance(x_hour, int) and not to_day: + # 合成x小时K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'{x_hour}hour', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, + how='any') + return df_target + + if to_day: + # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any') + return df_target + + return df + def call_strategy_func( self, strategy: CtaTemplate, func: Callable, params: Any = None ): @@ -1153,7 +1483,8 @@ class CtaEngine(BaseEngine): # Remove from symbol strategy map self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系') strategies = self.symbol_strategy_map[vt_symbol] - strategies.remove(strategy) + if strategy in strategies: + strategies.remove(strategy) # Remove from active orderid map if strategy_name in self.strategy_orderid_map: @@ -1528,6 +1859,27 @@ class CtaEngine(BaseEngine): return strategy_pos_list + def get_all_strategy_pos_from_hams(self): + """ + 获取hams中该账号下所有策略仓位明细 + """ + strategy_pos_list = [] + if not self.mongo_data: + self.init_mongo_data() + + if self.mongo_data and self.mongo_data.db_has_connected: + filter = {'account_id':self.engine_config.get('accountid','-')} + + pos_list = self.mongo_data.db_query( + db_name='Account', + col_name='today_strategy_pos', + filter_dict=filter + ) + for pos in pos_list: + strategy_pos_list.append(pos) + + return strategy_pos_list + def get_strategy_class_parameters(self, class_name: str): """ Get default parameters of a strategy class. @@ -1580,9 +1932,17 @@ class CtaEngine(BaseEngine): self.write_log(u'开始对比账号&策略的持仓') - # 获取当前策略得持仓 - if len(strategy_pos_list) == 0: - strategy_pos_list = self.get_all_strategy_pos() + # 获取hams数据库中所有运行实例得策略 + if self.engine_config.get("get_pos_from_db", False): + strategy_pos_list = self.get_all_strategy_pos_from_hams() + else: + # 获取当前实例运行策略得持仓 + if len(strategy_pos_list) == 0: + strategy_pos_list = self.get_all_strategy_pos() + + # 过滤账号中名字列表 + bypass_names = self.engine_config.get('compare_pos_bypass_names',[]) + self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) none_strategy_pos = self.get_none_strategy_pos_list() @@ -1597,6 +1957,12 @@ class CtaEngine(BaseEngine): for position in list(self.positions.values()): # gateway_name.symbol.exchange => symbol.exchange vt_symbol = position.vt_symbol + cn_name = self.get_name(vt_symbol) + + # 中文名字,在过滤清单中 + if any([name in cn_name for name in bypass_names]): + continue + vt_symbols.add(vt_symbol) compare_pos[vt_symbol] = OrderedDict( @@ -1631,6 +1997,8 @@ class CtaEngine(BaseEngine): u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) + compare_pos.update({vt_symbol:symbol_pos}) + pos_compare_result = '' # 精简输出 compare_info = '' @@ -1724,14 +2092,7 @@ class CtaEngine(BaseEngine): """ Load setting file. """ - # 读取引擎得配置 - # "accountid" : "xxxx", 资金账号,一般用于推送消息时附带 - # "strategy_group": "cta_strategy_pro", # 当前实例名。多个实例时,区分开 - # "trade_2_wx": true # 是否交易记录转发至微信通知 - # "event_log: false # 是否转发日志到event bus,显示在图形界面 - self.engine_config = load_json(self.engine_filename) - # 是否产生event log 日志(一般GUI界面才产生,而且比好消耗资源) - self.event_log = self.engine_config.get('event_log', False) + # 读取策略得配置 self.strategy_setting = load_json(self.setting_filename) diff --git a/vnpy/app/cta_stock/portfolio_testing.py b/vnpy/app/cta_stock/portfolio_testing.py index 58078afd..cf4a701d 100644 --- a/vnpy/app/cta_stock/portfolio_testing.py +++ b/vnpy/app/cta_stock/portfolio_testing.py @@ -15,6 +15,7 @@ import traceback import random import bz2 import pickle +import numpy as np from datetime import datetime, timedelta from time import sleep @@ -31,6 +32,7 @@ from vnpy.trader.constant import ( from vnpy.trader.utility import ( get_trading_date, extract_vt_symbol, + get_csv_last_dt ) from .back_testing import BackTestingEngine @@ -57,7 +59,7 @@ class PortfolioTestingEngine(BackTestingEngine): self.tick_path = None # tick级别回测, 路径 - def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None): + def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None, qfq=True): """ 加载回测bar数据到DataFrame 1. 增加前复权/后复权 @@ -65,14 +67,16 @@ class PortfolioTestingEngine(BackTestingEngine): :param bar_file: :param data_start_date: :param data_end_date: + :param qfq:True 前复权,False 后复权 :return: """ - self.output(u'loading {} from {}'.format(vt_symbol, bar_file)) + fq_name = '前复权' if qfq else '后复权' + self.output(u'加载数据[{}] :未复权文件{},复权转换:{}'.format(vt_symbol, bar_file, fq_name)) if vt_symbol in self.bar_df_dict: return True if bar_file is None or not os.path.exists(bar_file): - self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file)) + self.write_error(u'加载数据[{}]:对应的csv bar文件{}不存在'.format(vt_symbol, bar_file)) return False try: @@ -93,34 +97,64 @@ class PortfolioTestingEngine(BackTestingEngine): "date": str, "time": str } - # 加载csv文件 =》 dateframe - symbol_df = pd.read_csv(bar_file, dtype=data_types) - # 转换时间,str =》 datetime - symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") - # 设置时间为索引 - symbol_df = symbol_df.set_index("datetime") - # 裁剪数据 - symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] + symbol_df = None + auto_generate_fq = True + # 复权文件 + fq_bar_file = bar_file.replace('.csv', '_qfq.csv' if qfq else '_hfq.csv') + if os.path.exists(fq_bar_file): + # 存在复权文件 + last_dt = get_csv_last_dt(fq_bar_file) + if isinstance(last_dt, datetime): + if last_dt.strftime('%Y-%m-%d') < self.test_end_date: + self.write_log(f'加载数据[{vt_symbol}], 使用{fq_name}文件:{fq_bar_file}') + symbol_df = pd.read_csv(bar_file, dtype=data_types) + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") + # 裁剪数据 + symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] + # 不再产生复权文件 + auto_generate_fq = False - # 复权转换 - adj_list = self.adjust_factors.get(vt_symbol, []) - # 按照结束日期,裁剪复权记录 - adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date] + if not isinstance(symbol_df, pd.DataFrame): + # 加载csv文件 =》 dateframe + symbol_df = pd.read_csv(bar_file, dtype=data_types) + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") - if adj_list: - self.write_log(f'需要对{vt_symbol}进行前复权处理') - for row in adj_list: - row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'}) - # list -> dataframe, 转换复权日期格式 - adj_data = pd.DataFrame(adj_list) - adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S") - adj_data = adj_data.set_index("dividOperateDate") - # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 - symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore') + # 裁剪数据 + symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] - # 添加到待合并dataframe dict中 - self.bar_df_dict.update({vt_symbol: symbol_df}) + # 复权转换 + adj_list = self.adjust_factors.get(vt_symbol, []) + # 按照结束日期,裁剪复权记录 + adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date] + + if adj_list: + self.write_log(f'加载数据[{vt_symbol}], 对{vt_symbol}进行{fq_name}处理') + for row in adj_list: + d = row.get('dividOperateDate', "")[0:10] + if len(d) == 10: + row.update({'dividOperateDate': d + ' 09:31:00'}) + # list -> dataframe, 转换复权日期格式 + adj_data = pd.DataFrame(adj_list) + adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], + format="%Y-%m-%d %H:%M:%S") + adj_data = adj_data.set_index("dividOperateDate") + # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 + symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore' if qfq else "") + + if auto_generate_fq: + self.write_log(f'加载数据[{vt_symbol}] ,缓存{fq_name}文件=>{fq_bar_file}') + symbol_df.to_csv(fq_bar_file) + + if isinstance(symbol_df, pd.DataFrame): + # 添加到待合并dataframe dict中 + self.bar_df_dict.update({vt_symbol: symbol_df}) except Exception as ex: self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) @@ -277,7 +311,7 @@ class PortfolioTestingEngine(BackTestingEngine): bar.high_price = float(bar_data['high_price']) bar.low_price = float(bar_data['low_price']) - bar.volume = int(bar_data['volume']) + bar.volume = float(bar_data['volume']) bar.date = dt.strftime('%Y-%m-%d') bar.time = dt.strftime('%H:%M:%S') str_td = str(bar_data.get('trading_day', '')) @@ -299,6 +333,8 @@ class PortfolioTestingEngine(BackTestingEngine): # bar时间与队列时间不一致,先推送队列的bars random.shuffle(bars_same_dt) for _bar_ in bars_same_dt: + if np.isnan(_bar_.close_price): + continue self.new_bar(_bar_) # 创建新的队列 diff --git a/vnpy/app/cta_stock/template.py b/vnpy/app/cta_stock/template.py index 5ac50cd7..9c88fe25 100644 --- a/vnpy/app/cta_stock/template.py +++ b/vnpy/app/cta_stock/template.py @@ -476,40 +476,68 @@ class CtaStockTemplate(CtaTemplate): self.write_log(u'保存policy数据') self.policy.save() - def save_klines_to_cache(self, kline_names: list = []): + def save_klines_to_cache(self, kline_names: list = [], vt_symbol: str = ""): """ 保存K线数据到缓存 :param kline_names: 一般为self.klines的keys + :param vt_symbol: 指定股票代码, + 如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2 + 如果空白,加载 data/strategyname_klines.pkb2 :return: """ if len(kline_names) == 0: kline_names = list(self.klines.keys()) - # 获取保存路径 - save_path = self.cta_engine.get_data_path() - # 保存缓存的文件名 - file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) - with bz2.BZ2File(file_name, 'wb') as f: - klines = {} - for kline_name in kline_names: - kline = self.klines.get(kline_name, None) - # if kline: - # kline.strategy = None - # kline.cb_on_bar = None - klines.update({kline_name: kline}) - pickle.dump(klines, f) + try: + # 如果是指定合约的话,使用klines子目录 + if len(vt_symbol) > 0: + kline_names = [n for n in kline_names if vt_symbol in n] + save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines')) + if not os.path.exists(save_path): + os.makedirs(save_path) + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2')) + else: + # 获取保存路径 + save_path = self.cta_engine.get_data_path() + # 保存缓存的文件名 + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) - def load_klines_from_cache(self, kline_names: list = []): + with bz2.BZ2File(file_name, 'wb') as f: + klines = {} + for kline_name in kline_names: + kline = self.klines.get(kline_name, None) + # if kline: + # kline.strategy = None + # kline.cb_on_bar = None + klines.update({kline_name: kline}) + pickle.dump(klines, f) + self.write_log(f'保存{vt_symbol} K线数据成功=>{file_name}') + except Exception as ex: + self.write_error(f'保存k线数据异常:{str(ex)}') + self.write_error(traceback.format_exc()) + + def load_klines_from_cache(self, kline_names: list = [], vt_symbol: str = ""): """ 从缓存加载K线数据 :param kline_names: 指定需要加载的k线名称列表 + :param vt_symbol: 指定股票代码, + 如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2 + 如果空白,加载 data/strategyname_klines.pkb2 :return: """ if len(kline_names) == 0: kline_names = list(self.klines.keys()) - save_path = self.cta_engine.get_data_path() - file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + # 如果是指定合约的话,使用klines子目录 + if len(vt_symbol) > 0: + kline_names = [n for n in kline_names if vt_symbol in n] + save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines')) + if not os.path.exists(save_path): + os.makedirs(save_path) + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2')) + else: + save_path = self.cta_engine.get_data_path() + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) try: last_bar_dt = None with bz2.BZ2File(file_name, 'rb') as f: @@ -976,7 +1004,7 @@ class CtaStockTemplate(CtaTemplate): self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids) self.gt.save() - def tns_excute_sell_grids(self, vt_symbol=None): + def tns_excute_sell_grids(self, vt_symbol=None, force=False): """ 事务执行卖出网格 1、找出所有order_status=True,open_status=Talse, close_status=True的网格。 @@ -1027,17 +1055,27 @@ class CtaStockTemplate(CtaTemplate): sell_volume = ordering_grid.volume - ordering_grid.traded_volume if sell_volume > acc_symbol_pos.volume: - self.write_error(u'账号{}持仓{},不满足减仓目标:{}' - .format(vt_symbol, acc_symbol_pos.volume, sell_volume)) + if not force: + self.write_error(u'账号{}持仓{},不满足减仓目标:{}' + .format(vt_symbol, acc_symbol_pos.volume, sell_volume)) + continue + else: + self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}' + .format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume, + acc_symbol_pos.volume)) + sell_volume = acc_symbol_pos.volume + + if sell_volume == 0: + self.write_log(f'账号{vt_symbol}持仓{acc_symbol_pos.volume},卖出目标:{sell_volume}=0 不执行') + continue + + cur_price = self.cta_engine.get_price(vt_symbol) + if not cur_price: + self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol) continue # 实盘运行时,要加入市场买卖量的判断 - if not self.backtesting: - cur_price = self.cta_engine.get_price(vt_symbol) - if not cur_price: - self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol) - continue - + if not force and not self.backtesting: symbol_tick = self.cta_engine.get_tick(vt_symbol) if symbol_tick: symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) @@ -1072,7 +1110,7 @@ class CtaStockTemplate(CtaTemplate): self.write_log(f'{vt_symbol} 已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}') - def tns_finish_sell_grid(self, grid): + def tns_finish_sell_grid(self, grid:CtaGrid): """ 事务完成卖出网格 :param grid: @@ -1106,7 +1144,7 @@ class CtaStockTemplate(CtaTemplate): self.gt.save() self.policy.save() - def tns_execute_buy_grids(self, vt_symbol=None): + def tns_execute_buy_grids(self, vt_symbol=None, force=False): """ 事务执行买入网格 :return: @@ -1149,7 +1187,8 @@ class CtaStockTemplate(CtaTemplate): balance, availiable, _, _ = self.cta_engine.get_account() if availiable <= 0: - self.write_error(u'当前可用资金不足'.format(availiable)) + if not self.backtesting: + self.write_log(u'当前可用资金{}不足,总资金:{}'.format(availiable, balance)) continue vt_symbol = ordering_grid.vt_symbol cur_price = self.cta_engine.get_price(vt_symbol) @@ -1178,7 +1217,7 @@ class CtaStockTemplate(CtaTemplate): buy_volume = max_buy_volume # 实盘运行时,要加入市场买卖量的判断 - if not self.backtesting and 'market' in ordering_grid.snapshot: + if not force and not self.backtesting and 'market' in ordering_grid.snapshot: symbol_tick = self.cta_engine.get_tick(vt_symbol) if symbol_tick: # 根据市场计算,前5档买单数量 @@ -1207,7 +1246,7 @@ class CtaStockTemplate(CtaTemplate): else: self.write_log(f'{self.strategy_name}, {vt_orderids},已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}') - def tns_finish_buy_grid(self, grid): + def tns_finish_buy_grid(self, grid:CtaGrid): """ 事务完成买入网格 :return: