From 48f78cce8face840f9c2d28ccd3c3111c0cb8ee5 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Wed, 25 Mar 2020 16:36:01 +0800 Subject: [PATCH] =?UTF-8?q?[=E6=96=B0=E5=8A=9F=E8=83=BD]=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=9B=9E=E6=B5=8B=E8=AE=B0=E5=BD=95/=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=3D=E3=80=8B=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/app/cta_crypto/back_testing.py | 187 ++++++++++++++---- vnpy/app/cta_crypto/portfolio_testing.py | 14 +- vnpy/app/cta_strategy_pro/back_testing.py | 187 ++++++++++++++---- .../app/cta_strategy_pro/portfolio_testing.py | 14 +- vnpy/app/cta_strategy_pro/spread_testing.py | 5 +- 5 files changed, 317 insertions(+), 90 deletions(-) diff --git a/vnpy/app/cta_crypto/back_testing.py b/vnpy/app/cta_crypto/back_testing.py index 9903ff3c..1054feb1 100644 --- a/vnpy/app/cta_crypto/back_testing.py +++ b/vnpy/app/cta_crypto/back_testing.py @@ -16,6 +16,10 @@ import pandas as pd import traceback import numpy as np import logging +import socket +import zlib +import pickle +from bson import binary from collections import OrderedDict, defaultdict from datetime import datetime, timedelta @@ -58,7 +62,8 @@ from vnpy.trader.utility import ( ) from vnpy.trader.util_logger import setup_logger - +from vnpy.data.mongo.mongo_data import MongoData +from uuid import uuid1 class BackTestingEngine(object): """ @@ -199,6 +204,12 @@ class BackTestingEngine(object): self.fund_kline_dict = {} self.active_fund_kline = False + # 回测任务/回测结果,保存在数据库中 + self.mongo_api = None + self.task_id = None + self.test_setting = None # 回测设置 + self.strategy_setting = None # 所有回测策略得设置 + def create_fund_kline(self, name, use_renko=False): """ 创建资金曲线 @@ -421,39 +432,41 @@ class BackTestingEngine(object): """ self.daily_report_name = report_file - def prepare_env(self, test_settings): + def prepare_env(self, test_setting): """ 根据配置参数,准备环境 包括: 回测名称 ,是否debug,数据目录/日志目录, 资金/保证金类型/仓位控制 回测开始/结束日期 - :param test_settings: + :param test_setting: :return: """ - self.output('back_testing prepare_env') - if 'name' in test_settings: - self.set_name(test_settings.get('name')) + self.test_setting = copy.copy(test_setting) - self.mode = test_settings.get('mode', 'bar') + self.output('back_testing prepare_env') + if 'name' in test_setting: + self.set_name(test_setting.get('name')) + + self.mode = test_setting.get('mode', 'bar') self.output(f'采用{self.mode}方式回测') - self.contract_type = test_settings.get('contract_type', 'future') + self.contract_type = test_setting.get('contract_type', 'future') self.output(f'测试合约主要为{self.contract_type}') - self.debug = test_settings.get('debug', False) + self.debug = test_setting.get('debug', False) # 更新数据目录 - if 'data_path' in test_settings: - self.data_path = test_settings.get('data_path') + if 'data_path' in test_setting: + self.data_path = test_setting.get('data_path') else: self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) print(f'数据输出目录:{self.data_path}') # 更新日志目录 - if 'logs_path' in test_settings: - self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name)) + if 'logs_path' in test_setting: + self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name)) else: self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) print(f'日志输出目录:{self.logs_path}') @@ -462,55 +475,55 @@ class BackTestingEngine(object): self.create_logger(debug=self.debug) # 设置资金 - if 'init_capital' in test_settings: - self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital'))) - self.set_init_capital(test_settings.get('init_capital')) + if 'init_capital' in test_setting: + self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital'))) + self.set_init_capital(test_setting.get('init_capital')) # 缺省使用保证金方式。(期货使用保证金/股票不使用保证金) - self.use_margin = test_settings.get('use_margin', True) + self.use_margin = test_setting.get('use_margin', True) # 设置最大资金使用比例 - if 'percent_limit' in test_settings: - self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit'))) - self.percent_limit = test_settings.get('percent_limit') + if 'percent_limit' in test_setting: + self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit'))) + self.percent_limit = test_setting.get('percent_limit') - if 'start_date' in test_settings: - if 'strategy_start_date' not in test_settings: - init_days = test_settings.get('init_days', 10) - self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days)) - self.set_test_start_date(test_settings.get('start_date'), init_days) + if 'start_date' in test_setting: + if 'strategy_start_date' not in test_setting: + init_days = test_setting.get('init_days', 10) + self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days)) + self.set_test_start_date(test_setting.get('start_date'), init_days) else: - start_date = test_settings.get('start_date') - strategy_start_date = test_settings.get('strategy_start_date') + start_date = test_setting.get('start_date') + strategy_start_date = test_setting.get('strategy_start_date') self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date)) self.test_start_date = start_date self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d') self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d') - if 'end_date' in test_settings: - self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date'))) - self.set_test_end_date(test_settings.get('end_date')) + if 'end_date' in test_setting: + self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date'))) + self.set_test_end_date(test_setting.get('end_date')) # 准备数据 - if 'symbol_datas' in test_settings: + if 'symbol_datas' in test_setting: self.write_log(u'准备数据') - self.prepare_data(test_settings.get('symbol_datas')) + self.prepare_data(test_setting.get('symbol_datas')) if self.mode == 'tick': - self.tick_path = test_settings.get('tick_path', None) + self.tick_path = test_setting.get('tick_path', None) # 设置bar文件的时间间隔秒数 - if 'bar_interval_seconds' in test_settings: - self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds'))) - self.bar_interval_seconds = test_settings.get('bar_interval_seconds') + if 'bar_interval_seconds' in test_setting: + self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_setting.get('bar_interval_seconds'))) + self.bar_interval_seconds = test_setting.get('bar_interval_seconds') # 资金曲线 - self.active_fund_kline = test_settings.get('active_fund_kline', False) + self.active_fund_kline = test_setting.get('active_fund_kline', False) if self.active_fund_kline: # 创建资金K线 - self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) + self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False)) - self.is_plot_daily = test_settings.get('is_plot_daily', False) + self.is_plot_daily = test_setting.get('is_plot_daily', False) # 加载所有本地策略class self.load_strategy_class() @@ -2039,8 +2052,104 @@ class BackTestingEngine(object): result_info.update({u'Sharpe Ratio': d['sharpe']}) self.output(u'Sharpe Ratio:\t%s' % format_number(d['sharpe'])) + # 保存回测结果/交易记录/日线统计 至数据库 + self.save_result_to_mongo(result_info) + return result_info + def save_setting_to_mongo(self): + """ 保存测试设置到mongo中""" + self.task_id = self.test_setting.get('task_id', str(uuid1())) + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + d = { + 'task_id': self.task_id, # 单实例回测任务id + 'name': self.test_name, # 回测实例名称, 策略名+参数+时间 + 'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id + 'status': 'start', + 'task_start_time': datetime.now(), # 任务开始执行时间 + 'run_host': socket.gethostname(), # 任务运行得host主机 + 'test_setting': self.test_setting, # 回测参数 + 'strategy_setting': self.strategy_setting, # 策略参数 + } + + # 保存入数据库 + self.mongo_api.db_insert( + db_name=self.gateway_name, + col_name='tasks', + d=d) + + def save_fail_to_mongo(self, fail_msg): + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'fail'}) # 更新状态未完成 + d.update({'fail_msg': fail_msg}) + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + + def save_result_to_mongo(self, result_info): + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'finish'}) # 更新状态未完成 + d.update(result_info) # 补充回测结果 + d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间 + d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录 + d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录 + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + def put_strategy_event(self, strategy: CtaTemplate): """发送策略更新事件,回测中忽略""" pass diff --git a/vnpy/app/cta_crypto/portfolio_testing.py b/vnpy/app/cta_crypto/portfolio_testing.py index d3c0f94a..3b05604f 100644 --- a/vnpy/app/cta_crypto/portfolio_testing.py +++ b/vnpy/app/cta_crypto/portfolio_testing.py @@ -112,9 +112,9 @@ class PortfolioTestingEngine(BackTestingEngine): self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() self.bar_df_dict.clear() - def prepare_env(self, test_settings): + def prepare_env(self, test_setting): self.output('portfolio prepare_env') - super().prepare_env(test_settings) + super().prepare_env(test_setting) def prepare_data(self, data_dict): """ @@ -148,7 +148,7 @@ class PortfolioTestingEngine(BackTestingEngine): self.bar_csv_file.update({symbol: bar_file}) - def run_portfolio_test(self, strategy_settings: dict = {}): + def run_portfolio_test(self, strategy_setting: dict = {}): """ 运行组合回测 """ @@ -156,7 +156,7 @@ class PortfolioTestingEngine(BackTestingEngine): self.write_error(u'回测开始日期未设置。') return - if len(strategy_settings) == 0: + if len(strategy_setting) == 0: self.write_error('未提供有效配置策略实例') return @@ -164,9 +164,12 @@ class PortfolioTestingEngine(BackTestingEngine): if not self.data_end_date: self.data_end_date = datetime.today() + # 保存回测设置/策略设置/任务ID至数据库 + self.save_setting_to_mongo() + self.write_log(u'开始组合回测') - for strategy_name, strategy_setting in strategy_settings.items(): + for strategy_name, strategy_setting in strategy_setting.items(): self.load_strategy(strategy_name, strategy_setting) self.write_log(u'策略初始化完成') @@ -324,6 +327,7 @@ def single_test(test_setting: dict, strategy_setting: dict): except Exception as ex: print('组合回测异常{}'.format(str(ex))) traceback.print_exc() + engine.save_fail_to_mongo(f'回测异常{str(ex)}') return False print('测试结束') diff --git a/vnpy/app/cta_strategy_pro/back_testing.py b/vnpy/app/cta_strategy_pro/back_testing.py index f647646f..04be4627 100644 --- a/vnpy/app/cta_strategy_pro/back_testing.py +++ b/vnpy/app/cta_strategy_pro/back_testing.py @@ -16,6 +16,10 @@ import pandas as pd import traceback import numpy as np import logging +import socket +import zlib +import pickle +from bson import binary from collections import OrderedDict, defaultdict from datetime import datetime, timedelta @@ -58,7 +62,8 @@ from vnpy.trader.utility import ( ) from vnpy.trader.util_logger import setup_logger - +from vnpy.data.mongo.mongo_data import MongoData +from uuid import uuid1 class BackTestingEngine(object): """ @@ -198,6 +203,12 @@ class BackTestingEngine(object): self.fund_kline_dict = {} self.active_fund_kline = False + # 回测任务/回测结果,保存在数据库中 + self.mongo_api = None + self.task_id = None + self.test_setting = None # 回测设置 + self.strategy_setting = None # 所有回测策略得设置 + def create_fund_kline(self, name, use_renko=False): """ 创建资金曲线 @@ -406,39 +417,41 @@ class BackTestingEngine(object): """ self.daily_report_name = report_file - def prepare_env(self, test_settings): + def prepare_env(self, test_setting): """ 根据配置参数,准备环境 包括: 回测名称 ,是否debug,数据目录/日志目录, 资金/保证金类型/仓位控制 回测开始/结束日期 - :param test_settings: + :param test_setting: :return: """ - self.output('back_testing prepare_env') - if 'name' in test_settings: - self.set_name(test_settings.get('name')) + self.test_setting = copy.copy(test_setting) - self.mode = test_settings.get('mode', 'bar') + self.output('back_testing prepare_env') + if 'name' in test_setting: + self.set_name(test_setting.get('name')) + + self.mode = test_setting.get('mode', 'bar') self.output(f'采用{self.mode}方式回测') - self.contract_type = test_settings.get('contract_type', 'future') + self.contract_type = test_setting.get('contract_type', 'future') self.output(f'测试合约主要为{self.contract_type}') - self.debug = test_settings.get('debug', False) + self.debug = test_setting.get('debug', False) # 更新数据目录 - if 'data_path' in test_settings: - self.data_path = test_settings.get('data_path') + if 'data_path' in test_setting: + self.data_path = test_setting.get('data_path') else: self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) print(f'数据输出目录:{self.data_path}') # 更新日志目录 - if 'logs_path' in test_settings: - self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name)) + if 'logs_path' in test_setting: + self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name)) else: self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) print(f'日志输出目录:{self.logs_path}') @@ -447,55 +460,55 @@ class BackTestingEngine(object): self.create_logger(debug=self.debug) # 设置资金 - if 'init_capital' in test_settings: - self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital'))) - self.set_init_capital(test_settings.get('init_capital')) + if 'init_capital' in test_setting: + self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital'))) + self.set_init_capital(test_setting.get('init_capital')) # 缺省使用保证金方式。(期货使用保证金/股票不使用保证金) - self.use_margin = test_settings.get('use_margin', True) + self.use_margin = test_setting.get('use_margin', True) # 设置最大资金使用比例 - if 'percent_limit' in test_settings: - self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit'))) - self.percent_limit = test_settings.get('percent_limit') + if 'percent_limit' in test_setting: + self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit'))) + self.percent_limit = test_setting.get('percent_limit') - if 'start_date' in test_settings: - if 'strategy_start_date' not in test_settings: - init_days = test_settings.get('init_days', 10) - self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days)) - self.set_test_start_date(test_settings.get('start_date'), init_days) + if 'start_date' in test_setting: + if 'strategy_start_date' not in test_setting: + init_days = test_setting.get('init_days', 10) + self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days)) + self.set_test_start_date(test_setting.get('start_date'), init_days) else: - start_date = test_settings.get('start_date') - strategy_start_date = test_settings.get('strategy_start_date') + start_date = test_setting.get('start_date') + strategy_start_date = test_setting.get('strategy_start_date') self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date)) self.test_start_date = start_date self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d') self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d') - if 'end_date' in test_settings: - self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date'))) - self.set_test_end_date(test_settings.get('end_date')) + if 'end_date' in test_setting: + self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date'))) + self.set_test_end_date(test_setting.get('end_date')) # 准备数据 - if 'symbol_datas' in test_settings: + if 'symbol_datas' in test_setting: self.write_log(u'准备数据') - self.prepare_data(test_settings.get('symbol_datas')) + self.prepare_data(test_setting.get('symbol_datas')) if self.mode == 'tick': - self.tick_path = test_settings.get('tick_path', None) + self.tick_path = test_setting.get('tick_path', None) # 设置bar文件的时间间隔秒数 - if 'bar_interval_seconds' in test_settings: - self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds'))) - self.bar_interval_seconds = test_settings.get('bar_interval_seconds') + if 'bar_interval_seconds' in test_setting: + self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_setting.get('bar_interval_seconds'))) + self.bar_interval_seconds = test_setting.get('bar_interval_seconds') # 资金曲线 - self.active_fund_kline = test_settings.get('active_fund_kline', False) + self.active_fund_kline = test_setting.get('active_fund_kline', False) if self.active_fund_kline: # 创建资金K线 - self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) + self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False)) - self.is_plot_daily = test_settings.get('is_plot_daily', False) + self.is_plot_daily = test_setting.get('is_plot_daily', False) # 加载所有本地策略class self.load_strategy_class() @@ -2125,8 +2138,104 @@ class BackTestingEngine(object): result_info.update({u'Sharpe Ratio': d['sharpe']}) self.output(u'Sharpe Ratio:\t%s' % format_number(d['sharpe'])) + # 保存回测结果/交易记录/日线统计 至数据库 + self.save_result_to_mongo(result_info) + return result_info + def save_setting_to_mongo(self): + """ 保存测试设置到mongo中""" + self.task_id = self.test_setting.get('task_id', str(uuid1())) + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + d = { + 'task_id': self.task_id, # 单实例回测任务id + 'name': self.test_name, # 回测实例名称, 策略名+参数+时间 + 'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id + 'status': 'start', + 'task_start_time': datetime.now(), # 任务开始执行时间 + 'run_host': socket.gethostname(), # 任务运行得host主机 + 'test_setting': self.test_setting, # 回测参数 + 'strategy_setting': self.strategy_setting, # 策略参数 + } + + # 保存入数据库 + self.mongo_api.db_insert( + db_name=self.gateway_name, + col_name='tasks', + d=d) + + def save_fail_to_mongo(self, fail_msg): + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'fail'}) # 更新状态未完成 + d.update({'fail_msg': fail_msg}) + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + + def save_result_to_mongo(self, result_info): + + # 保存到mongo得配置 + save_mongo = self.test_setting.get('save_mongo', {}) + if len(save_mongo) == 0: + return + + if not self.mongo_api: + self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017)) + + # 更新数据到数据库回测记录中 + flt = {'task_id': self.task_id} + + d = self.mongo_api.db_query_one( + db_name=self.gateway_name, + col_name='tasks', + flt=flt) + + if d: + d.update({'status': 'finish'}) # 更新状态未完成 + d.update(result_info) # 补充回测结果 + d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间 + d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录 + d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录 + + self.write_log(u'更新回测结果至数据库') + + self.mongo_api.db_update( + db_name=self.gateway_name, + col_name='tasks', + filter_dict=flt, + data_dict=d, + replace=False) + def put_strategy_event(self, strategy: CtaTemplate): """发送策略更新事件,回测中忽略""" pass diff --git a/vnpy/app/cta_strategy_pro/portfolio_testing.py b/vnpy/app/cta_strategy_pro/portfolio_testing.py index ce031e75..60e4d334 100644 --- a/vnpy/app/cta_strategy_pro/portfolio_testing.py +++ b/vnpy/app/cta_strategy_pro/portfolio_testing.py @@ -113,9 +113,9 @@ class PortfolioTestingEngine(BackTestingEngine): self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() self.bar_df_dict.clear() - def prepare_env(self, test_settings): + def prepare_env(self, test_setting): self.output('portfolio prepare_env') - super().prepare_env(test_settings) + super().prepare_env(test_setting) def prepare_data(self, data_dict): """ @@ -149,7 +149,7 @@ class PortfolioTestingEngine(BackTestingEngine): self.bar_csv_file.update({symbol: bar_file}) - def run_portfolio_test(self, strategy_settings: dict = {}): + def run_portfolio_test(self, strategy_setting: dict = {}): """ 运行组合回测 """ @@ -157,7 +157,7 @@ class PortfolioTestingEngine(BackTestingEngine): self.write_error(u'回测开始日期未设置。') return - if len(strategy_settings) == 0: + if len(strategy_setting) == 0: self.write_error('未提供有效配置策略实例') return @@ -165,9 +165,12 @@ class PortfolioTestingEngine(BackTestingEngine): if not self.data_end_date: self.data_end_date = datetime.today() + # 保存回测脚本到数据库 + self.save_setting_to_mongo() + self.write_log(u'开始组合回测') - for strategy_name, strategy_setting in strategy_settings.items(): + for strategy_name, strategy_setting in strategy_setting.items(): self.load_strategy(strategy_name, strategy_setting) self.write_log(u'策略初始化完成') @@ -447,6 +450,7 @@ def single_test(test_setting: dict, strategy_setting: dict): except Exception as ex: print('组合回测异常{}'.format(str(ex))) traceback.print_exc() + engine.save_fail_to_mongo(f'回测异常{str(ex)}') return False print('测试结束') diff --git a/vnpy/app/cta_strategy_pro/spread_testing.py b/vnpy/app/cta_strategy_pro/spread_testing.py index 14a54ded..2a75a5d5 100644 --- a/vnpy/app/cta_strategy_pro/spread_testing.py +++ b/vnpy/app/cta_strategy_pro/spread_testing.py @@ -62,9 +62,9 @@ class SpreadTestingEngine(BackTestingEngine): self.strategy_start_date_dict = {} self.strategy_end_date_dict = {} - def prepare_env(self, test_settings): + def prepare_env(self, test_setting): self.output('portfolio prepare_env') - super().prepare_env(test_settings) + super().prepare_env(test_setting) def load_strategy(self, strategy_name: str, strategy_setting: dict = None): """ @@ -413,6 +413,7 @@ def single_test(test_setting: dict, strategy_setting: dict): except Exception as ex: print('组合回测异常{}'.format(str(ex))) + engine.save_fail_to_mongo(f'回测异常{str(ex)}') traceback.print_exc() return False