[update] 股票引擎。增加复权缓存,K线缓存
This commit is contained in:
parent
aa865aa38a
commit
344f877bda
@ -589,7 +589,7 @@ class BackTestingEngine(object):
|
||||
股票数据复权转换
|
||||
:param raw_data: 不复权数据
|
||||
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df)
|
||||
:param adj_type: 复权类型
|
||||
:param adj_type: 复权类型: fore 前复权, 其他:后复权
|
||||
:return:
|
||||
"""
|
||||
|
||||
@ -1735,7 +1735,8 @@ class BackTestingEngine(object):
|
||||
# 返回回测结果
|
||||
d = {}
|
||||
d['init_capital'] = self.init_capital
|
||||
d['profit'] = self.cur_capital - self.init_capital
|
||||
d['net_capital'] = self.net_capital
|
||||
d['profit'] = self.net_capital - self.init_capital
|
||||
d['max_capital'] = self.max_net_capital # 取消原 maxCapital
|
||||
|
||||
if len(self.pnl_list) == 0:
|
||||
@ -1816,6 +1817,9 @@ class BackTestingEngine(object):
|
||||
result_info.update({u'期初资金': d['init_capital']})
|
||||
self.output(u'期初资金:\t%s' % format_number(d['init_capital']))
|
||||
|
||||
result_info.update({u'期末资金': d['net_capital']})
|
||||
self.output(u'期末资金:\t%s' % format_number(d['net_capital']))
|
||||
|
||||
result_info.update({u'总盈亏': d['profit']})
|
||||
self.output(u'总盈亏:\t%s' % format_number(d['profit']))
|
||||
|
||||
|
@ -10,6 +10,8 @@ import traceback
|
||||
import json
|
||||
import pickle
|
||||
import bz2
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@ -60,12 +62,16 @@ from vnpy.trader.utility import (
|
||||
get_folder_path,
|
||||
get_underlying_symbol,
|
||||
append_data,
|
||||
import_module_by_str)
|
||||
import_module_by_str,
|
||||
get_csv_last_dt)
|
||||
|
||||
from vnpy.trader.util_logger import setup_logger, logging
|
||||
from vnpy.trader.util_wechat import send_wx_msg
|
||||
from vnpy.trader.converter import PositionHolding
|
||||
|
||||
from vnpy.data.mongo.mongo_data import MongoData
|
||||
from vnpy.trader.setting import SETTINGS
|
||||
from vnpy.data.stock.adjust_factor import get_all_adjust_factor
|
||||
from vnpy.data.stock.stock_base import get_stock_base
|
||||
from .base import (
|
||||
APP_NAME,
|
||||
EVENT_CTA_LOG,
|
||||
@ -115,9 +121,14 @@ class CtaEngine(BaseEngine):
|
||||
# "trade_2_wx": true # 是否交易记录转发至微信通知
|
||||
# "event_log: false # 是否转发日志到event bus,显示在图形界面
|
||||
# "snapshot2file": false # 是否保存切片到文件
|
||||
self.engine_config = {}
|
||||
# "get_pos_from_db": false # 是否从数据库中获取所有策略得持仓信息。(因为其他RPC进程也运行策略,推送到Mongodb)
|
||||
# "compare_pos": True, 是否比对仓位
|
||||
# "compare_pos_bypass_names": ["name1","name2"] # 比对仓位时,过滤账号中含有字符得股票,例如自动申购得转债等
|
||||
self.engine_config = load_json(self.engine_filename)
|
||||
|
||||
# 是否激活 write_log写入event bus(比较耗资源)
|
||||
self.event_log = False
|
||||
self.event_log = self.engine_config.get('event_log', False)
|
||||
|
||||
self.strategy_setting = {} # strategy_name: dict
|
||||
self.strategy_data = {} # strategy_name: dict
|
||||
|
||||
@ -157,6 +168,40 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar)
|
||||
|
||||
self.stock_adjust_factors = get_all_adjust_factor()
|
||||
|
||||
self.mongo_data = None
|
||||
|
||||
# 获取全量股票信息
|
||||
self.write_log(f'获取全量股票信息')
|
||||
self.symbol_dict = get_stock_base()
|
||||
self.write_log(f'共{len(self.symbol_dict)}个股票')
|
||||
# 除权因子
|
||||
self.write_log(f'获取所有除权因子')
|
||||
self.adjust_factor_dict = get_all_adjust_factor()
|
||||
self.write_log(f'共{len(self.adjust_factor_dict)}条除权信息')
|
||||
|
||||
# 寻找数据文件所在目录
|
||||
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
||||
self.write_log(f'项目所在目录:{vnpy_root}')
|
||||
self.bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))
|
||||
if os.path.exists(self.bar_data_folder):
|
||||
SSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SSE'))
|
||||
if os.path.exists(SSE_folder):
|
||||
self.write_log(f'上交所bar数据目录:{SSE_folder}')
|
||||
else:
|
||||
self.write_error(f'不存在上交所数据目录:{SSE_folder}')
|
||||
|
||||
SZSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SZSE'))
|
||||
if os.path.exists(SZSE_folder):
|
||||
self.write_log(f'深交所bar数据目录:{SZSE_folder}')
|
||||
else:
|
||||
self.write_error(f'不存在深交所数据目录:{SZSE_folder}')
|
||||
else:
|
||||
self.write_error(f'不存在bar数据目录:{self.bar_data_folder}')
|
||||
self.bar_data_folder = None
|
||||
|
||||
|
||||
def init_engine(self):
|
||||
"""
|
||||
"""
|
||||
@ -168,6 +213,25 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
self.write_log("CTA策略股票引擎初始化成功")
|
||||
|
||||
if self.engine_config.get('get_pos_from_db',False):
|
||||
self.init_mongo_data()
|
||||
|
||||
def init_mongo_data(self):
|
||||
"""初始化hams数据库"""
|
||||
host = SETTINGS.get('hams.host', 'localhost')
|
||||
port = SETTINGS.get('hams.port', 27017)
|
||||
self.write_log(f'初始化hams数据库连接:{host}:{port}')
|
||||
try:
|
||||
# Mongo数据连接客户端
|
||||
self.mongo_data = MongoData(host=host, port=port)
|
||||
|
||||
if self.mongo_data and self.mongo_data.db_has_connected:
|
||||
self.write_log(f'连接成功')
|
||||
else:
|
||||
self.write_error(f'HAMS数据库{host}:{port}连接异常.')
|
||||
except Exception as ex:
|
||||
self.write_error(f'HAMS数据库{host}:{port}连接异常.{str(ex)}')
|
||||
|
||||
def close(self):
|
||||
"""停止所属有的策略"""
|
||||
self.stop_all_strategies()
|
||||
@ -227,12 +291,12 @@ class CtaEngine(BaseEngine):
|
||||
self.health_check()
|
||||
|
||||
# 在国内股市开盘期间做检查
|
||||
if '0930' < self.last_minute < '1530':
|
||||
if '0930' < self.last_minute < '1830':
|
||||
# 主动获取所有策略得持仓信息
|
||||
all_strategy_pos = self.get_all_strategy_pos()
|
||||
|
||||
# 每5分钟检查一次
|
||||
if dt.minute % 10 == 0:
|
||||
if dt.minute % 10 == 0 and self.engine_config.get('compare_pos', True):
|
||||
# 比对仓位,使用上述获取得持仓信息,不用重复获取
|
||||
self.compare_pos(strategy_pos_list=copy(all_strategy_pos))
|
||||
|
||||
@ -881,6 +945,25 @@ class CtaEngine(BaseEngine):
|
||||
def get_contract(self, vt_symbol):
|
||||
return self.main_engine.get_contract(vt_symbol)
|
||||
|
||||
def get_adjust_factor(self, vt_symbol, check_date=None):
|
||||
"""
|
||||
获取[check_date前]除权因子
|
||||
:param vt_symbol:
|
||||
:param check_date: 某一指定日期
|
||||
:return:
|
||||
"""
|
||||
stock_adjust_factor_list = self.stock_adjust_factors.get(vt_symbol, [])
|
||||
if len(stock_adjust_factor_list) == 0:
|
||||
return None
|
||||
stock_adjust_factor_list.reverse()
|
||||
if check_date is None:
|
||||
check_date = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
for d in stock_adjust_factor_list:
|
||||
if d.get("dividOperateDate","") < check_date:
|
||||
return d
|
||||
return None
|
||||
|
||||
def get_account(self, vt_accountid: str = ""):
|
||||
""" 查询账号的资金"""
|
||||
# 如果启动风控,则使用风控中的最大仓位
|
||||
@ -955,6 +1038,253 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
callback(bar)
|
||||
|
||||
|
||||
def get_bars(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
days: int,
|
||||
interval: Interval,
|
||||
interval_num: int = 1
|
||||
):
|
||||
"""获取历史记录"""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
bars = []
|
||||
|
||||
# 检查股票代码
|
||||
if vt_symbol not in self.symbol_dict:
|
||||
self.write_error(f'{vt_symbol}不在基础配置股票信息中')
|
||||
return bars
|
||||
|
||||
# 检查数据文件目录
|
||||
if not self.bar_data_folder:
|
||||
self.write_error(f'没有bar数据目录')
|
||||
return bars
|
||||
# 按照交易所的存放目录
|
||||
bar_file_folder = os.path.abspath(os.path.join(self.bar_data_folder, f'{exchange.value}'))
|
||||
|
||||
resample_min = False
|
||||
resample_hour = False
|
||||
resample_day = False
|
||||
file_interval_num = 1
|
||||
# 只有1,5,15,30分钟,日线数据
|
||||
if interval == Interval.MINUTE:
|
||||
# 如果存在相应的分钟文件,直接读取
|
||||
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}m.csv'))
|
||||
if interval_num in [1, 5, 15, 30] and os.path.exists(bar_file_path):
|
||||
file_interval_num = interval
|
||||
# 需要resample
|
||||
else:
|
||||
resample_min = True
|
||||
if interval_num > 5:
|
||||
file_interval_num = 5
|
||||
|
||||
elif interval == Interval.HOUR:
|
||||
file_interval_num = 5
|
||||
resample_hour = True
|
||||
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
|
||||
elif interval == Interval.DAILY:
|
||||
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}d.csv'))
|
||||
if not os.path.exists(bar_file_path):
|
||||
file_interval_num = 5
|
||||
resample_day = True
|
||||
bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv'))
|
||||
else:
|
||||
self.write_error(f'目前仅支持分钟,小时,日线数据')
|
||||
return bars
|
||||
|
||||
bar_interval_seconds = interval_num * 60
|
||||
|
||||
if not os.path.exists(bar_file_path):
|
||||
self.write_error(f'没有bar数据文件:{bar_file_path}')
|
||||
return bars
|
||||
|
||||
try:
|
||||
data_types = {
|
||||
"datetime": str,
|
||||
"open": float,
|
||||
"high": float,
|
||||
"low": float,
|
||||
"close": float,
|
||||
"volume": float,
|
||||
"amount": float,
|
||||
"symbol": str,
|
||||
"trading_day": str,
|
||||
"date": str,
|
||||
"time": str
|
||||
}
|
||||
|
||||
symbol_df = None
|
||||
qfq_bar_file_path = bar_file_path.replace('.csv', '_qfq.csv')
|
||||
use_qfq_file = False
|
||||
last_qfq_dt = get_csv_last_dt(qfq_bar_file_path)
|
||||
if last_qfq_dt is not None:
|
||||
last_dt = get_csv_last_dt(bar_file_path)
|
||||
|
||||
if last_qfq_dt == last_dt:
|
||||
use_qfq_file = True
|
||||
|
||||
if use_qfq_file:
|
||||
self.write_log(f'使用前复权文件:{qfq_bar_file_path}')
|
||||
symbol_df = pd.read_csv(qfq_bar_file_path, dtype=data_types)
|
||||
else:
|
||||
# 加载csv文件 =》 dateframe
|
||||
self.write_log(f'使用未复权文件:{bar_file_path}')
|
||||
symbol_df = pd.read_csv(bar_file_path, dtype=data_types)
|
||||
|
||||
# 转换时间,str =》 datetime
|
||||
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||
# 设置时间为索引
|
||||
symbol_df = symbol_df.set_index("datetime")
|
||||
|
||||
# 裁剪数据
|
||||
symbol_df = symbol_df.loc[start:end]
|
||||
|
||||
if resample_day:
|
||||
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}day')
|
||||
symbol_df = self.resample_bars(df=symbol_df, to_day=True)
|
||||
elif resample_hour:
|
||||
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}hour')
|
||||
symbol_df = self.resample_bars(df=symbol_df, x_hour=interval_num)
|
||||
elif resample_min:
|
||||
self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}m')
|
||||
symbol_df = self.resample_bars(df=symbol_df, x_min=interval_num)
|
||||
|
||||
if len(symbol_df) == 0:
|
||||
return bars
|
||||
|
||||
if not use_qfq_file:
|
||||
# 复权转换
|
||||
adj_list = self.adjust_factor_dict.get(vt_symbol, [])
|
||||
# 按照结束日期,裁剪复权记录
|
||||
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= end.strftime('%Y%m%d')]
|
||||
|
||||
if len(adj_list) > 0:
|
||||
self.write_log(f'需要对{vt_symbol}进行前复权处理')
|
||||
for row in adj_list:
|
||||
row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:31:00'})
|
||||
# list -> dataframe, 转换复权日期格式
|
||||
adj_data = pd.DataFrame(adj_list)
|
||||
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
|
||||
adj_data = adj_data.set_index("dividOperateDate")
|
||||
# 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
|
||||
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
|
||||
|
||||
for dt, bar_data in symbol_df.iterrows():
|
||||
bar_datetime = dt #- timedelta(seconds=bar_interval_seconds)
|
||||
|
||||
bar = BarData(
|
||||
gateway_name='backtesting',
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=bar_datetime
|
||||
)
|
||||
if 'open' in bar_data:
|
||||
bar.open_price = float(bar_data['open'])
|
||||
bar.close_price = float(bar_data['close'])
|
||||
bar.high_price = float(bar_data['high'])
|
||||
bar.low_price = float(bar_data['low'])
|
||||
else:
|
||||
bar.open_price = float(bar_data['open_price'])
|
||||
bar.close_price = float(bar_data['close_price'])
|
||||
bar.high_price = float(bar_data['high_price'])
|
||||
bar.low_price = float(bar_data['low_price'])
|
||||
|
||||
bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0
|
||||
bar.date = dt.strftime('%Y-%m-%d')
|
||||
bar.time = dt.strftime('%H:%M:%S')
|
||||
str_td = str(bar_data.get('trading_day', ''))
|
||||
if len(str_td) == 8:
|
||||
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
|
||||
else:
|
||||
bar.trading_day = bar.date
|
||||
|
||||
bars.append(bar)
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file_path, ex))
|
||||
self.write_error(traceback.format_exc())
|
||||
return bars
|
||||
|
||||
return bars
|
||||
|
||||
def stock_to_adj(self, raw_data, adj_data, adj_type):
|
||||
"""
|
||||
股票数据复权转换
|
||||
:param raw_data: 不复权数据
|
||||
:param adj_data: 复权记录 ( 从barstock下载的复权记录列表=》df)
|
||||
:param adj_type: 复权类型
|
||||
:return:
|
||||
"""
|
||||
|
||||
if adj_type == 'fore':
|
||||
adj_factor = adj_data["foreAdjustFactor"]
|
||||
adj_factor = adj_factor / adj_factor.iloc[-1] # 保证最后一个复权因子是1
|
||||
else:
|
||||
adj_factor = adj_data["backAdjustFactor"]
|
||||
adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1
|
||||
|
||||
# 把raw_data的第一个日期,插入复权因子df,使用后填充
|
||||
if adj_factor.index[0] != raw_data.index[0]:
|
||||
adj_factor.loc[raw_data.index[0]] = np.nan
|
||||
adj_factor.sort_index(inplace=True)
|
||||
adj_factor = adj_factor.ffill()
|
||||
|
||||
adj_factor = adj_factor.reindex(index=raw_data.index) # 按价格dataframe的日期索引来扩展索引
|
||||
adj_factor = adj_factor.ffill() # 向前(向未来)填充扩展后的空单元格
|
||||
|
||||
# 把复权因子,作为adj字段,补充到raw_data中
|
||||
raw_data['adj'] = adj_factor
|
||||
|
||||
# 逐一复权高低开平和成交量
|
||||
for col in ['open', 'high', 'low', 'close']:
|
||||
raw_data[col] = raw_data[col] * raw_data['adj'] # 价格乘上复权系数
|
||||
raw_data['volume'] = raw_data['volume'] / raw_data['adj'] # 成交量除以复权系数
|
||||
|
||||
return raw_data
|
||||
|
||||
def resample_bars(self, df, x_min=None, x_hour=None, to_day=False):
|
||||
"""
|
||||
重建x分钟K线(或日线)
|
||||
:param df: 输入分钟数
|
||||
:param x_min: 5, 15, 30, 60
|
||||
:param x_hour: 1, 2, 3, 4
|
||||
:param include_day: 重建日线, True得时候,不会重建分钟数
|
||||
:return:
|
||||
"""
|
||||
# 设置df数据中每列的规则
|
||||
ohlc_rule = {
|
||||
'open': 'first', # open列:序列中第一个的值
|
||||
'high': 'max', # high列:序列中最大的值
|
||||
'low': 'min', # low列:序列中最小的值
|
||||
'close': 'last', # close列:序列中最后一个的值
|
||||
'volume': 'sum', # volume列:将所有序列里的volume值作和
|
||||
'amount': 'sum', # amount列:将所有序列里的amount值作和
|
||||
"symbol": 'first',
|
||||
"trading_date": 'first',
|
||||
"date": 'first',
|
||||
"time": 'first'
|
||||
}
|
||||
|
||||
if isinstance(x_min, int) and not to_day:
|
||||
# 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据
|
||||
df_target = df.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
|
||||
how='any')
|
||||
return df_target
|
||||
if isinstance(x_hour, int) and not to_day:
|
||||
# 合成x小时K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据
|
||||
df_target = df.resample(f'{x_hour}hour', closed='left', label='left').agg(ohlc_rule).dropna(axis=0,
|
||||
how='any')
|
||||
return df_target
|
||||
|
||||
if to_day:
|
||||
# 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据
|
||||
df_target = df.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any')
|
||||
return df_target
|
||||
|
||||
return df
|
||||
|
||||
def call_strategy_func(
|
||||
self, strategy: CtaTemplate, func: Callable, params: Any = None
|
||||
):
|
||||
@ -1153,7 +1483,8 @@ class CtaEngine(BaseEngine):
|
||||
# Remove from symbol strategy map
|
||||
self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系')
|
||||
strategies = self.symbol_strategy_map[vt_symbol]
|
||||
strategies.remove(strategy)
|
||||
if strategy in strategies:
|
||||
strategies.remove(strategy)
|
||||
|
||||
# Remove from active orderid map
|
||||
if strategy_name in self.strategy_orderid_map:
|
||||
@ -1528,6 +1859,27 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
return strategy_pos_list
|
||||
|
||||
def get_all_strategy_pos_from_hams(self):
|
||||
"""
|
||||
获取hams中该账号下所有策略仓位明细
|
||||
"""
|
||||
strategy_pos_list = []
|
||||
if not self.mongo_data:
|
||||
self.init_mongo_data()
|
||||
|
||||
if self.mongo_data and self.mongo_data.db_has_connected:
|
||||
filter = {'account_id':self.engine_config.get('accountid','-')}
|
||||
|
||||
pos_list = self.mongo_data.db_query(
|
||||
db_name='Account',
|
||||
col_name='today_strategy_pos',
|
||||
filter_dict=filter
|
||||
)
|
||||
for pos in pos_list:
|
||||
strategy_pos_list.append(pos)
|
||||
|
||||
return strategy_pos_list
|
||||
|
||||
def get_strategy_class_parameters(self, class_name: str):
|
||||
"""
|
||||
Get default parameters of a strategy class.
|
||||
@ -1580,9 +1932,17 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
self.write_log(u'开始对比账号&策略的持仓')
|
||||
|
||||
# 获取当前策略得持仓
|
||||
if len(strategy_pos_list) == 0:
|
||||
strategy_pos_list = self.get_all_strategy_pos()
|
||||
# 获取hams数据库中所有运行实例得策略
|
||||
if self.engine_config.get("get_pos_from_db", False):
|
||||
strategy_pos_list = self.get_all_strategy_pos_from_hams()
|
||||
else:
|
||||
# 获取当前实例运行策略得持仓
|
||||
if len(strategy_pos_list) == 0:
|
||||
strategy_pos_list = self.get_all_strategy_pos()
|
||||
|
||||
# 过滤账号中名字列表
|
||||
bypass_names = self.engine_config.get('compare_pos_bypass_names',[])
|
||||
|
||||
self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list))
|
||||
|
||||
none_strategy_pos = self.get_none_strategy_pos_list()
|
||||
@ -1597,6 +1957,12 @@ class CtaEngine(BaseEngine):
|
||||
for position in list(self.positions.values()):
|
||||
# gateway_name.symbol.exchange => symbol.exchange
|
||||
vt_symbol = position.vt_symbol
|
||||
cn_name = self.get_name(vt_symbol)
|
||||
|
||||
# 中文名字,在过滤清单中
|
||||
if any([name in cn_name for name in bypass_names]):
|
||||
continue
|
||||
|
||||
vt_symbols.add(vt_symbol)
|
||||
|
||||
compare_pos[vt_symbol] = OrderedDict(
|
||||
@ -1631,6 +1997,8 @@ class CtaEngine(BaseEngine):
|
||||
u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0))))
|
||||
self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0)))
|
||||
|
||||
compare_pos.update({vt_symbol:symbol_pos})
|
||||
|
||||
pos_compare_result = ''
|
||||
# 精简输出
|
||||
compare_info = ''
|
||||
@ -1724,14 +2092,7 @@ class CtaEngine(BaseEngine):
|
||||
"""
|
||||
Load setting file.
|
||||
"""
|
||||
# 读取引擎得配置
|
||||
# "accountid" : "xxxx", 资金账号,一般用于推送消息时附带
|
||||
# "strategy_group": "cta_strategy_pro", # 当前实例名。多个实例时,区分开
|
||||
# "trade_2_wx": true # 是否交易记录转发至微信通知
|
||||
# "event_log: false # 是否转发日志到event bus,显示在图形界面
|
||||
self.engine_config = load_json(self.engine_filename)
|
||||
# 是否产生event log 日志(一般GUI界面才产生,而且比好消耗资源)
|
||||
self.event_log = self.engine_config.get('event_log', False)
|
||||
|
||||
# 读取策略得配置
|
||||
self.strategy_setting = load_json(self.setting_filename)
|
||||
|
||||
|
@ -15,6 +15,7 @@ import traceback
|
||||
import random
|
||||
import bz2
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import sleep
|
||||
@ -31,6 +32,7 @@ from vnpy.trader.constant import (
|
||||
from vnpy.trader.utility import (
|
||||
get_trading_date,
|
||||
extract_vt_symbol,
|
||||
get_csv_last_dt
|
||||
)
|
||||
|
||||
from .back_testing import BackTestingEngine
|
||||
@ -57,7 +59,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
|
||||
self.tick_path = None # tick级别回测, 路径
|
||||
|
||||
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
|
||||
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None, qfq=True):
|
||||
"""
|
||||
加载回测bar数据到DataFrame
|
||||
1. 增加前复权/后复权
|
||||
@ -65,14 +67,16 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
:param bar_file:
|
||||
:param data_start_date:
|
||||
:param data_end_date:
|
||||
:param qfq:True 前复权,False 后复权
|
||||
:return:
|
||||
"""
|
||||
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
|
||||
fq_name = '前复权' if qfq else '后复权'
|
||||
self.output(u'加载数据[{}] :未复权文件{},复权转换:{}'.format(vt_symbol, bar_file, fq_name))
|
||||
if vt_symbol in self.bar_df_dict:
|
||||
return True
|
||||
|
||||
if bar_file is None or not os.path.exists(bar_file):
|
||||
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
|
||||
self.write_error(u'加载数据[{}]:对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
|
||||
return False
|
||||
|
||||
try:
|
||||
@ -93,34 +97,64 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
"date": str,
|
||||
"time": str
|
||||
}
|
||||
# 加载csv文件 =》 dateframe
|
||||
symbol_df = pd.read_csv(bar_file, dtype=data_types)
|
||||
# 转换时间,str =》 datetime
|
||||
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||
# 设置时间为索引
|
||||
symbol_df = symbol_df.set_index("datetime")
|
||||
|
||||
# 裁剪数据
|
||||
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
|
||||
symbol_df = None
|
||||
auto_generate_fq = True
|
||||
# 复权文件
|
||||
fq_bar_file = bar_file.replace('.csv', '_qfq.csv' if qfq else '_hfq.csv')
|
||||
if os.path.exists(fq_bar_file):
|
||||
# 存在复权文件
|
||||
last_dt = get_csv_last_dt(fq_bar_file)
|
||||
if isinstance(last_dt, datetime):
|
||||
if last_dt.strftime('%Y-%m-%d') < self.test_end_date:
|
||||
self.write_log(f'加载数据[{vt_symbol}], 使用{fq_name}文件:{fq_bar_file}')
|
||||
symbol_df = pd.read_csv(bar_file, dtype=data_types)
|
||||
# 转换时间,str =》 datetime
|
||||
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||
# 设置时间为索引
|
||||
symbol_df = symbol_df.set_index("datetime")
|
||||
# 裁剪数据
|
||||
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
|
||||
# 不再产生复权文件
|
||||
auto_generate_fq = False
|
||||
|
||||
# 复权转换
|
||||
adj_list = self.adjust_factors.get(vt_symbol, [])
|
||||
# 按照结束日期,裁剪复权记录
|
||||
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
|
||||
if not isinstance(symbol_df, pd.DataFrame):
|
||||
# 加载csv文件 =》 dateframe
|
||||
symbol_df = pd.read_csv(bar_file, dtype=data_types)
|
||||
# 转换时间,str =》 datetime
|
||||
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S")
|
||||
# 设置时间为索引
|
||||
symbol_df = symbol_df.set_index("datetime")
|
||||
|
||||
if adj_list:
|
||||
self.write_log(f'需要对{vt_symbol}进行前复权处理')
|
||||
for row in adj_list:
|
||||
row.update({'dividOperateDate': row.get('dividOperateDate') + ' 09:31:00'})
|
||||
# list -> dataframe, 转换复权日期格式
|
||||
adj_data = pd.DataFrame(adj_list)
|
||||
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S")
|
||||
adj_data = adj_data.set_index("dividOperateDate")
|
||||
# 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
|
||||
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore')
|
||||
# 裁剪数据
|
||||
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
|
||||
|
||||
# 添加到待合并dataframe dict中
|
||||
self.bar_df_dict.update({vt_symbol: symbol_df})
|
||||
# 复权转换
|
||||
adj_list = self.adjust_factors.get(vt_symbol, [])
|
||||
# 按照结束日期,裁剪复权记录
|
||||
adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= self.test_end_date]
|
||||
|
||||
if adj_list:
|
||||
self.write_log(f'加载数据[{vt_symbol}], 对{vt_symbol}进行{fq_name}处理')
|
||||
for row in adj_list:
|
||||
d = row.get('dividOperateDate', "")[0:10]
|
||||
if len(d) == 10:
|
||||
row.update({'dividOperateDate': d + ' 09:31:00'})
|
||||
# list -> dataframe, 转换复权日期格式
|
||||
adj_data = pd.DataFrame(adj_list)
|
||||
adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"],
|
||||
format="%Y-%m-%d %H:%M:%S")
|
||||
adj_data = adj_data.set_index("dividOperateDate")
|
||||
# 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权
|
||||
symbol_df = self.stock_to_adj(symbol_df, adj_data, adj_type='fore' if qfq else "")
|
||||
|
||||
if auto_generate_fq:
|
||||
self.write_log(f'加载数据[{vt_symbol}] ,缓存{fq_name}文件=>{fq_bar_file}')
|
||||
symbol_df.to_csv(fq_bar_file)
|
||||
|
||||
if isinstance(symbol_df, pd.DataFrame):
|
||||
# 添加到待合并dataframe dict中
|
||||
self.bar_df_dict.update({vt_symbol: symbol_df})
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
|
||||
@ -277,7 +311,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
bar.high_price = float(bar_data['high_price'])
|
||||
bar.low_price = float(bar_data['low_price'])
|
||||
|
||||
bar.volume = int(bar_data['volume'])
|
||||
bar.volume = float(bar_data['volume'])
|
||||
bar.date = dt.strftime('%Y-%m-%d')
|
||||
bar.time = dt.strftime('%H:%M:%S')
|
||||
str_td = str(bar_data.get('trading_day', ''))
|
||||
@ -299,6 +333,8 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
# bar时间与队列时间不一致,先推送队列的bars
|
||||
random.shuffle(bars_same_dt)
|
||||
for _bar_ in bars_same_dt:
|
||||
if np.isnan(_bar_.close_price):
|
||||
continue
|
||||
self.new_bar(_bar_)
|
||||
|
||||
# 创建新的队列
|
||||
|
@ -476,40 +476,68 @@ class CtaStockTemplate(CtaTemplate):
|
||||
self.write_log(u'保存policy数据')
|
||||
self.policy.save()
|
||||
|
||||
def save_klines_to_cache(self, kline_names: list = []):
|
||||
def save_klines_to_cache(self, kline_names: list = [], vt_symbol: str = ""):
|
||||
"""
|
||||
保存K线数据到缓存
|
||||
:param kline_names: 一般为self.klines的keys
|
||||
:param vt_symbol: 指定股票代码,
|
||||
如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2
|
||||
如果空白,加载 data/strategyname_klines.pkb2
|
||||
:return:
|
||||
"""
|
||||
if len(kline_names) == 0:
|
||||
kline_names = list(self.klines.keys())
|
||||
|
||||
# 获取保存路径
|
||||
save_path = self.cta_engine.get_data_path()
|
||||
# 保存缓存的文件名
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
|
||||
with bz2.BZ2File(file_name, 'wb') as f:
|
||||
klines = {}
|
||||
for kline_name in kline_names:
|
||||
kline = self.klines.get(kline_name, None)
|
||||
# if kline:
|
||||
# kline.strategy = None
|
||||
# kline.cb_on_bar = None
|
||||
klines.update({kline_name: kline})
|
||||
pickle.dump(klines, f)
|
||||
try:
|
||||
# 如果是指定合约的话,使用klines子目录
|
||||
if len(vt_symbol) > 0:
|
||||
kline_names = [n for n in kline_names if vt_symbol in n]
|
||||
save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
|
||||
else:
|
||||
# 获取保存路径
|
||||
save_path = self.cta_engine.get_data_path()
|
||||
# 保存缓存的文件名
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
|
||||
|
||||
def load_klines_from_cache(self, kline_names: list = []):
|
||||
with bz2.BZ2File(file_name, 'wb') as f:
|
||||
klines = {}
|
||||
for kline_name in kline_names:
|
||||
kline = self.klines.get(kline_name, None)
|
||||
# if kline:
|
||||
# kline.strategy = None
|
||||
# kline.cb_on_bar = None
|
||||
klines.update({kline_name: kline})
|
||||
pickle.dump(klines, f)
|
||||
self.write_log(f'保存{vt_symbol} K线数据成功=>{file_name}')
|
||||
except Exception as ex:
|
||||
self.write_error(f'保存k线数据异常:{str(ex)}')
|
||||
self.write_error(traceback.format_exc())
|
||||
|
||||
def load_klines_from_cache(self, kline_names: list = [], vt_symbol: str = ""):
|
||||
"""
|
||||
从缓存加载K线数据
|
||||
:param kline_names: 指定需要加载的k线名称列表
|
||||
:param vt_symbol: 指定股票代码,
|
||||
如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2
|
||||
如果空白,加载 data/strategyname_klines.pkb2
|
||||
:return:
|
||||
"""
|
||||
if len(kline_names) == 0:
|
||||
kline_names = list(self.klines.keys())
|
||||
|
||||
save_path = self.cta_engine.get_data_path()
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
|
||||
# 如果是指定合约的话,使用klines子目录
|
||||
if len(vt_symbol) > 0:
|
||||
kline_names = [n for n in kline_names if vt_symbol in n]
|
||||
save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines'))
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2'))
|
||||
else:
|
||||
save_path = self.cta_engine.get_data_path()
|
||||
file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2'))
|
||||
try:
|
||||
last_bar_dt = None
|
||||
with bz2.BZ2File(file_name, 'rb') as f:
|
||||
@ -976,7 +1004,7 @@ class CtaStockTemplate(CtaTemplate):
|
||||
self.gt.remove_grids_by_ids(direction=Direction.LONG, ids=remove_gids)
|
||||
self.gt.save()
|
||||
|
||||
def tns_excute_sell_grids(self, vt_symbol=None):
|
||||
def tns_excute_sell_grids(self, vt_symbol=None, force=False):
|
||||
"""
|
||||
事务执行卖出网格
|
||||
1、找出所有order_status=True,open_status=Talse, close_status=True的网格。
|
||||
@ -1027,17 +1055,27 @@ class CtaStockTemplate(CtaTemplate):
|
||||
sell_volume = ordering_grid.volume - ordering_grid.traded_volume
|
||||
|
||||
if sell_volume > acc_symbol_pos.volume:
|
||||
self.write_error(u'账号{}持仓{},不满足减仓目标:{}'
|
||||
.format(vt_symbol, acc_symbol_pos.volume, sell_volume))
|
||||
if not force:
|
||||
self.write_error(u'账号{}持仓{},不满足减仓目标:{}'
|
||||
.format(vt_symbol, acc_symbol_pos.volume, sell_volume))
|
||||
continue
|
||||
else:
|
||||
self.write_log(u'账号{}持仓{},不满足减仓目标:{}, 修正卖出数量:{}=>{}'
|
||||
.format(vt_symbol, acc_symbol_pos.volume, sell_volume, sell_volume,
|
||||
acc_symbol_pos.volume))
|
||||
sell_volume = acc_symbol_pos.volume
|
||||
|
||||
if sell_volume == 0:
|
||||
self.write_log(f'账号{vt_symbol}持仓{acc_symbol_pos.volume},卖出目标:{sell_volume}=0 不执行')
|
||||
continue
|
||||
|
||||
cur_price = self.cta_engine.get_price(vt_symbol)
|
||||
if not cur_price:
|
||||
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
|
||||
continue
|
||||
|
||||
# 实盘运行时,要加入市场买卖量的判断
|
||||
if not self.backtesting:
|
||||
cur_price = self.cta_engine.get_price(vt_symbol)
|
||||
if not cur_price:
|
||||
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
|
||||
continue
|
||||
|
||||
if not force and not self.backtesting:
|
||||
symbol_tick = self.cta_engine.get_tick(vt_symbol)
|
||||
if symbol_tick:
|
||||
symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol)
|
||||
@ -1072,7 +1110,7 @@ class CtaStockTemplate(CtaTemplate):
|
||||
self.write_log(f'{vt_symbol} 已委托卖出,{sell_volume},委托价:{sell_price}, 数量:{sell_volume}')
|
||||
|
||||
|
||||
def tns_finish_sell_grid(self, grid):
|
||||
def tns_finish_sell_grid(self, grid:CtaGrid):
|
||||
"""
|
||||
事务完成卖出网格
|
||||
:param grid:
|
||||
@ -1106,7 +1144,7 @@ class CtaStockTemplate(CtaTemplate):
|
||||
self.gt.save()
|
||||
self.policy.save()
|
||||
|
||||
def tns_execute_buy_grids(self, vt_symbol=None):
|
||||
def tns_execute_buy_grids(self, vt_symbol=None, force=False):
|
||||
"""
|
||||
事务执行买入网格
|
||||
:return:
|
||||
@ -1149,7 +1187,8 @@ class CtaStockTemplate(CtaTemplate):
|
||||
|
||||
balance, availiable, _, _ = self.cta_engine.get_account()
|
||||
if availiable <= 0:
|
||||
self.write_error(u'当前可用资金不足'.format(availiable))
|
||||
if not self.backtesting:
|
||||
self.write_log(u'当前可用资金{}不足,总资金:{}'.format(availiable, balance))
|
||||
continue
|
||||
vt_symbol = ordering_grid.vt_symbol
|
||||
cur_price = self.cta_engine.get_price(vt_symbol)
|
||||
@ -1178,7 +1217,7 @@ class CtaStockTemplate(CtaTemplate):
|
||||
buy_volume = max_buy_volume
|
||||
|
||||
# 实盘运行时,要加入市场买卖量的判断
|
||||
if not self.backtesting and 'market' in ordering_grid.snapshot:
|
||||
if not force and not self.backtesting and 'market' in ordering_grid.snapshot:
|
||||
symbol_tick = self.cta_engine.get_tick(vt_symbol)
|
||||
if symbol_tick:
|
||||
# 根据市场计算,前5档买单数量
|
||||
@ -1207,7 +1246,7 @@ class CtaStockTemplate(CtaTemplate):
|
||||
else:
|
||||
self.write_log(f'{self.strategy_name}, {vt_orderids},已委托买入,{vt_symbol} 委托价:{buy_price} 数量:{buy_volume}')
|
||||
|
||||
def tns_finish_buy_grid(self, grid):
|
||||
def tns_finish_buy_grid(self, grid:CtaGrid):
|
||||
"""
|
||||
事务完成买入网格
|
||||
:return:
|
||||
|
Loading…
Reference in New Issue
Block a user