[增强] 回测引擎支持PositionHolding

This commit is contained in:
msincenselee 2020-02-29 17:26:30 +08:00
parent cfedf54cde
commit 72222b5c49
5 changed files with 220 additions and 134 deletions

View File

@ -0,0 +1,37 @@
# flake8: noqa
import sys, os, copy, csv, signal
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if vnpy_root not in sys.path:
print(f'append {vnpy_root} into sys.path')
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.data.renko.rebuild_future import *
if __name__ == "__main__":
if len(sys.argv) < 4:
print(u'请输入三个参数 host symbol pricetick')
exit()
print(sys.argv)
host = sys.argv[1]
setting = {
"host": host,
"db_name": FUTURE_RENKO_DB_NAME,
"cache_folder": os.path.join(vnpy_root, 'ticks', 'tdx', 'future')
}
builder = FutureRenkoRebuilder(setting)
symbol = sys.argv[2]
price_tick = float(sys.argv[3])
print(f'启动期货renko补全,数据库:{host}/{FUTURE_RENKO_DB_NAME} 合约:{symbol}')
builder.start(symbol=symbol, price_tick=price_tick, height=[3, 5, 10, 'K3', 'K5', 'K10'], refill=True)
print(f'exit refill {symbol} renkos')

View File

@ -0,0 +1,63 @@
# flake8: noqa
"""
多周期显示K线
时间点同步
华富资产/李来佳
"""
import sys
import os
import ctypes
import platform
system = platform.system()
os.environ["VNPY_TESTING"] = "1"
# 将repostory的目录作为根目录添加到系统环境中。
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..' , '..'))
sys.path.append(ROOT_PATH)
from vnpy.trader.ui.kline.crosshair import Crosshair
from vnpy.trader.ui.kline.kline import *
if __name__ == '__main__':
# K线界面
try:
kline_settings = {
"M15":
{
"data_file": "log/renko_reverse_v1_J99_M15.csv",
"main_indicators": [
"pre_high",
"pre_low",
"ma5",
"ma18",
"ma60"
],
"sub_indicators": [
"atr"
]
},
"Renko_K3":
{
"data_file": "log/renko_reverse_v1_J99_Renko_K3.csv",
"main_indicators": [
"pre_high",
"pre_low",
"boll_upper",
"boll_middle",
"boll_lower",
"ma5",
"yb"
],
"sub_indicators": [
"sk",
"sd",
]
}
}
display_multi_grid(kline_settings)
except Exception as ex:
print(u'exception:{},trace:{}'.format(str(ex), traceback.format_exc()))

View File

@ -77,8 +77,11 @@ class BackTestingEngine(object):
# 绑定事件引擎 # 绑定事件引擎
self.event_engine = event_engine self.event_engine = event_engine
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测 'tick'根据分笔tick进行回测
# 引擎类型为回测 # 引擎类型为回测
self.engine_type = EngineType.BACKTESTING self.engine_type = EngineType.BACKTESTING
self.contract_type = 'future' # future, stock, digital
# 回测策略相关 # 回测策略相关
self.classes = {} # 策略类class_name: stategy_class self.classes = {} # 策略类class_name: stategy_class
@ -129,6 +132,8 @@ class BackTestingEngine(object):
self.long_position_list = [] # 多单持仓 self.long_position_list = [] # 多单持仓
self.short_position_list = [] # 空单持仓 self.short_position_list = [] # 空单持仓
self.holdings = {} # 多空持仓
# 当前最新数据,用于模拟成交用 # 当前最新数据,用于模拟成交用
self.gateway_name = u'BackTest' self.gateway_name = u'BackTest'
@ -357,6 +362,30 @@ class BackTestingEngine(object):
def get_exchange(self, symbol: str): def get_exchange(self, symbol: str):
return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL) return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:
symbol, exchange = extract_vt_symbol(vt_symbol)
if self.contract_type== 'future':
product = Product.FUTURES
elif self.contract_type == 'stock':
product = Product.EQUITY
else:
product = Product.SPOT
contract = ContractData(gateway_name=gateway_name,
name=vt_symbol,
product=product,
symbol=symbol,
exchange=exchange,
size=self.get_size(vt_symbol),
pricetick=self.get_price_tick(vt_symbol),
margin_rate=self.get_margin_rate(vt_symbol))
holding = PositionHolding(contract)
self.holdings[k] = holding
return holding
def set_name(self, test_name): def set_name(self, test_name):
""" """
设置组合的运行实例名称 设置组合的运行实例名称
@ -387,6 +416,12 @@ class BackTestingEngine(object):
if 'name' in test_settings: if 'name' in test_settings:
self.set_name(test_settings.get('name')) self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测')
self.contract_type = test_settings.get('contract_type', 'future')
self.output(f'测试合约主要为{self.contract_type}')
self.debug = test_settings.get('debug', False) self.debug = test_settings.get('debug', False)
# 更新数据目录 # 更新数据目录
@ -442,6 +477,9 @@ class BackTestingEngine(object):
self.write_log(u'准备数据') self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas')) self.prepare_data(test_settings.get('symbol_datas'))
if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None)
# 设置bar文件的时间间隔秒数 # 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings: if 'bar_interval_seconds' in test_settings:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds'))) self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds')))
@ -639,14 +677,18 @@ class BackTestingEngine(object):
vt_symbol = strategy_setting.get('vt_symbol') vt_symbol = strategy_setting.get('vt_symbol')
if '.' in vt_symbol: if '.' in vt_symbol:
symbol, exchange = extract_vt_symbol(vt_symbol) symbol, exchange = extract_vt_symbol(vt_symbol)
else: elif self.contract_type == 'future':
symbol = vt_symbol symbol = vt_symbol
underly_symbol = get_underlying_symbol(symbol).upper() underly_symbol = get_underlying_symbol(symbol).upper()
exchange = self.get_exchange(f'{underly_symbol}99') exchange = self.get_exchange(f'{underly_symbol}99')
vt_symbol = '.'.join([symbol, exchange.value]) vt_symbol = '.'.join([symbol, exchange.value])
else:
symbol = vt_symbol
exchange = Exchange.LOCAL
vt_symbol = '.'.join([symbol, exchange.value])
# 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约 # 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约
if '99' not in symbol and exchange != Exchange.SPD: if '99' not in symbol and exchange != Exchange.SPD and self.contract_type == 'future':
underly_symbol = get_underlying_symbol(symbol).upper() underly_symbol = get_underlying_symbol(symbol).upper()
self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value)) self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value))
vt_symbol = underly_symbol.upper() + '99.' + exchange.value vt_symbol = underly_symbol.upper() + '99.' + exchange.value
@ -1025,6 +1067,8 @@ class BackTestingEngine(object):
strategy.on_stop_order(stop_order) strategy.on_stop_order(stop_order)
strategy.on_order(order) strategy.on_order(order)
self.append_trade(trade) self.append_trade(trade)
holding = self.get_position_holding(vt_symbol=trade.vt_symbol)
holding.update_trade(trade)
strategy.on_trade(trade) strategy.on_trade(trade)
def cross_limit_order(self, bar: BarData = None, tick: TickData = None): def cross_limit_order(self, bar: BarData = None, tick: TickData = None):
@ -1854,10 +1898,28 @@ class BackTestingEngine(object):
self.daily_max_drawdown_rate = drawdown_rate self.daily_max_drawdown_rate = drawdown_rate
self.max_drawdown_rate_time = data['date'] self.max_drawdown_rate_time = data['date']
self.write_log(u'{}: net={}, capital={} max={} margin={} commission={} pos: {}' msg = u'{}: net={}, capital={} max={} margin={} commission={} pos: {}'\
.format(data['date'], data['net'], c, m, .format(data['date'],
today_holding_profit, commission, data['net'], c, m,
positionMsg)) today_holding_profit,
commission,
positionMsg)
if not self.debug:
self.output(msg)
else:
self.write_log(msg)
# 今仓 =》 昨仓
for holding in self.holdings.values():
if holding.long_td > 0:
self.write_log(f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}')
holding.long_td = 0
holding.long_yd = holding.long_pos
if holding.short_td > 0:
self.write_log(f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}')
holding.short_td = 0
holding.short_yd = holding.short_pos
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
def export_trade_result(self): def export_trade_result(self):

View File

@ -50,8 +50,6 @@ class PortfolioTestingEngine(BackTestingEngine):
"""Constructor""" """Constructor"""
super().__init__(event_engine) super().__init__(event_engine)
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测 'tick'根据分笔tick进行回测
self.bar_csv_file = {} self.bar_csv_file = {}
self.bar_df_dict = {} # 历史数据的df回测用 self.bar_df_dict = {} # 历史数据的df回测用
self.bar_df = None # 历史数据的df时间+symbol作为组合索引 self.bar_df = None # 历史数据的df时间+symbol作为组合索引
@ -116,83 +114,8 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_df_dict.clear() self.bar_df_dict.clear()
def prepare_env(self, test_settings): def prepare_env(self, test_settings):
self.output('prepare_env') self.output('portfolio prepare_env')
super().prepare_env(test_settings)
if 'name' in test_settings:
self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测')
self.debug = test_settings.get('debug', False)
# 更新数据目录
if 'data_path' in test_settings:
self.data_path = test_settings.get('data_path')
else:
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
self.output(f'数据输出目录:{self.data_path}')
# 更新日志目录
if 'logs_path' in test_settings:
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name))
else:
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
self.output(f'日志输出目录:{self.logs_path}')
# 创建日志
self.create_logger(debug=self.debug)
# 设置资金
if 'init_capital' in test_settings:
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital')))
self.set_init_capital(test_settings.get('init_capital'))
# 缺省使用保证金方式。
self.use_margin = test_settings.get('use_margin', True)
# 设置最大资金使用比例
if 'percent_limit' in test_settings:
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit')))
self.percent_limit = test_settings.get('percent_limit')
if 'start_date' in test_settings:
if 'strategy_start_date' not in test_settings:
init_days = test_settings.get('init_days', 10)
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days))
self.set_test_start_date(test_settings.get('start_date'), init_days)
else:
start_date = test_settings.get('start_date')
strategy_start_date = test_settings.get('strategy_start_date')
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
self.test_start_date = start_date
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
if 'end_date' in test_settings:
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date')))
self.set_test_end_date(test_settings.get('end_date'))
# 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds')))
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
# 准备数据
if 'symbol_datas' in test_settings:
self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas'))
if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None)
self.acivte_fund_kline = test_settings.get('acivte_fund_kline', False)
if self.acivte_fund_kline:
# 创建资金K线
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
self.load_strategy_class()
def prepare_data(self, data_dict): def prepare_data(self, data_dict):
""" """
@ -458,7 +381,7 @@ class PortfolioTestingEngine(BackTestingEngine):
continue continue
try: try:
for (dt, vt_symbol), bar_data in combined_df.iterrows(): for (dt, vt_symbol), tick_data in combined_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol) symbol, exchange = extract_vt_symbol(vt_symbol)
tick = TickData( tick = TickData(
gateway_name='backtesting', gateway_name='backtesting',
@ -468,8 +391,8 @@ class PortfolioTestingEngine(BackTestingEngine):
date=dt.strftime('%Y-%m-%d'), date=dt.strftime('%Y-%m-%d'),
time=dt.strftime('%H:%M:%S.%f'), time=dt.strftime('%H:%M:%S.%f'),
trading_day=test_day.strftime('%Y-%m-%d'), trading_day=test_day.strftime('%Y-%m-%d'),
last_price=bar_data['price'], last_price=tick_data['price'],
volume=bar_data['volume'] volume=tick_data['volume']
) )
self.new_tick(tick) self.new_tick(tick)

View File

@ -4,6 +4,7 @@ import uuid
import bz2 import bz2
import pickle import pickle
import traceback import traceback
import zlib
from abc import ABC from abc import ABC
from copy import copy from copy import copy
@ -704,9 +705,12 @@ class CtaProTemplate(CtaTemplate):
d = { d = {
'strategy': self.strategy_name, 'strategy': self.strategy_name,
'datetime': datetime.now()} 'datetime': datetime.now()}
klines = {}
for kline_name in sorted(self.klines.keys()): for kline_name in sorted(self.klines.keys()):
d.update({kline_name: self.klines.get(kline_name).get_data()}) klines.update({kline_name: self.klines.get(kline_name).get_data()})
kline_names = list(klines.keys())
binary_data = zlib.compress(pickle.dumps(klines))
d.update({'kline_names': kline_names, 'klines': binary_data, 'zlib': True})
return d return d
except Exception as ex: except Exception as ex:
self.write_error(f'获取klines切片数据失败:{str(ex)}') self.write_error(f'获取klines切片数据失败:{str(ex)}')
@ -1070,7 +1074,7 @@ class CtaProFutureTemplate(CtaProTemplate):
self.write_log(u'load_policy(),初始化Policy') self.write_log(u'load_policy(),初始化Policy')
self.policy.load() self.policy.load()
self.write_log(u'Policy:{}'.format(self.policy.toJson())) self.write_log(u'Policy:{}'.format(self.policy.to_json()))
def on_start(self): def on_start(self):
"""启动策略(必须由用户继承实现)""" """启动策略(必须由用户继承实现)"""
@ -1175,9 +1179,6 @@ class CtaProFutureTemplate(CtaProTemplate):
self.write_log(u'{},委托单:{}全部完成'.format(order.time, order.vt_orderid)) self.write_log(u'{},委托单:{}全部完成'.format(order.time, order.vt_orderid))
order_info = self.active_orders[order.vt_orderid] order_info = self.active_orders[order.vt_orderid]
# 平空仓完成(cover)
if order_info['direction'] == Direction.LONG.value and order.offset != Offset.OPEN:
self.write_log(u'{}平空仓完成(cover),委托价格:{}'.format(order.vt_symbol, order.price))
# 通过vt_orderid找到对应的网格 # 通过vt_orderid找到对应的网格
grid = order_info.get('grid', None) grid = order_info.get('grid', None)
if grid is not None: if grid is not None:
@ -1657,7 +1658,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param 平仓网格 :param 平仓网格
:return: :return:
""" """
self.write_log(u'执行事务平多仓位:{}'.format(grid.toJson())) self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json()))
# 平仓网格得合约 # 平仓网格得合约
sell_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol) sell_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
@ -1748,7 +1749,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param 平仓网格 :param 平仓网格
:return: :return:
""" """
self.write_log(u'执行事务平空仓位:{}'.format(grid.toJson())) self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json()))
# 平仓网格得合约 # 平仓网格得合约
cover_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol) cover_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
# vt_symbol => holding position # vt_symbol => holding position
@ -1849,7 +1850,7 @@ class CtaProFutureTemplate(CtaProTemplate):
symbol = g.snapshot.get('mi_symbol', self.vt_symbol) symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
if g.order_status or g.order_ids: if g.order_status or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.toJson())) self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.to_json()))
continue continue
if symbol != open_symbol: if symbol != open_symbol:
@ -1874,7 +1875,7 @@ class CtaProFutureTemplate(CtaProTemplate):
for g in locked_short_grids: for g in locked_short_grids:
symbol = g.snapshot.get('mi_symbol', self.vt_symbol) symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
if g.order_status or g.order_ids: if g.order_status or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson())) self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
continue continue
if symbol != open_symbol: if symbol != open_symbol:
self.write_log(u'不处理symbol不一致: 委托请求:{}, Grid mi Symbol:{}'.format(open_symbol, symbol)) self.write_log(u'不处理symbol不一致: 委托请求:{}, Grid mi Symbol:{}'.format(open_symbol, symbol))
@ -2037,7 +2038,7 @@ class CtaProFutureTemplate(CtaProTemplate):
volume = g.volume - g.traded_volume volume = g.volume - g.traded_volume
locked_long_dict.update({vt_symbol: locked_long_dict.get(vt_symbol, 0) + volume}) locked_long_dict.update({vt_symbol: locked_long_dict.get(vt_symbol, 0) + volume})
if g.orderStatus or g.order_ids: if g.orderStatus or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson())) self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
return return
locked_long_volume = sum(locked_long_dict.values(), 0) locked_long_volume = sum(locked_long_dict.values(), 0)
@ -2054,14 +2055,14 @@ class CtaProFutureTemplate(CtaProTemplate):
volume = g.volume - g.traded_volume volume = g.volume - g.traded_volume
locked_short_dict.update({vt_symbol: locked_short_dict.get(vt_symbol, 0) + volume}) locked_short_dict.update({vt_symbol: locked_short_dict.get(vt_symbol, 0) + volume})
if g.orderStatus or g.order_ids: if g.orderStatus or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson())) self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
return return
locked_short_volume = sum(locked_short_dict.values(), 0) locked_short_volume = sum(locked_short_dict.values(), 0)
# debug info # debug info
self.write_log(u'多单对锁格:{}'.format([g.toJson() for g in locked_long_grids])) self.write_log(u'多单对锁格:{}'.format([g.to_json() for g in locked_long_grids]))
self.write_log(u'空单对锁格:{}'.format([g.toJson() for g in locked_short_grids])) self.write_log(u'空单对锁格:{}'.format([g.to_json() for g in locked_short_grids]))
if locked_long_volume != locked_short_volume: if locked_long_volume != locked_short_volume:
self.write_error(u'对锁格多空数量不一致,不能解锁.\n多:{},\n空:{}' self.write_error(u'对锁格多空数量不一致,不能解锁.\n多:{},\n空:{}'