[增强] 期货套利回测引擎/套利模板/定义套利合约转算法引擎

This commit is contained in:
msincenselee 2020-03-08 21:59:11 +08:00
parent 2f3f65c694
commit 4ffbd50496
11 changed files with 1871 additions and 41 deletions

View File

@ -1,11 +1,16 @@
import os
import sys
from datetime import datetime
from functools import lru_cache
from vnpy.event import EventEngine, Event from vnpy.event import EventEngine, Event
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.event import ( from vnpy.trader.event import (
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
from vnpy.trader.constant import (Direction, Offset, OrderType) from vnpy.trader.constant import (Direction, Offset, OrderType, Status)
from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData) from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData, CancelRequest)
from vnpy.trader.utility import load_json, save_json, round_to from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path
from vnpy.trader.util_logger import setup_logger, logging
from .template import AlgoTemplate from .template import AlgoTemplate
@ -30,9 +35,13 @@ class AlgoEngine(BaseEngine):
self.symbol_algo_map = {} self.symbol_algo_map = {}
self.orderid_algo_map = {} self.orderid_algo_map = {}
self.algo_vtorderid_order_map = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
self.algo_templates = {} self.algo_templates = {}
self.algo_settings = {} self.algo_settings = {}
self.algo_loggers = {} # algo_name: logger
self.load_algo_template() self.load_algo_template()
self.register_event() self.register_event()
@ -172,7 +181,7 @@ class AlgoEngine(BaseEngine):
"""""" """"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if not contract: if not contract:
self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo) self.write_log(f'委托下单失败,找不到合约:{vt_symbol}', algo_name=algo.algo_name)
return return
volume = round_to(volume, contract.min_volume) volume = round_to(volume, contract.min_volume)
@ -204,33 +213,192 @@ class AlgoEngine(BaseEngine):
req = order.create_cancel_request() req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name) self.main_engine.cancel_order(req, order.gateway_name)
def send_algo_order(self, req: OrderRequest, gateway_name: str):
"""发送算法交易指令"""
self.write_log(u'创建算法交易,gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
.format(gateway_name, req.strategy_name, req.vt_symbol, req.price, req.volume))
# 创建算法实例,由算法引擎启动
trade_command = ''
if req.direction == Direction.LONG and req.offset == Offset.OPEN:
trade_command = 'Buy'
elif req.direction == Direction.SHORT and req.offset == Offset.OPEN:
trade_command = 'Short'
elif req.direction == Direction.SHORT and req.offset != Offset.OPEN:
trade_command = 'Sell'
elif req.direction == Direction.LONG and req.offset != Offset.OPEN:
trade_command = 'Cover'
all_custom_contracts = self.main_engine.get_all_custom_contracts()
contract_setting = all_custom_contracts.get(req.vt_symbol, {})
algo_setting = {
'templateName': u'SpreadTrading套利',
'order_vt_symbol': req.vt_symbol,
'order_command': trade_command,
'order_price': req.price,
'order_volume': req.volume,
'timer_interval': 60 * 60 * 24,
'strategy_name': req.strategy_name,
'gateway_name': gateway_name
}
algo_setting.update(contract_setting)
# 算法引擎
algo_name = self.start_algo(algo_setting)
self.write_log(u'send_algo_order(): start_algo {}={}'.format(algo_name, str(algo_setting)))
# 创建一个Order事件
order = req.create_order_data(orderid=algo_name, gateway_name=gateway_name)
order.orderTime = datetime.now().strftime('%H:%M:%S.%f')
order.status = Status.SUBMITTING
event1 = Event(type=EVENT_ORDER, data=order)
self.event_engine.put(event1)
# 登记在本地的算法委托字典中
self.algo_vtorderid_order_map.update({order.vt_orderid: order})
return order.vt_orderid
def is_algo_order(self, req: CancelRequest, gateway_name: str):
"""是否为外部算法委托单"""
vt_orderid = '.'.join([req.orderid, gateway_name])
if vt_orderid in self.algo_vtorderid_order_map:
return True
else:
return False
def cancel_algo_order(self, req: CancelRequest, gateway_name: str):
"""外部算法单撤单"""
vt_orderid = '.'.join([req.orderid, gateway_name])
order = self.algo_vtorderid_order_map.get(vt_orderid, None)
if not order:
self.write_error(f'{vt_orderid}不在算法引擎中,撤单失败')
return False
algo = self.algos.get(req.orderid, None)
if not algo:
self.write_error(f'{req.orderid}算法实例不在算法引擎中,撤单失败')
return False
ret = self.stop_algo(req.orderid)
if ret:
order.cancelTime = datetime.now().strftime('%H:%M:%S.%f')
order.status = Status.CANCELLED
event1 = Event(type=EVENT_ORDER, data=order)
self.event_engine.put(event1)
self.write_log(f'算法实例撤单成功:{req.orderid}')
return True
else:
self.write_error(f'算法实例撤单失败:{req.orderid}')
return False
def get_tick(self, algo: AlgoTemplate, vt_symbol: str): def get_tick(self, algo: AlgoTemplate, vt_symbol: str):
"""""" """"""
tick = self.main_engine.get_tick(vt_symbol) tick = self.main_engine.get_tick(vt_symbol)
if not tick: if not tick:
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo) self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo_name=algo.algo_name)
return tick return tick
def get_price(self, algo: AlgoTemplate, vt_symbol: str):
tick = self.main_engine.get_tick(vt_symbol)
if not tick:
self.write_log(f"查询行情失败,找不到行情:{vt_symbol}", algo_name=algo.algo_name)
return None
return tick.last_price
@lru_cache()
def get_size(self, vt_symbol: str):
"""查询合约的size"""
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
return 10
return contract.size
@lru_cache()
def get_margin_rate(self, vt_symbol: str):
"""查询保证金比率"""
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
return 0.1
if contract.margin_rate == 0:
return 0.1
return contract.margin_rate
@lru_cache()
def get_price_tick(self, vt_symbol: str):
"""查询价格最小跳动"""
contract = self.main_engine.get_contract(vt_symbol)
if contract is None:
self.write_error(f'查询不到{vt_symbol}合约信息')
return 0.1
return contract.pricetick
def get_account(self, vt_accountid: str = ""):
""" 查询账号的资金"""
# 如果启动风控,则使用风控中的最大仓位
if self.main_engine.rm_engine:
return self.main_engine.rm_engine.get_account(vt_accountid)
if len(vt_accountid) > 0:
account = self.main_engine.get_account(vt_accountid)
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01), 2), 100
else:
accounts = self.main_engine.get_all_accounts()
if len(accounts) > 0:
account = accounts[0]
return account.balance, account.avaliable, round(account.frozen * 100 / (account.balance + 0.01),
2), 100
else:
return 0, 0, 0, 0
def get_contract(self, algo: AlgoTemplate, vt_symbol: str): def get_contract(self, algo: AlgoTemplate, vt_symbol: str):
"""""" """"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if not contract: if not contract:
self.write_log(f"查询合约失败,找不到合约:{vt_symbol}", algo) self.write_log(msg=f"查询合约失败,找不到合约:{vt_symbol}", algo_name=algo.algo_name)
return contract return contract
def write_log(self, msg: str, algo: AlgoTemplate = None): def write_log(self, msg: str, algo_name: str = None, level: int = logging.INFO):
"""""" """增强版写日志"""
if algo: if algo_name:
msg = f"{algo.algo_name}{msg}" msg = f"{algo_name}{msg}"
log = LogData(msg=msg, gateway_name=APP_NAME) log = LogData(msg=msg, gateway_name=APP_NAME, level=level)
event = Event(EVENT_ALGO_LOG, data=log) event = Event(EVENT_ALGO_LOG, data=log)
self.event_engine.put(event) self.event_engine.put(event)
# 保存单独的策略日志
if algo_name:
algo_logger = self.algo_loggers.get(algo_name, None)
if not algo_logger:
log_path = get_folder_path('log')
log_filename = os.path.abspath(os.path.join(log_path, str(algo_name)))
print(u'create logger:{}'.format(log_filename))
self.algo_loggers[algo_name] = setup_logger(
file_name=log_filename,
name=str(algo_name))
algo_logger = self.algo_loggers.get(algo_name)
if algo_logger:
algo_logger.log(level, msg)
# 如果日志数据异常错误和告警输出至sys.stderr
if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]:
print(msg, file=sys.stderr)
def write_error(self, msg: str, algo_name: str = ''):
"""写入错误日志"""
self.write_log(msg=msg, algo_name=algo_name, level=logging.ERROR)
def put_setting_event(self, setting_name: str, setting: dict): def put_setting_event(self, setting_name: str, setting: dict):
"""""" """"""
event = Event(EVENT_ALGO_SETTING) event = Event(EVENT_ALGO_SETTING)

View File

@ -174,7 +174,10 @@ class AlgoTemplate:
def write_log(self, msg: str): def write_log(self, msg: str):
"""""" """"""
self.algo_engine.write_log(msg, self) self.algo_engine.write_log(msg, self.algo_name)
def write_error(self, msg: str):
self.algo_engine.write_error(msg, self.algo_name)
def put_parameters_event(self): def put_parameters_event(self):
"""""" """"""
@ -182,7 +185,7 @@ class AlgoTemplate:
for name in self.default_setting.keys(): for name in self.default_setting.keys():
parameters[name] = getattr(self, name) parameters[name] = getattr(self, name)
self.algo_engine.put_parameters_event(self, parameters) self.algo_engine.put_parameters_event(algo=self, parameters=parameters)
def put_variables_event(self): def put_variables_event(self):
"""""" """"""

View File

@ -196,7 +196,7 @@ class BackTestingEngine(object):
self.data_path = None self.data_path = None
self.fund_kline_dict = {} self.fund_kline_dict = {}
self.acivte_fund_kline = False self.active_fund_kline = False
def create_fund_kline(self, name, use_renko=False): def create_fund_kline(self, name, use_renko=False):
""" """
@ -334,7 +334,8 @@ class BackTestingEngine(object):
def get_price_tick(self, vt_symbol: str): def get_price_tick(self, vt_symbol: str):
return self.price_tick.get(vt_symbol, 1) return self.price_tick.get(vt_symbol, 1)
def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int, price_tick: float): def set_contract(self, symbol: str, exchange: Exchange, product: Product, name: str, size: int,
price_tick: float, margin_rate: float = 0.1):
"""设置合约信息""" """设置合约信息"""
vt_symbol = '.'.join([symbol, exchange.value]) vt_symbol = '.'.join([symbol, exchange.value])
if vt_symbol not in self.contract_dict: if vt_symbol not in self.contract_dict:
@ -345,11 +346,12 @@ class BackTestingEngine(object):
name=name, name=name,
product=product, product=product,
size=size, size=size,
pricetick=price_tick pricetick=price_tick,
margin_rate=margin_rate
) )
self.contract_dict.update({vt_symbol: c}) self.contract_dict.update({vt_symbol: c})
self.set_size(vt_symbol, size) self.set_size(vt_symbol, size)
# self.set_margin_rate(vt_symbol, ) self.set_margin_rate(vt_symbol, margin_rate)
self.set_price_tick(vt_symbol, price_tick) self.set_price_tick(vt_symbol, price_tick)
self.symbol_exchange_dict.update({symbol: exchange}) self.symbol_exchange_dict.update({symbol: exchange})
@ -486,8 +488,8 @@ class BackTestingEngine(object):
self.bar_interval_seconds = test_settings.get('bar_interval_seconds') self.bar_interval_seconds = test_settings.get('bar_interval_seconds')
# 资金曲线 # 资金曲线
self.acivte_fund_kline = test_settings.get('acivte_fund_kline', False) self.active_fund_kline = test_settings.get('active_fund_kline', False)
if self.acivte_fund_kline: if self.active_fund_kline:
# 创建资金K线 # 创建资金K线
self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False)) self.create_fund_kline(self.test_name, use_renko=test_settings.get('use_renko', False))
@ -515,8 +517,8 @@ class BackTestingEngine(object):
self.set_slippage(symbol, symbol_data.get('slippage', 0)) self.set_slippage(symbol, symbol_data.get('slippage', 0))
self.set_size(symbol, symbol_data.get('symbol_size', 10)) self.set_size(symbol, symbol_data.get('symbol_size', 10))
margin_rate = symbol_data.get('margin_rate', 0.1)
self.set_margin_rate(symbol, symbol_data.get('margin_rate', 0.1)) self.set_margin_rate(symbol, margin_rate)
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001))) self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001)))
@ -526,7 +528,8 @@ class BackTestingEngine(object):
exchange=Exchange(symbol_data.get('exchange', 'LOCAL')), exchange=Exchange(symbol_data.get('exchange', 'LOCAL')),
product=Product(symbol_data.get('product', "期货")), product=Product(symbol_data.get('product', "期货")),
size=symbol_data.get('symbol_size', 10), size=symbol_data.get('symbol_size', 10),
price_tick=symbol_data.get('price_tick', 1) price_tick=symbol_data.get('price_tick', 1),
margin_rate=margin_rate
) )
def new_tick(self, tick): def new_tick(self, tick):
@ -735,7 +738,7 @@ class BackTestingEngine(object):
self.write_log(u'自动启动策略') self.write_log(u'自动启动策略')
strategy.on_start() strategy.on_start()
if self.acivte_fund_kline: if self.active_fund_kline:
# 创建策略实例的资金K线 # 创建策略实例的资金K线
self.create_fund_kline(name=strategy_name, use_renko=False) self.create_fund_kline(name=strategy_name, use_renko=False)

View File

@ -202,7 +202,7 @@ class CtaLineBar(object):
self.write_log(u'导入卡尔曼过滤器失败,需先安装 pip install pykalman') self.write_log(u'导入卡尔曼过滤器失败,需先安装 pip install pykalman')
self.para_active_kf = False self.para_active_kf = False
def registerEvent(self, event_type, cb_func): def register_event(self, event_type, cb_func):
"""注册事件回调函数""" """注册事件回调函数"""
self.cb_dict.update({event_type: cb_func}) self.cb_dict.update({event_type: cb_func})
if event_type == self.CB_ON_PERIOD: if event_type == self.CB_ON_PERIOD:
@ -643,7 +643,7 @@ class CtaLineBar(object):
# 更新curPeriod的Highlow # 更新curPeriod的Highlow
if self.cur_period is not None: if self.cur_period is not None:
self.cur_period.onPrice(self.cur_tick.last_price) self.cur_period.update_price(self.cur_tick.last_price)
def add_bar(self, bar: BarData, bar_is_completed: bool = False, bar_freq: int = 1): def add_bar(self, bar: BarData, bar_is_completed: bool = False, bar_freq: int = 1):
""" """

View File

@ -96,8 +96,7 @@ class CtaEngine(BaseEngine):
def __init__(self, main_engine: MainEngine, event_engine: EventEngine): def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
"""""" """"""
super(CtaEngine, self).__init__( super().__init__(main_engine, event_engine, APP_NAME)
main_engine, event_engine, APP_NAME)
self.engine_config = {} self.engine_config = {}
@ -428,6 +427,7 @@ class CtaEngine(BaseEngine):
type=type, type=type,
price=price, price=price,
volume=volume, volume=volume,
strategy_name=strategy.strategy_name
) )
# 如果没有指定网关,则使用合约信息内的网关 # 如果没有指定网关,则使用合约信息内的网关
@ -749,6 +749,10 @@ class CtaEngine(BaseEngine):
return contract.pricetick return contract.pricetick
def get_tick(self, vt_symbol: str):
"""获取合约得最新tick"""
return self.main_engine.get_tick(vt_symbol)
def get_price(self, vt_symbol: str): def get_price(self, vt_symbol: str):
"""查询合约的最新价格""" """查询合约的最新价格"""
tick = self.main_engine.get_tick(vt_symbol) tick = self.main_engine.get_tick(vt_symbol)

View File

@ -0,0 +1,423 @@
# encoding: UTF-8
'''
本文件中包含的是CTA模块的组合套利回测引擎回测引擎的API和CTA引擎一致
可以使用和实盘相同的代码进行回测
华富资产 李来佳
'''
from __future__ import division
import sys
import os
import gc
import pandas as pd
import traceback
import random
import bz2
import pickle
from datetime import datetime, timedelta
from time import sleep
from vnpy.trader.object import (
TickData,
BarData,
RenkoBarData,
)
from vnpy.trader.constant import (
Exchange,
)
from vnpy.trader.utility import (
get_trading_date,
extract_vt_symbol,
get_underlying_symbol,
import_module_by_str
)
from .back_testing import BackTestingEngine
# vnpy交易所与淘宝数据tick目录得对应关系
VN_EXCHANGE_TICKFOLDER_MAP = {
Exchange.SHFE.value: 'SQ',
Exchange.DCE.value: 'DL',
Exchange.CZCE.value: 'ZZ',
Exchange.CFFEX.value: 'ZJ',
Exchange.INE.value: 'SQ'
}
class SpreadTestingEngine(BackTestingEngine):
"""
CTA套利组合回测引擎, 使用回测引擎作为父类
函数接口和策略引擎保持一样
从而实现同一套代码从回测到实盘
针对tick回测
导入CTA_Settings
"""
def __init__(self, event_engine=None):
"""Constructor"""
super().__init__(event_engine)
self.tick_path = None # tick级别回测 路径
self.strategy_start_date_dict = {}
self.strategy_end_date_dict = {}
def prepare_env(self, test_settings):
self.output('portfolio prepare_env')
super().prepare_env(test_settings)
def load_strategy(self, strategy_name: str, strategy_setting: dict = None):
"""
装载回测的策略
setting是参数设置包括
class_name: str, 策略类名字
vt_symbol: str, 缺省合约
setting: {}, 策略的参数
auto_init: True/False, 策略是否自动初始化
auto_start: True/False 策略是否自动启动
"""
# 获取策略的类名
class_name = strategy_setting.get('class_name', None)
if class_name is None or strategy_name is None:
self.write_error(u'setting中没有class_name')
return
# strategy_class => module.strategy_class
if '.' not in class_name:
module_name = self.class_module_map.get(class_name, None)
if module_name:
class_name = module_name + '.' + class_name
self.write_log(u'转换策略为全路径:{}'.format(class_name))
# 获取策略类的定义
strategy_class = import_module_by_str(class_name)
if strategy_class is None:
self.write_error(u'加载策略模块失败:{}'.format(class_name))
return
# 处理 vt_symbol
vt_symbol = strategy_setting.get('vt_symbol')
symbol, exchange = extract_vt_symbol(vt_symbol)
subscribe_symobls = [vt_symbol]
# 属于自定义套利合约
if exchange == Exchange.SPD:
act_symbol, act_ratio, pas_symbol, pas_ratio, spread_type = symbol.split('-')
act_underly = get_underlying_symbol(act_symbol).upper()
pas_underly = get_underlying_symbol(pas_symbol).upper()
act_exchange = self.get_exchange(f'{act_underly}99')
pas_exchange = self.get_exchange(f'{pas_underly}99')
idx_contract = self.get_contract(f'{act_underly}99.{act_exchange.value}')
self.set_contract(symbol=act_symbol,
exchange=act_exchange,
product=idx_contract.product,
name=act_symbol,
size=idx_contract.size,
price_tick=idx_contract.pricetick,
margin_rate=idx_contract.margin_rate)
if pas_underly != act_underly:
idx_contract = self.get_contract(f'{pas_underly}99.{pas_exchange.value}')
self.set_contract(symbol=pas_symbol,
exchange=pas_exchange,
product=idx_contract.product,
name=act_symbol,
size=idx_contract.size,
price_tick=idx_contract.pricetick,
margin_rate=idx_contract.margin_rate)
subscribe_symobls.remove(vt_symbol)
subscribe_symobls.append(f'{act_symbol}.{act_exchange.value}')
subscribe_symobls.append(f'{pas_symbol}.{pas_exchange.value}')
# 取消自动启动
if 'auto_start' in strategy_setting:
strategy_setting.update({'auto_start': False})
# 策略参数设置
setting = strategy_setting.get('setting', {})
# 强制更新回测为True
setting.update({'backtesting': True})
# 创建实例
strategy = strategy_class(self, strategy_name, vt_symbol, setting)
# 保存到策略实例映射表中
self.strategies.update({strategy_name: strategy})
# 更新vt_symbol合约与策略的订阅关系
for sub_vt_symbol in subscribe_symobls:
self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=sub_vt_symbol)
if strategy_setting.get('auto_init', False):
self.write_log(u'自动初始化策略')
strategy.on_init()
if strategy_setting.get('auto_start', False):
self.write_log(u'自动启动策略')
strategy.on_start()
if self.active_fund_kline:
# 创建策略实例的资金K线
self.create_fund_kline(name=strategy_name, use_renko=False)
def run_portfolio_test(self, strategy_settings: dict = {}):
"""
运行组合回测
"""
if not self.strategy_start_date:
self.write_error(u'回测开始日期未设置。')
return
if len(strategy_settings) == 0:
self.write_error('未提供有效配置策略实例')
return
self.cur_capital = self.init_capital # 更新设置期初资金
if not self.data_end_date:
self.data_end_date = datetime.today()
self.write_log(u'开始套利组合回测')
for strategy_name, strategy_setting in strategy_settings.items():
# 策略得启动日期
if 'start_date' in strategy_setting:
start_date = strategy_setting.get('start_date')
start_date = datetime.strptime(start_date, '%Y-%m-%d')
self.strategy_start_date_dict.update({strategy_name, start_date})
# 策略得结束日期
if 'end_date' in strategy_setting:
end_date = strategy_setting.get('end_date')
end_date = datetime.strptime(end_date, '%Y-%m-%d')
self.strategy_end_date_dict.update({strategy_name, end_date})
self.load_strategy(strategy_name, strategy_setting)
self.write_log(u'策略初始化完成')
self.write_log(u'开始回放数据')
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
self.run_tick_test()
def load_csv_file(self, tick_folder, vt_symbol, tick_date):
"""从文件中读取tick返回list[{dict}]"""
symbol, exchange = extract_vt_symbol(vt_symbol)
underly_symbol = get_underlying_symbol(symbol)
exchange_folder = VN_EXCHANGE_TICKFOLDER_MAP.get(exchange.value)
if exchange == Exchange.INE:
file_path = os.path.abspath(
os.path.join(
tick_folder,
exchange_folder,
tick_date.strftime('%Y'),
tick_date.strftime('%Y%m'),
tick_date.strftime('%Y%m%d'),
'{}_{}.csv'.format(symbol.upper(), tick_date.strftime('%Y%m%d'))))
else:
file_path = os.path.abspath(
os.path.join(
tick_folder,
exchange_folder,
tick_date.strftime('%Y'),
tick_date.strftime('%Y%m'),
tick_date.strftime('%Y%m%d'),
'{}{}_{}.csv'.format(underly_symbol.upper(), symbol[-2:], tick_date.strftime('%Y%m%d'))))
ticks = []
if not os.path.isfile(file_path):
self.write_log(u'{0}文件不存在'.format(file_path))
return None
df = pd.read_csv(file_path, encoding='gbk', parse_dates=False)
df.columns = ['date', 'time', 'last_price', 'volume', 'last_volume', 'open_interest',
'bid_price_1', 'bid_volume_1', 'bid_price_2', 'bid_volume_2', 'bid_price_3', 'bid_volume_3',
'ask_price_1', 'ask_volume_1', 'ask_price_2', 'ask_volume_2', 'ask_price_3', 'ask_volume_3', 'BS']
self.write_log(u'加载csv文件{}'.format(file_path))
last_time = None
for index, row in df.iterrows():
# 日期, 时间, 成交价, 成交量, 总量, 属性(持仓增减), B1价, B1量, B2价, B2量, B3价, B3量, S1价, S1量, S2价, S2量, S3价, S3量, BS
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
tick = row.to_dict()
tick.update({'symbol': symbol, 'exchange': exchange.value, 'trading_day': tick_date.strftime('%Y-%m-%d')})
tick_datetime = datetime.strptime(tick['date'] + ' ' + tick['time'], '%Y-%m-%d %H:%M:%S')
# 修正毫秒
if tick['time'] == last_time:
# 与上一个tick的时间去除毫秒后相同,修改为500毫秒
tick_datetime = tick_datetime.replace(microsecond=500)
tick['time'] = tick_datetime.strftime('%H:%M:%S.%f')
else:
last_time = tick['time']
tick_datetime = tick_datetime.replace(microsecond=0)
tick['time'] = tick_datetime.strftime('%H:%M:%S.%f')
tick['datetime'] = tick_datetime
# 排除涨停/跌停的数据
if (float(tick['bid_price_1']) == float('1.79769E308') and int(tick['bid_volume_1']) == 0) \
or (float(tick['ask_price_1']) == float('1.79769E308') and int(tick['ask_volume_1']) == 0):
continue
ticks.append(tick)
del df
return ticks
def load_bz2_cache(self, cache_folder, cache_symbol, cache_date):
"""
加载缓存数据
list[{dict}]
"""
if not os.path.exists(cache_folder):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
return None
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
if not os.path.exists(cache_folder_year_month):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
return None
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
if not os.path.isfile(cache_file):
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkz2'.format(cache_symbol, cache_date))
if not os.path.isfile(cache_file):
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
return None
with bz2.BZ2File(cache_file, 'rb') as f:
data = pickle.load(f)
return data
return None
def get_day_tick_df(self, test_day):
"""获取某一天得所有合约tick"""
tick_data_dict = {}
for vt_symbol in list(self.symbol_strategy_map.keys()):
symbol, exchange = extract_vt_symbol(vt_symbol)
tick_list = self.load_csv_file(tick_folder=self.tick_path,
vt_symbol=vt_symbol,
tick_date=test_day)
if not tick_list or len(tick_list) == 0:
continue
symbol_tick_df = pd.DataFrame(tick_list)
# 缓存文件中datetime字段已经是datetime格式
# 暂时根据时间去重没有汇总volume
symbol_tick_df.drop_duplicates(subset=['datetime'], keep='first', inplace=True)
symbol_tick_df.set_index('datetime', inplace=True)
tick_data_dict.update({vt_symbol: symbol_tick_df})
if len(tick_data_dict) == 0:
return None
tick_df = pd.concat(tick_data_dict, axis=0).swaplevel(0, 1).sort_index()
return tick_df
def run_tick_test(self):
"""运行tick级别组合回测"""
testdays = (self.data_end_date - self.data_start_date).days
if testdays < 1:
self.write_log(u'回测时间不足')
return
gc_collect_days = 0
# 循环每一天
for i in range(0, testdays):
test_day = self.data_start_date + timedelta(days=i)
combined_df = self.get_day_tick_df(test_day)
if combined_df is None:
continue
try:
for (dt, vt_symbol), tick_data in combined_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol)
tick = TickData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=dt,
date=dt.strftime('%Y-%m-%d'),
time=dt.strftime('%H:%M:%S.%f'),
trading_day=tick_data['trading_day'],
last_price=float(tick_data['last_price']),
volume=int(tick_data['volume']),
ask_price_1=float(tick_data['ask_price_1']),
ask_volume_1=int(tick_data['ask_volume_1']),
bid_price_1=float(tick_data['bid_price_1']),
bid_volume_1=int(tick_data['bid_volume_1'])
)
self.new_tick(tick)
# 结束一个交易日后,更新每日净值
self.saving_daily_data(test_day,
self.cur_capital,
self.max_net_capital,
self.total_commission)
self.cancel_orders()
# 更新持仓缓存
self.update_pos_buffer()
gc_collect_days += 1
if gc_collect_days >= 10:
# 执行内存回收
gc.collect()
sleep(1)
gc_collect_days = 0
if self.net_capital < 0:
self.write_error(u'净值低于0回测停止')
self.output(u'净值低于0回测停止')
return
except Exception as ex:
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
print(str(ex), file=sys.stderr)
traceback.print_exc()
return
self.write_log(u'tick数据回放完成')
def single_test(test_setting: dict, strategy_setting: dict):
"""
单一回测
: test_setting, 组合回测所需的配置包括合约信息数据tick信息回测时间资金等
strategy_setting, dict, 一个或多个策略配置
"""
# 创建组合回测引擎
engine = SpreadTestingEngine()
engine.prepare_env(test_setting)
try:
engine.run_portfolio_test(strategy_setting)
# 回测结果,保存
engine.show_backtesting_result()
except Exception as ex:
print('组合回测异常{}'.format(str(ex)))
traceback.print_exc()
return False
print('测试结束')
return True

View File

@ -1150,11 +1150,23 @@ class CtaProFutureTemplate(CtaProTemplate):
self.save_dist(dist_record) self.save_dist(dist_record)
self.pos = self.position.pos self.pos = self.position.pos
def fix_order(self, order: OrderData):
"""修正order被拆单得情况"""
order_info = self.active_orders.get(order.vt_orderid, None)
if order_info:
volume = order_info.get('volume')
if volume != order.volume:
self.write_log(f'调整{order.vt_orderid} volume:{volume}=>{order.volume}')
order_info.update({'volume': order.volume})
def on_order(self, order: OrderData): def on_order(self, order: OrderData):
"""报单更新""" """报单更新"""
# 未执行的订单中,存在是异常,删除 # 未执行的订单中,存在是异常,删除
self.write_log(u'{}报单更新,{}'.format(self.cur_datetime, order.__dict__)) self.write_log(u'{}报单更新,{}'.format(self.cur_datetime, order.__dict__))
# 修正order被拆单得情况"
self.fix_order(order)
if order.vt_orderid in self.active_orders: if order.vt_orderid in self.active_orders:
if order.volume == order.traded and order.status in [Status.ALLTRADED]: if order.volume == order.traded and order.status in [Status.ALLTRADED]:
@ -1543,7 +1555,7 @@ class CtaProFutureTemplate(CtaProTemplate):
for vt_orderid in list(self.active_orders.keys()): for vt_orderid in list(self.active_orders.keys()):
order_info = self.active_orders[vt_orderid] order_info = self.active_orders[vt_orderid]
order_symbol = order_info.get('symbol', self.vt_symbol) order_vt_symbol = order_info.get('vt_symbol', self.vt_symbol)
order_time = order_info['order_time'] order_time = order_info['order_time']
order_volume = order_info['volume'] - order_info['traded'] order_volume = order_info['volume'] - order_info['traded']
# order_price = order_info['price'] # order_price = order_info['price']
@ -1555,7 +1567,7 @@ class CtaProFutureTemplate(CtaProTemplate):
over_seconds = (dt - order_time).total_seconds() over_seconds = (dt - order_time).total_seconds()
# 只处理未成交的限价委托单 # 只处理未成交的限价委托单
if order_status in [Status.NOTTRADED] and (order_type == OrderType.LIMIT or '.SPD' in order_symbol): if order_status in [Status.NOTTRADED] and (order_type == OrderType.LIMIT or '.SPD' in order_vt_symbol):
if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交 if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交
self.write_log(u'超时{}秒未成交取消委托单vt_orderid:{},order:{}' self.write_log(u'超时{}秒未成交取消委托单vt_orderid:{},order:{}'
.format(over_seconds, vt_orderid, order_info)) .format(over_seconds, vt_orderid, order_info))
@ -1586,10 +1598,10 @@ class CtaProFutureTemplate(CtaProTemplate):
u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume)) u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume))
order_grid.volume = order_volume order_grid.volume = order_volume
self.write_log(u'重新提交{}开空委托,开空价{}v:{}'.format(order_symbol, short_price, order_volume)) self.write_log(u'重新提交{}开空委托,开空价{}v:{}'.format(order_vt_symbol, short_price, order_volume))
vt_orderids = self.short(price=short_price, vt_orderids = self.short(price=short_price,
volume=order_volume, volume=order_volume,
vt_symbol=order_symbol, vt_symbol=order_vt_symbol,
order_type=order_type, order_type=order_type,
order_time=self.cur_datetime, order_time=self.cur_datetime,
grid=order_grid) grid=order_grid)
@ -1606,10 +1618,10 @@ class CtaProFutureTemplate(CtaProTemplate):
u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume)) u'网格volume:{},order_volume:{}不一致,修正'.format(order_grid.volume, order_volume))
order_grid.volume = order_volume order_grid.volume = order_volume
self.write_log(u'重新提交{}开多委托,开多价{}v:{}'.format(order_symbol, buy_price, order_volume)) self.write_log(u'重新提交{}开多委托,开多价{}v:{}'.format(order_vt_symbol, buy_price, order_volume))
vt_orderids = self.buy(price=buy_price, vt_orderids = self.buy(price=buy_price,
volume=order_volume, volume=order_volume,
vt_symbol=order_symbol, vt_symbol=order_vt_symbol,
order_type=order_type, order_type=order_type,
order_time=self.cur_datetime, order_time=self.cur_datetime,
grid=order_grid) grid=order_grid)
@ -1623,10 +1635,10 @@ class CtaProFutureTemplate(CtaProTemplate):
# 属于平多委托单 # 属于平多委托单
if order_info['direction'] == Direction.SHORT: if order_info['direction'] == Direction.SHORT:
sell_price = self.cur_mi_price - self.price_tick sell_price = self.cur_mi_price - self.price_tick
self.write_log(u'重新提交{}平多委托,{}v:{}'.format(order_symbol, sell_price, order_volume)) self.write_log(u'重新提交{}平多委托,{}v:{}'.format(order_vt_symbol, sell_price, order_volume))
vt_orderids = self.sell(price=sell_price, vt_orderids = self.sell(price=sell_price,
volume=order_volume, volume=order_volume,
vt_symbol=order_symbol, vt_symbol=order_vt_symbol,
order_type=order_type, order_type=order_type,
order_time=self.cur_datetime, order_time=self.cur_datetime,
grid=order_grid) grid=order_grid)
@ -1637,10 +1649,10 @@ class CtaProFutureTemplate(CtaProTemplate):
# 属于平空委托单 # 属于平空委托单
else: else:
cover_price = self.cur_mi_price + self.price_tick cover_price = self.cur_mi_price + self.price_tick
self.write_log(u'重新提交{}平空委托,委托价{}v:{}'.format(order_symbol, cover_price, order_volume)) self.write_log(u'重新提交{}平空委托,委托价{}v:{}'.format(order_vt_symbol, cover_price, order_volume))
vt_orderids = self.cover(price=cover_price, vt_orderids = self.cover(price=cover_price,
volume=order_volume, volume=order_volume,
vt_symbol=order_symbol, vt_symbol=order_vt_symbol,
order_type=order_type, order_type=order_type,
order_time=self.cur_datetime, order_time=self.cur_datetime,
grid=order_grid) grid=order_grid)

File diff suppressed because it is too large Load Diff

View File

@ -65,14 +65,19 @@ class OffsetConverter:
return holding return holding
def convert_order_request(self, req: OrderRequest, lock: bool, gateway_name: str = ''): def convert_order_request(self, req: OrderRequest, lock: bool, gateway_name: str = ''):
"""""" """转换委托单"""
# 合约是净仓,不具有多空,不需要转换
if not self.is_convert_required(req.vt_symbol): if not self.is_convert_required(req.vt_symbol):
return [req] return [req]
# 获取当前持仓信息
holding = self.get_position_holding(req.vt_symbol, gateway_name) holding = self.get_position_holding(req.vt_symbol, gateway_name)
if lock: if lock:
# 锁仓转换
return holding.convert_order_request_lock(req) return holding.convert_order_request_lock(req)
# 平今/平昨拆分
elif req.exchange in [Exchange.SHFE, Exchange.INE]: elif req.exchange in [Exchange.SHFE, Exchange.INE]:
return holding.convert_order_request_shfe(req) return holding.convert_order_request_shfe(req)
else: else:
@ -231,7 +236,7 @@ class PositionHolding:
self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen self.short_pos_frozen = self.short_td_frozen + self.short_yd_frozen
def convert_order_request_shfe(self, req: OrderRequest): def convert_order_request_shfe(self, req: OrderRequest):
"""""" """上期所,委托单拆分"""
if req.offset == Offset.OPEN: if req.offset == Offset.OPEN:
return [req] return [req]

View File

@ -24,6 +24,7 @@ from .event import (
) )
from .gateway import BaseGateway from .gateway import BaseGateway
from .object import ( from .object import (
Exchange,
CancelRequest, CancelRequest,
LogData, LogData,
OrderRequest, OrderRequest,
@ -56,6 +57,7 @@ class MainEngine:
self.exchanges = [] self.exchanges = []
self.rm_engine = None self.rm_engine = None
self.algo_engine = None
os.chdir(TRADER_DIR) # Change working directory os.chdir(TRADER_DIR) # Change working directory
self.init_engines() # Initialize function engines self.init_engines() # Initialize function engines
@ -99,6 +101,8 @@ class MainEngine:
engine = self.add_engine(app.engine_class) engine = self.add_engine(app.engine_class)
if app.app_name == "RiskManager": if app.app_name == "RiskManager":
self.rm_engine = engine self.rm_engine = engine
elif app.app_name == "AlgoTrading":
self.algo_engine == engine
return engine return engine
@ -188,7 +192,14 @@ class MainEngine:
def send_order(self, req: OrderRequest, gateway_name: str): def send_order(self, req: OrderRequest, gateway_name: str):
""" """
Send new order request to a specific gateway. Send new order request to a specific gateway.
扩展支持自定义套利合约 由cta_strategy_pro发出算法单委托由算法引擎进行处理
""" """
# 自定义套利合约,交给算法引擎处理
if self.algo_engine and req.exchange == Exchange.SPD:
return self.algo_engine.send_algo_order(
req=req,
gateway_name=gateway_name)
gateway = self.get_gateway(gateway_name) gateway = self.get_gateway(gateway_name)
if gateway: if gateway:
return gateway.send_order(req) return gateway.send_order(req)
@ -206,6 +217,7 @@ class MainEngine:
def send_orders(self, reqs: Sequence[OrderRequest], gateway_name: str): def send_orders(self, reqs: Sequence[OrderRequest], gateway_name: str):
""" """
批量发单
""" """
gateway = self.get_gateway(gateway_name) gateway = self.get_gateway(gateway_name)
if gateway: if gateway:
@ -461,6 +473,7 @@ class OmsEngine(BaseEngine):
"""""" """"""
contract = event.data contract = event.data
self.contracts[contract.vt_symbol] = contract self.contracts[contract.vt_symbol] = contract
self.contracts[contract.symbol] = contract
def get_tick(self, vt_symbol): def get_tick(self, vt_symbol):
""" """
@ -595,7 +608,7 @@ class CustomContract(object):
gateway_name = setting.get('gateway_name', None) gateway_name = setting.get('gateway_name', None)
if gateway_name is None: if gateway_name is None:
gateway_name = SETTINGS.get('gateway_name', '') gateway_name = SETTINGS.get('gateway_name', '')
vn_exchange = Exchange(setting.get('exchange', 'LOCAL')) vn_exchange = Exchange(setting.get('exchange', 'SPD'))
contract = ContractData( contract = ContractData(
gateway_name=gateway_name, gateway_name=gateway_name,
symbol=symbol, symbol=symbol,

View File

@ -249,7 +249,7 @@ class AccountData(BaseData):
@dataclass @dataclass
class VtFundsFlowData(BaseData): class FundsFlowData(BaseData):
"""历史资金流水数据类(股票专用)""" """历史资金流水数据类(股票专用)"""
# 账号代码相关 # 账号代码相关
@ -353,6 +353,7 @@ class OrderRequest:
volume: float volume: float
price: float = 0 price: float = 0
offset: Offset = Offset.NONE offset: Offset = Offset.NONE
strategy_name: str = ""
def __post_init__(self): def __post_init__(self):
"""""" """"""