[增强] 回测引擎支持PositionHolding

This commit is contained in:
msincenselee 2020-02-29 17:26:30 +08:00
parent cfedf54cde
commit 72222b5c49
5 changed files with 220 additions and 134 deletions

View File

@ -0,0 +1,37 @@
# flake8: noqa
import sys, os, copy, csv, signal
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if vnpy_root not in sys.path:
print(f'append {vnpy_root} into sys.path')
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.data.renko.rebuild_future import *
if __name__ == "__main__":
if len(sys.argv) < 4:
print(u'请输入三个参数 host symbol pricetick')
exit()
print(sys.argv)
host = sys.argv[1]
setting = {
"host": host,
"db_name": FUTURE_RENKO_DB_NAME,
"cache_folder": os.path.join(vnpy_root, 'ticks', 'tdx', 'future')
}
builder = FutureRenkoRebuilder(setting)
symbol = sys.argv[2]
price_tick = float(sys.argv[3])
print(f'启动期货renko补全,数据库:{host}/{FUTURE_RENKO_DB_NAME} 合约:{symbol}')
builder.start(symbol=symbol, price_tick=price_tick, height=[3, 5, 10, 'K3', 'K5', 'K10'], refill=True)
print(f'exit refill {symbol} renkos')

View File

@ -0,0 +1,63 @@
# flake8: noqa
"""
多周期显示K线
时间点同步
华富资产/李来佳
"""
import sys
import os
import ctypes
import platform
system = platform.system()
os.environ["VNPY_TESTING"] = "1"
# 将repostory的目录作为根目录添加到系统环境中。
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..' , '..'))
sys.path.append(ROOT_PATH)
from vnpy.trader.ui.kline.crosshair import Crosshair
from vnpy.trader.ui.kline.kline import *
if __name__ == '__main__':
# K线界面
try:
kline_settings = {
"M15":
{
"data_file": "log/renko_reverse_v1_J99_M15.csv",
"main_indicators": [
"pre_high",
"pre_low",
"ma5",
"ma18",
"ma60"
],
"sub_indicators": [
"atr"
]
},
"Renko_K3":
{
"data_file": "log/renko_reverse_v1_J99_Renko_K3.csv",
"main_indicators": [
"pre_high",
"pre_low",
"boll_upper",
"boll_middle",
"boll_lower",
"ma5",
"yb"
],
"sub_indicators": [
"sk",
"sd",
]
}
}
display_multi_grid(kline_settings)
except Exception as ex:
print(u'exception:{},trace:{}'.format(str(ex), traceback.format_exc()))

View File

@ -77,8 +77,11 @@ class BackTestingEngine(object):
# 绑定事件引擎
self.event_engine = event_engine
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测 'tick'根据分笔tick进行回测
# 引擎类型为回测
self.engine_type = EngineType.BACKTESTING
self.contract_type = 'future' # future, stock, digital
# 回测策略相关
self.classes = {} # 策略类class_name: stategy_class
@ -129,6 +132,8 @@ class BackTestingEngine(object):
self.long_position_list = [] # 多单持仓
self.short_position_list = [] # 空单持仓
self.holdings = {} # 多空持仓
# 当前最新数据,用于模拟成交用
self.gateway_name = u'BackTest'
@ -357,6 +362,30 @@ class BackTestingEngine(object):
def get_exchange(self, symbol: str):
return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
k = f'{gateway_name}.{vt_symbol}'
holding = self.holdings.get(k, None)
if not holding:
symbol, exchange = extract_vt_symbol(vt_symbol)
if self.contract_type== 'future':
product = Product.FUTURES
elif self.contract_type == 'stock':
product = Product.EQUITY
else:
product = Product.SPOT
contract = ContractData(gateway_name=gateway_name,
name=vt_symbol,
product=product,
symbol=symbol,
exchange=exchange,
size=self.get_size(vt_symbol),
pricetick=self.get_price_tick(vt_symbol),
margin_rate=self.get_margin_rate(vt_symbol))
holding = PositionHolding(contract)
self.holdings[k] = holding
return holding
def set_name(self, test_name):
"""
设置组合的运行实例名称
@ -387,6 +416,12 @@ class BackTestingEngine(object):
if 'name' in test_settings:
self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测')
self.contract_type = test_settings.get('contract_type', 'future')
self.output(f'测试合约主要为{self.contract_type}')
self.debug = test_settings.get('debug', False)
# 更新数据目录
@ -442,6 +477,9 @@ class BackTestingEngine(object):
self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas'))
if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None)
# 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds')))
@ -639,14 +677,18 @@ class BackTestingEngine(object):
vt_symbol = strategy_setting.get('vt_symbol')
if '.' in vt_symbol:
symbol, exchange = extract_vt_symbol(vt_symbol)
else:
elif self.contract_type == 'future':
symbol = vt_symbol
underly_symbol = get_underlying_symbol(symbol).upper()
exchange = self.get_exchange(f'{underly_symbol}99')
vt_symbol = '.'.join([symbol, exchange.value])
else:
symbol = vt_symbol
exchange = Exchange.LOCAL
vt_symbol = '.'.join([symbol, exchange.value])
# 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约
if '99' not in symbol and exchange != Exchange.SPD:
if '99' not in symbol and exchange != Exchange.SPD and self.contract_type == 'future':
underly_symbol = get_underlying_symbol(symbol).upper()
self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value))
vt_symbol = underly_symbol.upper() + '99.' + exchange.value
@ -1025,6 +1067,8 @@ class BackTestingEngine(object):
strategy.on_stop_order(stop_order)
strategy.on_order(order)
self.append_trade(trade)
holding = self.get_position_holding(vt_symbol=trade.vt_symbol)
holding.update_trade(trade)
strategy.on_trade(trade)
def cross_limit_order(self, bar: BarData = None, tick: TickData = None):
@ -1854,10 +1898,28 @@ class BackTestingEngine(object):
self.daily_max_drawdown_rate = drawdown_rate
self.max_drawdown_rate_time = data['date']
self.write_log(u'{}: net={}, capital={} max={} margin={} commission={} pos: {}'
.format(data['date'], data['net'], c, m,
today_holding_profit, commission,
positionMsg))
msg = u'{}: net={}, capital={} max={} margin={} commission={} pos: {}'\
.format(data['date'],
data['net'], c, m,
today_holding_profit,
commission,
positionMsg)
if not self.debug:
self.output(msg)
else:
self.write_log(msg)
# 今仓 =》 昨仓
for holding in self.holdings.values():
if holding.long_td > 0:
self.write_log(f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}')
holding.long_td = 0
holding.long_yd = holding.long_pos
if holding.short_td > 0:
self.write_log(f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}')
holding.short_td = 0
holding.short_yd = holding.short_pos
# ---------------------------------------------------------------------
def export_trade_result(self):

View File

@ -50,8 +50,6 @@ class PortfolioTestingEngine(BackTestingEngine):
"""Constructor"""
super().__init__(event_engine)
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测 'tick'根据分笔tick进行回测
self.bar_csv_file = {}
self.bar_df_dict = {} # 历史数据的df回测用
self.bar_df = None # 历史数据的df时间+symbol作为组合索引
@ -116,83 +114,8 @@ class PortfolioTestingEngine(BackTestingEngine):
self.bar_df_dict.clear()
def prepare_env(self, test_settings):
self.output('prepare_env')
if 'name' in test_settings:
self.set_name(test_settings.get('name'))
self.mode = test_settings.get('mode', 'bar')
self.output(f'采用{self.mode}方式回测')
self.debug = test_settings.get('debug', False)
# 更新数据目录
if 'data_path' in test_settings:
self.data_path = test_settings.get('data_path')
else:
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
self.output(f'数据输出目录:{self.data_path}')
# 更新日志目录
if 'logs_path' in test_settings:
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name))
else:
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
self.output(f'日志输出目录:{self.logs_path}')
# 创建日志
self.create_logger(debug=self.debug)
# 设置资金
if 'init_capital' in test_settings:
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital')))
self.set_init_capital(test_settings.get('init_capital'))
# 缺省使用保证金方式。
self.use_margin = test_settings.get('use_margin', True)
# 设置最大资金使用比例
if 'percent_limit' in test_settings:
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit')))
self.percent_limit = test_settings.get('percent_limit')
if 'start_date' in test_settings:
if 'strategy_start_date' not in test_settings:
init_days = test_settings.get('init_days', 10)
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days))
self.set_test_start_date(test_settings.get('start_date'), init_days)
else:
start_date = test_settings.get('start_date')
strategy_start_date = test_settings.get('strategy_start_date')
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
self.test_start_date = start_date
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
if 'end_date' in test_settings:
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date')))
self.set_test_end_date(test_settings.get('end_date'))
# 设置bar文件的时间间隔秒数
if 'bar_interval_seconds' in test_settings:
self.write_log(u'设置bar文件的时间间隔秒数{}'.format(test_settings.get('bar_interval_seconds')))
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
# 准备数据
if 'symbol_datas' in test_settings:
self.write_log(u'准备数据')
self.prepare_data(test_settings.get('symbol_datas'))
if self.mode == 'tick':
self.tick_path = test_settings.get('tick_path', None)
self.acivte_fund_kline = test_settings.get('acivte_fund_kline', False)
if self.acivte_fund_kline:
# 创建资金K线
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
self.load_strategy_class()
self.output('portfolio prepare_env')
super().prepare_env(test_settings)
def prepare_data(self, data_dict):
"""
@ -458,7 +381,7 @@ class PortfolioTestingEngine(BackTestingEngine):
continue
try:
for (dt, vt_symbol), bar_data in combined_df.iterrows():
for (dt, vt_symbol), tick_data in combined_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol)
tick = TickData(
gateway_name='backtesting',
@ -468,8 +391,8 @@ class PortfolioTestingEngine(BackTestingEngine):
date=dt.strftime('%Y-%m-%d'),
time=dt.strftime('%H:%M:%S.%f'),
trading_day=test_day.strftime('%Y-%m-%d'),
last_price=bar_data['price'],
volume=bar_data['volume']
last_price=tick_data['price'],
volume=tick_data['volume']
)
self.new_tick(tick)

View File

@ -4,6 +4,7 @@ import uuid
import bz2
import pickle
import traceback
import zlib
from abc import ABC
from copy import copy
@ -704,9 +705,12 @@ class CtaProTemplate(CtaTemplate):
d = {
'strategy': self.strategy_name,
'datetime': datetime.now()}
klines = {}
for kline_name in sorted(self.klines.keys()):
d.update({kline_name: self.klines.get(kline_name).get_data()})
klines.update({kline_name: self.klines.get(kline_name).get_data()})
kline_names = list(klines.keys())
binary_data = zlib.compress(pickle.dumps(klines))
d.update({'kline_names': kline_names, 'klines': binary_data, 'zlib': True})
return d
except Exception as ex:
self.write_error(f'获取klines切片数据失败:{str(ex)}')
@ -1070,7 +1074,7 @@ class CtaProFutureTemplate(CtaProTemplate):
self.write_log(u'load_policy(),初始化Policy')
self.policy.load()
self.write_log(u'Policy:{}'.format(self.policy.toJson()))
self.write_log(u'Policy:{}'.format(self.policy.to_json()))
def on_start(self):
"""启动策略(必须由用户继承实现)"""
@ -1175,9 +1179,6 @@ class CtaProFutureTemplate(CtaProTemplate):
self.write_log(u'{},委托单:{}全部完成'.format(order.time, order.vt_orderid))
order_info = self.active_orders[order.vt_orderid]
# 平空仓完成(cover)
if order_info['direction'] == Direction.LONG.value and order.offset != Offset.OPEN:
self.write_log(u'{}平空仓完成(cover),委托价格:{}'.format(order.vt_symbol, order.price))
# 通过vt_orderid找到对应的网格
grid = order_info.get('grid', None)
if grid is not None:
@ -1657,7 +1658,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param 平仓网格
:return:
"""
self.write_log(u'执行事务平多仓位:{}'.format(grid.toJson()))
self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json()))
# 平仓网格得合约
sell_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
@ -1748,7 +1749,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param 平仓网格
:return:
"""
self.write_log(u'执行事务平空仓位:{}'.format(grid.toJson()))
self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json()))
# 平仓网格得合约
cover_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
# vt_symbol => holding position
@ -1849,7 +1850,7 @@ class CtaProFutureTemplate(CtaProTemplate):
symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
if g.order_status or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.toJson()))
self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.to_json()))
continue
if symbol != open_symbol:
@ -1874,7 +1875,7 @@ class CtaProFutureTemplate(CtaProTemplate):
for g in locked_short_grids:
symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
if g.order_status or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
continue
if symbol != open_symbol:
self.write_log(u'不处理symbol不一致: 委托请求:{}, Grid mi Symbol:{}'.format(open_symbol, symbol))
@ -2037,7 +2038,7 @@ class CtaProFutureTemplate(CtaProTemplate):
volume = g.volume - g.traded_volume
locked_long_dict.update({vt_symbol: locked_long_dict.get(vt_symbol, 0) + volume})
if g.orderStatus or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
return
locked_long_volume = sum(locked_long_dict.values(), 0)
@ -2054,14 +2055,14 @@ class CtaProFutureTemplate(CtaProTemplate):
volume = g.volume - g.traded_volume
locked_short_dict.update({vt_symbol: locked_short_dict.get(vt_symbol, 0) + volume})
if g.orderStatus or g.order_ids:
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
return
locked_short_volume = sum(locked_short_dict.values(), 0)
# debug info
self.write_log(u'多单对锁格:{}'.format([g.toJson() for g in locked_long_grids]))
self.write_log(u'空单对锁格:{}'.format([g.toJson() for g in locked_short_grids]))
self.write_log(u'多单对锁格:{}'.format([g.to_json() for g in locked_long_grids]))
self.write_log(u'空单对锁格:{}'.format([g.to_json() for g in locked_short_grids]))
if locked_long_volume != locked_short_volume:
self.write_error(u'对锁格多空数量不一致,不能解锁.\n多:{},\n空:{}'