[增强] 期货套利回测引擎/套利模板/定义套利合约转算法引擎
This commit is contained in:
parent
2f3f65c694
commit
4ffbd50496
@ -1,11 +1,16 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from vnpy.event import EventEngine, Event
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
|
||||
from vnpy.trader.constant import (Direction, Offset, OrderType)
|
||||
from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData)
|
||||
from vnpy.trader.utility import load_json, save_json, round_to
|
||||
from vnpy.trader.constant import (Direction, Offset, OrderType, Status)
|
||||
from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData, CancelRequest)
|
||||
from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path
|
||||
from vnpy.trader.util_logger import setup_logger, logging
|
||||
|
||||
from .template import AlgoTemplate
|
||||
|
||||
@ -30,9 +35,13 @@ class AlgoEngine(BaseEngine):
|
||||
self.symbol_algo_map = {}
|
||||
self.orderid_algo_map = {}
|
||||
|
||||
self.algo_vtorderid_order_map = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
|
||||
|
||||
self.algo_templates = {}
|
||||
self.algo_settings = {}
|
||||
|
||||
self.algo_loggers = {} # algo_name: logger
|
||||
|
||||
self.load_algo_template()
|
||||
self.register_event()
|
||||
|
||||
@ -172,7 +181,7 @@ class AlgoEngine(BaseEngine):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo)
|
||||
self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo_name=algo.algo_name)
|
||||
return
|
||||
|
||||
volume = round_to(volume, contract.min_volume)
|
||||
@ -204,33 +213,192 @@ class AlgoEngine(BaseEngine):
|
||||
req = order.create_cancel_request()
|
||||
self.main_engine.cancel_order(req, order.gateway_name)
|
||||
|
||||
def send_algo_order(self, req: OrderRequest, gateway_name: str):
|
||||
"""发送算法交易指令"""
|
||||
self.write_log(u'创建算法交易,gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
|
||||
.format(gateway_name, req.strategy_name, req.vt_symbol, req.price, req.volume))
|
||||
|
||||
# 创建算法实例,由算法引擎启动
|
||||
trade_command = ''
|
||||
if req.direction == Direction.LONG and req.offset == Offset.OPEN:
|
||||
trade_command = 'Buy'
|
||||
elif req.direction == Direction.SHORT and req.offset == Offset.OPEN:
|
||||
trade_command = 'Short'
|
||||
elif req.direction == Direction.SHORT and req.offset != Offset.OPEN:
|
||||
trade_command = 'Sell'
|
||||
elif req.direction == Direction.LONG and req.offset != Offset.OPEN:
|
||||
trade_command = 'Cover'
|
||||
|
||||
all_custom_contracts = self.main_engine.get_all_custom_contracts()
|
||||
contract_setting = all_custom_contracts.get(req.vt_symbol, {})
|
||||
algo_setting = {
|
||||
'templateName': u'SpreadTrading套利',
|
||||
'order_vt_symbol': req.vt_symbol,
|
||||
'order_command': trade_command,
|
||||
'order_price': req.price,
|
||||
'order_volume': req.volume,
|
||||
'timer_interval': 60 * 60 * 24,
|
||||
'strategy_name': req.strategy_name,
|
||||
'gateway_name': gateway_name
|
||||
}
|
||||
algo_setting.update(contract_setting)
|
||||
|
||||
# 算法引擎
|
||||
algo_name = self.start_algo(algo_setting)
|
||||
self.write_log(u'send_algo_order(): start_algo {}={}'.format(algo_name, str(algo_setting)))
|
||||
|
||||
# 创建一个Order事件
|
||||
order = req.create_order_data(orderid=algo_name, gateway_name=gateway_name)
|
||||
order.orderTime = datetime.now().strftime('%H:%M:%S.%f')
|
||||
order.status = Status.SUBMITTING
|
||||
|
||||
event1 = Event(type=EVENT_ORDER, data=order)
|
||||
self.event_engine.put(event1)
|
||||
|
||||
# 登记在本地的算法委托字典中
|
||||
self.algo_vtorderid_order_map.update({order.vt_orderid: order})
|
||||
|
||||
return order.vt_orderid
|
||||
|
||||
def is_algo_order(self, req: CancelRequest, gateway_name: str):
|
||||
"""是否为外部算法委托单"""
|
||||
vt_orderid = '.'.join([req.orderid, gateway_name])
|
||||
if vt_orderid in self.algo_vtorderid_order_map:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def cancel_algo_order(self, req: CancelRequest, gateway_name: str):
|
||||
"""外部算法单撤单"""
|
||||
vt_orderid = '.'.join([req.orderid, gateway_name])
|
||||
order = self.algo_vtorderid_order_map.get(vt_orderid, None)
|
||||
if not order:
|
||||
self.write_error(f'{vt_orderid}不在算法引擎中,撤单失败')
|
||||
return False
|
||||
|
||||
algo = self.algos.get(req.orderid, None)
|
||||
if not algo:
|
||||
self.write_error(f'{req.orderid}算法实例不在算法引擎中,撤单失败')
|
||||
return False
|
||||
|
||||
ret = self.stop_algo(req.orderid)
|
||||
if ret:
|
||||
order.cancelTime = datetime.now().strftime('%H:%M:%S.%f')
|
||||
order.status = Status.CANCELLED
|
||||
event1 = Event(type=EVENT_ORDER, data=order)
|
||||
self.event_engine.put(event1)
|
||||
self.write_log(f'算法实例撤单成功:{req.orderid}')
|
||||
return True
|
||||
else:
|
||||
self.write_error(f'算法实例撤单失败:{req.orderid}')
|
||||
return False
|
||||
|
||||
def get_tick(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
""""""
|
||||
tick = self.main_engine.get_tick(vt_symbol)
|
||||
|
||||
if not tick:
|
||||
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo)
|
||||
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo_name=algo.algo_name)
|
||||
|
||||
return tick
|
||||
|
||||
def get_price(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
tick = self.main_engine.get_tick(vt_symbol)
|
||||
|
||||
if not tick:
|
||||
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo_name=algo.algo_name)
|
||||
return None
|
||||
|
||||
return tick.last_price
|
||||
|
||||
@lru_cache()
|
||||
def get_size(self, vt_symbol: str):
|
||||
"""查询合约的size"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract is None:
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息')
|
||||
return 10
|
||||
return contract.size
|
||||
|
||||
@lru_cache()
|
||||
def get_margin_rate(self, vt_symbol: str):
|
||||
"""查询保证金比率"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract is None:
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息')
|
||||
return 0.1
|
||||
if contract.margin_rate == 0:
|
||||
return 0.1
|
||||
return contract.margin_rate
|
||||
|
||||
@lru_cache()
|
||||
def get_price_tick(self, vt_symbol: str):
|
||||
"""查询价格最小跳动"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract is None:
|
||||
self.write_error(f'查询不到{vt_symbol}合约信息')
|
||||
return 0.1
|
||||
|
||||
return contract.pricetick
|
||||
|
||||
def get_account(self, vt_accountid: str = ""):
|
||||
""" 查询账号的资金"""
|
||||
# 如果启动风控,则使用风控中的最大仓位
|
||||
if self.main_engine.rm_engine:
|
||||
return self.main_engine.rm_engine.get_account(vt_accountid)
|
||||
|
||||
if len(vt_accountid) > 0:
|
||||
account = self.main_engine.get_account(vt_accountid)
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
|
||||
else:
|
||||
accounts = self.main_engine.get_all_accounts()
|
||||
if len(accounts) > 0:
|
||||
account = accounts[0]
|
||||
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01),
|
||||
2), 100
|
||||
else:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
def get_contract(self, algo: AlgoTemplate, vt_symbol: str):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
|
||||
if not contract:
|
||||
self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo)
|
||||
self.write_log(msg=f"查询合约失败,找不到合约:{vt_symbol}", algo_name=algo.algo_name)
|
||||
|
||||
return contract
|
||||
|
||||
def write_log(self, msg: str, algo: AlgoTemplate = None):
|
||||
""""""
|
||||
if algo:
|
||||
msg = f"{algo.algo_name}:{msg}"
|
||||
def write_log(self, msg: str, algo_name: str = None, level: int = logging.INFO):
|
||||
"""增强版写日志"""
|
||||
if algo_name:
|
||||
msg = f"{algo_name}:{msg}"
|
||||
|
||||
log = LogData(msg=msg, gateway_name=APP_NAME)
|
||||
log = LogData(msg=msg, gateway_name=APP_NAME, level=level)
|
||||
event = Event(EVENT_ALGO_LOG, data=log)
|
||||
self.event_engine.put(event)
|
||||
|
||||
# 保存单独的策略日志
|
||||
if algo_name:
|
||||
algo_logger = self.algo_loggers.get(algo_name, None)
|
||||
if not algo_logger:
|
||||
log_path = get_folder_path('log')
|
||||
log_filename = os.path.abspath(os.path.join(log_path, str(algo_name)))
|
||||
print(u'create logger:{}'.format(log_filename))
|
||||
self.algo_loggers[algo_name] = setup_logger(
|
||||
file_name=log_filename,
|
||||
name=str(algo_name))
|
||||
algo_logger = self.algo_loggers.get(algo_name)
|
||||
if algo_logger:
|
||||
algo_logger.log(level, msg)
|
||||
|
||||
# 如果日志数据异常,错误和告警,输出至sys.stderr
|
||||
if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]:
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
def write_error(self, msg: str, algo_name: str = ''):
|
||||
"""写入错误日志"""
|
||||
self.write_log(msg=msg, algo_name=algo_name, level=logging.ERROR)
|
||||
|
||||
def put_setting_event(self, setting_name: str, setting: dict):
|
||||
""""""
|
||||
event = Event(EVENT_ALGO_SETTING)
|
||||
|
@ -174,7 +174,10 @@ class AlgoTemplate:
|
||||
|
||||
def write_log(self, msg: str):
|
||||
""""""
|
||||
self.algo_engine.write_log(msg, self)
|
||||
self.algo_engine.write_log(msg, self.algo_name)
|
||||
|
||||
def write_error(self, msg: str):
|
||||
self.algo_engine.write_error(msg, self.algo_name)
|
||||
|
||||
def put_parameters_event(self):
|
||||
""""""
|
||||
@ -182,7 +185,7 @@ class AlgoTemplate:
|
||||
for name in self.default_setting.keys():
|
||||
parameters[name] = getattr(self, name)
|
||||
|
||||
self.algo_engine.put_parameters_event(self, parameters)
|
||||
self.algo_engine.put_parameters_event(algo=self, parameters=parameters)
|
||||
|
||||
def put_variables_event(self):
|
||||
""""""
|
||||
|
@ -196,7 +196,7 @@ class BackTestingEngine(object):
|
||||
self.data_path = None
|
||||
|
||||
self.fund_kline_dict = {}
|
||||
self.acivte_fund_kline = False
|
||||
self.active_fund_kline = False
|
||||
|
||||
def create_fund_kline(self, name, use_renko=False):
|
||||
"""
|
||||
@ -334,7 +334,8 @@ class BackTestingEngine(object):
|
||||
def get_price_tick(self, vt_symbol: str):
|
||||
return self.price_tick.get(vt_symbol, 1)
|
||||
|
||||
def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int, price_tick: float):
|
||||
def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int,
|
||||
price_tick: float, margin_rate: float = 0.1):
|
||||
"""设置合约信息"""
|
||||
vt_symbol = '.'.join([symbol, exchange.value])
|
||||
if vt_symbol not in self.contract_dict:
|
||||
@ -345,11 +346,12 @@ class BackTestingEngine(object):
|
||||
name=name,
|
||||
product=product,
|
||||
size=size,
|
||||
pricetick=price_tick
|
||||
pricetick=price_tick,
|
||||
margin_rate=margin_rate
|
||||
)
|
||||
self.contract_dict.update({vt_symbol: c})
|
||||
self.set_size(vt_symbol, size)
|
||||
# self.set_margin_rate(vt_symbol, )
|
||||
self.set_margin_rate(vt_symbol, margin_rate)
|
||||
self.set_price_tick(vt_symbol, price_tick)
|
||||
self.symbol_exchange_dict.update({symbol: exchange})
|
||||
|
||||
@ -486,8 +488,8 @@ class BackTestingEngine(object):
|
||||
self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
|
||||
|
||||
# 资金曲线
|
||||
self.acivte_fund_kline = test_settings.get('acivte_fund_kline', False)
|
||||
if self.acivte_fund_kline:
|
||||
self.active_fund_kline = test_settings.get('active_fund_kline', False)
|
||||
if self.active_fund_kline:
|
||||
# 创建资金K线
|
||||
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
|
||||
|
||||
@ -515,8 +517,8 @@ class BackTestingEngine(object):
|
||||
self.set_slippage(symbol, symbol_data.get('slippage', 0))
|
||||
|
||||
self.set_size(symbol, symbol_data.get('symbol_size', 10))
|
||||
|
||||
self.set_margin_rate(symbol, symbol_data.get('margin_rate', 0.1))
|
||||
margin_rate = symbol_data.get('margin_rate', 0.1)
|
||||
self.set_margin_rate(symbol, margin_rate)
|
||||
|
||||
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001)))
|
||||
|
||||
@ -526,7 +528,8 @@ class BackTestingEngine(object):
|
||||
exchange=Exchange(symbol_data.get('exchange', 'LOCAL')),
|
||||
product=Product(symbol_data.get('product', "期货")),
|
||||
size=symbol_data.get('symbol_size', 10),
|
||||
price_tick=symbol_data.get('price_tick', 1)
|
||||
price_tick=symbol_data.get('price_tick', 1),
|
||||
margin_rate=margin_rate
|
||||
)
|
||||
|
||||
def new_tick(self, tick):
|
||||
@ -735,7 +738,7 @@ class BackTestingEngine(object):
|
||||
self.write_log(u'自动启动策略')
|
||||
strategy.on_start()
|
||||
|
||||
if self.acivte_fund_kline:
|
||||
if self.active_fund_kline:
|
||||
# 创建策略实例的资金K线
|
||||
self.create_fund_kline(name=strategy_name, use_renko=False)
|
||||
|
||||
|
@ -202,7 +202,7 @@ class CtaLineBar(object):
|
||||
self.write_log(u'导入卡尔曼过滤器失败,需先安装 pip install pykalman')
|
||||
self.para_active_kf = False
|
||||
|
||||
def registerEvent(self, event_type, cb_func):
|
||||
def register_event(self, event_type, cb_func):
|
||||
"""注册事件回调函数"""
|
||||
self.cb_dict.update({event_type: cb_func})
|
||||
if event_type == self.CB_ON_PERIOD:
|
||||
@ -643,7 +643,7 @@ class CtaLineBar(object):
|
||||
|
||||
# 更新curPeriod的High,low
|
||||
if self.cur_period is not None:
|
||||
self.cur_period.onPrice(self.cur_tick.last_price)
|
||||
self.cur_period.update_price(self.cur_tick.last_price)
|
||||
|
||||
def add_bar(self, bar: BarData, bar_is_completed: bool = False, bar_freq: int = 1):
|
||||
"""
|
||||
|
@ -96,8 +96,7 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super(CtaEngine, self).__init__(
|
||||
main_engine, event_engine, APP_NAME)
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.engine_config = {}
|
||||
|
||||
@ -428,6 +427,7 @@ class CtaEngine(BaseEngine):
|
||||
type=type,
|
||||
price=price,
|
||||
volume=volume,
|
||||
strategy_name=strategy.strategy_name
|
||||
)
|
||||
|
||||
# 如果没有指定网关,则使用合约信息内的网关
|
||||
@ -749,6 +749,10 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
return contract.pricetick
|
||||
|
||||
def get_tick(self, vt_symbol: str):
|
||||
"""获取合约得最新tick"""
|
||||
return self.main_engine.get_tick(vt_symbol)
|
||||
|
||||
def get_price(self, vt_symbol: str):
|
||||
"""查询合约的最新价格"""
|
||||
tick = self.main_engine.get_tick(vt_symbol)
|
||||
|
423
vnpy/app/cta_strategy_pro/spread_testing.py
Normal file
423
vnpy/app/cta_strategy_pro/spread_testing.py
Normal file
@ -0,0 +1,423 @@
|
||||
# encoding: UTF-8
|
||||
|
||||
'''
|
||||
本文件中包含的是CTA模块的组合套利回测引擎,回测引擎的API和CTA引擎一致,
|
||||
可以使用和实盘相同的代码进行回测。
|
||||
华富资产 李来佳
|
||||
'''
|
||||
from __future__ import division
|
||||
|
||||
import sys
|
||||
import os
|
||||
import gc
|
||||
import pandas as pd
|
||||
import traceback
|
||||
import random
|
||||
import bz2
|
||||
import pickle
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from time import sleep
|
||||
|
||||
from vnpy.trader.object import (
|
||||
TickData,
|
||||
BarData,
|
||||
RenkoBarData,
|
||||
)
|
||||
from vnpy.trader.constant import (
|
||||
Exchange,
|
||||
)
|
||||
|
||||
from vnpy.trader.utility import (
|
||||
get_trading_date,
|
||||
extract_vt_symbol,
|
||||
get_underlying_symbol,
|
||||
import_module_by_str
|
||||
)
|
||||
|
||||
from .back_testing import BackTestingEngine
|
||||
|
||||
# vnpy交易所,与淘宝数据tick目录得对应关系
|
||||
VN_EXCHANGE_TICKFOLDER_MAP = {
|
||||
Exchange.SHFE.value: 'SQ',
|
||||
Exchange.DCE.value: 'DL',
|
||||
Exchange.CZCE.value: 'ZZ',
|
||||
Exchange.CFFEX.value: 'ZJ',
|
||||
Exchange.INE.value: 'SQ'
|
||||
}
|
||||
|
||||
class SpreadTestingEngine(BackTestingEngine):
|
||||
"""
|
||||
CTA套利组合回测引擎, 使用回测引擎作为父类
|
||||
函数接口和策略引擎保持一样,
|
||||
从而实现同一套代码从回测到实盘。
|
||||
针对tick回测
|
||||
导入CTA_Settings
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, event_engine=None):
|
||||
"""Constructor"""
|
||||
super().__init__(event_engine)
|
||||
self.tick_path = None # tick级别回测, 路径
|
||||
|
||||
self.strategy_start_date_dict = {}
|
||||
self.strategy_end_date_dict = {}
|
||||
|
||||
def prepare_env(self, test_settings):
|
||||
self.output('portfolio prepare_env')
|
||||
super().prepare_env(test_settings)
|
||||
|
||||
def load_strategy(self, strategy_name: str, strategy_setting: dict = None):
|
||||
"""
|
||||
装载回测的策略
|
||||
setting是参数设置,包括
|
||||
class_name: str, 策略类名字
|
||||
vt_symbol: str, 缺省合约
|
||||
setting: {}, 策略的参数
|
||||
auto_init: True/False, 策略是否自动初始化
|
||||
auto_start: True/False, 策略是否自动启动
|
||||
"""
|
||||
|
||||
# 获取策略的类名
|
||||
class_name = strategy_setting.get('class_name', None)
|
||||
if class_name is None or strategy_name is None:
|
||||
self.write_error(u'setting中没有class_name')
|
||||
return
|
||||
|
||||
# strategy_class => module.strategy_class
|
||||
if '.' not in class_name:
|
||||
module_name = self.class_module_map.get(class_name, None)
|
||||
if module_name:
|
||||
class_name = module_name + '.' + class_name
|
||||
self.write_log(u'转换策略为全路径:{}'.format(class_name))
|
||||
|
||||
# 获取策略类的定义
|
||||
strategy_class = import_module_by_str(class_name)
|
||||
if strategy_class is None:
|
||||
self.write_error(u'加载策略模块失败:{}'.format(class_name))
|
||||
return
|
||||
|
||||
# 处理 vt_symbol
|
||||
vt_symbol = strategy_setting.get('vt_symbol')
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
|
||||
subscribe_symobls = [vt_symbol]
|
||||
# 属于自定义套利合约
|
||||
if exchange == Exchange.SPD:
|
||||
act_symbol, act_ratio, pas_symbol, pas_ratio, spread_type = symbol.split('-')
|
||||
act_underly = get_underlying_symbol(act_symbol).upper()
|
||||
pas_underly = get_underlying_symbol(pas_symbol).upper()
|
||||
act_exchange = self.get_exchange(f'{act_underly}99')
|
||||
pas_exchange = self.get_exchange(f'{pas_underly}99')
|
||||
idx_contract = self.get_contract(f'{act_underly}99.{act_exchange.value}')
|
||||
self.set_contract(symbol=act_symbol,
|
||||
exchange=act_exchange,
|
||||
product=idx_contract.product,
|
||||
name=act_symbol,
|
||||
size=idx_contract.size,
|
||||
price_tick=idx_contract.pricetick,
|
||||
margin_rate=idx_contract.margin_rate)
|
||||
if pas_underly != act_underly:
|
||||
idx_contract = self.get_contract(f'{pas_underly}99.{pas_exchange.value}')
|
||||
|
||||
self.set_contract(symbol=pas_symbol,
|
||||
exchange=pas_exchange,
|
||||
product=idx_contract.product,
|
||||
name=act_symbol,
|
||||
size=idx_contract.size,
|
||||
price_tick=idx_contract.pricetick,
|
||||
margin_rate=idx_contract.margin_rate)
|
||||
|
||||
subscribe_symobls.remove(vt_symbol)
|
||||
subscribe_symobls.append(f'{act_symbol}.{act_exchange.value}')
|
||||
subscribe_symobls.append(f'{pas_symbol}.{pas_exchange.value}')
|
||||
|
||||
# 取消自动启动
|
||||
if 'auto_start' in strategy_setting:
|
||||
strategy_setting.update({'auto_start': False})
|
||||
|
||||
# 策略参数设置
|
||||
setting = strategy_setting.get('setting', {})
|
||||
|
||||
# 强制更新回测为True
|
||||
setting.update({'backtesting': True})
|
||||
|
||||
# 创建实例
|
||||
strategy = strategy_class(self, strategy_name, vt_symbol, setting)
|
||||
|
||||
# 保存到策略实例映射表中
|
||||
self.strategies.update({strategy_name: strategy})
|
||||
|
||||
# 更新vt_symbol合约与策略的订阅关系
|
||||
for sub_vt_symbol in subscribe_symobls:
|
||||
self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=sub_vt_symbol)
|
||||
|
||||
if strategy_setting.get('auto_init', False):
|
||||
self.write_log(u'自动初始化策略')
|
||||
strategy.on_init()
|
||||
|
||||
if strategy_setting.get('auto_start', False):
|
||||
self.write_log(u'自动启动策略')
|
||||
strategy.on_start()
|
||||
|
||||
if self.active_fund_kline:
|
||||
# 创建策略实例的资金K线
|
||||
self.create_fund_kline(name=strategy_name, use_renko=False)
|
||||
|
||||
def run_portfolio_test(self, strategy_settings: dict = {}):
|
||||
"""
|
||||
运行组合回测
|
||||
"""
|
||||
if not self.strategy_start_date:
|
||||
self.write_error(u'回测开始日期未设置。')
|
||||
return
|
||||
|
||||
if len(strategy_settings) == 0:
|
||||
self.write_error('未提供有效配置策略实例')
|
||||
return
|
||||
|
||||
self.cur_capital = self.init_capital # 更新设置期初资金
|
||||
if not self.data_end_date:
|
||||
self.data_end_date = datetime.today()
|
||||
|
||||
self.write_log(u'开始套利组合回测')
|
||||
|
||||
for strategy_name, strategy_setting in strategy_settings.items():
|
||||
# 策略得启动日期
|
||||
if 'start_date' in strategy_setting:
|
||||
start_date = strategy_setting.get('start_date')
|
||||
start_date = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
self.strategy_start_date_dict.update({strategy_name, start_date})
|
||||
# 策略得结束日期
|
||||
if 'end_date' in strategy_setting:
|
||||
end_date = strategy_setting.get('end_date')
|
||||
end_date = datetime.strptime(end_date, '%Y-%m-%d')
|
||||
self.strategy_end_date_dict.update({strategy_name, end_date})
|
||||
|
||||
self.load_strategy(strategy_name, strategy_setting)
|
||||
|
||||
self.write_log(u'策略初始化完成')
|
||||
|
||||
self.write_log(u'开始回放数据')
|
||||
|
||||
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
|
||||
|
||||
self.run_tick_test()
|
||||
|
||||
def load_csv_file(self, tick_folder, vt_symbol, tick_date):
|
||||
"""从文件中读取tick,返回list[{dict}]"""
|
||||
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
underly_symbol = get_underlying_symbol(symbol)
|
||||
exchange_folder = VN_EXCHANGE_TICKFOLDER_MAP.get(exchange.value)
|
||||
|
||||
if exchange == Exchange.INE:
|
||||
file_path = os.path.abspath(
|
||||
os.path.join(
|
||||
tick_folder,
|
||||
exchange_folder,
|
||||
tick_date.strftime('%Y'),
|
||||
tick_date.strftime('%Y%m'),
|
||||
tick_date.strftime('%Y%m%d'),
|
||||
'{}_{}.csv'.format(symbol.upper(), tick_date.strftime('%Y%m%d'))))
|
||||
else:
|
||||
file_path = os.path.abspath(
|
||||
os.path.join(
|
||||
tick_folder,
|
||||
exchange_folder,
|
||||
tick_date.strftime('%Y'),
|
||||
tick_date.strftime('%Y%m'),
|
||||
tick_date.strftime('%Y%m%d'),
|
||||
'{}{}_{}.csv'.format(underly_symbol.upper(), symbol[-2:], tick_date.strftime('%Y%m%d'))))
|
||||
|
||||
|
||||
ticks = []
|
||||
if not os.path.isfile(file_path):
|
||||
self.write_log(u'{0}文件不存在'.format(file_path))
|
||||
return None
|
||||
|
||||
df = pd.read_csv(file_path, encoding='gbk', parse_dates=False)
|
||||
df.columns = ['date', 'time', 'last_price', 'volume', 'last_volume', 'open_interest',
|
||||
'bid_price_1', 'bid_volume_1', 'bid_price_2', 'bid_volume_2', 'bid_price_3', 'bid_volume_3',
|
||||
'ask_price_1', 'ask_volume_1', 'ask_price_2', 'ask_volume_2', 'ask_price_3', 'ask_volume_3', 'BS']
|
||||
|
||||
self.write_log(u'加载csv文件{}'.format(file_path))
|
||||
last_time = None
|
||||
for index, row in df.iterrows():
|
||||
# 日期, 时间, 成交价, 成交量, 总量, 属性(持仓增减), B1价, B1量, B2价, B2量, B3价, B3量, S1价, S1量, S2价, S2量, S3价, S3量, BS
|
||||
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
|
||||
|
||||
tick = row.to_dict()
|
||||
tick.update({'symbol': symbol, 'exchange': exchange.value, 'trading_day': tick_date.strftime('%Y-%m-%d')})
|
||||
tick_datetime = datetime.strptime(tick['date'] + ' ' + tick['time'], '%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 修正毫秒
|
||||
if tick['time'] == last_time:
|
||||
# 与上一个tick的时间(去除毫秒后)相同,修改为500毫秒
|
||||
tick_datetime = tick_datetime.replace(microsecond=500)
|
||||
tick['time'] = tick_datetime.strftime('%H:%M:%S.%f')
|
||||
else:
|
||||
last_time = tick['time']
|
||||
tick_datetime = tick_datetime.replace(microsecond=0)
|
||||
tick['time'] = tick_datetime.strftime('%H:%M:%S.%f')
|
||||
tick['datetime'] = tick_datetime
|
||||
|
||||
# 排除涨停/跌停的数据
|
||||
if (float(tick['bid_price_1']) == float('1.79769E308') and int(tick['bid_volume_1']) == 0) \
|
||||
or (float(tick['ask_price_1']) == float('1.79769E308') and int(tick['ask_volume_1']) == 0):
|
||||
continue
|
||||
|
||||
ticks.append(tick)
|
||||
|
||||
del df
|
||||
|
||||
return ticks
|
||||
|
||||
def load_bz2_cache(self, cache_folder, cache_symbol, cache_date):
|
||||
"""
|
||||
加载缓存数据
|
||||
list[{dict}]
|
||||
"""
|
||||
if not os.path.exists(cache_folder):
|
||||
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
|
||||
return None
|
||||
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
|
||||
if not os.path.exists(cache_folder_year_month):
|
||||
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
|
||||
return None
|
||||
|
||||
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
|
||||
if not os.path.isfile(cache_file):
|
||||
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date))
|
||||
if not os.path.isfile(cache_file):
|
||||
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
|
||||
return None
|
||||
|
||||
with bz2.BZ2File(cache_file, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
|
||||
return None
|
||||
|
||||
def get_day_tick_df(self, test_day):
|
||||
"""获取某一天得所有合约tick"""
|
||||
tick_data_dict = {}
|
||||
|
||||
for vt_symbol in list(self.symbol_strategy_map.keys()):
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
tick_list = self.load_csv_file(tick_folder=self.tick_path,
|
||||
vt_symbol=vt_symbol,
|
||||
tick_date=test_day)
|
||||
if not tick_list or len(tick_list) == 0:
|
||||
continue
|
||||
|
||||
symbol_tick_df = pd.DataFrame(tick_list)
|
||||
# 缓存文件中,datetime字段,已经是datetime格式
|
||||
# 暂时根据时间去重,没有汇总volume
|
||||
symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True)
|
||||
symbol_tick_df.set_index('datetime', inplace=True)
|
||||
|
||||
tick_data_dict.update({vt_symbol: symbol_tick_df})
|
||||
|
||||
if len(tick_data_dict) == 0:
|
||||
return None
|
||||
|
||||
tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index()
|
||||
|
||||
return tick_df
|
||||
|
||||
def run_tick_test(self):
|
||||
"""运行tick级别组合回测"""
|
||||
testdays = (self.data_end_date - self.data_start_date).days
|
||||
|
||||
if testdays < 1:
|
||||
self.write_log(u'回测时间不足')
|
||||
return
|
||||
|
||||
gc_collect_days = 0
|
||||
|
||||
# 循环每一天
|
||||
for i in range(0, testdays):
|
||||
test_day = self.data_start_date + timedelta(days=i)
|
||||
|
||||
combined_df = self.get_day_tick_df(test_day)
|
||||
|
||||
if combined_df is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
for (dt, vt_symbol), tick_data in combined_df.iterrows():
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
tick = TickData(
|
||||
gateway_name='backtesting',
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
datetime=dt,
|
||||
date=dt.strftime('%Y-%m-%d'),
|
||||
time=dt.strftime('%H:%M:%S.%f'),
|
||||
trading_day=tick_data['trading_day'],
|
||||
last_price=float(tick_data['last_price']),
|
||||
volume=int(tick_data['volume']),
|
||||
ask_price_1=float(tick_data['ask_price_1']),
|
||||
ask_volume_1=int(tick_data['ask_volume_1']),
|
||||
bid_price_1=float(tick_data['bid_price_1']),
|
||||
bid_volume_1=int(tick_data['bid_volume_1'])
|
||||
)
|
||||
|
||||
self.new_tick(tick)
|
||||
|
||||
# 结束一个交易日后,更新每日净值
|
||||
self.saving_daily_data(test_day,
|
||||
self.cur_capital,
|
||||
self.max_net_capital,
|
||||
self.total_commission)
|
||||
|
||||
self.cancel_orders()
|
||||
# 更新持仓缓存
|
||||
self.update_pos_buffer()
|
||||
|
||||
gc_collect_days += 1
|
||||
if gc_collect_days >= 10:
|
||||
# 执行内存回收
|
||||
gc.collect()
|
||||
sleep(1)
|
||||
gc_collect_days = 0
|
||||
|
||||
if self.net_capital < 0:
|
||||
self.write_error(u'净值低于0,回测停止')
|
||||
self.output(u'净值低于0,回测停止')
|
||||
return
|
||||
|
||||
except Exception as ex:
|
||||
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
|
||||
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
|
||||
print(str(ex), file=sys.stderr)
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
self.write_log(u'tick数据回放完成')
|
||||
|
||||
|
||||
def single_test(test_setting: dict, strategy_setting: dict):
|
||||
"""
|
||||
单一回测
|
||||
: test_setting, 组合回测所需的配置,包括合约信息,数据tick信息,回测时间,资金等。
|
||||
:strategy_setting, dict, 一个或多个策略配置
|
||||
"""
|
||||
# 创建组合回测引擎
|
||||
engine = SpreadTestingEngine()
|
||||
|
||||
engine.prepare_env(test_setting)
|
||||
try:
|
||||
engine.run_portfolio_test(strategy_setting)
|
||||
# 回测结果,保存
|
||||
engine.show_backtesting_result()
|
||||
|
||||
except Exception as ex:
|
||||
print('组合回测异常{}'.format(str(ex)))
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
print('测试结束')
|
||||
return True
|
@ -1150,11 +1150,23 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
self.save_dist(dist_record)
|
||||
self.pos = self.position.pos
|
||||
|
||||
def fix_order(self, order: OrderData):
|
||||
"""修正order被拆单得情况"""
|
||||
order_info = self.active_orders.get(order.vt_orderid, None)
|
||||
if order_info:
|
||||
volume = order_info.get('volume')
|
||||
if volume != order.volume:
|
||||
self.write_log(f'调整{order.vt_orderid} volume:{volume}=>{order.volume}')
|
||||
order_info.update({'volume': order.volume})
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
"""报单更新"""
|
||||
# 未执行的订单中,存在是异常,删除
|
||||
self.write_log(u'{}报单更新,{}'.format(self.cur_datetime, order.__dict__))
|
||||
|
||||
# 修正order被拆单得情况"
|
||||
self.fix_order(order)
|
||||
|
||||
if order.vt_orderid in self.active_orders:
|
||||
|
||||
if order.volume == order.traded and order.status in [Status.ALLTRADED]:
|
||||
@ -1543,7 +1555,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
|
||||
for vt_orderid in list(self.active_orders.keys()):
|
||||
order_info = self.active_orders[vt_orderid]
|
||||
order_symbol = order_info.get('symbol', self.vt_symbol)
|
||||
order_vt_symbol = order_info.get('vt_symbol', self.vt_symbol)
|
||||
order_time = order_info['order_time']
|
||||
order_volume = order_info['volume'] - order_info['traded']
|
||||
# order_price = order_info['price']
|
||||
@ -1555,7 +1567,7 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
over_seconds = (dt - order_time).total_seconds()
|
||||
|
||||
# 只处理未成交的限价委托单
|
||||
if order_status in [Status.NOTTRADED] and (order_type == OrderType.LIMIT or '.SPD' in order_symbol):
|
||||
if order_status in [Status.NOTTRADED] and (order_type == OrderType.LIMIT or '.SPD' in order_vt_symbol):
|
||||
if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交
|
||||
self.write_log(u'超时{}秒未成交,取消委托单:vt_orderid:{},order:{}'
|
||||
.format(over_seconds, vt_orderid, order_info))
|
||||
@ -1586,10 +1598,10 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume))
|
||||
order_grid.volume = order_volume
|
||||
|
||||
self.write_log(u'重新提交{}开空委托,开空价{},v:{}'.format(order_symbol, short_price, order_volume))
|
||||
self.write_log(u'重新提交{}开空委托,开空价{},v:{}'.format(order_vt_symbol, short_price, order_volume))
|
||||
vt_orderids = self.short(price=short_price,
|
||||
volume=order_volume,
|
||||
vt_symbol=order_symbol,
|
||||
vt_symbol=order_vt_symbol,
|
||||
order_type=order_type,
|
||||
order_time=self.cur_datetime,
|
||||
grid=order_grid)
|
||||
@ -1606,10 +1618,10 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume))
|
||||
order_grid.volume = order_volume
|
||||
|
||||
self.write_log(u'重新提交{}开多委托,开多价{},v:{}'.format(order_symbol, buy_price, order_volume))
|
||||
self.write_log(u'重新提交{}开多委托,开多价{},v:{}'.format(order_vt_symbol, buy_price, order_volume))
|
||||
vt_orderids = self.buy(price=buy_price,
|
||||
volume=order_volume,
|
||||
vt_symbol=order_symbol,
|
||||
vt_symbol=order_vt_symbol,
|
||||
order_type=order_type,
|
||||
order_time=self.cur_datetime,
|
||||
grid=order_grid)
|
||||
@ -1623,10 +1635,10 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
# 属于平多委托单
|
||||
if order_info['direction'] == Direction.SHORT:
|
||||
sell_price = self.cur_mi_price - self.price_tick
|
||||
self.write_log(u'重新提交{}平多委托,{},v:{}'.format(order_symbol, sell_price, order_volume))
|
||||
self.write_log(u'重新提交{}平多委托,{},v:{}'.format(order_vt_symbol, sell_price, order_volume))
|
||||
vt_orderids = self.sell(price=sell_price,
|
||||
volume=order_volume,
|
||||
vt_symbol=order_symbol,
|
||||
vt_symbol=order_vt_symbol,
|
||||
order_type=order_type,
|
||||
order_time=self.cur_datetime,
|
||||
grid=order_grid)
|
||||
@ -1637,10 +1649,10 @@ class CtaProFutureTemplate(CtaProTemplate):
|
||||
# 属于平空委托单
|
||||
else:
|
||||
cover_price = self.cur_mi_price + self.price_tick
|
||||
self.write_log(u'重新提交{}平空委托,委托价{},v:{}'.format(order_symbol, cover_price, order_volume))
|
||||
self.write_log(u'重新提交{}平空委托,委托价{},v:{}'.format(order_vt_symbol, cover_price, order_volume))
|
||||
vt_orderids = self.cover(price=cover_price,
|
||||
volume=order_volume,
|
||||
vt_symbol=order_symbol,
|
||||
vt_symbol=order_vt_symbol,
|
||||
order_type=order_type,
|
||||
order_time=self.cur_datetime,
|
||||
grid=order_grid)
|
||||
|
1198
vnpy/app/cta_strategy_pro/template_spread.py
Normal file
1198
vnpy/app/cta_strategy_pro/template_spread.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -65,14 +65,19 @@ class OffsetConverter:
|
||||
return holding
|
||||
|
||||
def convert_order_request(self, req: OrderRequest, lock: bool, gateway_name: str = ''):
|
||||
""""""
|
||||
"""转换委托单"""
|
||||
# 合约是净仓,不具有多空,不需要转换
|
||||
if not self.is_convert_required(req.vt_symbol):
|
||||
return [req]
|
||||
|
||||
# 获取当前持仓信息
|
||||
holding = self.get_position_holding(req.vt_symbol, gateway_name)
|
||||
|
||||
if lock:
|
||||
# 锁仓转换
|
||||
return holding.convert_order_request_lock(req)
|
||||
|
||||
# 平今/平昨拆分
|
||||
elif req.exchange in [Exchange.SHFE, Exchange.INE]:
|
||||
return holding.convert_order_request_shfe(req)
|
||||
else:
|
||||
@ -231,7 +236,7 @@ class PositionHolding:
|
||||
self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen
|
||||
|
||||
def convert_order_request_shfe(self, req: OrderRequest):
|
||||
""""""
|
||||
"""上期所,委托单拆分"""
|
||||
if req.offset == Offset.OPEN:
|
||||
return [req]
|
||||
|
||||
|
@ -24,6 +24,7 @@ from .event import (
|
||||
)
|
||||
from .gateway import BaseGateway
|
||||
from .object import (
|
||||
Exchange,
|
||||
CancelRequest,
|
||||
LogData,
|
||||
OrderRequest,
|
||||
@ -56,6 +57,7 @@ class MainEngine:
|
||||
self.exchanges = []
|
||||
|
||||
self.rm_engine = None
|
||||
self.algo_engine = None
|
||||
|
||||
os.chdir(TRADER_DIR) # Change working directory
|
||||
self.init_engines() # Initialize function engines
|
||||
@ -99,6 +101,8 @@ class MainEngine:
|
||||
engine = self.add_engine(app.engine_class)
|
||||
if app.app_name == "RiskManager":
|
||||
self.rm_engine = engine
|
||||
elif app.app_name == "AlgoTrading":
|
||||
self.algo_engine == engine
|
||||
|
||||
return engine
|
||||
|
||||
@ -188,7 +192,14 @@ class MainEngine:
|
||||
def send_order(self, req: OrderRequest, gateway_name: str):
|
||||
"""
|
||||
Send new order request to a specific gateway.
|
||||
扩展支持自定义套利合约。 由cta_strategy_pro发出算法单委托,由算法引擎进行处理
|
||||
"""
|
||||
# 自定义套利合约,交给算法引擎处理
|
||||
if self.algo_engine and req.exchange == Exchange.SPD:
|
||||
return self.algo_engine.send_algo_order(
|
||||
req=req,
|
||||
gateway_name=gateway_name)
|
||||
|
||||
gateway = self.get_gateway(gateway_name)
|
||||
if gateway:
|
||||
return gateway.send_order(req)
|
||||
@ -206,6 +217,7 @@ class MainEngine:
|
||||
|
||||
def send_orders(self, reqs: Sequence[OrderRequest], gateway_name: str):
|
||||
"""
|
||||
批量发单
|
||||
"""
|
||||
gateway = self.get_gateway(gateway_name)
|
||||
if gateway:
|
||||
@ -461,6 +473,7 @@ class OmsEngine(BaseEngine):
|
||||
""""""
|
||||
contract = event.data
|
||||
self.contracts[contract.vt_symbol] = contract
|
||||
self.contracts[contract.symbol] = contract
|
||||
|
||||
def get_tick(self, vt_symbol):
|
||||
"""
|
||||
@ -595,7 +608,7 @@ class CustomContract(object):
|
||||
gateway_name = setting.get('gateway_name', None)
|
||||
if gateway_name is None:
|
||||
gateway_name = SETTINGS.get('gateway_name', '')
|
||||
vn_exchange = Exchange(setting.get('exchange', 'LOCAL'))
|
||||
vn_exchange = Exchange(setting.get('exchange', 'SPD'))
|
||||
contract = ContractData(
|
||||
gateway_name=gateway_name,
|
||||
symbol=symbol,
|
||||
|
@ -249,7 +249,7 @@ class AccountData(BaseData):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VtFundsFlowData(BaseData):
|
||||
class FundsFlowData(BaseData):
|
||||
"""历史资金流水数据类(股票专用)"""
|
||||
|
||||
# 账号代码相关
|
||||
@ -353,6 +353,7 @@ class OrderRequest:
|
||||
volume: float
|
||||
price: float = 0
|
||||
offset: Offset = Offset.NONE
|
||||
strategy_name: str = ""
|
||||
|
||||
def __post_init__(self):
|
||||
""""""
|
||||
|
Loading…
Reference in New Issue
Block a user