[新功能] 增加回测记录/结果=》数据库

This commit is contained in:
msincenselee 2020-03-25 16:36:01 +08:00
parent 65a1410146
commit 48f78cce8f
5 changed files with 317 additions and 90 deletions

View File

@ -16,6 +16,10 @@ import pandas as pd
import traceback import traceback
import numpy as np import numpy as np
import logging import logging
import socket
import zlib
import pickle
from bson import binary
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -58,7 +62,8 @@ from vnpy.trader.utility import (
) )
from vnpy.trader.util_logger import setup_logger from vnpy.trader.util_logger import setup_logger
from vnpy.data.mongo.mongo_data import MongoData
from uuid import uuid1
class BackTestingEngine(object): class BackTestingEngine(object):
""" """
@ -199,6 +204,12 @@ class BackTestingEngine(object):
self.fund_kline_dict = {} self.fund_kline_dict = {}
self.active_fund_kline = False self.active_fund_kline = False
# 回测任务/回测结果,保存在数据库中
self.mongo_api = None
self.task_id = None
self.test_setting = None # 回测设置
self.strategy_setting = None # 所有回测策略得设置
def create_fund_kline(self, name, use_renko=False): def create_fund_kline(self, name, use_renko=False):
""" """
创建资金曲线 创建资金曲线
@ -421,39 +432,41 @@ class BackTestingEngine(object):
""" """
self.daily_report_name = report_file self.daily_report_name = report_file
def prepare_env(self, test_settings): def prepare_env(self, test_setting):
""" """
根据配置参数准备环境 根据配置参数准备环境
包括 包括
回测名称 是否debug数据目录/日志目录 回测名称 是否debug数据目录/日志目录
资金/保证金类型/仓位控制 资金/保证金类型/仓位控制
回测开始/结束日期 回测开始/结束日期
:param test_settings: :param test_setting:
:return: :return:
""" """
self.output('back_testing prepare_env') self.test_setting = copy.copy(test_setting)
if 'name' in test_settings:
self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar') self.output('back_testing prepare_env')
if 'name' in test_setting:
self.set_name(test_setting.get('name'))
self.mode = test_setting.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测') self.output(f'采用{self.mode}方式回测')
self.contract_type = test_settings.get('contract_type', 'future') self.contract_type = test_setting.get('contract_type', 'future')
self.output(f'测试合约主要为{self.contract_type}') self.output(f'测试合约主要为{self.contract_type}')
self.debug = test_settings.get('debug', False) self.debug = test_setting.get('debug', False)
# 更新数据目录 # 更新数据目录
if 'data_path' in test_settings: if 'data_path' in test_setting:
self.data_path = test_settings.get('data_path') self.data_path = test_setting.get('data_path')
else: else:
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
print(f'数据输出目录:{self.data_path}') print(f'数据输出目录:{self.data_path}')
# 更新日志目录 # 更新日志目录
if 'logs_path' in test_settings: if 'logs_path' in test_setting:
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name)) self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name))
else: else:
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
print(f'日志输出目录:{self.logs_path}') print(f'日志输出目录:{self.logs_path}')
@ -462,55 +475,55 @@ class BackTestingEngine(object):
self.create_logger(debug=self.debug) self.create_logger(debug=self.debug)
# 设置资金 # 设置资金
if 'init_capital' in test_settings: if 'init_capital' in test_setting:
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital'))) self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital')))
self.set_init_capital(test_settings.get('init_capital')) self.set_init_capital(test_setting.get('init_capital'))
# 缺省使用保证金方式。(期货使用保证金/股票不使用保证金) # 缺省使用保证金方式。(期货使用保证金/股票不使用保证金)
self.use_margin = test_settings.get('use_margin', True) self.use_margin = test_setting.get('use_margin', True)
# 设置最大资金使用比例 # 设置最大资金使用比例
if 'percent_limit' in test_settings: if 'percent_limit' in test_setting:
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit'))) self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit')))
self.percent_limit = test_settings.get('percent_limit') self.percent_limit = test_setting.get('percent_limit')
if 'start_date' in test_settings: if 'start_date' in test_setting:
if 'strategy_start_date' not in test_settings: if 'strategy_start_date' not in test_setting:
init_days = test_settings.get('init_days', 10) init_days = test_setting.get('init_days', 10)
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days)) self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days))
self.set_test_start_date(test_settings.get('start_date'), init_days) self.set_test_start_date(test_setting.get('start_date'), init_days)
else: else:
start_date = test_settings.get('start_date') start_date = test_setting.get('start_date')
strategy_start_date = test_settings.get('strategy_start_date') strategy_start_date = test_setting.get('strategy_start_date')
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date)) self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
self.test_start_date = start_date self.test_start_date = start_date
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d') self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d') self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
if 'end_date' in test_settings: if 'end_date' in test_setting:
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date'))) self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date')))
self.set_test_end_date(test_settings.get('end_date')) self.set_test_end_date(test_setting.get('end_date'))
# 准备数据 # 准备数据
if 'symbol_datas' in test_settings: if 'symbol_datas' in test_setting:
self.write_log(u'准备数据') self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas')) self.prepare_data(test_setting.get('symbol_datas'))
if self.mode == 'tick': if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None) self.tick_path = test_setting.get('tick_path', None)
# 设置bar文件的时间间隔秒数 # 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings: if 'bar_interval_seconds' in test_setting:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds'))) self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_setting.get('bar_interval_seconds')))
self.bar_interval_seconds = test_settings.get('bar_interval_seconds') self.bar_interval_seconds = test_setting.get('bar_interval_seconds')
# 资金曲线 # 资金曲线
self.active_fund_kline = test_settings.get('active_fund_kline', False) self.active_fund_kline = test_setting.get('active_fund_kline', False)
if self.active_fund_kline: if self.active_fund_kline:
# 创建资金K线 # 创建资金K线
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False))
self.is_plot_daily = test_settings.get('is_plot_daily', False) self.is_plot_daily = test_setting.get('is_plot_daily', False)
# 加载所有本地策略class # 加载所有本地策略class
self.load_strategy_class() self.load_strategy_class()
@ -2039,8 +2052,104 @@ class BackTestingEngine(object):
result_info.update({u'Sharpe Ratio': d['sharpe']}) result_info.update({u'Sharpe Ratio': d['sharpe']})
self.output(u'Sharpe Ratio\t%s' % format_number(d['sharpe'])) self.output(u'Sharpe Ratio\t%s' % format_number(d['sharpe']))
# 保存回测结果/交易记录/日线统计 至数据库
self.save_result_to_mongo(result_info)
return result_info return result_info
def save_setting_to_mongo(self):
""" 保存测试设置到mongo中"""
self.task_id = self.test_setting.get('task_id', str(uuid1()))
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
d = {
'task_id': self.task_id, # 单实例回测任务id
'name': self.test_name, # 回测实例名称, 策略名+参数+时间
'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id
'status': 'start',
'task_start_time': datetime.now(), # 任务开始执行时间
'run_host': socket.gethostname(), # 任务运行得host主机
'test_setting': self.test_setting, # 回测参数
'strategy_setting': self.strategy_setting, # 策略参数
}
# 保存入数据库
self.mongo_api.db_insert(
db_name=self.gateway_name,
col_name='tasks',
d=d)
def save_fail_to_mongo(self, fail_msg):
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
# 更新数据到数据库回测记录中
flt = {'task_id': self.task_id}
d = self.mongo_api.db_query_one(
db_name=self.gateway_name,
col_name='tasks',
flt=flt)
if d:
d.update({'status': 'fail'}) # 更新状态未完成
d.update({'fail_msg': fail_msg})
self.write_log(u'更新回测结果至数据库')
self.mongo_api.db_update(
db_name=self.gateway_name,
col_name='tasks',
filter_dict=flt,
data_dict=d,
replace=False)
def save_result_to_mongo(self, result_info):
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
# 更新数据到数据库回测记录中
flt = {'task_id': self.task_id}
d = self.mongo_api.db_query_one(
db_name=self.gateway_name,
col_name='tasks',
flt=flt)
if d:
d.update({'status': 'finish'}) # 更新状态未完成
d.update(result_info) # 补充回测结果
d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间
d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录
d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录
self.write_log(u'更新回测结果至数据库')
self.mongo_api.db_update(
db_name=self.gateway_name,
col_name='tasks',
filter_dict=flt,
data_dict=d,
replace=False)
def put_strategy_event(self, strategy: CtaTemplate): def put_strategy_event(self, strategy: CtaTemplate):
"""发送策略更新事件,回测中忽略""" """发送策略更新事件,回测中忽略"""
pass pass

View File

@ -112,9 +112,9 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
self.bar_df_dict.clear() self.bar_df_dict.clear()
def prepare_env(self, test_settings): def prepare_env(self, test_setting):
self.output('portfolio prepare_env') self.output('portfolio prepare_env')
super().prepare_env(test_settings) super().prepare_env(test_setting)
def prepare_data(self, data_dict): def prepare_data(self, data_dict):
""" """
@ -148,7 +148,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_csv_file.update({symbol: bar_file}) self.bar_csv_file.update({symbol: bar_file})
def run_portfolio_test(self, strategy_settings: dict = {}): def run_portfolio_test(self, strategy_setting: dict = {}):
""" """
运行组合回测 运行组合回测
""" """
@ -156,7 +156,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.write_error(u'回测开始日期未设置。') self.write_error(u'回测开始日期未设置。')
return return
if len(strategy_settings) == 0: if len(strategy_setting) == 0:
self.write_error('未提供有效配置策略实例') self.write_error('未提供有效配置策略实例')
return return
@ -164,9 +164,12 @@ class PortfolioTestingEngine(BackTestingEngine):
if not self.data_end_date: if not self.data_end_date:
self.data_end_date = datetime.today() self.data_end_date = datetime.today()
# 保存回测设置/策略设置/任务ID至数据库
self.save_setting_to_mongo()
self.write_log(u'开始组合回测') self.write_log(u'开始组合回测')
for strategy_name, strategy_setting in strategy_settings.items(): for strategy_name, strategy_setting in strategy_setting.items():
self.load_strategy(strategy_name, strategy_setting) self.load_strategy(strategy_name, strategy_setting)
self.write_log(u'策略初始化完成') self.write_log(u'策略初始化完成')
@ -324,6 +327,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
except Exception as ex: except Exception as ex:
print('组合回测异常{}'.format(str(ex))) print('组合回测异常{}'.format(str(ex)))
traceback.print_exc() traceback.print_exc()
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
return False return False
print('测试结束') print('测试结束')

View File

@ -16,6 +16,10 @@ import pandas as pd
import traceback import traceback
import numpy as np import numpy as np
import logging import logging
import socket
import zlib
import pickle
from bson import binary
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -58,7 +62,8 @@ from vnpy.trader.utility import (
) )
from vnpy.trader.util_logger import setup_logger from vnpy.trader.util_logger import setup_logger
from vnpy.data.mongo.mongo_data import MongoData
from uuid import uuid1
class BackTestingEngine(object): class BackTestingEngine(object):
""" """
@ -198,6 +203,12 @@ class BackTestingEngine(object):
self.fund_kline_dict = {} self.fund_kline_dict = {}
self.active_fund_kline = False self.active_fund_kline = False
# 回测任务/回测结果,保存在数据库中
self.mongo_api = None
self.task_id = None
self.test_setting = None # 回测设置
self.strategy_setting = None # 所有回测策略得设置
def create_fund_kline(self, name, use_renko=False): def create_fund_kline(self, name, use_renko=False):
""" """
创建资金曲线 创建资金曲线
@ -406,39 +417,41 @@ class BackTestingEngine(object):
""" """
self.daily_report_name = report_file self.daily_report_name = report_file
def prepare_env(self, test_settings): def prepare_env(self, test_setting):
""" """
根据配置参数准备环境 根据配置参数准备环境
包括 包括
回测名称 是否debug数据目录/日志目录 回测名称 是否debug数据目录/日志目录
资金/保证金类型/仓位控制 资金/保证金类型/仓位控制
回测开始/结束日期 回测开始/结束日期
:param test_settings: :param test_setting:
:return: :return:
""" """
self.output('back_testing prepare_env') self.test_setting = copy.copy(test_setting)
if 'name' in test_settings:
self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar') self.output('back_testing prepare_env')
if 'name' in test_setting:
self.set_name(test_setting.get('name'))
self.mode = test_setting.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测') self.output(f'采用{self.mode}方式回测')
self.contract_type = test_settings.get('contract_type', 'future') self.contract_type = test_setting.get('contract_type', 'future')
self.output(f'测试合约主要为{self.contract_type}') self.output(f'测试合约主要为{self.contract_type}')
self.debug = test_settings.get('debug', False) self.debug = test_setting.get('debug', False)
# 更新数据目录 # 更新数据目录
if 'data_path' in test_settings: if 'data_path' in test_setting:
self.data_path = test_settings.get('data_path') self.data_path = test_setting.get('data_path')
else: else:
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data')) self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
print(f'数据输出目录:{self.data_path}') print(f'数据输出目录:{self.data_path}')
# 更新日志目录 # 更新日志目录
if 'logs_path' in test_settings: if 'logs_path' in test_setting:
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name)) self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name))
else: else:
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name)) self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
print(f'日志输出目录:{self.logs_path}') print(f'日志输出目录:{self.logs_path}')
@ -447,55 +460,55 @@ class BackTestingEngine(object):
self.create_logger(debug=self.debug) self.create_logger(debug=self.debug)
# 设置资金 # 设置资金
if 'init_capital' in test_settings: if 'init_capital' in test_setting:
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital'))) self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital')))
self.set_init_capital(test_settings.get('init_capital')) self.set_init_capital(test_setting.get('init_capital'))
# 缺省使用保证金方式。(期货使用保证金/股票不使用保证金) # 缺省使用保证金方式。(期货使用保证金/股票不使用保证金)
self.use_margin = test_settings.get('use_margin', True) self.use_margin = test_setting.get('use_margin', True)
# 设置最大资金使用比例 # 设置最大资金使用比例
if 'percent_limit' in test_settings: if 'percent_limit' in test_setting:
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit'))) self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit')))
self.percent_limit = test_settings.get('percent_limit') self.percent_limit = test_setting.get('percent_limit')
if 'start_date' in test_settings: if 'start_date' in test_setting:
if 'strategy_start_date' not in test_settings: if 'strategy_start_date' not in test_setting:
init_days = test_settings.get('init_days', 10) init_days = test_setting.get('init_days', 10)
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days)) self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days))
self.set_test_start_date(test_settings.get('start_date'), init_days) self.set_test_start_date(test_setting.get('start_date'), init_days)
else: else:
start_date = test_settings.get('start_date') start_date = test_setting.get('start_date')
strategy_start_date = test_settings.get('strategy_start_date') strategy_start_date = test_setting.get('strategy_start_date')
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date)) self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
self.test_start_date = start_date self.test_start_date = start_date
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d') self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d') self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
if 'end_date' in test_settings: if 'end_date' in test_setting:
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date'))) self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date')))
self.set_test_end_date(test_settings.get('end_date')) self.set_test_end_date(test_setting.get('end_date'))
# 准备数据 # 准备数据
if 'symbol_datas' in test_settings: if 'symbol_datas' in test_setting:
self.write_log(u'准备数据') self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas')) self.prepare_data(test_setting.get('symbol_datas'))
if self.mode == 'tick': if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None) self.tick_path = test_setting.get('tick_path', None)
# 设置bar文件的时间间隔秒数 # 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings: if 'bar_interval_seconds' in test_setting:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds'))) self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_setting.get('bar_interval_seconds')))
self.bar_interval_seconds = test_settings.get('bar_interval_seconds') self.bar_interval_seconds = test_setting.get('bar_interval_seconds')
# 资金曲线 # 资金曲线
self.active_fund_kline = test_settings.get('active_fund_kline', False) self.active_fund_kline = test_setting.get('active_fund_kline', False)
if self.active_fund_kline: if self.active_fund_kline:
# 创建资金K线 # 创建资金K线
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False))
self.is_plot_daily = test_settings.get('is_plot_daily', False) self.is_plot_daily = test_setting.get('is_plot_daily', False)
# 加载所有本地策略class # 加载所有本地策略class
self.load_strategy_class() self.load_strategy_class()
@ -2125,8 +2138,104 @@ class BackTestingEngine(object):
result_info.update({u'Sharpe Ratio': d['sharpe']}) result_info.update({u'Sharpe Ratio': d['sharpe']})
self.output(u'Sharpe Ratio\t%s' % format_number(d['sharpe'])) self.output(u'Sharpe Ratio\t%s' % format_number(d['sharpe']))
# 保存回测结果/交易记录/日线统计 至数据库
self.save_result_to_mongo(result_info)
return result_info return result_info
def save_setting_to_mongo(self):
""" 保存测试设置到mongo中"""
self.task_id = self.test_setting.get('task_id', str(uuid1()))
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
d = {
'task_id': self.task_id, # 单实例回测任务id
'name': self.test_name, # 回测实例名称, 策略名+参数+时间
'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id
'status': 'start',
'task_start_time': datetime.now(), # 任务开始执行时间
'run_host': socket.gethostname(), # 任务运行得host主机
'test_setting': self.test_setting, # 回测参数
'strategy_setting': self.strategy_setting, # 策略参数
}
# 保存入数据库
self.mongo_api.db_insert(
db_name=self.gateway_name,
col_name='tasks',
d=d)
def save_fail_to_mongo(self, fail_msg):
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
# 更新数据到数据库回测记录中
flt = {'task_id': self.task_id}
d = self.mongo_api.db_query_one(
db_name=self.gateway_name,
col_name='tasks',
flt=flt)
if d:
d.update({'status': 'fail'}) # 更新状态未完成
d.update({'fail_msg': fail_msg})
self.write_log(u'更新回测结果至数据库')
self.mongo_api.db_update(
db_name=self.gateway_name,
col_name='tasks',
filter_dict=flt,
data_dict=d,
replace=False)
def save_result_to_mongo(self, result_info):
# 保存到mongo得配置
save_mongo = self.test_setting.get('save_mongo', {})
if len(save_mongo) == 0:
return
if not self.mongo_api:
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
# 更新数据到数据库回测记录中
flt = {'task_id': self.task_id}
d = self.mongo_api.db_query_one(
db_name=self.gateway_name,
col_name='tasks',
flt=flt)
if d:
d.update({'status': 'finish'}) # 更新状态未完成
d.update(result_info) # 补充回测结果
d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间
d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录
d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录
self.write_log(u'更新回测结果至数据库')
self.mongo_api.db_update(
db_name=self.gateway_name,
col_name='tasks',
filter_dict=flt,
data_dict=d,
replace=False)
def put_strategy_event(self, strategy: CtaTemplate): def put_strategy_event(self, strategy: CtaTemplate):
"""发送策略更新事件,回测中忽略""" """发送策略更新事件,回测中忽略"""
pass pass

View File

@ -113,9 +113,9 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index() self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
self.bar_df_dict.clear() self.bar_df_dict.clear()
def prepare_env(self, test_settings): def prepare_env(self, test_setting):
self.output('portfolio prepare_env') self.output('portfolio prepare_env')
super().prepare_env(test_settings) super().prepare_env(test_setting)
def prepare_data(self, data_dict): def prepare_data(self, data_dict):
""" """
@ -149,7 +149,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_csv_file.update({symbol: bar_file}) self.bar_csv_file.update({symbol: bar_file})
def run_portfolio_test(self, strategy_settings: dict = {}): def run_portfolio_test(self, strategy_setting: dict = {}):
""" """
运行组合回测 运行组合回测
""" """
@ -157,7 +157,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.write_error(u'回测开始日期未设置。') self.write_error(u'回测开始日期未设置。')
return return
if len(strategy_settings) == 0: if len(strategy_setting) == 0:
self.write_error('未提供有效配置策略实例') self.write_error('未提供有效配置策略实例')
return return
@ -165,9 +165,12 @@ class PortfolioTestingEngine(BackTestingEngine):
if not self.data_end_date: if not self.data_end_date:
self.data_end_date = datetime.today() self.data_end_date = datetime.today()
# 保存回测脚本到数据库
self.save_setting_to_mongo()
self.write_log(u'开始组合回测') self.write_log(u'开始组合回测')
for strategy_name, strategy_setting in strategy_settings.items(): for strategy_name, strategy_setting in strategy_setting.items():
self.load_strategy(strategy_name, strategy_setting) self.load_strategy(strategy_name, strategy_setting)
self.write_log(u'策略初始化完成') self.write_log(u'策略初始化完成')
@ -447,6 +450,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
except Exception as ex: except Exception as ex:
print('组合回测异常{}'.format(str(ex))) print('组合回测异常{}'.format(str(ex)))
traceback.print_exc() traceback.print_exc()
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
return False return False
print('测试结束') print('测试结束')

View File

@ -62,9 +62,9 @@ class SpreadTestingEngine(BackTestingEngine):
self.strategy_start_date_dict = {} self.strategy_start_date_dict = {}
self.strategy_end_date_dict = {} self.strategy_end_date_dict = {}
def prepare_env(self, test_settings): def prepare_env(self, test_setting):
self.output('portfolio prepare_env') self.output('portfolio prepare_env')
super().prepare_env(test_settings) super().prepare_env(test_setting)
def load_strategy(self, strategy_name: str, strategy_setting: dict = None): def load_strategy(self, strategy_name: str, strategy_setting: dict = None):
""" """
@ -413,6 +413,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
except Exception as ex: except Exception as ex:
print('组合回测异常{}'.format(str(ex))) print('组合回测异常{}'.format(str(ex)))
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
traceback.print_exc() traceback.print_exc()
return False return False