[update] 股票引擎。增加复权缓存,K线缓存

This commit is contained in:
msincenselee 2021-08-30 08:59:45 +08:00
parent aa865aa38a
commit 344f877bda
4 changed files with 519 additions and 79 deletions

View File

@ -589,7 +589,7 @@ class BackTestingEngine(object):
股票数据复权转换
:param raw_data: 不复权数据
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df
:param adj_type: 复权类型
:param adj_type: 复权类型: fore 前复权, 其他后复权
:return:
"""
@ -1735,7 +1735,8 @@ class BackTestingEngine(object):
# 返回回测结果
d = {}
d['init_capital'] = self.init_capital
d['profit'] = self.cur_capital - self.init_capital
d['net_capital'] = self.net_capital
d['profit'] = self.net_capital - self.init_capital
d['max_capital'] = self.max_net_capital # 取消原 maxCapital
if len(self.pnl_list) == 0:
@ -1816,6 +1817,9 @@ class BackTestingEngine(object):
result_info.update({u'期初资金': d['init_capital']})
self.output(u'期初资金:\t%s' % format_number(d['init_capital']))
result_info.update({u'期末资金': d['net_capital']})
self.output(u'期末资金:\t%s' % format_number(d['net_capital']))
result_info.update({u'总盈亏': d['profit']})
self.output(u'总盈亏:\t%s' % format_number(d['profit']))

View File

@ -10,6 +10,8 @@ import traceback
import json
import pickle
import bz2
import pandas as pd
import numpy as np
from collections import defaultdict
from pathlib import Path
@ -60,12 +62,16 @@ from vnpy.trader.utility import (
get_folder_path,
get_underlying_symbol,
append_data,
import_module_by_str)
import_module_by_str,
get_csv_last_dt)
from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.util_wechat import send_wx_msg
from vnpy.trader.converter import PositionHolding
from vnpy.data.mongo.mongo_data import MongoData
from vnpy.trader.setting import SETTINGS
from vnpy.data.stock.adjust_factor import get_all_adjust_factor
from vnpy.data.stock.stock_base import get_stock_base
from .base import (
APP_NAME,
EVENT_CTA_LOG,
@ -115,9 +121,14 @@ class CtaEngine(BaseEngine):
# "trade_2_wx": true # 是否交易记录转发至微信通知
# "event_log: false # 是否转发日志到event bus显示在图形界面
# "snapshot2file": false # 是否保存切片到文件
self.engine_config = {}
# "get_pos_from_db": false # 是否从数据库中获取所有策略得持仓信息。因为其他RPC进程也运行策略推送到Mongodb)
# "compare_pos": True 是否比对仓位
# "compare_pos_bypass_names": ["name1","name2"] # 比对仓位时,过滤账号中含有字符得股票,例如自动申购得转债等
self.engine_config = load_json(self.engine_filename)
# 是否激活 write_log写入event bus(比较耗资源)
self.event_log = False
self.event_log = self.engine_config.get('event_log', False)
self.strategy_setting = {} # strategy_name: dict
self.strategy_data = {} # strategy_name: dict
@ -157,6 +168,40 @@ class CtaEngine(BaseEngine):
self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar)
self.stock_adjust_factors = get_all_adjust_factor()
self.mongo_data = None
# 获取全量股票信息
self.write_log(f'获取全量股票信息')
self.symbol_dict = get_stock_base()
self.write_log(f'{len(self.symbol_dict)}个股票')
# 除权因子
self.write_log(f'获取所有除权因子')
self.adjust_factor_dict = get_all_adjust_factor()
self.write_log(f'{len(self.adjust_factor_dict)}条除权信息')
# 寻找数据文件所在目录
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
self.write_log(f'项目所在目录:{vnpy_root}')
self.bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))
if os.path.exists(self.bar_data_folder):
SSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SSE'))
if os.path.exists(SSE_folder):
self.write_log(f'上交所bar数据目录:{SSE_folder}')
else:
self.write_error(f'不存在上交所数据目录:{SSE_folder}')
SZSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SZSE'))
if os.path.exists(SZSE_folder):
self.write_log(f'深交所bar数据目录:{SZSE_folder}')
else:
self.write_error(f'不存在深交所数据目录:{SZSE_folder}')
else:
self.write_error(f'不存在bar数据目录:{self.bar_data_folder}')
self.bar_data_folder = None
def init_engine(self):
"""
"""
@ -168,6 +213,25 @@ class CtaEngine(BaseEngine):
self.write_log("CTA策略股票引擎初始化成功")
if self.engine_config.get('get_pos_from_db',False):
self.init_mongo_data()
def init_mongo_data(self):
"""初始化hams数据库"""
host = SETTINGS.get('hams.host', 'localhost')
port = SETTINGS.get('hams.port', 27017)
self.write_log(f'初始化hams数据库连接:{host}:{port}')
try:
# Mongo数据连接客户端
self.mongo_data = MongoData(host=host, port=port)
if self.mongo_data and self.mongo_data.db_has_connected:
self.write_log(f'连接成功')
else:
self.write_error(f'HAMS数据库{host}:{port}连接异常.')
except Exception as ex:
self.write_error(f'HAMS数据库{host}:{port}连接异常.{str(ex)}')
def close(self):
"""停止所属有的策略"""
self.stop_all_strategies()
@ -227,12 +291,12 @@ class CtaEngine(BaseEngine):
self.health_check()
# 在国内股市开盘期间做检查
if '0930' < self.last_minute < '1530':
if '0930' < self.last_minute < '1830':
# 主动获取所有策略得持仓信息
all_strategy_pos = self.get_all_strategy_pos()
# 每5分钟检查一次
if dt.minute % 10 == 0:
if dt.minute % 10 == 0 and self.engine_config.get('compare_pos', True):
# 比对仓位,使用上述获取得持仓信息,不用重复获取
self.compare_pos(strategy_pos_list=copy(all_strategy_pos))
@ -881,6 +945,25 @@ class CtaEngine(BaseEngine):
def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol)
def get_adjust_factor(self, vt_symbol, check_date=None):
"""
获取[check_date前]除权因子
:param vt_symbol:
:param check_date: 某一指定日期
:return:
"""
stock_adjust_factor_list = self.stock_adjust_factors.get(vt_symbol, [])
if len(stock_adjust_factor_list) == 0:
return None
stock_adjust_factor_list.reverse()
if check_date is None:
check_date = datetime.now().strftime('%Y-%m-%d')
for d in stock_adjust_factor_list:
if d.get("dividOperateDate","") < check_date:
return d
return None
def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位
@ -955,6 +1038,253 @@ class CtaEngine(BaseEngine):
callback(bar)
def get_bars(
self,
vt_symbol: str,
days: int,
interval: Interval,
interval_num: int = 1
):
"""获取历史记录"""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now()
start = end - timedelta(days)
bars = []
# 检查股票代码
if vt_symbol not in self.symbol_dict:
self.write_error(f'{vt_symbol}不在基础配置股票信息中')
return bars
# 检查数据文件目录
if not self.bar_data_folder:
self.write_error(f'没有bar数据目录')
return bars
# 按照交易所的存放目录
bar_file_folder = os.path.abspath(os.path.join(self.bar_data_folder, f'{exchange.value}'))
resample_min = False
resample_hour = False
resample_day = False
file_interval_num = 1
# 只有1,5,15,30分钟日线数据
if interval == Interval.MINUTE:
# 如果存在相应的分钟文件,直接读取
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}m.csv'))
if interval_num in [1, 5, 15, 30] and os.path.exists(bar_file_path):
file_interval_num = interval
# 需要resample
else:
resample_min = True
if interval_num > 5:
file_interval_num = 5
elif interval == Interval.HOUR:
file_interval_num = 5
resample_hour = True
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
elif interval == Interval.DAILY:
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}d.csv'))
if not os.path.exists(bar_file_path):
file_interval_num = 5
resample_day = True
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
else:
self.write_error(f'目前仅支持分钟,小时,日线数据')
return bars
bar_interval_seconds = interval_num * 60
if not os.path.exists(bar_file_path):
self.write_error(f'没有bar数据文件{bar_file_path}')
return bars
try:
data_types = {
"datetime": str,
"open": float,
"high": float,
"low": float,
"close": float,
"volume": float,
"amount": float,
"symbol": str,
"trading_day": str,
"date": str,
"time": str
}
symbol_df = None
qfq_bar_file_path = bar_file_path.replace('.csv', '_qfq.csv')
use_qfq_file = False
last_qfq_dt = get_csv_last_dt(qfq_bar_file_path)
if last_qfq_dt is not None:
last_dt = get_csv_last_dt(bar_file_path)
if last_qfq_dt == last_dt:
use_qfq_file = True
if use_qfq_file:
self.write_log(f'使用前复权文件:{qfq_bar_file_path}')
symbol_df = pd.read_csv(qfq_bar_file_path, dtype=data_types)
else:
# 加载csv文件 =》 dateframe
self.write_log(f'使用未复权文件:{bar_file_path}')
symbol_df = pd.read_csv(bar_file_path, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[start:end]
if resample_day:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}day')
symbol_df = self.resample_bars(df=symbol_df, to_day=True)
elif resample_hour:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}hour')
symbol_df = self.resample_bars(df=symbol_df, x_hour=interval_num)
elif resample_min:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}m')
symbol_df = self.resample_bars(df=symbol_df, x_min=interval_num)
if len(symbol_df) == 0:
return bars
if not use_qfq_file:
# 复权转换
adj_list = self.adjust_factor_dict.get(vt_symbol, [])
# 按照结束日期,裁剪复权记录
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= end.strftime('%Y%m%d')]
if len(adj_list) > 0:
self.write_log(f'需要对{vt_symbol}进行前复权处理')
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
for dt, bar_data in symbol_df.iterrows():
bar_datetime = dt #- timedelta(seconds=bar_interval_seconds)
bar = BarData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=bar_datetime
)
if 'open' in bar_data:
bar.open_price = float(bar_data['open'])
bar.close_price = float(bar_data['close'])
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
else:
bar.open_price = float(bar_data['open_price'])
bar.close_price = float(bar_data['close_price'])
bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0
bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
else:
bar.trading_day = bar.date
bars.append(bar)
except Exception as ex:
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file_path, ex))
self.write_error(traceback.format_exc())
return bars
return bars
def stock_to_adj(self, raw_data, adj_data, adj_type):
"""
股票数据复权转换
:param raw_data: 不复权数据
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df
:param adj_type: 复权类型
:return:
"""
if adj_type == 'fore':
adj_factor = adj_data["foreAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1
else:
adj_factor = adj_data["backAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1
# 把raw_data的第一个日期插入复权因子df使用后填充
if adj_factor.index[0] != raw_data.index[0]:
adj_factor.loc[raw_data.index[0]] = np.nan
adj_factor.sort_index(inplace=True)
adj_factor = adj_factor.ffill()
adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引
adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格
# 把复权因子作为adj字段补充到raw_data中
raw_data['adj'] = adj_factor
# 逐一复权高低开平和成交量
for col in ['open', 'high', 'low', 'close']:
raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数
raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数
return raw_data
def resample_bars(self, df, x_min=None, x_hour=None, to_day=False):
"""
重建x分钟K线或日线
:param df: 输入分钟数
:param x_min: 5, 15, 30, 60
:param x_hour: 1, 2, 3, 4
:param include_day: 重建日线, True得时候不会重建分钟数
:return:
"""
# 设置df数据中每列的规则
ohlc_rule = {
'open': 'first', # open列序列中第一个的值
'high': 'max', # high列序列中最大的值
'low': 'min', # low列序列中最小的值
'close': 'last', # close列序列中最后一个的值
'volume': 'sum', # volume列将所有序列里的volume值作和
'amount': 'sum', # amount列将所有序列里的amount值作和
"symbol": 'first',
"trading_date": 'first',
"date": 'first',
"time": 'first'
}
if isinstance(x_min, int) and not to_day:
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
how='any')
return df_target
if isinstance(x_hour, int) and not to_day:
# 合成x小时K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'{x_hour}hour', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
how='any')
return df_target
if to_day:
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any')
return df_target
return df
def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None
):
@ -1153,6 +1483,7 @@ class CtaEngine(BaseEngine):
# Remove from symbol strategy map
self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系')
strategies = self.symbol_strategy_map[vt_symbol]
if strategy in strategies:
strategies.remove(strategy)
# Remove from active orderid map
@ -1528,6 +1859,27 @@ class CtaEngine(BaseEngine):
return strategy_pos_list
def get_all_strategy_pos_from_hams(self):
"""
获取hams中该账号下所有策略仓位明细
"""
strategy_pos_list = []
if not self.mongo_data:
self.init_mongo_data()
if self.mongo_data and self.mongo_data.db_has_connected:
filter = {'account_id':self.engine_config.get('accountid','-')}
pos_list = self.mongo_data.db_query(
db_name='Account',
col_name='today_strategy_pos',
filter_dict=filter
)
for pos in pos_list:
strategy_pos_list.append(pos)
return strategy_pos_list
def get_strategy_class_parameters(self, class_name: str):
"""
Get default parameters of a strategy class.
@ -1580,9 +1932,17 @@ class CtaEngine(BaseEngine):
self.write_log(u'开始对比账号&策略的持仓')
# 获取当前策略得持仓
# 获取hams数据库中所有运行实例得策略
if self.engine_config.get("get_pos_from_db", False):
strategy_pos_list = self.get_all_strategy_pos_from_hams()
else:
# 获取当前实例运行策略得持仓
if len(strategy_pos_list) == 0:
strategy_pos_list = self.get_all_strategy_pos()
# 过滤账号中名字列表
bypass_names = self.engine_config.get('compare_pos_bypass_names',[])
self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list))
none_strategy_pos = self.get_none_strategy_pos_list()
@ -1597,6 +1957,12 @@ class CtaEngine(BaseEngine):
for position in list(self.positions.values()):
# gateway_name.symbol.exchange => symbol.exchange
vt_symbol = position.vt_symbol
cn_name = self.get_name(vt_symbol)
# 中文名字,在过滤清单中
if any([name in cn_name for name in bypass_names]):
continue
vt_symbols.add(vt_symbol)
compare_pos[vt_symbol] = OrderedDict(
@ -1631,6 +1997,8 @@ class CtaEngine(BaseEngine):
u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0))))
self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0)))
compare_pos.update({vt_symbol:symbol_pos})
pos_compare_result = ''
# 精简输出
compare_info = ''
@ -1724,14 +2092,7 @@ class CtaEngine(BaseEngine):
"""
Load setting file.
"""
# 读取引擎得配置
# "accountid" : "xxxx", 资金账号,一般用于推送消息时附带
# "strategy_group": "cta_strategy_pro", # 当前实例名。多个实例时,区分开
# "trade_2_wx": true # 是否交易记录转发至微信通知
# "event_log: false # 是否转发日志到event bus显示在图形界面
self.engine_config = load_json(self.engine_filename)
# 是否产生event log 日志一般GUI界面才产生而且比好消耗资源)
self.event_log = self.engine_config.get('event_log', False)
# 读取策略得配置
self.strategy_setting = load_json(self.setting_filename)

View File

@ -15,6 +15,7 @@ import traceback
import random
import bz2
import pickle
import numpy as np
from datetime import datetime, timedelta
from time import sleep
@ -31,6 +32,7 @@ from vnpy.trader.constant import (
from vnpy.trader.utility import (
get_trading_date,
extract_vt_symbol,
get_csv_last_dt
)
from .back_testing import BackTestingEngine
@ -57,7 +59,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.tick_path = None # tick级别回测 路径
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None, qfq=True):
"""
加载回测bar数据到DataFrame
1. 增加前复权/后复权
@ -65,14 +67,16 @@ class PortfolioTestingEngine(BackTestingEngine):
:param bar_file:
:param data_start_date:
:param data_end_date:
:param qfq:True 前复权False 后复权
:return:
"""
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
fq_name = '前复权' if qfq else '后复权'
self.output(u'加载数据[{}] :未复权文件{},复权转换:{}'.format(vt_symbol, bar_file, fq_name))
if vt_symbol in self.bar_df_dict:
return True
if bar_file is None or not os.path.exists(bar_file):
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
self.write_error(u'加载数据[{}]:对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
return False
try:
@ -93,6 +97,28 @@ class PortfolioTestingEngine(BackTestingEngine):
"date": str,
"time": str
}
symbol_df = None
auto_generate_fq = True
# 复权文件
fq_bar_file = bar_file.replace('.csv', '_qfq.csv' if qfq else '_hfq.csv')
if os.path.exists(fq_bar_file):
# 存在复权文件
last_dt = get_csv_last_dt(fq_bar_file)
if isinstance(last_dt, datetime):
if last_dt.strftime('%Y-%m-%d') < self.test_end_date:
self.write_log(f'加载数据[{vt_symbol}], 使用{fq_name}文件:{fq_bar_file}')
symbol_df = pd.read_csv(bar_file, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
# 不再产生复权文件
auto_generate_fq = False
if not isinstance(symbol_df, pd.DataFrame):
# 加载csv文件 =》 dateframe
symbol_df = pd.read_csv(bar_file, dtype=data_types)
# 转换时间str =》 datetime
@ -109,16 +135,24 @@ class PortfolioTestingEngine(BackTestingEngine):
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
if adj_list:
self.write_log(f'需要对{vt_symbol}进行前复权处理')
self.write_log(f'加载数据[{vt_symbol}], 对{vt_symbol}进行{fq_name}处理')
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
d = row.get('dividOperateDate', "")[0:10]
if len(d) == 10:
row.update({'dividOperateDate': d + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"],
format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore' if qfq else "")
if auto_generate_fq:
self.write_log(f'加载数据[{vt_symbol}] ,缓存{fq_name}文件=>{fq_bar_file}')
symbol_df.to_csv(fq_bar_file)
if isinstance(symbol_df, pd.DataFrame):
# 添加到待合并dataframe dict中
self.bar_df_dict.update({vt_symbol: symbol_df})
@ -277,7 +311,7 @@ class PortfolioTestingEngine(BackTestingEngine):
bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume'])
bar.volume = float(bar_data['volume'])
bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
@ -299,6 +333,8 @@ class PortfolioTestingEngine(BackTestingEngine):
# bar时间与队列时间不一致先推送队列的bars
random.shuffle(bars_same_dt)
for _bar_ in bars_same_dt:
if np.isnan(_bar_.close_price):
continue
self.new_bar(_bar_)
# 创建新的队列

View File

@ -476,19 +476,32 @@ class CtaStockTemplate(CtaTemplate):
self.write_log(u'保存policy数据')
self.policy.save()
def save_klines_to_cache(self, kline_names: list = []):
def save_klines_to_cache(self, kline_names: list = [], vt_symbol: str = ""):
"""
保存K线数据到缓存
:param kline_names: 一般为self.klines的keys
:param vt_symbol: 指定股票代码,
如果使用该选项加载 data/klines/strategyname_vtsymbol_klines.pkb2
如果空白加载 data/strategyname_klines.pkb2
:return:
"""
if len(kline_names) == 0:
kline_names = list(self.klines.keys())
try:
# 如果是指定合约的话使用klines子目录
if len(vt_symbol) > 0:
kline_names = [n for n in kline_names if vt_symbol in n]
save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
if not os.path.exists(save_path):
os.makedirs(save_path)
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
else:
# 获取保存路径
save_path = self.cta_engine.get_data_path()
# 保存缓存的文件名
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
with bz2.BZ2File(file_name, 'wb') as f:
klines = {}
for kline_name in kline_names:
@ -498,16 +511,31 @@ class CtaStockTemplate(CtaTemplate):
# kline.cb_on_bar = None
klines.update({kline_name: kline})
pickle.dump(klines, f)
self.write_log(f'保存{vt_symbol} K线数据成功=>{file_name}')
except Exception as ex:
self.write_error(f'保存k线数据异常:{str(ex)}')
self.write_error(traceback.format_exc())
def load_klines_from_cache(self, kline_names: list = []):
def load_klines_from_cache(self, kline_names: list = [], vt_symbol: str = ""):
"""
从缓存加载K线数据
:param kline_names: 指定需要加载的k线名称列表
:param vt_symbol: 指定股票代码,
如果使用该选项加载 data/klines/strategyname_vtsymbol_klines.pkb2
如果空白加载 data/strategyname_klines.pkb2
:return:
"""
if len(kline_names) == 0:
kline_names = list(self.klines.keys())
# 如果是指定合约的话使用klines子目录
if len(vt_symbol) > 0:
kline_names = [n for n in kline_names if vt_symbol in n]
save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
if not os.path.exists(save_path):
os.makedirs(save_path)
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
else:
save_path = self.cta_engine.get_data_path()
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
try:
@ -976,7 +1004,7 @@ class CtaStockTemplate(CtaTemplate):
self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids)
self.gt.save()
def tns_excute_sell_grids(self, vt_symbol=None):
def tns_excute_sell_grids(self, vt_symbol=None, force=False):
"""
事务执行卖出网格
1找出所有order_status=True,open_status=Talse, close_status=True的网格
@ -1027,17 +1055,27 @@ class CtaStockTemplate(CtaTemplate):
sell_volume = ordering_grid.volume - ordering_grid.traded_volume
if sell_volume > acc_symbol_pos.volume:
if not force:
self.write_error(u'账号{}持仓{},不满足减仓目标:{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume))
continue
else:
self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume,
acc_symbol_pos.volume))
sell_volume = acc_symbol_pos.volume
if sell_volume == 0:
self.write_log(f'账号{vt_symbol}持仓{acc_symbol_pos.volume},卖出目标:{sell_volume}=0 不执行')
continue
# 实盘运行时,要加入市场买卖量的判断
if not self.backtesting:
cur_price = self.cta_engine.get_price(vt_symbol)
if not cur_price:
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
continue
# 实盘运行时,要加入市场买卖量的判断
if not force and not self.backtesting:
symbol_tick = self.cta_engine.get_tick(vt_symbol)
if symbol_tick:
symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol)
@ -1072,7 +1110,7 @@ class CtaStockTemplate(CtaTemplate):
self.write_log(f'{vt_symbol} 已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}')
def tns_finish_sell_grid(self, grid):
def tns_finish_sell_grid(self, grid:CtaGrid):
"""
事务完成卖出网格
:param grid:
@ -1106,7 +1144,7 @@ class CtaStockTemplate(CtaTemplate):
self.gt.save()
self.policy.save()
def tns_execute_buy_grids(self, vt_symbol=None):
def tns_execute_buy_grids(self, vt_symbol=None, force=False):
"""
事务执行买入网格
:return:
@ -1149,7 +1187,8 @@ class CtaStockTemplate(CtaTemplate):
balance, availiable, _, _ = self.cta_engine.get_account()
if availiable <= 0:
self.write_error(u'当前可用资金不足'.format(availiable))
if not self.backtesting:
self.write_log(u'当前可用资金{}不足,总资金:{}'.format(availiable, balance))
continue
vt_symbol = ordering_grid.vt_symbol
cur_price = self.cta_engine.get_price(vt_symbol)
@ -1178,7 +1217,7 @@ class CtaStockTemplate(CtaTemplate):
buy_volume = max_buy_volume
# 实盘运行时,要加入市场买卖量的判断
if not self.backtesting and 'market' in ordering_grid.snapshot:
if not force and not self.backtesting and 'market' in ordering_grid.snapshot:
symbol_tick = self.cta_engine.get_tick(vt_symbol)
if symbol_tick:
# 根据市场计算前5档买单数量
@ -1207,7 +1246,7 @@ class CtaStockTemplate(CtaTemplate):
else:
self.write_log(f'{self.strategy_name}, {vt_orderids},已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}')
def tns_finish_buy_grid(self, grid):
def tns_finish_buy_grid(self, grid:CtaGrid):
"""
事务完成买入网格
:return: