[update] 股票引擎。增加复权缓存,K线缓存

This commit is contained in:
msincenselee 2021-08-30 08:59:45 +08:00
parent aa865aa38a
commit 344f877bda
4 changed files with 519 additions and 79 deletions

View File

@ -589,7 +589,7 @@ class BackTestingEngine(object):
股票数据复权转换 股票数据复权转换
:param raw_data: 不复权数据 :param raw_data: 不复权数据
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df :param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df
:param adj_type: 复权类型 :param adj_type: 复权类型: fore 前复权, 其他后复权
:return: :return:
""" """
@ -1735,7 +1735,8 @@ class BackTestingEngine(object):
# 返回回测结果 # 返回回测结果
d = {} d = {}
d['init_capital'] = self.init_capital d['init_capital'] = self.init_capital
d['profit'] = self.cur_capital - self.init_capital d['net_capital'] = self.net_capital
d['profit'] = self.net_capital - self.init_capital
d['max_capital'] = self.max_net_capital # 取消原 maxCapital d['max_capital'] = self.max_net_capital # 取消原 maxCapital
if len(self.pnl_list) == 0: if len(self.pnl_list) == 0:
@ -1816,6 +1817,9 @@ class BackTestingEngine(object):
result_info.update({u'期初资金': d['init_capital']}) result_info.update({u'期初资金': d['init_capital']})
self.output(u'期初资金:\t%s' % format_number(d['init_capital'])) self.output(u'期初资金:\t%s' % format_number(d['init_capital']))
result_info.update({u'期末资金': d['net_capital']})
self.output(u'期末资金:\t%s' % format_number(d['net_capital']))
result_info.update({u'总盈亏': d['profit']}) result_info.update({u'总盈亏': d['profit']})
self.output(u'总盈亏:\t%s' % format_number(d['profit'])) self.output(u'总盈亏:\t%s' % format_number(d['profit']))

View File

@ -10,6 +10,8 @@ import traceback
import json import json
import pickle import pickle
import bz2 import bz2
import pandas as pd
import numpy as np
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
@ -60,12 +62,16 @@ from vnpy.trader.utility import (
get_folder_path, get_folder_path,
get_underlying_symbol, get_underlying_symbol,
append_data, append_data,
import_module_by_str) import_module_by_str,
get_csv_last_dt)
from vnpy.trader.util_logger import setup_logger, logging from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.util_wechat import send_wx_msg
from vnpy.trader.converter import PositionHolding from vnpy.trader.converter import PositionHolding
from vnpy.data.mongo.mongo_data import MongoData
from vnpy.trader.setting import SETTINGS
from vnpy.data.stock.adjust_factor import get_all_adjust_factor
from vnpy.data.stock.stock_base import get_stock_base
from .base import ( from .base import (
APP_NAME, APP_NAME,
EVENT_CTA_LOG, EVENT_CTA_LOG,
@ -115,9 +121,14 @@ class CtaEngine(BaseEngine):
# "trade_2_wx": true # 是否交易记录转发至微信通知 # "trade_2_wx": true # 是否交易记录转发至微信通知
# "event_log: false # 是否转发日志到event bus显示在图形界面 # "event_log: false # 是否转发日志到event bus显示在图形界面
# "snapshot2file": false # 是否保存切片到文件 # "snapshot2file": false # 是否保存切片到文件
self.engine_config = {} # "get_pos_from_db": false # 是否从数据库中获取所有策略得持仓信息。因为其他RPC进程也运行策略推送到Mongodb)
# "compare_pos": True 是否比对仓位
# "compare_pos_bypass_names": ["name1","name2"] # 比对仓位时,过滤账号中含有字符得股票,例如自动申购得转债等
self.engine_config = load_json(self.engine_filename)
# 是否激活 write_log写入event bus(比较耗资源) # 是否激活 write_log写入event bus(比较耗资源)
self.event_log = False self.event_log = self.engine_config.get('event_log', False)
self.strategy_setting = {} # strategy_name: dict self.strategy_setting = {} # strategy_name: dict
self.strategy_data = {} # strategy_name: dict self.strategy_data = {} # strategy_name: dict
@ -157,6 +168,40 @@ class CtaEngine(BaseEngine):
self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar) self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar)
self.stock_adjust_factors = get_all_adjust_factor()
self.mongo_data = None
# 获取全量股票信息
self.write_log(f'获取全量股票信息')
self.symbol_dict = get_stock_base()
self.write_log(f'{len(self.symbol_dict)}个股票')
# 除权因子
self.write_log(f'获取所有除权因子')
self.adjust_factor_dict = get_all_adjust_factor()
self.write_log(f'{len(self.adjust_factor_dict)}条除权信息')
# 寻找数据文件所在目录
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
self.write_log(f'项目所在目录:{vnpy_root}')
self.bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))
if os.path.exists(self.bar_data_folder):
SSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SSE'))
if os.path.exists(SSE_folder):
self.write_log(f'上交所bar数据目录:{SSE_folder}')
else:
self.write_error(f'不存在上交所数据目录:{SSE_folder}')
SZSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SZSE'))
if os.path.exists(SZSE_folder):
self.write_log(f'深交所bar数据目录:{SZSE_folder}')
else:
self.write_error(f'不存在深交所数据目录:{SZSE_folder}')
else:
self.write_error(f'不存在bar数据目录:{self.bar_data_folder}')
self.bar_data_folder = None
def init_engine(self): def init_engine(self):
""" """
""" """
@ -168,6 +213,25 @@ class CtaEngine(BaseEngine):
self.write_log("CTA策略股票引擎初始化成功") self.write_log("CTA策略股票引擎初始化成功")
if self.engine_config.get('get_pos_from_db',False):
self.init_mongo_data()
def init_mongo_data(self):
"""初始化hams数据库"""
host = SETTINGS.get('hams.host', 'localhost')
port = SETTINGS.get('hams.port', 27017)
self.write_log(f'初始化hams数据库连接:{host}:{port}')
try:
# Mongo数据连接客户端
self.mongo_data = MongoData(host=host, port=port)
if self.mongo_data and self.mongo_data.db_has_connected:
self.write_log(f'连接成功')
else:
self.write_error(f'HAMS数据库{host}:{port}连接异常.')
except Exception as ex:
self.write_error(f'HAMS数据库{host}:{port}连接异常.{str(ex)}')
def close(self): def close(self):
"""停止所属有的策略""" """停止所属有的策略"""
self.stop_all_strategies() self.stop_all_strategies()
@ -227,12 +291,12 @@ class CtaEngine(BaseEngine):
self.health_check() self.health_check()
# 在国内股市开盘期间做检查 # 在国内股市开盘期间做检查
if '0930' < self.last_minute < '1530': if '0930' < self.last_minute < '1830':
# 主动获取所有策略得持仓信息 # 主动获取所有策略得持仓信息
all_strategy_pos = self.get_all_strategy_pos() all_strategy_pos = self.get_all_strategy_pos()
# 每5分钟检查一次 # 每5分钟检查一次
if dt.minute % 10 == 0: if dt.minute % 10 == 0 and self.engine_config.get('compare_pos', True):
# 比对仓位,使用上述获取得持仓信息,不用重复获取 # 比对仓位,使用上述获取得持仓信息,不用重复获取
self.compare_pos(strategy_pos_list=copy(all_strategy_pos)) self.compare_pos(strategy_pos_list=copy(all_strategy_pos))
@ -881,6 +945,25 @@ class CtaEngine(BaseEngine):
def get_contract(self, vt_symbol): def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol) return self.main_engine.get_contract(vt_symbol)
def get_adjust_factor(self, vt_symbol, check_date=None):
"""
获取[check_date前]除权因子
:param vt_symbol:
:param check_date: 某一指定日期
:return:
"""
stock_adjust_factor_list = self.stock_adjust_factors.get(vt_symbol, [])
if len(stock_adjust_factor_list) == 0:
return None
stock_adjust_factor_list.reverse()
if check_date is None:
check_date = datetime.now().strftime('%Y-%m-%d')
for d in stock_adjust_factor_list:
if d.get("dividOperateDate","") < check_date:
return d
return None
def get_account(self, vt_accountid: str = ""): def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金""" """ 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位 # 如果启动风控,则使用风控中的最大仓位
@ -955,6 +1038,253 @@ class CtaEngine(BaseEngine):
callback(bar) callback(bar)
def get_bars(
self,
vt_symbol: str,
days: int,
interval: Interval,
interval_num: int = 1
):
"""获取历史记录"""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now()
start = end - timedelta(days)
bars = []
# 检查股票代码
if vt_symbol not in self.symbol_dict:
self.write_error(f'{vt_symbol}不在基础配置股票信息中')
return bars
# 检查数据文件目录
if not self.bar_data_folder:
self.write_error(f'没有bar数据目录')
return bars
# 按照交易所的存放目录
bar_file_folder = os.path.abspath(os.path.join(self.bar_data_folder, f'{exchange.value}'))
resample_min = False
resample_hour = False
resample_day = False
file_interval_num = 1
# 只有1,5,15,30分钟日线数据
if interval == Interval.MINUTE:
# 如果存在相应的分钟文件,直接读取
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}m.csv'))
if interval_num in [1, 5, 15, 30] and os.path.exists(bar_file_path):
file_interval_num = interval
# 需要resample
else:
resample_min = True
if interval_num > 5:
file_interval_num = 5
elif interval == Interval.HOUR:
file_interval_num = 5
resample_hour = True
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
elif interval == Interval.DAILY:
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}d.csv'))
if not os.path.exists(bar_file_path):
file_interval_num = 5
resample_day = True
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
else:
self.write_error(f'目前仅支持分钟,小时,日线数据')
return bars
bar_interval_seconds = interval_num * 60
if not os.path.exists(bar_file_path):
self.write_error(f'没有bar数据文件{bar_file_path}')
return bars
try:
data_types = {
"datetime": str,
"open": float,
"high": float,
"low": float,
"close": float,
"volume": float,
"amount": float,
"symbol": str,
"trading_day": str,
"date": str,
"time": str
}
symbol_df = None
qfq_bar_file_path = bar_file_path.replace('.csv', '_qfq.csv')
use_qfq_file = False
last_qfq_dt = get_csv_last_dt(qfq_bar_file_path)
if last_qfq_dt is not None:
last_dt = get_csv_last_dt(bar_file_path)
if last_qfq_dt == last_dt:
use_qfq_file = True
if use_qfq_file:
self.write_log(f'使用前复权文件:{qfq_bar_file_path}')
symbol_df = pd.read_csv(qfq_bar_file_path, dtype=data_types)
else:
# 加载csv文件 =》 dateframe
self.write_log(f'使用未复权文件:{bar_file_path}')
symbol_df = pd.read_csv(bar_file_path, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[start:end]
if resample_day:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}day')
symbol_df = self.resample_bars(df=symbol_df, to_day=True)
elif resample_hour:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}hour')
symbol_df = self.resample_bars(df=symbol_df, x_hour=interval_num)
elif resample_min:
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}m')
symbol_df = self.resample_bars(df=symbol_df, x_min=interval_num)
if len(symbol_df) == 0:
return bars
if not use_qfq_file:
# 复权转换
adj_list = self.adjust_factor_dict.get(vt_symbol, [])
# 按照结束日期,裁剪复权记录
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= end.strftime('%Y%m%d')]
if len(adj_list) > 0:
self.write_log(f'需要对{vt_symbol}进行前复权处理')
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
for dt, bar_data in symbol_df.iterrows():
bar_datetime = dt #- timedelta(seconds=bar_interval_seconds)
bar = BarData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=bar_datetime
)
if 'open' in bar_data:
bar.open_price = float(bar_data['open'])
bar.close_price = float(bar_data['close'])
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
else:
bar.open_price = float(bar_data['open_price'])
bar.close_price = float(bar_data['close_price'])
bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0
bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
else:
bar.trading_day = bar.date
bars.append(bar)
except Exception as ex:
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file_path, ex))
self.write_error(traceback.format_exc())
return bars
return bars
def stock_to_adj(self, raw_data, adj_data, adj_type):
"""
股票数据复权转换
:param raw_data: 不复权数据
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=df
:param adj_type: 复权类型
:return:
"""
if adj_type == 'fore':
adj_factor = adj_data["foreAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1
else:
adj_factor = adj_data["backAdjustFactor"]
adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1
# 把raw_data的第一个日期插入复权因子df使用后填充
if adj_factor.index[0] != raw_data.index[0]:
adj_factor.loc[raw_data.index[0]] = np.nan
adj_factor.sort_index(inplace=True)
adj_factor = adj_factor.ffill()
adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引
adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格
# 把复权因子作为adj字段补充到raw_data中
raw_data['adj'] = adj_factor
# 逐一复权高低开平和成交量
for col in ['open', 'high', 'low', 'close']:
raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数
raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数
return raw_data
def resample_bars(self, df, x_min=None, x_hour=None, to_day=False):
"""
重建x分钟K线或日线
:param df: 输入分钟数
:param x_min: 5, 15, 30, 60
:param x_hour: 1, 2, 3, 4
:param include_day: 重建日线, True得时候不会重建分钟数
:return:
"""
# 设置df数据中每列的规则
ohlc_rule = {
'open': 'first', # open列序列中第一个的值
'high': 'max', # high列序列中最大的值
'low': 'min', # low列序列中最小的值
'close': 'last', # close列序列中最后一个的值
'volume': 'sum', # volume列将所有序列里的volume值作和
'amount': 'sum', # amount列将所有序列里的amount值作和
"symbol": 'first',
"trading_date": 'first',
"date": 'first',
"time": 'first'
}
if isinstance(x_min, int) and not to_day:
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
how='any')
return df_target
if isinstance(x_hour, int) and not to_day:
# 合成x小时K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'{x_hour}hour', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
how='any')
return df_target
if to_day:
# 合成x分钟K线并删除为空的行 参数 closedleft类似向上取值既 0930的k线数据是包含0930-0935之间的数据
df_target = df.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any')
return df_target
return df
def call_strategy_func( def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None self, strategy: CtaTemplate, func: Callable, params: Any = None
): ):
@ -1153,7 +1483,8 @@ class CtaEngine(BaseEngine):
# Remove from symbol strategy map # Remove from symbol strategy map
self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系') self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系')
strategies = self.symbol_strategy_map[vt_symbol] strategies = self.symbol_strategy_map[vt_symbol]
strategies.remove(strategy) if strategy in strategies:
strategies.remove(strategy)
# Remove from active orderid map # Remove from active orderid map
if strategy_name in self.strategy_orderid_map: if strategy_name in self.strategy_orderid_map:
@ -1528,6 +1859,27 @@ class CtaEngine(BaseEngine):
return strategy_pos_list return strategy_pos_list
def get_all_strategy_pos_from_hams(self):
"""
获取hams中该账号下所有策略仓位明细
"""
strategy_pos_list = []
if not self.mongo_data:
self.init_mongo_data()
if self.mongo_data and self.mongo_data.db_has_connected:
filter = {'account_id':self.engine_config.get('accountid','-')}
pos_list = self.mongo_data.db_query(
db_name='Account',
col_name='today_strategy_pos',
filter_dict=filter
)
for pos in pos_list:
strategy_pos_list.append(pos)
return strategy_pos_list
def get_strategy_class_parameters(self, class_name: str): def get_strategy_class_parameters(self, class_name: str):
""" """
Get default parameters of a strategy class. Get default parameters of a strategy class.
@ -1580,9 +1932,17 @@ class CtaEngine(BaseEngine):
self.write_log(u'开始对比账号&策略的持仓') self.write_log(u'开始对比账号&策略的持仓')
# 获取当前策略得持仓 # 获取hams数据库中所有运行实例得策略
if len(strategy_pos_list) == 0: if self.engine_config.get("get_pos_from_db", False):
strategy_pos_list = self.get_all_strategy_pos() strategy_pos_list = self.get_all_strategy_pos_from_hams()
else:
# 获取当前实例运行策略得持仓
if len(strategy_pos_list) == 0:
strategy_pos_list = self.get_all_strategy_pos()
# 过滤账号中名字列表
bypass_names = self.engine_config.get('compare_pos_bypass_names',[])
self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list))
none_strategy_pos = self.get_none_strategy_pos_list() none_strategy_pos = self.get_none_strategy_pos_list()
@ -1597,6 +1957,12 @@ class CtaEngine(BaseEngine):
for position in list(self.positions.values()): for position in list(self.positions.values()):
# gateway_name.symbol.exchange => symbol.exchange # gateway_name.symbol.exchange => symbol.exchange
vt_symbol = position.vt_symbol vt_symbol = position.vt_symbol
cn_name = self.get_name(vt_symbol)
# 中文名字,在过滤清单中
if any([name in cn_name for name in bypass_names]):
continue
vt_symbols.add(vt_symbol) vt_symbols.add(vt_symbol)
compare_pos[vt_symbol] = OrderedDict( compare_pos[vt_symbol] = OrderedDict(
@ -1631,6 +1997,8 @@ class CtaEngine(BaseEngine):
u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0))))
self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0)))
compare_pos.update({vt_symbol:symbol_pos})
pos_compare_result = '' pos_compare_result = ''
# 精简输出 # 精简输出
compare_info = '' compare_info = ''
@ -1724,14 +2092,7 @@ class CtaEngine(BaseEngine):
""" """
Load setting file. Load setting file.
""" """
# 读取引擎得配置
# "accountid" : "xxxx", 资金账号,一般用于推送消息时附带
# "strategy_group": "cta_strategy_pro", # 当前实例名。多个实例时,区分开
# "trade_2_wx": true # 是否交易记录转发至微信通知
# "event_log: false # 是否转发日志到event bus显示在图形界面
self.engine_config = load_json(self.engine_filename)
# 是否产生event log 日志一般GUI界面才产生而且比好消耗资源)
self.event_log = self.engine_config.get('event_log', False)
# 读取策略得配置 # 读取策略得配置
self.strategy_setting = load_json(self.setting_filename) self.strategy_setting = load_json(self.setting_filename)

View File

@ -15,6 +15,7 @@ import traceback
import random import random
import bz2 import bz2
import pickle import pickle
import numpy as np
from datetime import datetime, timedelta from datetime import datetime, timedelta
from time import sleep from time import sleep
@ -31,6 +32,7 @@ from vnpy.trader.constant import (
from vnpy.trader.utility import ( from vnpy.trader.utility import (
get_trading_date, get_trading_date,
extract_vt_symbol, extract_vt_symbol,
get_csv_last_dt
) )
from .back_testing import BackTestingEngine from .back_testing import BackTestingEngine
@ -57,7 +59,7 @@ class PortfolioTestingEngine(BackTestingEngine):
self.tick_path = None # tick级别回测 路径 self.tick_path = None # tick级别回测 路径
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None): def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None, qfq=True):
""" """
加载回测bar数据到DataFrame 加载回测bar数据到DataFrame
1. 增加前复权/后复权 1. 增加前复权/后复权
@ -65,14 +67,16 @@ class PortfolioTestingEngine(BackTestingEngine):
:param bar_file: :param bar_file:
:param data_start_date: :param data_start_date:
:param data_end_date: :param data_end_date:
:param qfq:True 前复权False 后复权
:return: :return:
""" """
self.output(u'loading {} from {}'.format(vt_symbol, bar_file)) fq_name = '前复权' if qfq else '后复权'
self.output(u'加载数据[{}] :未复权文件{},复权转换:{}'.format(vt_symbol, bar_file, fq_name))
if vt_symbol in self.bar_df_dict: if vt_symbol in self.bar_df_dict:
return True return True
if bar_file is None or not os.path.exists(bar_file): if bar_file is None or not os.path.exists(bar_file):
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file)) self.write_error(u'加载数据[{}]:对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
return False return False
try: try:
@ -93,34 +97,64 @@ class PortfolioTestingEngine(BackTestingEngine):
"date": str, "date": str,
"time": str "time": str
} }
# 加载csv文件 =》 dateframe
symbol_df = pd.read_csv(bar_file, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据 symbol_df = None
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date] auto_generate_fq = True
# 复权文件
fq_bar_file = bar_file.replace('.csv', '_qfq.csv' if qfq else '_hfq.csv')
if os.path.exists(fq_bar_file):
# 存在复权文件
last_dt = get_csv_last_dt(fq_bar_file)
if isinstance(last_dt, datetime):
if last_dt.strftime('%Y-%m-%d') < self.test_end_date:
self.write_log(f'加载数据[{vt_symbol}], 使用{fq_name}文件:{fq_bar_file}')
symbol_df = pd.read_csv(bar_file, dtype=data_types)
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
# 不再产生复权文件
auto_generate_fq = False
# 复权转换 if not isinstance(symbol_df, pd.DataFrame):
adj_list = self.adjust_factors.get(vt_symbol, []) # 加载csv文件 =》 dateframe
# 按照结束日期,裁剪复权记录 symbol_df = pd.read_csv(bar_file, dtype=data_types)
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date] # 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
if adj_list: # 裁剪数据
self.write_log(f'需要对{vt_symbol}进行前复权处理') symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
for row in adj_list:
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
# 添加到待合并dataframe dict中 # 复权转换
self.bar_df_dict.update({vt_symbol: symbol_df}) adj_list = self.adjust_factors.get(vt_symbol, [])
# 按照结束日期,裁剪复权记录
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
if adj_list:
self.write_log(f'加载数据[{vt_symbol}], 对{vt_symbol}进行{fq_name}处理')
for row in adj_list:
d = row.get('dividOperateDate', "")[0:10]
if len(d) == 10:
row.update({'dividOperateDate': d + ' 09:31:00'})
# list -> dataframe, 转换复权日期格式
adj_data = pd.DataFrame(adj_list)
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"],
format="%Y-%m-%d %H:%M:%S")
adj_data = adj_data.set_index("dividOperateDate")
# 调用转换方法对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore' if qfq else "")
if auto_generate_fq:
self.write_log(f'加载数据[{vt_symbol}] ,缓存{fq_name}文件=>{fq_bar_file}')
symbol_df.to_csv(fq_bar_file)
if isinstance(symbol_df, pd.DataFrame):
# 添加到待合并dataframe dict中
self.bar_df_dict.update({vt_symbol: symbol_df})
except Exception as ex: except Exception as ex:
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex)) self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
@ -277,7 +311,7 @@ class PortfolioTestingEngine(BackTestingEngine):
bar.high_price = float(bar_data['high_price']) bar.high_price = float(bar_data['high_price'])
bar.low_price = float(bar_data['low_price']) bar.low_price = float(bar_data['low_price'])
bar.volume = int(bar_data['volume']) bar.volume = float(bar_data['volume'])
bar.date = dt.strftime('%Y-%m-%d') bar.date = dt.strftime('%Y-%m-%d')
bar.time = dt.strftime('%H:%M:%S') bar.time = dt.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', '')) str_td = str(bar_data.get('trading_day', ''))
@ -299,6 +333,8 @@ class PortfolioTestingEngine(BackTestingEngine):
# bar时间与队列时间不一致先推送队列的bars # bar时间与队列时间不一致先推送队列的bars
random.shuffle(bars_same_dt) random.shuffle(bars_same_dt)
for _bar_ in bars_same_dt: for _bar_ in bars_same_dt:
if np.isnan(_bar_.close_price):
continue
self.new_bar(_bar_) self.new_bar(_bar_)
# 创建新的队列 # 创建新的队列

View File

@ -476,40 +476,68 @@ class CtaStockTemplate(CtaTemplate):
self.write_log(u'保存policy数据') self.write_log(u'保存policy数据')
self.policy.save() self.policy.save()
def save_klines_to_cache(self, kline_names: list = []): def save_klines_to_cache(self, kline_names: list = [], vt_symbol: str = ""):
""" """
保存K线数据到缓存 保存K线数据到缓存
:param kline_names: 一般为self.klines的keys :param kline_names: 一般为self.klines的keys
:param vt_symbol: 指定股票代码,
如果使用该选项加载 data/klines/strategyname_vtsymbol_klines.pkb2
如果空白加载 data/strategyname_klines.pkb2
:return: :return:
""" """
if len(kline_names) == 0: if len(kline_names) == 0:
kline_names = list(self.klines.keys()) kline_names = list(self.klines.keys())
# 获取保存路径 try:
save_path = self.cta_engine.get_data_path() # 如果是指定合约的话使用klines子目录
# 保存缓存的文件名 if len(vt_symbol) > 0:
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) kline_names = [n for n in kline_names if vt_symbol in n]
with bz2.BZ2File(file_name, 'wb') as f: save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
klines = {} if not os.path.exists(save_path):
for kline_name in kline_names: os.makedirs(save_path)
kline = self.klines.get(kline_name, None) file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
# if kline: else:
# kline.strategy = None # 获取保存路径
# kline.cb_on_bar = None save_path = self.cta_engine.get_data_path()
klines.update({kline_name: kline}) # 保存缓存的文件名
pickle.dump(klines, f) file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
def load_klines_from_cache(self, kline_names: list = []): with bz2.BZ2File(file_name, 'wb') as f:
klines = {}
for kline_name in kline_names:
kline = self.klines.get(kline_name, None)
# if kline:
# kline.strategy = None
# kline.cb_on_bar = None
klines.update({kline_name: kline})
pickle.dump(klines, f)
self.write_log(f'保存{vt_symbol} K线数据成功=>{file_name}')
except Exception as ex:
self.write_error(f'保存k线数据异常:{str(ex)}')
self.write_error(traceback.format_exc())
def load_klines_from_cache(self, kline_names: list = [], vt_symbol: str = ""):
""" """
从缓存加载K线数据 从缓存加载K线数据
:param kline_names: 指定需要加载的k线名称列表 :param kline_names: 指定需要加载的k线名称列表
:param vt_symbol: 指定股票代码,
如果使用该选项加载 data/klines/strategyname_vtsymbol_klines.pkb2
如果空白加载 data/strategyname_klines.pkb2
:return: :return:
""" """
if len(kline_names) == 0: if len(kline_names) == 0:
kline_names = list(self.klines.keys()) kline_names = list(self.klines.keys())
save_path = self.cta_engine.get_data_path() # 如果是指定合约的话使用klines子目录
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) if len(vt_symbol) > 0:
kline_names = [n for n in kline_names if vt_symbol in n]
save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
if not os.path.exists(save_path):
os.makedirs(save_path)
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
else:
save_path = self.cta_engine.get_data_path()
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
try: try:
last_bar_dt = None last_bar_dt = None
with bz2.BZ2File(file_name, 'rb') as f: with bz2.BZ2File(file_name, 'rb') as f:
@ -976,7 +1004,7 @@ class CtaStockTemplate(CtaTemplate):
self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids) self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids)
self.gt.save() self.gt.save()
def tns_excute_sell_grids(self, vt_symbol=None): def tns_excute_sell_grids(self, vt_symbol=None, force=False):
""" """
事务执行卖出网格 事务执行卖出网格
1找出所有order_status=True,open_status=Talse, close_status=True的网格 1找出所有order_status=True,open_status=Talse, close_status=True的网格
@ -1027,17 +1055,27 @@ class CtaStockTemplate(CtaTemplate):
sell_volume = ordering_grid.volume - ordering_grid.traded_volume sell_volume = ordering_grid.volume - ordering_grid.traded_volume
if sell_volume > acc_symbol_pos.volume: if sell_volume > acc_symbol_pos.volume:
self.write_error(u'账号{}持仓{},不满足减仓目标:{}' if not force:
.format(vt_symbol, acc_symbol_pos.volume, sell_volume)) self.write_error(u'账号{}持仓{},不满足减仓目标:{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume))
continue
else:
self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
.format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume,
acc_symbol_pos.volume))
sell_volume = acc_symbol_pos.volume
if sell_volume == 0:
self.write_log(f'账号{vt_symbol}持仓{acc_symbol_pos.volume},卖出目标:{sell_volume}=0 不执行')
continue
cur_price = self.cta_engine.get_price(vt_symbol)
if not cur_price:
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
continue continue
# 实盘运行时,要加入市场买卖量的判断 # 实盘运行时,要加入市场买卖量的判断
if not self.backtesting: if not force and not self.backtesting:
cur_price = self.cta_engine.get_price(vt_symbol)
if not cur_price:
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
continue
symbol_tick = self.cta_engine.get_tick(vt_symbol) symbol_tick = self.cta_engine.get_tick(vt_symbol)
if symbol_tick: if symbol_tick:
symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol)
@ -1072,7 +1110,7 @@ class CtaStockTemplate(CtaTemplate):
self.write_log(f'{vt_symbol} 已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}') self.write_log(f'{vt_symbol} 已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}')
def tns_finish_sell_grid(self, grid): def tns_finish_sell_grid(self, grid:CtaGrid):
""" """
事务完成卖出网格 事务完成卖出网格
:param grid: :param grid:
@ -1106,7 +1144,7 @@ class CtaStockTemplate(CtaTemplate):
self.gt.save() self.gt.save()
self.policy.save() self.policy.save()
def tns_execute_buy_grids(self, vt_symbol=None): def tns_execute_buy_grids(self, vt_symbol=None, force=False):
""" """
事务执行买入网格 事务执行买入网格
:return: :return:
@ -1149,7 +1187,8 @@ class CtaStockTemplate(CtaTemplate):
balance, availiable, _, _ = self.cta_engine.get_account() balance, availiable, _, _ = self.cta_engine.get_account()
if availiable <= 0: if availiable <= 0:
self.write_error(u'当前可用资金不足'.format(availiable)) if not self.backtesting:
self.write_log(u'当前可用资金{}不足,总资金:{}'.format(availiable, balance))
continue continue
vt_symbol = ordering_grid.vt_symbol vt_symbol = ordering_grid.vt_symbol
cur_price = self.cta_engine.get_price(vt_symbol) cur_price = self.cta_engine.get_price(vt_symbol)
@ -1178,7 +1217,7 @@ class CtaStockTemplate(CtaTemplate):
buy_volume = max_buy_volume buy_volume = max_buy_volume
# 实盘运行时,要加入市场买卖量的判断 # 实盘运行时,要加入市场买卖量的判断
if not self.backtesting and 'market' in ordering_grid.snapshot: if not force and not self.backtesting and 'market' in ordering_grid.snapshot:
symbol_tick = self.cta_engine.get_tick(vt_symbol) symbol_tick = self.cta_engine.get_tick(vt_symbol)
if symbol_tick: if symbol_tick:
# 根据市场计算前5档买单数量 # 根据市场计算前5档买单数量
@ -1207,7 +1246,7 @@ class CtaStockTemplate(CtaTemplate):
else: else:
self.write_log(f'{self.strategy_name}, {vt_orderids},已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}') self.write_log(f'{self.strategy_name}, {vt_orderids},已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}')
def tns_finish_buy_grid(self, grid): def tns_finish_buy_grid(self, grid:CtaGrid):
""" """
事务完成买入网格 事务完成买入网格
:return: :return: