[新功能] 增加回测记录/结果=》数据库
This commit is contained in:
parent
65a1410146
commit
48f78cce8f
@ -16,6 +16,10 @@ import pandas as pd
|
||||
import traceback
|
||||
import numpy as np
|
||||
import logging
|
||||
import socket
|
||||
import zlib
|
||||
import pickle
|
||||
from bson import binary
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
@ -58,7 +62,8 @@ from vnpy.trader.utility import (
|
||||
)
|
||||
|
||||
from vnpy.trader.util_logger import setup_logger
|
||||
|
||||
from vnpy.data.mongo.mongo_data import MongoData
|
||||
from uuid import uuid1
|
||||
|
||||
class BackTestingEngine(object):
|
||||
"""
|
||||
@ -199,6 +204,12 @@ class BackTestingEngine(object):
|
||||
self.fund_kline_dict = {}
|
||||
self.active_fund_kline = False
|
||||
|
||||
# 回测任务/回测结果,保存在数据库中
|
||||
self.mongo_api = None
|
||||
self.task_id = None
|
||||
self.test_setting = None # 回测设置
|
||||
self.strategy_setting = None # 所有回测策略得设置
|
||||
|
||||
def create_fund_kline(self, name, use_renko=False):
|
||||
"""
|
||||
创建资金曲线
|
||||
@ -421,39 +432,41 @@ class BackTestingEngine(object):
|
||||
"""
|
||||
self.daily_report_name = report_file
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
def prepare_env(self, test_setting):
|
||||
"""
|
||||
根据配置参数,准备环境
|
||||
包括:
|
||||
回测名称 ,是否debug,数据目录/日志目录,
|
||||
资金/保证金类型/仓位控制
|
||||
回测开始/结束日期
|
||||
:param test_settings:
|
||||
:param test_setting:
|
||||
:return:
|
||||
"""
|
||||
self.output('back_testing prepare_env')
|
||||
if 'name' in test_settings:
|
||||
self.set_name(test_settings.get('name'))
|
||||
self.test_setting = copy.copy(test_setting)
|
||||
|
||||
self.mode = test_settings.get('mode', 'bar')
|
||||
self.output('back_testing prepare_env')
|
||||
if 'name' in test_setting:
|
||||
self.set_name(test_setting.get('name'))
|
||||
|
||||
self.mode = test_setting.get('mode', 'bar')
|
||||
self.output(f'采用{self.mode}方式回测')
|
||||
|
||||
self.contract_type = test_settings.get('contract_type', 'future')
|
||||
self.contract_type = test_setting.get('contract_type', 'future')
|
||||
self.output(f'测试合约主要为{self.contract_type}')
|
||||
|
||||
self.debug = test_settings.get('debug', False)
|
||||
self.debug = test_setting.get('debug', False)
|
||||
|
||||
# 更新数据目录
|
||||
if 'data_path' in test_settings:
|
||||
self.data_path = test_settings.get('data_path')
|
||||
if 'data_path' in test_setting:
|
||||
self.data_path = test_setting.get('data_path')
|
||||
else:
|
||||
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
|
||||
|
||||
print(f'数据输出目录:{self.data_path}')
|
||||
|
||||
# 更新日志目录
|
||||
if 'logs_path' in test_settings:
|
||||
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name))
|
||||
if 'logs_path' in test_setting:
|
||||
self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name))
|
||||
else:
|
||||
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
|
||||
print(f'日志输出目录:{self.logs_path}')
|
||||
@ -462,55 +475,55 @@ class BackTestingEngine(object):
|
||||
self.create_logger(debug=self.debug)
|
||||
|
||||
# 设置资金
|
||||
if 'init_capital' in test_settings:
|
||||
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital')))
|
||||
self.set_init_capital(test_settings.get('init_capital'))
|
||||
if 'init_capital' in test_setting:
|
||||
self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital')))
|
||||
self.set_init_capital(test_setting.get('init_capital'))
|
||||
|
||||
# 缺省使用保证金方式。(期货使用保证金/股票不使用保证金)
|
||||
self.use_margin = test_settings.get('use_margin', True)
|
||||
self.use_margin = test_setting.get('use_margin', True)
|
||||
|
||||
# 设置最大资金使用比例
|
||||
if 'percent_limit' in test_settings:
|
||||
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit')))
|
||||
self.percent_limit = test_settings.get('percent_limit')
|
||||
if 'percent_limit' in test_setting:
|
||||
self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit')))
|
||||
self.percent_limit = test_setting.get('percent_limit')
|
||||
|
||||
if 'start_date' in test_settings:
|
||||
if 'strategy_start_date' not in test_settings:
|
||||
init_days = test_settings.get('init_days', 10)
|
||||
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days))
|
||||
self.set_test_start_date(test_settings.get('start_date'), init_days)
|
||||
if 'start_date' in test_setting:
|
||||
if 'strategy_start_date' not in test_setting:
|
||||
init_days = test_setting.get('init_days', 10)
|
||||
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days))
|
||||
self.set_test_start_date(test_setting.get('start_date'), init_days)
|
||||
else:
|
||||
start_date = test_settings.get('start_date')
|
||||
strategy_start_date = test_settings.get('strategy_start_date')
|
||||
start_date = test_setting.get('start_date')
|
||||
strategy_start_date = test_setting.get('strategy_start_date')
|
||||
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
|
||||
self.test_start_date = start_date
|
||||
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
|
||||
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
|
||||
|
||||
if 'end_date' in test_settings:
|
||||
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date')))
|
||||
self.set_test_end_date(test_settings.get('end_date'))
|
||||
if 'end_date' in test_setting:
|
||||
self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date')))
|
||||
self.set_test_end_date(test_setting.get('end_date'))
|
||||
|
||||
# 准备数据
|
||||
if 'symbol_datas' in test_settings:
|
||||
if 'symbol_datas' in test_setting:
|
||||
self.write_log(u'准备数据')
|
||||
self.prepare_data(test_settings.get('symbol_datas'))
|
||||
self.prepare_data(test_setting.get('symbol_datas'))
|
||||
|
||||
if self.mode == 'tick':
|
||||
self.tick_path = test_settings.get('tick_path', None)
|
||||
self.tick_path = test_setting.get('tick_path', None)
|
||||
|
||||
# 设置bar文件的时间间隔秒数
|
||||
if 'bar_interval_seconds' in test_settings:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds')))
|
||||
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
|
||||
if 'bar_interval_seconds' in test_setting:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_setting.get('bar_interval_seconds')))
|
||||
self.bar_interval_seconds = test_setting.get('bar_interval_seconds')
|
||||
|
||||
# 资金曲线
|
||||
self.active_fund_kline = test_settings.get('active_fund_kline', False)
|
||||
self.active_fund_kline = test_setting.get('active_fund_kline', False)
|
||||
if self.active_fund_kline:
|
||||
# 创建资金K线
|
||||
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
|
||||
self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False))
|
||||
|
||||
self.is_plot_daily = test_settings.get('is_plot_daily', False)
|
||||
self.is_plot_daily = test_setting.get('is_plot_daily', False)
|
||||
|
||||
# 加载所有本地策略class
|
||||
self.load_strategy_class()
|
||||
@ -2039,8 +2052,104 @@ class BackTestingEngine(object):
|
||||
result_info.update({u'Sharpe Ratio': d['sharpe']})
|
||||
self.output(u'Sharpe Ratio:\t%s' % format_number(d['sharpe']))
|
||||
|
||||
# 保存回测结果/交易记录/日线统计 至数据库
|
||||
self.save_result_to_mongo(result_info)
|
||||
|
||||
return result_info
|
||||
|
||||
def save_setting_to_mongo(self):
|
||||
""" 保存测试设置到mongo中"""
|
||||
self.task_id = self.test_setting.get('task_id', str(uuid1()))
|
||||
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
d = {
|
||||
'task_id': self.task_id, # 单实例回测任务id
|
||||
'name': self.test_name, # 回测实例名称, 策略名+参数+时间
|
||||
'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id
|
||||
'status': 'start',
|
||||
'task_start_time': datetime.now(), # 任务开始执行时间
|
||||
'run_host': socket.gethostname(), # 任务运行得host主机
|
||||
'test_setting': self.test_setting, # 回测参数
|
||||
'strategy_setting': self.strategy_setting, # 策略参数
|
||||
}
|
||||
|
||||
# 保存入数据库
|
||||
self.mongo_api.db_insert(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
d=d)
|
||||
|
||||
def save_fail_to_mongo(self, fail_msg):
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
# 更新数据到数据库回测记录中
|
||||
flt = {'task_id': self.task_id}
|
||||
|
||||
d = self.mongo_api.db_query_one(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
flt=flt)
|
||||
|
||||
if d:
|
||||
d.update({'status': 'fail'}) # 更新状态未完成
|
||||
d.update({'fail_msg': fail_msg})
|
||||
|
||||
self.write_log(u'更新回测结果至数据库')
|
||||
|
||||
self.mongo_api.db_update(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
filter_dict=flt,
|
||||
data_dict=d,
|
||||
replace=False)
|
||||
|
||||
def save_result_to_mongo(self, result_info):
|
||||
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
# 更新数据到数据库回测记录中
|
||||
flt = {'task_id': self.task_id}
|
||||
|
||||
d = self.mongo_api.db_query_one(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
flt=flt)
|
||||
|
||||
if d:
|
||||
d.update({'status': 'finish'}) # 更新状态未完成
|
||||
d.update(result_info) # 补充回测结果
|
||||
d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间
|
||||
d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录
|
||||
d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录
|
||||
|
||||
self.write_log(u'更新回测结果至数据库')
|
||||
|
||||
self.mongo_api.db_update(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
filter_dict=flt,
|
||||
data_dict=d,
|
||||
replace=False)
|
||||
|
||||
def put_strategy_event(self, strategy: CtaTemplate):
|
||||
"""发送策略更新事件,回测中忽略"""
|
||||
pass
|
||||
|
@ -112,9 +112,9 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
self.bar_df_dict.clear()
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
def prepare_env(self, test_setting):
|
||||
self.output('portfolio prepare_env')
|
||||
super().prepare_env(test_settings)
|
||||
super().prepare_env(test_setting)
|
||||
|
||||
def prepare_data(self, data_dict):
|
||||
"""
|
||||
@ -148,7 +148,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
|
||||
self.bar_csv_file.update({symbol: bar_file})
|
||||
|
||||
def run_portfolio_test(self, strategy_settings: dict = {}):
|
||||
def run_portfolio_test(self, strategy_setting: dict = {}):
|
||||
"""
|
||||
运行组合回测
|
||||
"""
|
||||
@ -156,7 +156,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
self.write_error(u'回测开始日期未设置。')
|
||||
return
|
||||
|
||||
if len(strategy_settings) == 0:
|
||||
if len(strategy_setting) == 0:
|
||||
self.write_error('未提供有效配置策略实例')
|
||||
return
|
||||
|
||||
@ -164,9 +164,12 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
if not self.data_end_date:
|
||||
self.data_end_date = datetime.today()
|
||||
|
||||
# 保存回测设置/策略设置/任务ID至数据库
|
||||
self.save_setting_to_mongo()
|
||||
|
||||
self.write_log(u'开始组合回测')
|
||||
|
||||
for strategy_name, strategy_setting in strategy_settings.items():
|
||||
for strategy_name, strategy_setting in strategy_setting.items():
|
||||
self.load_strategy(strategy_name, strategy_setting)
|
||||
|
||||
self.write_log(u'策略初始化完成')
|
||||
@ -324,6 +327,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
traceback.print_exc()
|
||||
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
|
||||
return False
|
||||
|
||||
print('测试结束')
|
||||
|
@ -16,6 +16,10 @@ import pandas as pd
|
||||
import traceback
|
||||
import numpy as np
|
||||
import logging
|
||||
import socket
|
||||
import zlib
|
||||
import pickle
|
||||
from bson import binary
|
||||
|
||||
from collections import OrderedDict, defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
@ -58,7 +62,8 @@ from vnpy.trader.utility import (
|
||||
)
|
||||
|
||||
from vnpy.trader.util_logger import setup_logger
|
||||
|
||||
from vnpy.data.mongo.mongo_data import MongoData
|
||||
from uuid import uuid1
|
||||
|
||||
class BackTestingEngine(object):
|
||||
"""
|
||||
@ -198,6 +203,12 @@ class BackTestingEngine(object):
|
||||
self.fund_kline_dict = {}
|
||||
self.active_fund_kline = False
|
||||
|
||||
# 回测任务/回测结果,保存在数据库中
|
||||
self.mongo_api = None
|
||||
self.task_id = None
|
||||
self.test_setting = None # 回测设置
|
||||
self.strategy_setting = None # 所有回测策略得设置
|
||||
|
||||
def create_fund_kline(self, name, use_renko=False):
|
||||
"""
|
||||
创建资金曲线
|
||||
@ -406,39 +417,41 @@ class BackTestingEngine(object):
|
||||
"""
|
||||
self.daily_report_name = report_file
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
def prepare_env(self, test_setting):
|
||||
"""
|
||||
根据配置参数,准备环境
|
||||
包括:
|
||||
回测名称 ,是否debug,数据目录/日志目录,
|
||||
资金/保证金类型/仓位控制
|
||||
回测开始/结束日期
|
||||
:param test_settings:
|
||||
:param test_setting:
|
||||
:return:
|
||||
"""
|
||||
self.output('back_testing prepare_env')
|
||||
if 'name' in test_settings:
|
||||
self.set_name(test_settings.get('name'))
|
||||
self.test_setting = copy.copy(test_setting)
|
||||
|
||||
self.mode = test_settings.get('mode', 'bar')
|
||||
self.output('back_testing prepare_env')
|
||||
if 'name' in test_setting:
|
||||
self.set_name(test_setting.get('name'))
|
||||
|
||||
self.mode = test_setting.get('mode', 'bar')
|
||||
self.output(f'采用{self.mode}方式回测')
|
||||
|
||||
self.contract_type = test_settings.get('contract_type', 'future')
|
||||
self.contract_type = test_setting.get('contract_type', 'future')
|
||||
self.output(f'测试合约主要为{self.contract_type}')
|
||||
|
||||
self.debug = test_settings.get('debug', False)
|
||||
self.debug = test_setting.get('debug', False)
|
||||
|
||||
# 更新数据目录
|
||||
if 'data_path' in test_settings:
|
||||
self.data_path = test_settings.get('data_path')
|
||||
if 'data_path' in test_setting:
|
||||
self.data_path = test_setting.get('data_path')
|
||||
else:
|
||||
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
|
||||
|
||||
print(f'数据输出目录:{self.data_path}')
|
||||
|
||||
# 更新日志目录
|
||||
if 'logs_path' in test_settings:
|
||||
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name))
|
||||
if 'logs_path' in test_setting:
|
||||
self.logs_path = os.path.abspath(os.path.join(test_setting.get('logs_path'), self.test_name))
|
||||
else:
|
||||
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
|
||||
print(f'日志输出目录:{self.logs_path}')
|
||||
@ -447,55 +460,55 @@ class BackTestingEngine(object):
|
||||
self.create_logger(debug=self.debug)
|
||||
|
||||
# 设置资金
|
||||
if 'init_capital' in test_settings:
|
||||
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital')))
|
||||
self.set_init_capital(test_settings.get('init_capital'))
|
||||
if 'init_capital' in test_setting:
|
||||
self.write_log(u'设置期初资金:{}'.format(test_setting.get('init_capital')))
|
||||
self.set_init_capital(test_setting.get('init_capital'))
|
||||
|
||||
# 缺省使用保证金方式。(期货使用保证金/股票不使用保证金)
|
||||
self.use_margin = test_settings.get('use_margin', True)
|
||||
self.use_margin = test_setting.get('use_margin', True)
|
||||
|
||||
# 设置最大资金使用比例
|
||||
if 'percent_limit' in test_settings:
|
||||
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit')))
|
||||
self.percent_limit = test_settings.get('percent_limit')
|
||||
if 'percent_limit' in test_setting:
|
||||
self.write_log(u'设置最大资金使用比例:{}%'.format(test_setting.get('percent_limit')))
|
||||
self.percent_limit = test_setting.get('percent_limit')
|
||||
|
||||
if 'start_date' in test_settings:
|
||||
if 'strategy_start_date' not in test_settings:
|
||||
init_days = test_settings.get('init_days', 10)
|
||||
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days))
|
||||
self.set_test_start_date(test_settings.get('start_date'), init_days)
|
||||
if 'start_date' in test_setting:
|
||||
if 'strategy_start_date' not in test_setting:
|
||||
init_days = test_setting.get('init_days', 10)
|
||||
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_setting.get('start_date'), init_days))
|
||||
self.set_test_start_date(test_setting.get('start_date'), init_days)
|
||||
else:
|
||||
start_date = test_settings.get('start_date')
|
||||
strategy_start_date = test_settings.get('strategy_start_date')
|
||||
start_date = test_setting.get('start_date')
|
||||
strategy_start_date = test_setting.get('strategy_start_date')
|
||||
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
|
||||
self.test_start_date = start_date
|
||||
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
|
||||
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
|
||||
|
||||
if 'end_date' in test_settings:
|
||||
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date')))
|
||||
self.set_test_end_date(test_settings.get('end_date'))
|
||||
if 'end_date' in test_setting:
|
||||
self.write_log(u'设置回测结束日期:{}'.format(test_setting.get('end_date')))
|
||||
self.set_test_end_date(test_setting.get('end_date'))
|
||||
|
||||
# 准备数据
|
||||
if 'symbol_datas' in test_settings:
|
||||
if 'symbol_datas' in test_setting:
|
||||
self.write_log(u'准备数据')
|
||||
self.prepare_data(test_settings.get('symbol_datas'))
|
||||
self.prepare_data(test_setting.get('symbol_datas'))
|
||||
|
||||
if self.mode == 'tick':
|
||||
self.tick_path = test_settings.get('tick_path', None)
|
||||
self.tick_path = test_setting.get('tick_path', None)
|
||||
|
||||
# 设置bar文件的时间间隔秒数
|
||||
if 'bar_interval_seconds' in test_settings:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds')))
|
||||
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
|
||||
if 'bar_interval_seconds' in test_setting:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_setting.get('bar_interval_seconds')))
|
||||
self.bar_interval_seconds = test_setting.get('bar_interval_seconds')
|
||||
|
||||
# 资金曲线
|
||||
self.active_fund_kline = test_settings.get('active_fund_kline', False)
|
||||
self.active_fund_kline = test_setting.get('active_fund_kline', False)
|
||||
if self.active_fund_kline:
|
||||
# 创建资金K线
|
||||
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
|
||||
self.create_fund_kline(self.test_name, use_renko=test_setting.get('use_renko', False))
|
||||
|
||||
self.is_plot_daily = test_settings.get('is_plot_daily', False)
|
||||
self.is_plot_daily = test_setting.get('is_plot_daily', False)
|
||||
|
||||
# 加载所有本地策略class
|
||||
self.load_strategy_class()
|
||||
@ -2125,8 +2138,104 @@ class BackTestingEngine(object):
|
||||
result_info.update({u'Sharpe Ratio': d['sharpe']})
|
||||
self.output(u'Sharpe Ratio:\t%s' % format_number(d['sharpe']))
|
||||
|
||||
# 保存回测结果/交易记录/日线统计 至数据库
|
||||
self.save_result_to_mongo(result_info)
|
||||
|
||||
return result_info
|
||||
|
||||
def save_setting_to_mongo(self):
|
||||
""" 保存测试设置到mongo中"""
|
||||
self.task_id = self.test_setting.get('task_id', str(uuid1()))
|
||||
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
d = {
|
||||
'task_id': self.task_id, # 单实例回测任务id
|
||||
'name': self.test_name, # 回测实例名称, 策略名+参数+时间
|
||||
'group_id': self.test_setting.get('group_id', datetime.now().strftime('%y-%m-%d')), # 回测组合id
|
||||
'status': 'start',
|
||||
'task_start_time': datetime.now(), # 任务开始执行时间
|
||||
'run_host': socket.gethostname(), # 任务运行得host主机
|
||||
'test_setting': self.test_setting, # 回测参数
|
||||
'strategy_setting': self.strategy_setting, # 策略参数
|
||||
}
|
||||
|
||||
# 保存入数据库
|
||||
self.mongo_api.db_insert(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
d=d)
|
||||
|
||||
def save_fail_to_mongo(self, fail_msg):
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
# 更新数据到数据库回测记录中
|
||||
flt = {'task_id': self.task_id}
|
||||
|
||||
d = self.mongo_api.db_query_one(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
flt=flt)
|
||||
|
||||
if d:
|
||||
d.update({'status': 'fail'}) # 更新状态未完成
|
||||
d.update({'fail_msg': fail_msg})
|
||||
|
||||
self.write_log(u'更新回测结果至数据库')
|
||||
|
||||
self.mongo_api.db_update(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
filter_dict=flt,
|
||||
data_dict=d,
|
||||
replace=False)
|
||||
|
||||
def save_result_to_mongo(self, result_info):
|
||||
|
||||
# 保存到mongo得配置
|
||||
save_mongo = self.test_setting.get('save_mongo', {})
|
||||
if len(save_mongo) == 0:
|
||||
return
|
||||
|
||||
if not self.mongo_api:
|
||||
self.mongo_api = MongoData(host=save_mongo.get('host', 'localhost'), port=save_mongo.get('port', 27017))
|
||||
|
||||
# 更新数据到数据库回测记录中
|
||||
flt = {'task_id': self.task_id}
|
||||
|
||||
d = self.mongo_api.db_query_one(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
flt=flt)
|
||||
|
||||
if d:
|
||||
d.update({'status': 'finish'}) # 更新状态未完成
|
||||
d.update(result_info) # 补充回测结果
|
||||
d.update({'task_finish_time': datetime.now()}) # 更新回测完成时间
|
||||
d.update({'trade_list': binary.Binary(zlib.compress(pickle.dumps(self.trade_pnl_list)))}) # 更新交易记录
|
||||
d.update({'daily_list': binary.Binary(zlib.compress(pickle.dumps(self.daily_list)))}) # 更新每日净值记录
|
||||
|
||||
self.write_log(u'更新回测结果至数据库')
|
||||
|
||||
self.mongo_api.db_update(
|
||||
db_name=self.gateway_name,
|
||||
col_name='tasks',
|
||||
filter_dict=flt,
|
||||
data_dict=d,
|
||||
replace=False)
|
||||
|
||||
def put_strategy_event(self, strategy: CtaTemplate):
|
||||
"""发送策略更新事件,回测中忽略"""
|
||||
pass
|
||||
|
@ -113,9 +113,9 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
self.bar_df_dict.clear()
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
def prepare_env(self, test_setting):
|
||||
self.output('portfolio prepare_env')
|
||||
super().prepare_env(test_settings)
|
||||
super().prepare_env(test_setting)
|
||||
|
||||
def prepare_data(self, data_dict):
|
||||
"""
|
||||
@ -149,7 +149,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
|
||||
self.bar_csv_file.update({symbol: bar_file})
|
||||
|
||||
def run_portfolio_test(self, strategy_settings: dict = {}):
|
||||
def run_portfolio_test(self, strategy_setting: dict = {}):
|
||||
"""
|
||||
运行组合回测
|
||||
"""
|
||||
@ -157,7 +157,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
self.write_error(u'回测开始日期未设置。')
|
||||
return
|
||||
|
||||
if len(strategy_settings) == 0:
|
||||
if len(strategy_setting) == 0:
|
||||
self.write_error('未提供有效配置策略实例')
|
||||
return
|
||||
|
||||
@ -165,9 +165,12 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
if not self.data_end_date:
|
||||
self.data_end_date = datetime.today()
|
||||
|
||||
# 保存回测脚本到数据库
|
||||
self.save_setting_to_mongo()
|
||||
|
||||
self.write_log(u'开始组合回测')
|
||||
|
||||
for strategy_name, strategy_setting in strategy_settings.items():
|
||||
for strategy_name, strategy_setting in strategy_setting.items():
|
||||
self.load_strategy(strategy_name, strategy_setting)
|
||||
|
||||
self.write_log(u'策略初始化完成')
|
||||
@ -447,6 +450,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
traceback.print_exc()
|
||||
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
|
||||
return False
|
||||
|
||||
print('测试结束')
|
||||
|
@ -62,9 +62,9 @@ class SpreadTestingEngine(BackTestingEngine):
|
||||
self.strategy_start_date_dict = {}
|
||||
self.strategy_end_date_dict = {}
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
def prepare_env(self, test_setting):
|
||||
self.output('portfolio prepare_env')
|
||||
super().prepare_env(test_settings)
|
||||
super().prepare_env(test_setting)
|
||||
|
||||
def load_strategy(self, strategy_name: str, strategy_setting: dict = None):
|
||||
"""
|
||||
@ -413,6 +413,7 @@ def single_test(test_setting: dict, strategy_setting: dict):
|
||||
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
engine.save_fail_to_mongo(f'回测异常{str(ex)}')
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user