[增强] 回测引擎支持PositionHolding
This commit is contained in:
parent
cfedf54cde
commit
72222b5c49
37
prod/jobs/refill_future_renko.py
Normal file
37
prod/jobs/refill_future_renko.py
Normal file
@ -0,0 +1,37 @@
|
||||
# flake8: noqa
|
||||
|
||||
import sys, os, copy, csv, signal
|
||||
|
||||
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
if vnpy_root not in sys.path:
|
||||
print(f'append {vnpy_root} into sys.path')
|
||||
sys.path.append(vnpy_root)
|
||||
|
||||
os.environ["VNPY_TESTING"] = "1"
|
||||
|
||||
from vnpy.data.renko.rebuild_future import *
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if len(sys.argv) < 4:
|
||||
print(u'请输入三个参数 host symbol pricetick')
|
||||
exit()
|
||||
print(sys.argv)
|
||||
host = sys.argv[1]
|
||||
|
||||
setting = {
|
||||
"host": host,
|
||||
"db_name": FUTURE_RENKO_DB_NAME,
|
||||
"cache_folder": os.path.join(vnpy_root, 'ticks', 'tdx', 'future')
|
||||
}
|
||||
builder = FutureRenkoRebuilder(setting)
|
||||
|
||||
symbol = sys.argv[2]
|
||||
price_tick = float(sys.argv[3])
|
||||
|
||||
print(f'启动期货renko补全,数据库:{host}/{FUTURE_RENKO_DB_NAME} 合约:{symbol}')
|
||||
builder.start(symbol=symbol, price_tick=price_tick, height=[3, 5, 10, 'K3', 'K5', 'K10'], refill=True)
|
||||
|
||||
print(f'exit refill {symbol} renkos')
|
||||
|
63
tests/renko/ui_renko_reverse_klines.py
Normal file
63
tests/renko/ui_renko_reverse_klines.py
Normal file
@ -0,0 +1,63 @@
|
||||
# flake8: noqa
|
||||
"""
|
||||
多周期显示K线,
|
||||
时间点同步
|
||||
华富资产/李来佳
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import ctypes
|
||||
import platform
|
||||
system = platform.system()
|
||||
os.environ["VNPY_TESTING"] = "1"
|
||||
|
||||
# 将repostory的目录,作为根目录,添加到系统环境中。
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..' , '..'))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
from vnpy.trader.ui.kline.crosshair import Crosshair
|
||||
from vnpy.trader.ui.kline.kline import *
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# K线界面
|
||||
try:
|
||||
kline_settings = {
|
||||
"M15":
|
||||
{
|
||||
"data_file": "log/renko_reverse_v1_J99_M15.csv",
|
||||
"main_indicators": [
|
||||
"pre_high",
|
||||
"pre_low",
|
||||
"ma5",
|
||||
"ma18",
|
||||
"ma60"
|
||||
],
|
||||
"sub_indicators": [
|
||||
"atr"
|
||||
]
|
||||
},
|
||||
"Renko_K3":
|
||||
{
|
||||
"data_file": "log/renko_reverse_v1_J99_Renko_K3.csv",
|
||||
"main_indicators": [
|
||||
"pre_high",
|
||||
"pre_low",
|
||||
"boll_upper",
|
||||
"boll_middle",
|
||||
"boll_lower",
|
||||
"ma5",
|
||||
"yb"
|
||||
],
|
||||
"sub_indicators": [
|
||||
"sk",
|
||||
"sd",
|
||||
]
|
||||
|
||||
}
|
||||
}
|
||||
display_multi_grid(kline_settings)
|
||||
|
||||
except Exception as ex:
|
||||
print(u'exception:{},trace:{}'.format(str(ex), traceback.format_exc()))
|
@ -77,8 +77,11 @@ class BackTestingEngine(object):
|
||||
# 绑定事件引擎
|
||||
self.event_engine = event_engine
|
||||
|
||||
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测, 'tick',根据分笔tick进行回测
|
||||
|
||||
# 引擎类型为回测
|
||||
self.engine_type = EngineType.BACKTESTING
|
||||
self.contract_type = 'future' # future, stock, digital
|
||||
|
||||
# 回测策略相关
|
||||
self.classes = {} # 策略类,class_name: stategy_class
|
||||
@ -129,6 +132,8 @@ class BackTestingEngine(object):
|
||||
self.long_position_list = [] # 多单持仓
|
||||
self.short_position_list = [] # 空单持仓
|
||||
|
||||
self.holdings = {} # 多空持仓
|
||||
|
||||
# 当前最新数据,用于模拟成交用
|
||||
self.gateway_name = u'BackTest'
|
||||
|
||||
@ -357,6 +362,30 @@ class BackTestingEngine(object):
|
||||
def get_exchange(self, symbol: str):
|
||||
return self.symbol_exchange_dict.get(symbol, Exchange.LOCAL)
|
||||
|
||||
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
|
||||
""" 查询合约在账号的持仓(包含多空)"""
|
||||
k = f'{gateway_name}.{vt_symbol}'
|
||||
holding = self.holdings.get(k, None)
|
||||
if not holding:
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
if self.contract_type== 'future':
|
||||
product = Product.FUTURES
|
||||
elif self.contract_type == 'stock':
|
||||
product = Product.EQUITY
|
||||
else:
|
||||
product = Product.SPOT
|
||||
contract = ContractData(gateway_name=gateway_name,
|
||||
name=vt_symbol,
|
||||
product=product,
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
size=self.get_size(vt_symbol),
|
||||
pricetick=self.get_price_tick(vt_symbol),
|
||||
margin_rate=self.get_margin_rate(vt_symbol))
|
||||
holding = PositionHolding(contract)
|
||||
self.holdings[k] = holding
|
||||
return holding
|
||||
|
||||
def set_name(self, test_name):
|
||||
"""
|
||||
设置组合的运行实例名称
|
||||
@ -387,6 +416,12 @@ class BackTestingEngine(object):
|
||||
if 'name' in test_settings:
|
||||
self.set_name(test_settings.get('name'))
|
||||
|
||||
self.mode = test_settings.get('mode', 'bar')
|
||||
self.output(f'采用{self.mode}方式回测')
|
||||
|
||||
self.contract_type = test_settings.get('contract_type', 'future')
|
||||
self.output(f'测试合约主要为{self.contract_type}')
|
||||
|
||||
self.debug = test_settings.get('debug', False)
|
||||
|
||||
# 更新数据目录
|
||||
@ -442,6 +477,9 @@ class BackTestingEngine(object):
|
||||
self.write_log(u'准备数据')
|
||||
self.prepare_data(test_settings.get('symbol_datas'))
|
||||
|
||||
if self.mode == 'tick':
|
||||
self.tick_path = test_settings.get('tick_path', None)
|
||||
|
||||
# 设置bar文件的时间间隔秒数
|
||||
if 'bar_interval_seconds' in test_settings:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds')))
|
||||
@ -639,14 +677,18 @@ class BackTestingEngine(object):
|
||||
vt_symbol = strategy_setting.get('vt_symbol')
|
||||
if '.' in vt_symbol:
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
else:
|
||||
elif self.contract_type == 'future':
|
||||
symbol = vt_symbol
|
||||
underly_symbol = get_underlying_symbol(symbol).upper()
|
||||
exchange = self.get_exchange(f'{underly_symbol}99')
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
else:
|
||||
symbol = vt_symbol
|
||||
exchange = Exchange.LOCAL
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
|
||||
# 在期货组合回测,中需要把一般配置的主力合约,更换为指数合约
|
||||
if '99' not in symbol and exchange != Exchange.SPD:
|
||||
if '99' not in symbol and exchange != Exchange.SPD and self.contract_type == 'future':
|
||||
underly_symbol = get_underlying_symbol(symbol).upper()
|
||||
self.write_log(u'更新vt_symbol为指数合约:{}=>{}'.format(vt_symbol, underly_symbol + '99.' + exchange.value))
|
||||
vt_symbol = underly_symbol.upper() + '99.' + exchange.value
|
||||
@ -1025,6 +1067,8 @@ class BackTestingEngine(object):
|
||||
strategy.on_stop_order(stop_order)
|
||||
strategy.on_order(order)
|
||||
self.append_trade(trade)
|
||||
holding = self.get_position_holding(vt_symbol=trade.vt_symbol)
|
||||
holding.update_trade(trade)
|
||||
strategy.on_trade(trade)
|
||||
|
||||
def cross_limit_order(self, bar: BarData = None, tick: TickData = None):
|
||||
@ -1854,10 +1898,28 @@ class BackTestingEngine(object):
|
||||
self.daily_max_drawdown_rate = drawdown_rate
|
||||
self.max_drawdown_rate_time = data['date']
|
||||
|
||||
self.write_log(u'{}: net={}, capital={} max={} margin={} commission={}, pos: {}'
|
||||
.format(data['date'], data['net'], c, m,
|
||||
today_holding_profit, commission,
|
||||
positionMsg))
|
||||
msg = u'{}: net={}, capital={} max={} margin={} commission={}, pos: {}'\
|
||||
.format(data['date'],
|
||||
data['net'], c, m,
|
||||
today_holding_profit,
|
||||
commission,
|
||||
positionMsg)
|
||||
|
||||
if not self.debug:
|
||||
self.output(msg)
|
||||
else:
|
||||
self.write_log(msg)
|
||||
|
||||
# 今仓 =》 昨仓
|
||||
for holding in self.holdings.values():
|
||||
if holding.long_td > 0:
|
||||
self.write_log(f'{holding.vt_symbol} 多单今仓{holding.long_td},昨仓:{holding.long_yd}=> 昨仓:{holding.long_pos}')
|
||||
holding.long_td = 0
|
||||
holding.long_yd = holding.long_pos
|
||||
if holding.short_td > 0:
|
||||
self.write_log(f'{holding.vt_symbol} 空单今仓{holding.short_td},昨仓:{holding.short_yd}=> 昨仓:{holding.short_pos}')
|
||||
holding.short_td = 0
|
||||
holding.short_yd = holding.short_pos
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
def export_trade_result(self):
|
||||
|
@ -50,8 +50,6 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
"""Constructor"""
|
||||
super().__init__(event_engine)
|
||||
|
||||
self.mode = 'bar' # 'bar': 根据1分钟k线进行回测, 'tick',根据分笔tick进行回测
|
||||
|
||||
self.bar_csv_file = {}
|
||||
self.bar_df_dict = {} # 历史数据的df,回测用
|
||||
self.bar_df = None # 历史数据的df,时间+symbol作为组合索引
|
||||
@ -116,83 +114,8 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
self.bar_df_dict.clear()
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
self.output('prepare_env')
|
||||
|
||||
if 'name' in test_settings:
|
||||
self.set_name(test_settings.get('name'))
|
||||
|
||||
self.mode = test_settings.get('mode', 'bar')
|
||||
self.output(f'采用{self.mode}方式回测')
|
||||
|
||||
self.debug = test_settings.get('debug', False)
|
||||
|
||||
# 更新数据目录
|
||||
if 'data_path' in test_settings:
|
||||
self.data_path = test_settings.get('data_path')
|
||||
else:
|
||||
self.data_path = os.path.abspath(os.path.join(os.getcwd(), 'data'))
|
||||
|
||||
self.output(f'数据输出目录:{self.data_path}')
|
||||
|
||||
# 更新日志目录
|
||||
if 'logs_path' in test_settings:
|
||||
self.logs_path = os.path.abspath(os.path.join(test_settings.get('logs_path'), self.test_name))
|
||||
else:
|
||||
self.logs_path = os.path.abspath(os.path.join(os.getcwd(), 'log', self.test_name))
|
||||
self.output(f'日志输出目录:{self.logs_path}')
|
||||
|
||||
# 创建日志
|
||||
self.create_logger(debug=self.debug)
|
||||
|
||||
# 设置资金
|
||||
if 'init_capital' in test_settings:
|
||||
self.write_log(u'设置期初资金:{}'.format(test_settings.get('init_capital')))
|
||||
self.set_init_capital(test_settings.get('init_capital'))
|
||||
|
||||
# 缺省使用保证金方式。
|
||||
self.use_margin = test_settings.get('use_margin', True)
|
||||
|
||||
# 设置最大资金使用比例
|
||||
if 'percent_limit' in test_settings:
|
||||
self.write_log(u'设置最大资金使用比例:{}%'.format(test_settings.get('percent_limit')))
|
||||
self.percent_limit = test_settings.get('percent_limit')
|
||||
|
||||
if 'start_date' in test_settings:
|
||||
if 'strategy_start_date' not in test_settings:
|
||||
init_days = test_settings.get('init_days', 10)
|
||||
self.write_log(u'设置回测开始日期:{},数据加载日数:{}'.format(test_settings.get('start_date'), init_days))
|
||||
self.set_test_start_date(test_settings.get('start_date'), init_days)
|
||||
else:
|
||||
start_date = test_settings.get('start_date')
|
||||
strategy_start_date = test_settings.get('strategy_start_date')
|
||||
self.write_log(u'使用指定的数据开始日期:{}和策略启动日期:{}'.format(start_date, strategy_start_date))
|
||||
self.test_start_date = start_date
|
||||
self.data_start_date = datetime.strptime(start_date.replace('-', ''), '%Y%m%d')
|
||||
self.strategy_start_date = datetime.strptime(strategy_start_date.replace('-', ''), '%Y%m%d')
|
||||
|
||||
if 'end_date' in test_settings:
|
||||
self.write_log(u'设置回测结束日期:{}'.format(test_settings.get('end_date')))
|
||||
self.set_test_end_date(test_settings.get('end_date'))
|
||||
|
||||
# 设置bar文件的时间间隔秒数
|
||||
if 'bar_interval_seconds' in test_settings:
|
||||
self.write_log(u'设置bar文件的时间间隔秒数:{}'.format(test_settings.get('bar_interval_seconds')))
|
||||
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
|
||||
|
||||
# 准备数据
|
||||
if 'symbol_datas' in test_settings:
|
||||
self.write_log(u'准备数据')
|
||||
self.prepare_data(test_settings.get('symbol_datas'))
|
||||
|
||||
if self.mode == 'tick':
|
||||
self.tick_path = test_settings.get('tick_path', None)
|
||||
|
||||
self.acivte_fund_kline = test_settings.get('acivte_fund_kline', False)
|
||||
if self.acivte_fund_kline:
|
||||
# 创建资金K线
|
||||
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
|
||||
|
||||
self.load_strategy_class()
|
||||
self.output('portfolio prepare_env')
|
||||
super().prepare_env(test_settings)
|
||||
|
||||
def prepare_data(self, data_dict):
|
||||
"""
|
||||
@ -458,7 +381,7 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
continue
|
||||
|
||||
try:
|
||||
for (dt, vt_symbol), bar_data in combined_df.iterrows():
|
||||
for (dt, vt_symbol), tick_data in combined_df.iterrows():
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
tick = TickData(
|
||||
gateway_name='backtesting',
|
||||
@ -468,8 +391,8 @@ class PortfolioTestingEngine(BackTestingEngine):
|
||||
date=dt.strftime('%Y-%m-%d'),
|
||||
time=dt.strftime('%H:%M:%S.%f'),
|
||||
trading_day=test_day.strftime('%Y-%m-%d'),
|
||||
last_price=bar_data['price'],
|
||||
volume=bar_data['volume']
|
||||
last_price=tick_data['price'],
|
||||
volume=tick_data['volume']
|
||||
)
|
||||
|
||||
self.new_tick(tick)
|
||||
|
@ -4,6 +4,7 @@ import uuid
|
||||
import bz2
|
||||
import pickle
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
from abc import ABC
|
||||
from copy import copy
|
||||
@ -704,9 +705,12 @@ class CtaProTemplate(CtaTemplate):
|
||||
d = {
|
||||
'strategy': self.strategy_name,
|
||||
'datetime': datetime.now()}
|
||||
|
||||
klines = {}
|
||||
for kline_name in sorted(self.klines.keys()):
|
||||
d.update({kline_name: self.klines.get(kline_name).get_data()})
|
||||
klines.update({kline_name: self.klines.get(kline_name).get_data()})
|
||||
kline_names = list(klines.keys())
|
||||
binary_data = zlib.compress(pickle.dumps(klines))
|
||||
d.update({'kline_names': kline_names, 'klines': binary_data, 'zlib': True})
|
||||
return d
|
||||
except Exception as ex:
|
||||
self.write_error(f'获取klines切片数据失败:{str(ex)}')
|
||||
@ -1070,7 +1074,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
self.write_log(u'load_policy(),初始化Policy')
|
||||
|
||||
self.policy.load()
|
||||
self.write_log(u'Policy:{}'.format(self.policy.toJson()))
|
||||
self.write_log(u'Policy:{}'.format(self.policy.to_json()))
|
||||
|
||||
def on_start(self):
|
||||
"""启动策略(必须由用户继承实现)"""
|
||||
@ -1175,9 +1179,6 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
self.write_log(u'{},委托单:{}全部完成'.format(order.time, order.vt_orderid))
|
||||
order_info = self.active_orders[order.vt_orderid]
|
||||
|
||||
# 平空仓完成(cover)
|
||||
if order_info['direction'] == Direction.LONG.value and order.offset != Offset.OPEN:
|
||||
self.write_log(u'{}平空仓完成(cover),委托价格:{}'.format(order.vt_symbol, order.price))
|
||||
# 通过vt_orderid,找到对应的网格
|
||||
grid = order_info.get('grid', None)
|
||||
if grid is not None:
|
||||
@ -1657,7 +1658,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
:param 平仓网格
|
||||
:return:
|
||||
"""
|
||||
self.write_log(u'执行事务平多仓位:{}'.format(grid.toJson()))
|
||||
self.write_log(u'执行事务平多仓位:{}'.format(grid.to_json()))
|
||||
|
||||
# 平仓网格得合约
|
||||
sell_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
|
||||
@ -1748,7 +1749,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
:param 平仓网格
|
||||
:return:
|
||||
"""
|
||||
self.write_log(u'执行事务平空仓位:{}'.format(grid.toJson()))
|
||||
self.write_log(u'执行事务平空仓位:{}'.format(grid.to_json()))
|
||||
# 平仓网格得合约
|
||||
cover_symbol = grid.snapshot.get('mi_symbol', self.vt_symbol)
|
||||
# vt_symbol => holding position
|
||||
@ -1849,7 +1850,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
|
||||
|
||||
if g.order_status or g.order_ids:
|
||||
self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.toJson()))
|
||||
self.write_log(u'当前对锁格:{}存在委托,不纳入计算'.format(g.to_json()))
|
||||
continue
|
||||
|
||||
if symbol != open_symbol:
|
||||
@ -1874,7 +1875,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
for g in locked_short_grids:
|
||||
symbol = g.snapshot.get('mi_symbol', self.vt_symbol)
|
||||
if g.order_status or g.order_ids:
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
|
||||
continue
|
||||
if symbol != open_symbol:
|
||||
self.write_log(u'不处理symbol不一致: 委托请求:{}, Grid mi Symbol:{}'.format(open_symbol, symbol))
|
||||
@ -2037,7 +2038,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
volume = g.volume - g.traded_volume
|
||||
locked_long_dict.update({vt_symbol: locked_long_dict.get(vt_symbol, 0) + volume})
|
||||
if g.orderStatus or g.order_ids:
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
|
||||
return
|
||||
|
||||
locked_long_volume = sum(locked_long_dict.values(), 0)
|
||||
@ -2054,14 +2055,14 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
volume = g.volume - g.traded_volume
|
||||
locked_short_dict.update({vt_symbol: locked_short_dict.get(vt_symbol, 0) + volume})
|
||||
if g.orderStatus or g.order_ids:
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.toJson()))
|
||||
self.write_log(u'当前对锁格:{}存在委托,不进行解锁'.format(g.to_json()))
|
||||
return
|
||||
|
||||
locked_short_volume = sum(locked_short_dict.values(), 0)
|
||||
|
||||
# debug info
|
||||
self.write_log(u'多单对锁格:{}'.format([g.toJson() for g in locked_long_grids]))
|
||||
self.write_log(u'空单对锁格:{}'.format([g.toJson() for g in locked_short_grids]))
|
||||
self.write_log(u'多单对锁格:{}'.format([g.to_json() for g in locked_long_grids]))
|
||||
self.write_log(u'空单对锁格:{}'.format([g.to_json() for g in locked_short_grids]))
|
||||
|
||||
if locked_long_volume != locked_short_volume:
|
||||
self.write_error(u'对锁格多空数量不一致,不能解锁.\n多:{},\n空:{}'
|
||||
|
Loading…
Reference in New Issue
Block a user