[增强功能] 回测引擎增加跨品种套利bar模式; 价差行情模块移动=> gateway.py; kline增加缠论支持

This commit is contained in:
msincenselee 2020-10-22 11:52:10 +08:00
parent ce1a85b656
commit bd7280bfaf
13 changed files with 2192 additions and 427 deletions

View File

@ -8,6 +8,7 @@ from .engine import CtaEngine
from .template import ( from .template import (
Direction, Direction,
Offset, Offset,
Exchange,
Status, Status,
Color, Color,
TickData, TickData,

View File

@ -1240,7 +1240,7 @@ class BackTestingEngine(object):
active_exchange = self.get_exchange(active_symbol) active_exchange = self.get_exchange(active_symbol)
active_vt_symbol = active_symbol + '.' + active_exchange.value active_vt_symbol = active_symbol + '.' + active_exchange.value
passive_exchange = self.get_exchange(passive_symbol) passive_exchange = self.get_exchange(passive_symbol)
# passive_vt_symbol = active_symbol + '.' + passive_exchange.value passive_vt_symbol = passive_symbol + '.' + passive_exchange.value
# 主动腿成交记录 # 主动腿成交记录
act_trade = TradeData(gateway_name=self.gateway_name, act_trade = TradeData(gateway_name=self.gateway_name,
symbol=active_symbol, symbol=active_symbol,
@ -1438,10 +1438,10 @@ class BackTestingEngine(object):
# 如果当前没有空单,属于异常行为 # 如果当前没有空单,属于异常行为
if len(self.short_position_list) == 0: if len(self.short_position_list) == 0:
self.write_error(u'异常!没有空单持仓不能cover') self.write_error(u'异常!没有空单持仓不能cover')
raise Exception(u'异常!没有空单持仓不能cover') # raise Exception(u'异常!没有空单持仓不能cover')
return return
cur_short_pos_list = [s_pos.volume for s_pos in self.short_position_list] cur_short_pos_list = [s_pos.volume for s_pos in self.short_position_list if s_pos.vt_symbol == trade.vt_symbol]
self.write_log(u'{}当前空单:{}'.format(trade.vt_symbol, cur_short_pos_list)) self.write_log(u'{}当前空单:{}'.format(trade.vt_symbol, cur_short_pos_list))
@ -1450,8 +1450,12 @@ class BackTestingEngine(object):
val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name] val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name]
if len(pop_indexs) < 1: if len(pop_indexs) < 1:
self.write_error(u'异常,{}没有对应symbol:{}的空单持仓'.format(trade.strategy_name, trade.vt_symbol)) if 'spd' in vt_tradeid:
raise Exception(u'realtimeCalculate2() Exception,没有对应symbol:{0}的空单持仓'.format(trade.vt_symbol)) self.write_error(f'没有{trade.strategy_name}对应的symbol:{trade.vt_symbol}的空单持仓, 继续')
break
else:
self.write_error(u'异常,{}没有对应symbol:{}的空单持仓, 终止'.format(trade.strategy_name, trade.vt_symbol))
# raise Exception(u'realtimeCalculate2() Exception,没有对应symbol:{0}的空单持仓'.format(trade.vt_symbol))
return return
pop_index = pop_indexs[0] pop_index = pop_indexs[0]
@ -1494,7 +1498,7 @@ class BackTestingEngine(object):
self.trade_pnl_list.append(t) self.trade_pnl_list.append(t)
# 非自定义套利对,才更新到策略盈亏 # 非自定义套利对,才更新到策略盈亏
if not open_trade.vt_symbol.endswith('SPD'): if not (open_trade.vt_symbol.endswith('SPD') or open_trade.vt_symbol.endswith('SPD99')):
# 更新策略实例的累加盈亏 # 更新策略实例的累加盈亏
self.pnl_strategy_dict.update( self.pnl_strategy_dict.update(
{open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name, {open_trade.strategy_name: self.pnl_strategy_dict.get(open_trade.strategy_name,
@ -1506,6 +1510,8 @@ class BackTestingEngine(object):
open_trade.volume, result.pnl, result.commission) open_trade.volume, result.pnl, result.commission)
self.write_log(msg) self.write_log(msg)
# 添加到交易结果汇总
result_list.append(result) result_list.append(result)
if g_result is None: if g_result is None:
@ -1569,6 +1575,9 @@ class BackTestingEngine(object):
self.write_log(msg) self.write_log(msg)
# 添加到交易结果汇总
result_list.append(result)
# 更新减少开仓单的volume,重新推进开仓单列表中 # 更新减少开仓单的volume,重新推进开仓单列表中
open_trade.volume = remain_volume open_trade.volume = remain_volume
self.write_log(u'更新减少开仓单的volume,重新推进开仓单列表中:{}'.format(open_trade.volume)) self.write_log(u'更新减少开仓单的volume,重新推进开仓单列表中:{}'.format(open_trade.volume))
@ -1577,7 +1586,7 @@ class BackTestingEngine(object):
self.write_log(u'当前空单:{}'.format(cur_short_pos_list)) self.write_log(u'当前空单:{}'.format(cur_short_pos_list))
cover_volume = 0 cover_volume = 0
result_list.append(result)
if g_result is not None: if g_result is not None:
# 更新组合的数据 # 更新组合的数据
@ -1606,18 +1615,21 @@ class BackTestingEngine(object):
while sell_volume > 0: while sell_volume > 0:
if len(self.long_position_list) == 0: if len(self.long_position_list) == 0:
self.write_error(f'异常,没有{trade.vt_symbol}的多仓') self.write_error(f'异常,没有{trade.vt_symbol}的多仓')
raise RuntimeError(u'realtimeCalculate2() Exception,没有开多单') # raise RuntimeError(u'realtimeCalculate2() Exception,没有开多单')
return return
pop_indexs = [i for i, val in enumerate(self.long_position_list) if pop_indexs = [i for i, val in enumerate(self.long_position_list) if
val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name] val.vt_symbol == trade.vt_symbol and val.strategy_name == trade.strategy_name]
if len(pop_indexs) < 1: if len(pop_indexs) < 1:
self.write_error(f'没有{trade.strategy_name}对应的symbol{trade.vt_symbol}多单数据,') if 'spd' in vt_tradeid:
raise RuntimeError( self.write_error(f'没有{trade.strategy_name}对应的symbol:{trade.vt_symbol}多单数据, 继续')
f'realtimeCalculate2() Exception,没有对应的symbol{trade.vt_symbol}多单数据,') break
else:
self.write_error(f'没有{trade.strategy_name}对应的symbol:{trade.vt_symbol}多单数据, 终止')
# raise RuntimeError(f'realtimeCalculate2() Exception,没有对应的symbol:{trade.vt_symbol}多单数据,')
return return
cur_long_pos_list = [s_pos.volume for s_pos in self.long_position_list] cur_long_pos_list = [s_pos.volume for s_pos in self.long_position_list if s_pos.vt_symbol == trade.vt_symbol]
self.write_log(u'{}当前多单:{}'.format(trade.vt_symbol, cur_long_pos_list)) self.write_log(u'{}当前多单:{}'.format(trade.vt_symbol, cur_long_pos_list))
@ -1669,6 +1681,8 @@ class BackTestingEngine(object):
open_trade.volume, result.pnl, result.commission) open_trade.volume, result.pnl, result.commission)
self.write_log(msg) self.write_log(msg)
# 添加到交易结果汇总
result_list.append(result) result_list.append(result)
if g_result is None: if g_result is None:
@ -1728,13 +1742,14 @@ class BackTestingEngine(object):
result.commission) result.commission)
self.write_log(msg) self.write_log(msg)
# 添加到交易结果汇总
result_list.append(result)
# 减少开多volume,重新推进多单持仓列表中 # 减少开多volume,重新推进多单持仓列表中
open_trade.volume = remain_volume open_trade.volume = remain_volume
self.long_position_list.append(open_trade) self.long_position_list.append(open_trade)
sell_volume = 0 sell_volume = 0
result_list.append(result)
if g_result is not None: if g_result is not None:
# 更新组合的数据 # 更新组合的数据
@ -1786,8 +1801,11 @@ class BackTestingEngine(object):
continue continue
# 当前空单保证金 # 当前空单保证金
if self.use_margin: if self.use_margin:
try:
cur_occupy_money = max(self.get_price(t.vt_symbol), t.price) * abs(t.volume) * self.get_size( cur_occupy_money = max(self.get_price(t.vt_symbol), t.price) * abs(t.volume) * self.get_size(
t.vt_symbol) * self.get_margin_rate(t.vt_symbol) t.vt_symbol) * self.get_margin_rate(t.vt_symbol)
except Exception as ex:
self.write_error(ex)
else: else:
cur_occupy_money = self.get_price(t.vt_symbol) * abs(t.volume) * self.get_size( cur_occupy_money = self.get_price(t.vt_symbol) * abs(t.volume) * self.get_size(
t.vt_symbol) * self.get_margin_rate(t.vt_symbol) t.vt_symbol) * self.get_margin_rate(t.vt_symbol)

View File

@ -797,6 +797,10 @@ class CtaEngine(BaseEngine):
return True return True
@lru_cache()
def get_exchange(self, symbol):
return self.main_engine.get_exchange(symbol)
@lru_cache() @lru_cache()
def get_name(self, vt_symbol: str): def get_name(self, vt_symbol: str):
"""查询合约的name""" """查询合约的name"""
@ -868,6 +872,9 @@ class CtaEngine(BaseEngine):
def get_contract(self, vt_symbol): def get_contract(self, vt_symbol):
return self.main_engine.get_contract(vt_symbol) return self.main_engine.get_contract(vt_symbol)
def get_custom_contract(self, vt_symbol):
return self.main_engine.get_custom_contract(vt_symbol.split('.')[0])
def get_all_contracts(self): def get_all_contracts(self):
return self.main_engine.get_all_contracts() return self.main_engine.get_all_contracts()
@ -986,6 +993,7 @@ class CtaEngine(BaseEngine):
""" """
Add a new strategy. Add a new strategy.
""" """
try:
if strategy_name in self.strategies: if strategy_name in self.strategies:
msg = f"创建策略失败,存在重名{strategy_name}" msg = f"创建策略失败,存在重名{strategy_name}"
self.write_log(msg=msg, self.write_log(msg=msg,
@ -1019,6 +1027,14 @@ class CtaEngine(BaseEngine):
if auto_init: if auto_init:
self.init_strategy(strategy_name, auto_start=auto_start) self.init_strategy(strategy_name, auto_start=auto_start)
except Exception as ex:
msg = f'添加策略实例{strategy_name}失败,{str(ex)}'
self.write_error(msg)
self.write_error(traceback.format_exc())
self.send_wechat(msg)
return False, f'添加策略实例{strategy_name}失败'
return True, f'成功添加{strategy_name}' return True, f'成功添加{strategy_name}'
def init_strategy(self, strategy_name: str, auto_start: bool = False): def init_strategy(self, strategy_name: str, auto_start: bool = False):
@ -1804,7 +1820,8 @@ class CtaEngine(BaseEngine):
symbol_pos.get('策略空单', 0) symbol_pos.get('策略空单', 0)
)) ))
diff_pos_dict.update({vt_symbol: {"long": symbol_pos.get('账号多单', 0) - symbol_pos.get('策略多单', 0), diff_pos_dict.update({vt_symbol: {"long": symbol_pos.get('账号多单', 0) - symbol_pos.get('策略多单', 0),
"short":symbol_pos.get('账号空单', 0) - symbol_pos.get('策略空单', 0)}}) "short": symbol_pos.get('账号空单', 0) - symbol_pos.get('策略空单',
0)}})
else: else:
match = round(symbol_pos.get('账号空单', 0), 7) == round(symbol_pos.get('策略空单', 0), 7) and \ match = round(symbol_pos.get('账号空单', 0), 7) == round(symbol_pos.get('策略空单', 0), 7) and \
round(symbol_pos.get('账号多单', 0), 7) == round(symbol_pos.get('策略多单', 0), 7) round(symbol_pos.get('账号多单', 0), 7) == round(symbol_pos.get('策略多单', 0), 7)

View File

@ -277,7 +277,7 @@ class PortfolioTestingEngine(BackTestingEngine):
bar.high_price = float(bar_data['high']) bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low']) bar.low_price = float(bar_data['low'])
bar.volume = int(bar_data['volume']) bar.volume = int(bar_data['volume'])
bar.open_interest = int(bar_data.get('open_interest', 0)) bar.open_interest = float(bar_data.get('open_interest', 0))
bar.date = bar_datetime.strftime('%Y-%m-%d') bar.date = bar_datetime.strftime('%Y-%m-%d')
bar.time = bar_datetime.strftime('%H:%M:%S') bar.time = bar_datetime.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', '')) str_td = str(bar_data.get('trading_day', ''))

View File

@ -13,7 +13,7 @@ import gc
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import traceback import traceback
import random
import bz2 import bz2
import pickle import pickle
@ -21,7 +21,7 @@ from datetime import datetime, timedelta
from time import sleep from time import sleep
from vnpy.trader.object import ( from vnpy.trader.object import (
TickData, TickData, BarData
) )
from vnpy.trader.constant import ( from vnpy.trader.constant import (
Exchange, Exchange,
@ -33,6 +33,7 @@ from vnpy.trader.utility import (
get_trading_date, get_trading_date,
import_module_by_str import_module_by_str
) )
from vnpy.trader.gateway import TickCombiner
from .back_testing import BackTestingEngine from .back_testing import BackTestingEngine
@ -51,8 +52,9 @@ class SpreadTestingEngine(BackTestingEngine):
CTA套利组合回测引擎, 使用回测引擎作为父类 CTA套利组合回测引擎, 使用回测引擎作为父类
函数接口和策略引擎保持一样 函数接口和策略引擎保持一样
从而实现同一套代码从回测到实盘 从而实现同一套代码从回测到实盘
针对tick回测
导入CTA_Settings tick回测:
1,设置tick_path,
""" """
@ -60,9 +62,17 @@ class SpreadTestingEngine(BackTestingEngine):
"""Constructor""" """Constructor"""
super().__init__(event_engine) super().__init__(event_engine)
self.tick_path = None # tick级别回测 路径 self.tick_path = None # tick级别回测 路径
self.use_tq = False self.use_tq = False # True:使用tq数据; False:使用淘宝购买的数据(19年之前)
self.strategy_start_date_dict = {} self.strategy_start_date_dict = {}
self.strategy_end_date_dict = {} self.strategy_end_date_dict = {}
self.tick_combiner_dict = {} # tick合成器
self.symbol_combiner_dict = {} # symbol : [combiner]
self.bar_csv_file = {}
self.bar_df_dict = {} # 历史数据的df回测用
self.bar_df = None # 历史数据的df时间+symbol作为组合索引
self.bar_interval_seconds = 60 # bar csv文件属于K线类型K线的周期秒数,缺省是1分钟
self.on_tick = self.new_tick # 仿造 gateway的on_tick接口 => new_tick
def prepare_env(self, test_setting): def prepare_env(self, test_setting):
self.output('portfolio prepare_env') self.output('portfolio prepare_env')
@ -70,6 +80,38 @@ class SpreadTestingEngine(BackTestingEngine):
self.use_tq = test_setting.get('use_tq', False) self.use_tq = test_setting.get('use_tq', False)
def prepare_data(self, data_dict):
"""
准备组合数据
:param data_dict: 合约得配置参数
:return:
"""
# 调用回测引擎,跟新合约得数据
super().prepare_data(data_dict)
if len(data_dict) == 0:
self.write_log(u'请指定回测数据和文件')
return
if self.mode == 'tick':
return
# 检查/更新bar文件
for symbol, symbol_data in data_dict.items():
self.write_log(u'配置{}数据:{}'.format(symbol, symbol_data))
bar_file = symbol_data.get('bar_file', None)
if bar_file is None:
self.write_error(u'{}没有配置数据文件')
continue
if not os.path.isfile(bar_file):
self.write_log(u'{0}文件不存在'.format(bar_file))
continue
self.bar_csv_file.update({symbol: bar_file})
def load_strategy(self, strategy_name: str, strategy_setting: dict = None): def load_strategy(self, strategy_name: str, strategy_setting: dict = None):
""" """
装载回测的策略 装载回测的策略
@ -112,29 +154,90 @@ class SpreadTestingEngine(BackTestingEngine):
pas_underly = get_underlying_symbol(pas_symbol).upper() pas_underly = get_underlying_symbol(pas_symbol).upper()
act_exchange = self.get_exchange(f'{act_underly}99') act_exchange = self.get_exchange(f'{act_underly}99')
pas_exchange = self.get_exchange(f'{pas_underly}99') pas_exchange = self.get_exchange(f'{pas_underly}99')
idx_contract = self.get_contract(f'{act_underly}99.{act_exchange.value}') act_contract = self.get_contract(f'{act_underly}99.{act_exchange.value}')
if self.get_contract(f'{act_symbol}.{act_exchange.value}') is None:
self.set_contract(symbol=act_symbol, self.set_contract(symbol=act_symbol,
exchange=act_exchange, exchange=act_exchange,
product=idx_contract.product, product=act_contract.product,
name=act_symbol, name=act_symbol,
size=idx_contract.size, size=act_contract.size,
price_tick=idx_contract.pricetick, price_tick=act_contract.pricetick,
margin_rate=idx_contract.margin_rate) margin_rate=act_contract.margin_rate)
self.write_log(f'设置主动腿指数合约信息{act_symbol}.{act_exchange.value}')
if pas_underly != act_underly: if pas_underly != act_underly:
idx_contract = self.get_contract(f'{pas_underly}99.{pas_exchange.value}') pas_contract = self.get_contract(f'{pas_underly}99.{pas_exchange.value}')
else:
pas_contract = act_contract
if self.get_contract(f'{pas_symbol}.{pas_exchange.value}') is None:
self.set_contract(symbol=pas_symbol, self.set_contract(symbol=pas_symbol,
exchange=pas_exchange, exchange=pas_exchange,
product=idx_contract.product, product=pas_contract.product,
name=act_symbol, name=act_symbol,
size=idx_contract.size, size=pas_contract.size,
price_tick=idx_contract.pricetick, price_tick=pas_contract.pricetick,
margin_rate=idx_contract.margin_rate) margin_rate=pas_contract.margin_rate)
self.write_log(f'设置被动腿指数合约信息{pas_symbol}.{pas_exchange.value}')
idx_spd_symbol=f'{act_underly}99-{act_ratio}-{pas_underly}99-{pas_ratio}-{spread_type}'
subscribe_symobls.remove(vt_symbol) if f'{idx_spd_symbol}.SPD' not in self.contract_dict:
if spread_type == 'CJ':
if act_underly == pas_underly:
spd_price_tick = act_contract.pricetick
spd_size = act_contract.size
spd_margin_rate = act_contract.margin_rate
else:
spd_price_tick = min(act_contract.pricetick, pas_contract.pricetick)
spd_size = min(act_contract.size, pas_contract.size)
spd_margin_rate = max(act_contract.margin_rate, pas_contract.margin_rate)
else:
spd_price_tick = 0.01
spd_size = 100
spd_margin_rate = 0.1
self.set_contract(
symbol=idx_spd_symbol,
exchange=Exchange.SPD,
product=act_contract.product,
name=idx_spd_symbol,
size=spd_size,
price_tick=spd_price_tick,
margin_rate=spd_margin_rate
)
self.write_log(f'设置套利合约信息{idx_spd_symbol}.SPD')
spd_contract =self.contract_dict.get(f'{idx_spd_symbol}.SPD')
# subscribe_symobls.remove(vt_symbol)
subscribe_symobls.append(f'{act_symbol}.{act_exchange.value}') subscribe_symobls.append(f'{act_symbol}.{act_exchange.value}')
subscribe_symobls.append(f'{pas_symbol}.{pas_exchange.value}') subscribe_symobls.append(f'{pas_symbol}.{pas_exchange.value}')
# 价差生成器
combiner = self.tick_combiner_dict.get(vt_symbol, None)
act_combiners = self.symbol_combiner_dict.get(act_symbol, [])
pas_combiners = self.symbol_combiner_dict.get(pas_symbol, [])
if combiner is None:
combiner = TickCombiner(
gateway=self,
setting={
"symbol": symbol,
"leg1_symbol": act_symbol,
"leg1_ratio": int(act_ratio),
"leg2_symbol": pas_symbol,
"leg2_ratio": int(pas_ratio),
"price_tick": spd_contract.pricetick,
"is_spread": True if spread_type == "CJ" else False,
"is_ratio": True if spread_type == "BJ" else False}
)
self.tick_combiner_dict[vt_symbol] = combiner
self.write_log(f'添加{vt_symbol} tick合成器')
if combiner not in act_combiners:
act_combiners.append(combiner)
self.symbol_combiner_dict.update({act_symbol: act_combiners})
self.write_log(f'添加{act_symbol} => {vt_symbol} 合成器映射关系')
if combiner not in pas_combiners:
pas_combiners.append(combiner)
self.symbol_combiner_dict.update({pas_symbol: pas_combiners})
self.write_log(f'添加{pas_symbol} => {vt_symbol} 合成器映射关系')
# 取消自动启动 # 取消自动启动
if 'auto_start' in strategy_setting: if 'auto_start' in strategy_setting:
strategy_setting.update({'auto_start': False}) strategy_setting.update({'auto_start': False})
@ -205,7 +308,10 @@ class SpreadTestingEngine(BackTestingEngine):
self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date)) self.write_log(u'开始回测:{} ~ {}'.format(self.data_start_date, self.data_end_date))
if self.mode == 'tick':
self.run_tick_test() self.run_tick_test()
else:
self.run_bar_test()
def load_csv_file(self, tick_folder, vt_symbol, tick_date): def load_csv_file(self, tick_folder, vt_symbol, tick_date):
"""从文件中读取tick返回list[{dict}]""" """从文件中读取tick返回list[{dict}]"""
@ -237,7 +343,7 @@ class SpreadTestingEngine(BackTestingEngine):
ticks = [] ticks = []
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
self.write_log(u'{0}文件不存在'.format(file_path)) self.write_log(f'{file_path}文件不存在')
return None return None
df = pd.read_csv(file_path, encoding='gbk', parse_dates=False) df = pd.read_csv(file_path, encoding='gbk', parse_dates=False)
@ -292,7 +398,7 @@ class SpreadTestingEngine(BackTestingEngine):
ticks = [] ticks = []
if not os.path.isfile(file_path): if not os.path.isfile(file_path):
self.write_log(u'{0}文件不存在'.format(file_path)) self.write_log(u'{}文件不存在'.format(file_path))
return None return None
try: try:
df = pd.read_csv(file_path, parse_dates=False) df = pd.read_csv(file_path, parse_dates=False)
@ -332,7 +438,7 @@ class SpreadTestingEngine(BackTestingEngine):
del df del df
except Exception as ex: except Exception as ex:
self.write_log(u'{0}文件读取不成功'.format(file_path)) self.write_log(f'{file_path}文件读取不成功: {str(ex)}')
return None return None
return ticks return ticks
@ -389,6 +495,186 @@ class SpreadTestingEngine(BackTestingEngine):
return tick_df return tick_df
def load_bar_csv_to_df(self, vt_symbol, bar_file, data_start_date=None, data_end_date=None):
"""加载回测bar数据到DataFrame"""
self.output(u'loading {} from {}'.format(vt_symbol, bar_file))
if vt_symbol in self.bar_df_dict:
return True
if bar_file is None or not os.path.exists(bar_file):
self.write_error(u'回测时,{}对应的csv bar文件{}不存在'.format(vt_symbol, bar_file))
return False
try:
data_types = {
"datetime": str,
"open": float,
"high": float,
"low": float,
"close": float,
"open_interest": float,
"volume": float,
"instrument_id": str,
"symbol": str,
"total_turnover": float,
"limit_down": float,
"limit_up": float,
"trading_day": str,
"date": str,
"time": str
}
# 加载csv文件 =》 dateframe
symbol_df = pd.read_csv(bar_file, dtype=data_types)
if len(symbol_df) == 0:
self.write_error(f'回测时加载{vt_symbol} csv文件{bar_file}失败。')
return False
first_dt = symbol_df.iloc[0]['datetime']
if '.' in first_dt:
datetime_format = "%Y-%m-%d %H:%M:%S.%f"
else:
datetime_format = "%Y-%m-%d %H:%M:%S"
# 转换时间str =》 datetime
symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format=datetime_format)
# 设置时间为索引
symbol_df = symbol_df.set_index("datetime")
# 裁剪数据
symbol_df = symbol_df.loc[self.test_start_date:self.test_end_date]
self.bar_df_dict.update({vt_symbol: symbol_df})
except Exception as ex:
self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
self.output(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file, ex))
return False
return True
def comine_bar_df(self):
"""
合并所有回测合约的bar DataFrame =集中的DataFrame
把bar_df_dict =bar_df
:return:
"""
self.output('comine_df')
if len(self.bar_df_dict) == 0:
self.output(f'无加载任何数据,请检查bar文件路径配置')
self.bar_df = pd.concat(self.bar_df_dict, axis=0).swaplevel(0, 1).sort_index()
self.bar_df_dict.clear()
def run_bar_test(self):
"""使用bar进行组合回测"""
testdays = (self.data_end_date - self.data_start_date).days
if testdays < 1:
self.write_log(u'回测时间不足')
return
# 加载数据
for vt_symbol in self.symbol_strategy_map.keys():
symbol, exchange = extract_vt_symbol(vt_symbol)
# 不读取SPD的bar文件
if exchange == Exchange.SPD:
continue
self.load_bar_csv_to_df(vt_symbol, self.bar_csv_file.get(symbol))
# 合并数据
self.comine_bar_df()
last_trading_day = None
bars_dt = None
bars_same_dt = []
gc_collect_days = 0
try:
for (dt, vt_symbol), bar_data in self.bar_df.iterrows():
symbol, exchange = extract_vt_symbol(vt_symbol)
bar_datetime = dt - timedelta(seconds=self.bar_interval_seconds)
bar = BarData(
gateway_name='backtesting',
symbol=symbol,
exchange=exchange,
datetime=bar_datetime
)
bar.open_price = float(bar_data['open'])
bar.close_price = float(bar_data['close'])
bar.high_price = float(bar_data['high'])
bar.low_price = float(bar_data['low'])
bar.volume = int(bar_data['volume'])
bar.open_interest = int(bar_data.get('open_interest', 0))
bar.date = bar_datetime.strftime('%Y-%m-%d')
bar.time = bar_datetime.strftime('%H:%M:%S')
str_td = str(bar_data.get('trading_day', ''))
if len(str_td) == 8:
bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8]
elif len(str_td) == 10:
bar.trading_day = str_td
else:
bar.trading_day = get_trading_date(bar_datetime)
if last_trading_day != bar.trading_day:
self.output(u'回测数据日期:{},资金:{}'.format(bar.trading_day, self.net_capital))
if self.strategy_start_date > bar.datetime:
last_trading_day = bar.trading_day
# bar时间与队列时间一致添加到队列中
if dt == bars_dt:
bars_same_dt.append(bar)
continue
else:
# bar时间与队列时间不一致先推送队列的bars
random.shuffle(bars_same_dt)
for _bar_ in bars_same_dt:
self.new_bar(_bar_)
# 创建新的队列
bars_same_dt = [bar]
bars_dt = dt
# 更新每日净值
if self.strategy_start_date <= dt <= self.data_end_date:
if last_trading_day != bar.trading_day:
if last_trading_day is not None:
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
self.max_net_capital, self.total_commission)
last_trading_day = bar.trading_day
# 第二个交易日,撤单
self.cancel_orders()
# 更新持仓缓存
self.update_pos_buffer()
gc_collect_days += 1
if gc_collect_days >= 10:
# 执行内存回收
gc.collect()
sleep(1)
gc_collect_days = 0
if self.net_capital < 0:
self.write_error(u'净值低于0回测停止')
self.output(u'净值低于0回测停止')
return
self.write_log(u'bar数据回放完成')
if last_trading_day is not None:
self.saving_daily_data(datetime.strptime(last_trading_day, '%Y-%m-%d'), self.cur_capital,
self.max_net_capital, self.total_commission)
except Exception as ex:
self.write_error(u'回测异常导致停止:{}'.format(str(ex)))
self.write_error(u'{},{}'.format(str(ex), traceback.format_exc()))
print(str(ex), file=sys.stderr)
traceback.print_exc()
return
def run_tick_test(self): def run_tick_test(self):
"""运行tick级别组合回测""" """运行tick级别组合回测"""
testdays = (self.data_end_date - self.data_start_date).days testdays = (self.data_end_date - self.data_start_date).days
@ -432,6 +718,9 @@ class SpreadTestingEngine(BackTestingEngine):
self.new_tick(tick) self.new_tick(tick)
# 推送至所有tick combiner
[c.on_tick(tick) for c in self.symbol_combiner_dict.get(tick.symbol, [])]
# 结束一个交易日后,更新每日净值 # 结束一个交易日后,更新每日净值
self.saving_daily_data(test_day, self.saving_daily_data(test_day,
self.cur_capital, self.cur_capital,
@ -463,6 +752,47 @@ class SpreadTestingEngine(BackTestingEngine):
self.write_log(u'tick数据回放完成') self.write_log(u'tick数据回放完成')
def new_bar(self, bar):
"""
重载new_bar方法
bar => tick => 合成器 => new_tick
:param bar:
:return:
"""
tick = self.bar_to_tick(bar)
self.new_tick(tick)
# 推送至所有tick combiner
[c.on_tick(tick) for c in self.symbol_combiner_dict.get(tick.symbol, [])]
def bar_to_tick(self, bar):
""" 通过bar分时bar转换为tick数据 """
# tick =》 增加一分钟
tick = TickData(
gateway_name='backtesting',
symbol=bar.symbol,
exchange=bar.exchange,
datetime=bar.datetime + timedelta(minutes=1)
)
tick.date = tick.datetime.strftime('%Y-%m-%d')
tick.time = tick.datetime.strftime('%H:%M:%S.000')
tick.trading_day = bar.trading_day if bar.trading_day else get_trading_date(tick.datetime)
tick.volume = bar.volume
tick.open_interest = bar.open_interest
tick.last_price = bar.close_price
tick.last_volume = bar.volume
tick.limit_up = 0
tick.limit_down = 0
tick.open_price = 0
tick.high_price = 0
tick.low_price = 0
tick.pre_close = 0
tick.bid_price_1 = bar.close_price
tick.ask_price_1 = bar.close_price
tick.bid_volume_1 = bar.volume
tick.ask_volume_1 = bar.volume
return tick
def single_test(test_setting: dict, strategy_setting: dict): def single_test(test_setting: dict, strategy_setting: dict):
""" """

View File

@ -1160,7 +1160,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param trade: :param trade:
:return: :return:
""" """
self.write_log(u'{},交易更新事件:{},当前持仓:{} ' self.write_log(u'{},交易更新 =>{},\n 当前持仓:{} '
.format(self.cur_datetime, .format(self.cur_datetime,
trade.__dict__, trade.__dict__,
self.position.pos)) self.position.pos))
@ -1173,6 +1173,8 @@ class CtaProFutureTemplate(CtaProTemplate):
dist_record['volume'] = trade.volume dist_record['volume'] = trade.volume
dist_record['price'] = trade.price dist_record['price'] = trade.price
dist_record['symbol'] = trade.vt_symbol dist_record['symbol'] = trade.vt_symbol
# 处理股指锁单
if trade.exchange == Exchange.CFFEX: if trade.exchange == Exchange.CFFEX:
if trade.direction == Direction.LONG: if trade.direction == Direction.LONG:
if abs(self.position.short_pos) >= trade.volume: if abs(self.position.short_pos) >= trade.volume:
@ -1228,7 +1230,7 @@ class CtaProFutureTemplate(CtaProTemplate):
def on_order(self, order: OrderData): def on_order(self, order: OrderData):
"""报单更新""" """报单更新"""
# 未执行的订单中,存在是异常,删除 # 未执行的订单中,存在是异常,删除
self.write_log(u'{}报单更新{}'.format(self.cur_datetime, order.__dict__)) self.write_log(u'{}报单更新 => {}'.format(self.cur_datetime, order.__dict__))
# 修正order被拆单得情况" # 修正order被拆单得情况"
self.fix_order(order) self.fix_order(order)
@ -1274,7 +1276,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param order: :param order:
:return: :return:
""" """
self.write_log(u'委托单全部完成:{}'.format(order.__dict__)) self.write_log(u'报单更新 => 委托单全部完成:{}'.format(order.__dict__))
active_order = self.active_orders[order.vt_orderid] active_order = self.active_orders[order.vt_orderid]
# 通过vt_orderid找到对应的网格 # 通过vt_orderid找到对应的网格
@ -1330,7 +1332,7 @@ class CtaProFutureTemplate(CtaProTemplate):
:param order: :param order:
:return: :return:
""" """
self.write_log(u'委托开仓单撤销:{}'.format(order.__dict__)) self.write_log(u'报单更新 => 委托开仓 => 撤销:{}'.format(order.__dict__))
if not self.trading: if not self.trading:
if not self.backtesting: if not self.backtesting:
@ -1343,7 +1345,7 @@ class CtaProFutureTemplate(CtaProTemplate):
# 直接更新“未完成委托单”更新volume,retry次数 # 直接更新“未完成委托单”更新volume,retry次数
old_order = self.active_orders[order.vt_orderid] old_order = self.active_orders[order.vt_orderid]
self.write_log(u'{} 委托信息:{}'.format(order.vt_orderid, old_order)) self.write_log(u'报单更新 => {} 未完成订单信息:{}'.format(order.vt_orderid, old_order))
old_order['traded'] = order.traded old_order['traded'] = order.traded
order_vt_symbol = copy(old_order['vt_symbol']) order_vt_symbol = copy(old_order['vt_symbol'])
order_volume = old_order['volume'] - old_order['traded'] order_volume = old_order['volume'] - old_order['traded']
@ -1477,7 +1479,7 @@ class CtaProFutureTemplate(CtaProTemplate):
else: else:
pre_status = old_order.get('status', Status.NOTTRADED) pre_status = old_order.get('status', Status.NOTTRADED)
old_order.update({'status': Status.CANCELLED}) old_order.update({'status': Status.CANCELLED})
self.write_log(u'委托单状态:{}=>{}'.format(pre_status, old_order.get('status'))) self.write_log(u'委托单方式{},状态:{}=>{}'.format(order_type, pre_status, old_order.get('status')))
if grid: if grid:
if order.vt_orderid in grid.order_ids: if order.vt_orderid in grid.order_ids:
grid.order_ids.remove(order.vt_orderid) grid.order_ids.remove(order.vt_orderid)
@ -1492,7 +1494,7 @@ class CtaProFutureTemplate(CtaProTemplate):
def on_order_close_canceled(self, order: OrderData): def on_order_close_canceled(self, order: OrderData):
"""委托平仓单撤销""" """委托平仓单撤销"""
self.write_log(u'委托平仓单撤销:{}'.format(order.__dict__)) self.write_log(u'报单更新 => 委托平仓 => 撤销:{}'.format(order.__dict__))
if order.vt_orderid not in self.active_orders: if order.vt_orderid not in self.active_orders:
self.write_error(u'{}不在未完成的委托单中:{}'.format(order.vt_orderid, self.active_orders)) self.write_error(u'{}不在未完成的委托单中:{}'.format(order.vt_orderid, self.active_orders))
@ -1504,7 +1506,7 @@ class CtaProFutureTemplate(CtaProTemplate):
# 直接更新“未完成委托单”更新volume,Retry次数 # 直接更新“未完成委托单”更新volume,Retry次数
old_order = self.active_orders[order.vt_orderid] old_order = self.active_orders[order.vt_orderid]
self.write_log(u'{} 订单信息:{}'.format(order.vt_orderid, old_order)) self.write_log(u'报单更新 => {} 未完成订单信息:{}'.format(order.vt_orderid, old_order))
old_order['traded'] = order.traded old_order['traded'] = order.traded
# order_time = old_order['order_time'] # order_time = old_order['order_time']
order_vt_symbol = copy(old_order['vt_symbol']) order_vt_symbol = copy(old_order['vt_symbol'])
@ -1692,13 +1694,13 @@ class CtaProFutureTemplate(CtaProTemplate):
if order_status in [Status.NOTTRADED, Status.SUBMITTING] and ( if order_status in [Status.NOTTRADED, Status.SUBMITTING] and (
order_type == OrderType.LIMIT or '.SPD' in order_vt_symbol): order_type == OrderType.LIMIT or '.SPD' in order_vt_symbol):
if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交 if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交
self.write_log(u'超时{}秒未成交取消委托单vt_orderid:{},order:{}' self.write_log(u'撤单逻辑 => 超时{}秒未成交取消委托单vt_orderid:{},order:{}'
.format(over_seconds, vt_orderid, order_info)) .format(over_seconds, vt_orderid, order_info))
order_info.update({'status': Status.CANCELLING}) order_info.update({'status': Status.CANCELLING})
self.active_orders.update({vt_orderid: order_info}) self.active_orders.update({vt_orderid: order_info})
ret = self.cancel_order(str(vt_orderid)) ret = self.cancel_order(str(vt_orderid))
if not ret: if not ret:
self.write_log(u'撤单失败,更新状态为撤单成功') self.write_log(u'撤单逻辑 => 撤单失败,更新状态为撤单成功')
order_info.update({'status': Status.CANCELLED}) order_info.update({'status': Status.CANCELLED})
self.active_orders.update({vt_orderid: order_info}) self.active_orders.update({vt_orderid: order_info})
if order_grid: if order_grid:
@ -1710,13 +1712,13 @@ class CtaProFutureTemplate(CtaProTemplate):
# 处理状态为‘撤销’的委托单 # 处理状态为‘撤销’的委托单
elif order_status == Status.CANCELLED: elif order_status == Status.CANCELLED:
self.write_log(u'委托单{}已成功撤单,删除{}'.format(vt_orderid, order_info)) self.write_log(u'撤单逻辑 => 委托单{}已成功撤单,删除未完成订单{}'.format(vt_orderid, order_info))
canceled_ids.append(vt_orderid) canceled_ids.append(vt_orderid)
if reopen: if reopen:
# 撤销的委托单,属于开仓类,需要重新委托 # 撤销的委托单,属于开仓类,需要重新委托
if order_info['offset'] == Offset.OPEN: if order_info['offset'] == Offset.OPEN:
self.write_log(u'超时撤单后,重新开仓') self.write_log(u'撤单逻辑 => 重新开仓')
# 开空委托单 # 开空委托单
if order_info['direction'] == Direction.SHORT: if order_info['direction'] == Direction.SHORT:
short_price = self.cur_mi_price - self.price_tick short_price = self.cur_mi_price - self.price_tick
@ -1788,17 +1790,29 @@ class CtaProFutureTemplate(CtaProTemplate):
else: else:
self.write_error(u'撤单后,重新委托平空仓失败') self.write_error(u'撤单后,重新委托平空仓失败')
else: else:
self.write_log(u'撤单逻辑 => 无须重新开仓')
if order_info['offset'] == Offset.OPEN \ if order_info['offset'] == Offset.OPEN \
and order_grid \ and order_grid \
and len(order_grid.order_ids) == 0 \ and len(order_grid.order_ids) == 0:
and order_grid.traded_volume == 0:
self.write_log(u'移除委托网格{}'.format(order_grid.__dict__)) if order_info['traded'] == 0 and order_grid.traded_volume == 0:
self.write_log(u'撤单逻辑 => 无任何成交 => 移除委托网格{}'.format(order_grid.__dict__))
order_info['grid'] = None order_info['grid'] = None
self.gt.remove_grids_by_ids(direction=order_grid.direction, ids=[order_grid.id]) self.gt.remove_grids_by_ids(direction=order_grid.direction, ids=[order_grid.id])
elif order_info['traded'] > 0:
self.write_log('撤单逻辑 = > 部分开仓')
if order_grid.traded_volume < order_info['traded']:
self.write_log('撤单逻辑 = > 调整网格开仓数 {} => {}'.format(order_grid.traded_volume, order_grid['traded'] ))
order_grid.traded_volume = order_info['traded']
self.write_log(f'撤单逻辑 => 调整网格委托状态=> False, 开仓状态:True, 开仓数量:{order_grid.volume}=>{order_grid.traded_volume}')
order_grid.order_status = False
order_grid.open_status = True
order_grid.volume = order_grid.traded_volume
order_grid.traded_volume = 0
# 删除撤单的订单 # 删除撤单的订单
for vt_orderid in canceled_ids: for vt_orderid in canceled_ids:
self.write_log(u'删除orderID:{0}'.format(vt_orderid)) self.write_log(u'撤单逻辑 => 删除未完成订单:{}'.format(vt_orderid))
self.active_orders.pop(vt_orderid, None) self.active_orders.pop(vt_orderid, None)
if len(self.active_orders) == 0: if len(self.active_orders) == 0:

View File

@ -1163,19 +1163,19 @@ class CtaSpreadTemplate(CtaTemplate):
return True return True
# 检查流动性缺失 # 检查流动性缺失
if not self.cur_act_tick.bid_price_1 <= self.cur_act_tick.last_price <= self.cur_act_tick.ask_price_1 \ # if not self.cur_act_tick.bid_price_1 <= self.cur_act_tick.last_price <= self.cur_act_tick.ask_price_1 \
and self.cur_act_tick.volume > 0: # and self.cur_act_tick.volume > 0:
self.write_log(u'流动性缺失导致leg1最新价{0} /V:{1}超出买1 {2}卖1 {3}范围,' # self.write_log(u'流动性缺失导致leg1最新价{0} /V:{1}超出买1 {2}卖1 {3}范围,'
.format(self.cur_act_tick.last_price, self.cur_act_tick.volume, # .format(self.cur_act_tick.last_price, self.cur_act_tick.volume,
self.cur_act_tick.bid_price_1, self.cur_act_tick.ask_price_1)) # self.cur_act_tick.bid_price_1, self.cur_act_tick.ask_price_1))
return False # return False
#
if not self.cur_pas_tick.bid_price_1 <= self.cur_pas_tick.last_price <= self.cur_pas_tick.ask_price_1 \ # if not self.cur_pas_tick.bid_price_1 <= self.cur_pas_tick.last_price <= self.cur_pas_tick.ask_price_1 \
and self.cur_pas_tick.volume > 0: # and self.cur_pas_tick.volume > 0:
self.write_log(u'流动性缺失导致leg2最新价{0} /V:{1}超出买1 {2}卖1 {3}范围,' # self.write_log(u'流动性缺失导致leg2最新价{0} /V:{1}超出买1 {2}卖1 {3}范围,'
.format(self.cur_pas_tick.last_price, self.cur_pas_tick.volume, # .format(self.cur_pas_tick.last_price, self.cur_pas_tick.volume,
self.cur_pas_tick.bid_price_1, self.cur_pas_tick.ask_price_1)) # self.cur_pas_tick.bid_price_1, self.cur_pas_tick.ask_price_1))
return False # return False
# 如果设置了方向和volume检查是否满足 # 如果设置了方向和volume检查是否满足
if direction==Direction.LONG: if direction==Direction.LONG:

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,8 @@ import sys
import traceback import traceback
import talib as ta import talib as ta
import numpy as np import numpy as np
import pandas as pd
import csv
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -20,6 +22,12 @@ from vnpy.trader.constant import Direction, Color
from vnpy.component.cta_period import CtaPeriod, Period from vnpy.component.cta_period import CtaPeriod, Period
try:
from vnpy.component.chanlun import ChanGraph, ChanLibrary
except Exception as ex:
print('can not import pyChanlun from vnpy.component.chanlun')
class CtaRenkoBar(object): class CtaRenkoBar(object):
"""CTA 砖型K线""" """CTA 砖型K线"""
@ -75,6 +83,7 @@ class CtaRenkoBar(object):
self.param_list.append('para_kdj_smooth_len') self.param_list.append('para_kdj_smooth_len')
self.param_list.append('para_active_kf') # 卡尔曼均线 self.param_list.append('para_active_kf') # 卡尔曼均线
self.param_list.append('para_kf_obscov_len') # 卡尔曼均线观测方差的长度
self.param_list.append('para_active_skd') # 摆动指标 self.param_list.append('para_active_skd') # 摆动指标
self.param_list.append('para_skd_fast_len') self.param_list.append('para_skd_fast_len')
@ -91,6 +100,7 @@ class CtaRenkoBar(object):
self.param_list.append('para_yb_ref') self.param_list.append('para_yb_ref')
self.param_list.append('para_golden_n') # 黄金分割 self.param_list.append('para_golden_n') # 黄金分割
self.param_list.append('para_active_chanlun') # 激活缠论
# 输入参数 # 输入参数
@ -322,6 +332,7 @@ class CtaRenkoBar(object):
# 卡尔曼过滤器 # 卡尔曼过滤器
self.para_active_kf = False self.para_active_kf = False
self.para_kf_obscov_len = 1 # t+1时刻的观测协方差
self.kf = None self.kf = None
self.line_state_mean = [] self.line_state_mean = []
self.line_state_covar = [] self.line_state_covar = []
@ -389,6 +400,7 @@ class CtaRenkoBar(object):
self.is_7x24 = False self.is_7x24 = False
# (实时运行时或者addbar小于bar得周期时不包含最后一根Bar # (实时运行时或者addbar小于bar得周期时不包含最后一根Bar
self.index_list = []
self.open_array = np.zeros(self.max_hold_bars) # 与lineBar一致得开仓价清单 self.open_array = np.zeros(self.max_hold_bars) # 与lineBar一致得开仓价清单
self.open_array[:] = np.nan self.open_array[:] = np.nan
self.high_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最高价清单 self.high_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最高价清单
@ -405,6 +417,16 @@ class CtaRenkoBar(object):
self.mid5_array = np.zeros(self.max_hold_bars) # 收盘价*2/开仓价/最高/最低价 的平均价 self.mid5_array = np.zeros(self.max_hold_bars) # 收盘价*2/开仓价/最高/最低价 的平均价
self.mid5_array[:] = np.nan self.mid5_array[:] = np.nan
self.para_active_chanlun = False # 是否激活缠论
self.chan_lib = None
self.chan_graph = None
self.chanlun_calculated = False # 当前bar是否计算过
self._fenxing_list = [] # 分型列表
self._bi_list = [] # 笔列表
self._bi_zs_list = [] # 笔中枢列表
self._duan_list = [] # 段列表
self._duan_zs_list = [] # 段中枢列表
# 导出到csv文件 # 导出到csv文件
self.export_filename = None self.export_filename = None
self.export_fields = [] self.export_fields = []
@ -417,6 +439,7 @@ class CtaRenkoBar(object):
# 事件回调函数 # 事件回调函数
self.cb_dict = {} self.cb_dict = {}
if setting: if setting:
self.setParam(setting) self.setParam(setting)
@ -429,6 +452,13 @@ class CtaRenkoBar(object):
if self.kilo_height > 0: if self.kilo_height > 0:
self.height = self.price_tick * self.kilo_height self.height = self.price_tick * self.kilo_height
if self.para_active_chanlun:
try:
self.chan_lib = ChanLibrary(bi_style=2, duan_style=1)
except:
self.write_log(u'导入缠论组件失败')
self.chan_lib = None
def __getstate__(self): def __getstate__(self):
"""移除Pickle dump()时不支持的Attribute""" """移除Pickle dump()时不支持的Attribute"""
state = self.__dict__.copy() state = self.__dict__.copy()
@ -508,7 +538,8 @@ class CtaRenkoBar(object):
observation_matrices=[1], observation_matrices=[1],
initial_state_mean=self.last_price_list[-1], initial_state_mean=self.last_price_list[-1],
initial_state_covariance=1, initial_state_covariance=1,
transition_covariance=0.01) transition_covariance=0.01,
observation_covariance=self.para_kf_obscov_len)
state_means, state_covariances = self.tick_kf.filter(np.array(self.last_price_list, dtype=float)) state_means, state_covariances = self.tick_kf.filter(np.array(self.last_price_list, dtype=float))
m = state_means[-1].item() m = state_means[-1].item()
c = state_covariances[-1].item() c = state_covariances[-1].item()
@ -600,6 +631,9 @@ class CtaRenkoBar(object):
# 新添加得bar比现有得bar时间晚不添加 # 新添加得bar比现有得bar时间晚不添加
if bar.datetime < self.line_bar[-1].datetime: if bar.datetime < self.line_bar[-1].datetime:
return return
if bar.datetime == self.line_bar[-1].datetime:
if bar.close_price != self.line_bar[-1].close_price:
bar.datetime += timedelta(microseconds=1)
# 更新最后价格 # 更新最后价格
self.cur_price = bar.close_price self.cur_price = bar.close_price
@ -615,7 +649,9 @@ class CtaRenkoBar(object):
bar_mid4 = round((2 * bar.close_price + bar.high_price + bar.low_price) / 4, self.round_n) bar_mid4 = round((2 * bar.close_price + bar.high_price + bar.low_price) / 4, self.round_n)
bar_mid5 = round((2 * bar.close_price + bar.open_price + bar.high_price + bar.low_price) / 5, self.round_n) bar_mid5 = round((2 * bar.close_price + bar.open_price + bar.high_price + bar.low_price) / 5, self.round_n)
# 扩展open,close,high,low numpy array列表 # 扩展时间索引,open,close,high,low numpy array列表 平移更新序列最新值
self.index_list.append(bar.datetime.strftime('%Y-%m-%d %H:%M:%S.%f'))
self.open_array[:-1] = self.open_array[1:] self.open_array[:-1] = self.open_array[1:]
self.open_array[-1] = bar.open_price self.open_array[-1] = bar.open_price
@ -654,7 +690,9 @@ class CtaRenkoBar(object):
elif bar.close_price < bar.open_price: elif bar.close_price < bar.open_price:
bar.color = Color.BLUE bar.color = Color.BLUE
# 扩展open,close,high,low 列表 # 扩展时间索引,open,close,high,low numpy array列表 平移更新序列最新值
self.index_list.append(bar.datetime.strftime('%Y-%m-%d %H:%M:%S.%f'))
self.open_array[:-1] = self.open_array[1:] self.open_array[:-1] = self.open_array[1:]
self.open_array[-1] = bar.open_price self.open_array[-1] = bar.open_price
@ -709,6 +747,8 @@ class CtaRenkoBar(object):
self.runtime_recount() self.runtime_recount()
self.chanlun_calculated = False
# 回调上层调用者 # 回调上层调用者
self.cb_on_bar(bar, self.name) self.cb_on_bar(bar, self.name)
@ -2848,13 +2888,14 @@ class CtaRenkoBar(object):
observation_matrices=[1], observation_matrices=[1],
initial_state_mean=self.close_array[-1], initial_state_mean=self.close_array[-1],
initial_state_covariance=1, initial_state_covariance=1,
transition_covariance=0.01,
transition_covariance=0.01) observation_covariance = self.para_kf_obscov_len
)
except Exception: except Exception:
self.write_log(u'导入卡尔曼过滤器失败,需先安装 pip install pykalman') self.write_log(u'导入卡尔曼过滤器失败,需先安装 pip install pykalman')
self.para_active_kf = False self.para_active_kf = False
state_means, state_covariances = self.kf.filter(np.array(self.close_array, dtype=float)) state_means, state_covariances = self.kf.filter(np.array(self.close_array[-1], dtype=float))
m = state_means[-1].item() m = state_means[-1].item()
c = state_covariances[-1].item() c = state_covariances[-1].item()
else: else:
@ -3521,6 +3562,432 @@ class CtaRenkoBar(object):
self.write_log(u'call back event{} exception:{}'.format(self.CB_ON_PERIOD, str(ex))) self.write_log(u'call back event{} exception:{}'.format(self.CB_ON_PERIOD, str(ex)))
self.write_log(u'traceback:{}'.format(traceback.format_exc())) self.write_log(u'traceback:{}'.format(traceback.format_exc()))
def __count_chanlun(self):
"""重新计算缠论"""
if self.chanlun_calculated:
return
if not self.chan_lib:
return
if self.bar_len <= 3:
return
if self.chan_graph is not None:
del self.chan_graph
self.chan_graph = None
self.chan_graph = ChanGraph(chan_lib=self.chan_lib,
index=self.index_list[-self.bar_len+1:],
high=self.high_array[-self.bar_len+1:],
low=self.low_array[-self.bar_len+1:])
self._fenxing_list = self.chan_graph.fenxing_list
self._bi_list = self.chan_graph.bi_list
self._bi_zs_list = self.chan_graph.bi_zhongshu_list
self._duan_list = self.chan_graph.duan_list
self._duan_zs_list = self.chan_graph.duan_zhongshu_list
self.chanlun_calculated = True
@property
def fenxing_list(self):
if not self.chanlun_calculated:
self.__count_chanlun()
return self._fenxing_list
@property
def bi_list(self):
if not self.chanlun_calculated:
self.__count_chanlun()
return self._bi_list
@property
def bi_zs_list(self):
if not self.chanlun_calculated:
self.__count_chanlun()
return self._bi_zs_list
@property
def duan_list(self):
if not self.chanlun_calculated:
self.__count_chanlun()
return self._duan_list
@property
def duan_zs_list(self):
if not self.chanlun_calculated:
self.__count_chanlun()
return self._duan_zs_list
def is_bi_beichi_inside_duan(self, direction):
"""当前段内的笔,是否形成背驰"""
if len(self._duan_list) == 0:
return False
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 分型需要确认
if self.fenxing_list[-1].is_rt:
return False
# 当前段
duan = self._duan_list[-1]
if duan.direction != direction:
return False
# 当前段包含的分笔必须大于等于5(缠论里面,如果只有三个分笔,背驰的力度比较弱)
if len(duan.bi_list) < 5:
return False
# 获取最近2个匹配direction的分型
fx_list = [fx for fx in self._fenxing_list[-4:] if fx.direction == direction]
if len(fx_list) != 2:
return False
# 这里是排除段的信号出错,获取了很久之前的一段,而不是最新的一段
if duan.end < fx_list[0].index:
return False
# 分笔与段同向
if duan.bi_list[-1].direction != direction \
or duan.bi_list[-3].direction != direction \
or duan.bi_list[-5].direction != direction:
return False
# 背驰: 同向分笔,逐笔提升,最后一笔,比上一同向笔,短,斜率也比上一同向笔小
if direction == 1:
if duan.bi_list[-1].low > duan.bi_list[-3].low > duan.bi_list[-5].low \
and duan.bi_list[-1].low > duan.bi_list[-5].high \
and duan.bi_list[-1].height < duan.bi_list[-3].height \
and duan.bi_list[-1].atan < duan.bi_list[-3].atan:
return True
if direction == -1:
if duan.bi_list[-1].high < duan.bi_list[-3].high < duan.bi_list[-5].high \
and duan.bi_list[-1].high < duan.bi_list[-5].low \
and duan.bi_list[-1].height < duan.bi_list[-3].height\
and duan.bi_list[-1].atan < duan.bi_list[-3].atan:
return True
return False
def is_fx_macd_divergence(self, direction):
"""
分型的macd背离
:param direction: 1-1 或者 Direction.LONG判断是否顶背离, Direction.SHORT判断是否底背离
:return:
"""
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 当前段
duan = self._duan_list[-1]
if duan.direction != direction:
return False
# 当前段包含的分笔必须大于3
if len(duan.bi_list) <= 3:
return False
# 获取最近2个匹配direction的分型
fx_list = [fx for fx in self._fenxing_list[-4:] if fx.direction == direction]
if len(fx_list) != 2:
return False
# 这里是排除段的信号出错,获取了很久之前的一段,而不是最新的一段
if duan.end < fx_list[0].index:
return False
pre_dif = self.get_dif_by_dt(fx_list[0].index)
cur_dif = self.get_dif_by_dt(fx_list[1].index)
if pre_dif is None or cur_dif is None:
return False
if direction == 1:
# 前顶分型顶部价格
pre_price = fx_list[0].high
# 当前顶分型顶部价格
cur_price = fx_list[1].high
if pre_price < cur_price and pre_dif >= cur_dif and 0 < self.line_dif[-1] < self.line_dif[-2]:
return True
else:
pre_price = fx_list[0].low
cur_price = fx_list[1].low
if pre_price > cur_price and pre_dif <= cur_dif and self.line_dif[-2] < self.line_dif[-1] < 0:
return True
return False
def is_2nd_opportunity(self, direction):
"""
是二买二卖机会
二买当前线段下行最后2笔不在线段中最后一笔与下行线段同向该笔底部不破线段底部底分型出现且确认
二卖当前线段上行最后2笔不在线段中最后一笔与上行线段同向该笔顶部不破线段顶部顶分型出现且确认
:param direction: 1Direction.LONG, 当前线段的方向, 判断是否二卖机会 -1 Direction.SHORT 判断是否二买
:return:
"""
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 具备段
if len(self.duan_list) < 1:
return False
cur_duan = self.duan_list[-1]
if cur_duan.direction != direction:
return False
# 当前段到最新bar之间的笔列表此时未出现中枢
extra_bi_list = [bi for bi in self.bi_list[-3:] if bi.end > cur_duan.end]
if len(extra_bi_list) < 2:
return False
# 最后一笔是同向
if extra_bi_list[-1].direction != direction:
return False
# 线段外一笔的高度,不能超过线段最后一笔高度
if extra_bi_list[0].height > cur_duan.bi_list[-1].height:
return False
# 最后一笔的高度不能超过最后一段的高度的黄金分割38%
if extra_bi_list[-1].height > cur_duan.height * 0.38:
return False
# 最后的分型,不是实时。
if not self.fenxing_list[-1].is_rt:
return True
return False
def is_contain_zs_inside_duan(self, direction, zs_num):
"""最近段符合方向并且至少包含zs_num个中枢"""
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 具备中枢
if len(self.bi_zs_list) < zs_num:
return False
# 具备段
if len(self.duan_list) < 1:
return False
cur_duan = self.duan_list[-1]
if cur_duan.direction != direction:
return False
# 段的开始时间至少大于前zs_num个中枢的结束时间
if cur_duan.start > self.bi_zs_list[-zs_num].end:
return False
return True
def is_contain_zs_with_direction(self, start, direction, zs_num):
"""从start开始计算至少包含zs_num(>1)个中枢,且最后两个中枢符合方向"""
if zs_num < 2:
return False
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 具备中枢
if len(self.bi_zs_list) < zs_num:
return False
bi_zs_list = [zs for zs in self.bi_zs_list[-zs_num:] if zs.end > start]
if len(bi_zs_list) != zs_num:
return False
if direction == 1 and bi_zs_list[-2].high < bi_zs_list[-1].high:
return True
if direction == -1 and bi_zs_list[-2].high > bi_zs_list[-1].high:
return True
return False
def is_zs_beichi_inside_duan(self, direction):
"""是否中枢盘整背驰,进入笔、离去笔,高度,能量背驰"""
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 具备中枢
if len(self.bi_zs_list) < 1:
return False
# 具备段
if len(self.duan_list) < 1:
return False
# 最后线段
cur_duan = self.duan_list[-1]
if cur_duan.direction != direction:
return False
# 线段内的笔中枢(取前三个就可以了)
zs_list_inside_duan = [zs for zs in self.bi_zs_list[-3:] if zs.start >= cur_duan.start]
# 无中枢或者超过1个中枢都不符合中枢背驰
if len(zs_list_inside_duan) != 1:
return False
# 当前中枢
cur_zs = zs_list_inside_duan[0]
# 当前中枢最后一笔,与段最后一笔不一致
if cur_duan.bi_list[-1].end != cur_zs.bi_list[-1].end:
return False
# 分型需要确认
if self.fenxing_list[-1].is_rt:
return False
# 找出中枢得进入笔
entry_bi = cur_zs.bi_list[0]
if entry_bi.direction != direction:
# 找出中枢之前,与段同向得笔
before_bi_list = [bi for bi in cur_duan.bi_list if bi.start < entry_bi.start and bi.direction==direction]
# 中枢之前得同向笔,不存在(一般不可能,因为中枢得第一笔不同向,该中枢存在与段中间)
if len(before_bi_list) == 0:
return False
entry_bi = before_bi_list[-1]
# 中枢第一笔,与最后一笔,比较力度和能量
if entry_bi.height > cur_zs.bi_list[-1].height\
and entry_bi.atan > cur_zs.bi_list[-1].atan:
return True
return False
def is_zs_fangda(self, cur_bi_zs = None, start=False, last_bi=False):
"""
判断中枢是否为放大型中枢
中枢放大一般是反向力量的强烈试探导致
cur_bi_zs: 指定的笔中枢若不指定则默认为最后一个中枢
start: True从中枢开始的笔进行计算前三 False 从最后三笔计算
last_bi: 采用缺省最后一笔时是否要求最后一笔必须等于中枢得最后一笔
"""
if cur_bi_zs is None:
# 具备中枢
if len(self.bi_zs_list) < 1:
return False
cur_bi_zs = self.bi_zs_list[-1]
if last_bi:
cur_bi = self.bi_list[-1]
# 要求最后一笔,必须等于中枢得最后一笔
if cur_bi.start != cur_bi_zs.bi_list[-1].start:
return False
if len(cur_bi_zs.bi_list) < 3:
return False
# 从开始前三笔计算
if start and cur_bi_zs.bi_list[2].height > cur_bi_zs.bi_list[1].height > cur_bi_zs.bi_list[0].height:
return True
# 从最后的三笔计算
if not start and cur_bi_zs.bi_list[-1].height > cur_bi_zs.bi_list[-2].height > cur_bi_zs.bi_list[-3].height:
return True
return False
def is_zs_shoulian(self, cur_bi_zs=None, start=False, last_bi=False):
"""
判断中枢是否为收殓型中枢
中枢收敛一般是多空力量的趋于平衡如果是段中的第二个或以上中枢可能存在变盘
cur_bi_zs: 指定的中枢或者最后一个中枢
start: True从中枢开始的笔进行计算前三 False 从最后三笔计算
"""
if cur_bi_zs is None:
# 具备中枢
if len(self.bi_zs_list) < 1:
return False
cur_bi_zs = self.bi_zs_list[-1]
if last_bi:
cur_bi = self.bi_list[-1]
# 要求最后一笔,必须等于中枢得最后一笔
if cur_bi.start != cur_bi_zs.bi_list[-1].start:
return False
if len(cur_bi_zs.bi_list) < 3:
return False
if start and cur_bi_zs.bi_list[2].height < cur_bi_zs.bi_list[1].height < cur_bi_zs.bi_list[0].height:
return True
if not start and cur_bi_zs.bi_list[-1].height < cur_bi_zs.bi_list[-2].height < cur_bi_zs.bi_list[-3].height:
return True
return False
def is_zoushi_beichi(self, direction):
"""
判断是否走势背驰
:param direction:
:return:
"""
# Direction => int
if isinstance(direction, Direction):
direction = 1 if direction == Direction.LONG else -1
# 具备中枢
if len(self.bi_zs_list) < 1:
return False
# 具备段
if len(self.duan_list) < 1:
return False
# 最后线段
cur_duan = self.duan_list[-1]
if cur_duan.direction != direction:
return False
# 线段内的笔中枢(取前三个就可以了)
zs_list_inside_duan = [zs for zs in self.bi_zs_list[-3:] if zs.start >= cur_duan.start]
# 少于2个中枢都不符合走势背驰
if len(zs_list_inside_duan) < 2:
return False
# 当前中枢
cur_zs = zs_list_inside_duan[-1]
# 上一个中枢
pre_zs = zs_list_inside_duan[-2]
bi_list_between_zs = [bi for bi in cur_duan.bi_list if bi.direction == direction and bi.end > pre_zs.end and bi.start < cur_zs.start]
if len(bi_list_between_zs) ==0:
return False
# 最后一笔作为2个中枢间的笔
bi_between_zs = bi_list_between_zs[-1]
bi_list_after_cur_zs = [bi for bi in cur_duan.bi_list if bi.direction==direction and bi.end > cur_zs.end]
if len(bi_list_after_cur_zs) == 0:
return False
# 离开中枢的一笔
bi_leave_cur_zs = bi_list_after_cur_zs[0]
# 离开中枢的一笔,不是段的最后一笔
if bi_leave_cur_zs.start != cur_duan.bi_list[-1].start:
return False
# 离开中枢的一笔,不是最后一笔
if bi_leave_cur_zs.start != self.bi_list[-1].start:
return False
fx = [fx for fx in self.fenxing_list[-2:] if fx.direction==direction][-1]
if fx.is_rt:
return False
# 中枢间的分笔,能量大于最后分笔,形成走势背驰
if bi_between_zs.height > bi_leave_cur_zs.height and bi_between_zs.atan > bi_leave_cur_zs.atan:
return True
return False
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
def write_log(self, content): def write_log(self, content):
"""记录CTA日志""" """记录CTA日志"""

View File

@ -54,7 +54,7 @@ from vnpy.trader.constant import (
OptionType, OptionType,
Interval Interval
) )
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway, TickCombiner
from vnpy.trader.object import ( from vnpy.trader.object import (
TickData, TickData,
BarData, BarData,
@ -2091,226 +2091,3 @@ class TqMdApi():
except Exception as e: except Exception as e:
self.gateway.write_log('退出天勤行情api异常:{}'.format(str(e))) self.gateway.write_log('退出天勤行情api异常:{}'.format(str(e)))
class TickCombiner(object):
"""
Tick合成类
"""
def __init__(self, gateway, setting):
self.gateway = gateway
self.gateway_name = self.gateway.gateway_name
self.gateway.write_log(u'创建tick合成类:{}'.format(setting))
self.symbol = setting.get('symbol', None)
self.leg1_symbol = setting.get('leg1_symbol', None)
self.leg2_symbol = setting.get('leg2_symbol', None)
self.leg1_ratio = setting.get('leg1_ratio', 1) # 腿1的数量配比
self.leg2_ratio = setting.get('leg2_ratio', 1) # 腿2的数量配比
self.price_tick = setting.get('price_tick', 1) # 合成价差加比后的最小跳动
# 价差
self.is_spread = setting.get('is_spread', False)
# 价比
self.is_ratio = setting.get('is_ratio', False)
self.last_leg1_tick = None
self.last_leg2_tick = None
# 价差日内最高/最低价
self.spread_high = None
self.spread_low = None
# 价比日内最高/最低价
self.ratio_high = None
self.ratio_low = None
# 当前交易日
self.trading_day = None
if self.is_ratio and self.is_spread:
self.gateway.write_error(u'{}参数有误,不能同时做价差/加比.setting:{}'.format(self.symbol, setting))
return
self.gateway.write_log(u'初始化{}合成器成功'.format(self.symbol))
if self.is_spread:
self.gateway.write_log(
u'leg1:{} * {} - leg2:{} * {}'.format(self.leg1_symbol, self.leg1_ratio, self.leg2_symbol,
self.leg2_ratio))
if self.is_ratio:
self.gateway.write_log(
u'leg1:{} * {} / leg2:{} * {}'.format(self.leg1_symbol, self.leg1_ratio, self.leg2_symbol,
self.leg2_ratio))
def on_tick(self, tick):
"""OnTick处理"""
combinable = False
if tick.symbol == self.leg1_symbol:
# leg1合约
self.last_leg1_tick = tick
if self.last_leg2_tick is not None:
if self.last_leg1_tick.datetime.replace(microsecond=0) == self.last_leg2_tick.datetime.replace(
microsecond=0):
combinable = True
elif tick.symbol == self.leg2_symbol:
# leg2合约
self.last_leg2_tick = tick
if self.last_leg1_tick is not None:
if self.last_leg2_tick.datetime.replace(microsecond=0) == self.last_leg1_tick.datetime.replace(
microsecond=0):
combinable = True
# 不能合并
if not combinable:
return
if not self.is_ratio and not self.is_spread:
return
# 以下情况,基本为单腿涨跌停,不合成价差/价格比 Tick
if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.limit_up) \
and self.last_leg1_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}涨停{1}不合成价差Tick'.format(self.last_leg1_tick.vt_symbol, self.last_leg1_tick.bid_price_1))
return
if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.limit_down) \
and self.last_leg1_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}跌停{1}不合成价差Tick'.format(self.last_leg1_tick.vt_symbol, self.last_leg1_tick.ask_price_1))
return
if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.limit_up) \
and self.last_leg2_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}涨停{1}不合成价差Tick'.format(self.last_leg2_tick.vt_symbol, self.last_leg2_tick.bid_price_1))
return
if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.limit_down) \
and self.last_leg2_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}跌停{1}不合成价差Tick'.format(self.last_leg2_tick.vt_symbol, self.last_leg2_tick.ask_price_1))
return
if self.trading_day != tick.trading_day:
self.trading_day = tick.trading_day
self.spread_high = None
self.spread_low = None
self.ratio_high = None
self.ratio_low = None
if self.is_spread:
spread_tick = TickData(gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=Exchange.SPD,
datetime=tick.datetime)
spread_tick.trading_day = tick.trading_day
spread_tick.date = tick.date
spread_tick.time = tick.time
# 叫卖价差=leg1.ask_price_1 * 配比 - leg2.bid_price_1 * 配比volume为两者最小
spread_tick.ask_price_1 = round_to(target=self.price_tick,
value=self.last_leg1_tick.ask_price_1 * self.leg1_ratio - self.last_leg2_tick.bid_price_1 * self.leg2_ratio)
spread_tick.ask_volume_1 = min(self.last_leg1_tick.ask_volume_1, self.last_leg2_tick.bid_volume_1)
# 叫买价差=leg1.bid_price_1 * 配比 - leg2.ask_price_1 * 配比volume为两者最小
spread_tick.bid_price_1 = round_to(target=self.price_tick,
value=self.last_leg1_tick.bid_price_1 * self.leg1_ratio - self.last_leg2_tick.ask_price_1 * self.leg2_ratio)
spread_tick.bid_volume_1 = min(self.last_leg1_tick.bid_volume_1, self.last_leg2_tick.ask_volume_1)
# 最新价
spread_tick.last_price = round_to(target=self.price_tick,
value=(spread_tick.ask_price_1 + spread_tick.bid_price_1) / 2)
# 昨收盘价
if self.last_leg2_tick.pre_close > 0 and self.last_leg1_tick.pre_close > 0:
spread_tick.pre_close = round_to(target=self.price_tick,
value=self.last_leg1_tick.pre_close * self.leg1_ratio - self.last_leg2_tick.pre_close * self.leg2_ratio)
# 开盘价
if self.last_leg2_tick.open_price > 0 and self.last_leg1_tick.open_price > 0:
spread_tick.open_price = round_to(target=self.price_tick,
value=self.last_leg1_tick.open_price * self.leg1_ratio - self.last_leg2_tick.open_price * self.leg2_ratio)
# 最高价
if self.spread_high:
self.spread_high = max(self.spread_high, spread_tick.ask_price_1)
else:
self.spread_high = spread_tick.ask_price_1
spread_tick.high_price = self.spread_high
# 最低价
if self.spread_low:
self.spread_low = min(self.spread_low, spread_tick.bid_price_1)
else:
self.spread_low = spread_tick.bid_price_1
spread_tick.low_price = self.spread_low
self.gateway.on_tick(spread_tick)
if self.is_ratio:
ratio_tick = TickData(
gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=Exchange.SPD,
datetime=tick.datetime
)
ratio_tick.trading_day = tick.trading_day
ratio_tick.date = tick.date
ratio_tick.time = tick.time
# 比率tick = (腿1 * 腿1 手数 / 腿2价格 * 腿2手数) 百分比
ratio_tick.ask_price_1 = 100 * self.last_leg1_tick.ask_price_1 * self.leg1_ratio \
/ (self.last_leg2_tick.bid_price_1 * self.leg2_ratio) # noqa
ratio_tick.ask_price_1 = round_to(
target=self.price_tick,
value=ratio_tick.ask_price_1
)
ratio_tick.ask_volume_1 = min(self.last_leg1_tick.ask_volume_1, self.last_leg2_tick.bid_volume_1)
ratio_tick.bid_price_1 = 100 * self.last_leg1_tick.bid_price_1 * self.leg1_ratio \
/ (self.last_leg2_tick.ask_price_1 * self.leg2_ratio) # noqa
ratio_tick.bid_price_1 = round_to(
target=self.price_tick,
value=ratio_tick.bid_price_1
)
ratio_tick.bid_volume_1 = min(self.last_leg1_tick.bid_volume_1, self.last_leg2_tick.ask_volume_1)
ratio_tick.last_price = (ratio_tick.ask_price_1 + ratio_tick.bid_price_1) / 2
ratio_tick.last_price = round_to(
target=self.price_tick,
value=ratio_tick.last_price
)
# 昨收盘价
if self.last_leg2_tick.pre_close > 0 and self.last_leg1_tick.pre_close > 0:
ratio_tick.pre_close = 100 * self.last_leg1_tick.pre_close * self.leg1_ratio / (
self.last_leg2_tick.pre_close * self.leg2_ratio) # noqa
ratio_tick.pre_close = round_to(
target=self.price_tick,
value=ratio_tick.pre_close
)
# 开盘价
if self.last_leg2_tick.open_price > 0 and self.last_leg1_tick.open_price > 0:
ratio_tick.open_price = 100 * self.last_leg1_tick.open_price * self.leg1_ratio / (
self.last_leg2_tick.open_price * self.leg2_ratio) # noqa
ratio_tick.open_price = round_to(
target=self.price_tick,
value=ratio_tick.open_price
)
# 最高价
if self.ratio_high:
self.ratio_high = max(self.ratio_high, ratio_tick.ask_price_1)
else:
self.ratio_high = ratio_tick.ask_price_1
ratio_tick.high_price = self.spread_high
# 最低价
if self.ratio_low:
self.ratio_low = min(self.ratio_low, ratio_tick.bid_price_1)
else:
self.ratio_low = ratio_tick.bid_price_1
ratio_tick.low_price = self.spread_low
self.gateway.on_tick(ratio_tick)

View File

@ -374,8 +374,8 @@ class PbGateway(BaseGateway):
product_id=product_id, product_id=product_id,
unit_id=unit_id, unit_id=unit_id,
holder_ids=holder_ids) holder_ids=holder_ids)
self.tq_api = TqMdApi(self) #self.tq_api = TqMdApi(self)
self.tq_api.connect() #self.tq_api.connect()
self.init_query() self.init_query()
def close(self) -> None: def close(self) -> None:

View File

@ -507,6 +507,8 @@ class OmsEngine(BaseEngine):
self.main_engine.get_position = self.get_position self.main_engine.get_position = self.get_position
self.main_engine.get_account = self.get_account self.main_engine.get_account = self.get_account
self.main_engine.get_contract = self.get_contract self.main_engine.get_contract = self.get_contract
self.main_engine.get_exchange = self.get_exchange
self.main_engine.get_custom_contract = self.get_custom_contract
self.main_engine.get_all_ticks = self.get_all_ticks self.main_engine.get_all_ticks = self.get_all_ticks
self.main_engine.get_all_orders = self.get_all_orders self.main_engine.get_all_orders = self.get_all_orders
self.main_engine.get_all_trades = self.get_all_trades self.main_engine.get_all_trades = self.get_all_trades
@ -650,6 +652,13 @@ class OmsEngine(BaseEngine):
self.today_contracts[contract.vt_symbol] = contract self.today_contracts[contract.vt_symbol] = contract
self.today_contracts[contract.symbol] = contract self.today_contracts[contract.symbol] = contract
def get_exchange(self, symbol: str) -> Exchange:
"""获取合约对应的交易所"""
contract = self.contracts.get(symbol, None)
if contract is None:
return Exchange.LOCAL
return contract.exchange
def get_tick(self, vt_symbol: str) -> Optional[TickData]: def get_tick(self, vt_symbol: str) -> Optional[TickData]:
""" """
Get latest market tick data by vt_symbol. Get latest market tick data by vt_symbol.
@ -746,6 +755,27 @@ class OmsEngine(BaseEngine):
] ]
return active_orders return active_orders
def get_custom_contract(self, symbol):
"""
获取自定义合约的设置
:param symbol: "pb2012-1-pb2101-1-CJ"
:return: {
"name": "pb跨期价差",
"exchange": "SPD",
"leg1_symbol": "pb2012",
"leg1_exchange": "SHFE",
"leg1_ratio": 1,
"leg2_symbol": "pb2101",
"leg2_exchange": "SHFE",
"leg2_ratio": 1,
"is_spread": true,
"size": 1,
"margin_rate": 0.1,
"price_tick": 5
}
"""
return self.custom_settings.get(symbol, None)
def get_all_custom_contracts(self, rtn_setting=False): def get_all_custom_contracts(self, rtn_setting=False):
""" """
获取所有自定义合约 获取所有自定义合约
@ -759,6 +789,7 @@ class OmsEngine(BaseEngine):
if len(self.custom_contracts) == 0: if len(self.custom_contracts) == 0:
c = CustomContract() c = CustomContract()
self.custom_settings = c.get_config()
self.custom_contracts = c.get_contracts() self.custom_contracts = c.get_contracts()
return self.custom_contracts return self.custom_contracts

View File

@ -34,7 +34,7 @@ from .object import (
Exchange Exchange
) )
from vnpy.trader.utility import get_folder_path from vnpy.trader.utility import get_folder_path, round_to
from vnpy.trader.util_logger import setup_logger from vnpy.trader.util_logger import setup_logger
@ -329,6 +329,229 @@ class BaseGateway(ABC):
return self.status return self.status
class TickCombiner(object):
"""
Tick合成类
"""
def __init__(self, gateway, setting):
self.gateway = gateway
self.gateway_name = self.gateway.gateway_name
self.gateway.write_log(u'创建tick合成类:{}'.format(setting))
self.symbol = setting.get('symbol', None)
self.leg1_symbol = setting.get('leg1_symbol', None)
self.leg2_symbol = setting.get('leg2_symbol', None)
self.leg1_ratio = setting.get('leg1_ratio', 1) # 腿1的数量配比
self.leg2_ratio = setting.get('leg2_ratio', 1) # 腿2的数量配比
self.price_tick = setting.get('price_tick', 1) # 合成价差加比后的最小跳动
# 价差
self.is_spread = setting.get('is_spread', False)
# 价比
self.is_ratio = setting.get('is_ratio', False)
self.last_leg1_tick = None
self.last_leg2_tick = None
# 价差日内最高/最低价
self.spread_high = None
self.spread_low = None
# 价比日内最高/最低价
self.ratio_high = None
self.ratio_low = None
# 当前交易日
self.trading_day = None
if self.is_ratio and self.is_spread:
self.gateway.write_error(u'{}参数有误,不能同时做价差/加比.setting:{}'.format(self.symbol, setting))
return
self.gateway.write_log(u'初始化{}合成器成功'.format(self.symbol))
if self.is_spread:
self.gateway.write_log(
u'leg1:{} * {} - leg2:{} * {}'.format(self.leg1_symbol, self.leg1_ratio, self.leg2_symbol,
self.leg2_ratio))
if self.is_ratio:
self.gateway.write_log(
u'leg1:{} * {} / leg2:{} * {}'.format(self.leg1_symbol, self.leg1_ratio, self.leg2_symbol,
self.leg2_ratio))
def on_tick(self, tick):
"""OnTick处理"""
combinable = False
if tick.symbol == self.leg1_symbol:
# leg1合约
self.last_leg1_tick = tick
if self.last_leg2_tick is not None:
if self.last_leg1_tick.datetime.replace(microsecond=0) == self.last_leg2_tick.datetime.replace(
microsecond=0):
combinable = True
elif tick.symbol == self.leg2_symbol:
# leg2合约
self.last_leg2_tick = tick
if self.last_leg1_tick is not None:
if self.last_leg2_tick.datetime.replace(microsecond=0) == self.last_leg1_tick.datetime.replace(
microsecond=0):
combinable = True
# 不能合并
if not combinable:
return
if not self.is_ratio and not self.is_spread:
return
# 以下情况,基本为单腿涨跌停,不合成价差/价格比 Tick
if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.limit_up) \
and self.last_leg1_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}涨停{1}不合成价差Tick'.format(self.last_leg1_tick.vt_symbol, self.last_leg1_tick.bid_price_1))
return
if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.limit_down) \
and self.last_leg1_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}跌停{1}不合成价差Tick'.format(self.last_leg1_tick.vt_symbol, self.last_leg1_tick.ask_price_1))
return
if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.limit_up) \
and self.last_leg2_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}涨停{1}不合成价差Tick'.format(self.last_leg2_tick.vt_symbol, self.last_leg2_tick.bid_price_1))
return
if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.limit_down) \
and self.last_leg2_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}跌停{1}不合成价差Tick'.format(self.last_leg2_tick.vt_symbol, self.last_leg2_tick.ask_price_1))
return
if self.trading_day != tick.trading_day:
self.trading_day = tick.trading_day
self.spread_high = None
self.spread_low = None
self.ratio_high = None
self.ratio_low = None
if self.is_spread:
spread_tick = TickData(gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=Exchange.SPD,
datetime=tick.datetime)
spread_tick.trading_day = tick.trading_day
spread_tick.date = tick.date
spread_tick.time = tick.time
# 叫卖价差=leg1.ask_price_1 * 配比 - leg2.bid_price_1 * 配比volume为两者最小
spread_tick.ask_price_1 = round_to(target=self.price_tick,
value=self.last_leg1_tick.ask_price_1 * self.leg1_ratio - self.last_leg2_tick.bid_price_1 * self.leg2_ratio)
spread_tick.ask_volume_1 = min(self.last_leg1_tick.ask_volume_1, self.last_leg2_tick.bid_volume_1)
# 叫买价差=leg1.bid_price_1 * 配比 - leg2.ask_price_1 * 配比volume为两者最小
spread_tick.bid_price_1 = round_to(target=self.price_tick,
value=self.last_leg1_tick.bid_price_1 * self.leg1_ratio - self.last_leg2_tick.ask_price_1 * self.leg2_ratio)
spread_tick.bid_volume_1 = min(self.last_leg1_tick.bid_volume_1, self.last_leg2_tick.ask_volume_1)
# 最新价
spread_tick.last_price = round_to(target=self.price_tick,
value=(spread_tick.ask_price_1 + spread_tick.bid_price_1) / 2)
# 昨收盘价
if self.last_leg2_tick.pre_close > 0 and self.last_leg1_tick.pre_close > 0:
spread_tick.pre_close = round_to(target=self.price_tick,
value=self.last_leg1_tick.pre_close * self.leg1_ratio - self.last_leg2_tick.pre_close * self.leg2_ratio)
# 开盘价
if self.last_leg2_tick.open_price > 0 and self.last_leg1_tick.open_price > 0:
spread_tick.open_price = round_to(target=self.price_tick,
value=self.last_leg1_tick.open_price * self.leg1_ratio - self.last_leg2_tick.open_price * self.leg2_ratio)
# 最高价
if self.spread_high:
self.spread_high = max(self.spread_high, spread_tick.ask_price_1)
else:
self.spread_high = spread_tick.ask_price_1
spread_tick.high_price = self.spread_high
# 最低价
if self.spread_low:
self.spread_low = min(self.spread_low, spread_tick.bid_price_1)
else:
self.spread_low = spread_tick.bid_price_1
spread_tick.low_price = self.spread_low
self.gateway.on_tick(spread_tick)
if self.is_ratio:
ratio_tick = TickData(
gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=Exchange.SPD,
datetime=tick.datetime
)
ratio_tick.trading_day = tick.trading_day
ratio_tick.date = tick.date
ratio_tick.time = tick.time
# 比率tick = (腿1 * 腿1 手数 / 腿2价格 * 腿2手数) 百分比
ratio_tick.ask_price_1 = 100 * self.last_leg1_tick.ask_price_1 * self.leg1_ratio \
/ (self.last_leg2_tick.bid_price_1 * self.leg2_ratio) # noqa
ratio_tick.ask_price_1 = round_to(
target=self.price_tick,
value=ratio_tick.ask_price_1
)
ratio_tick.ask_volume_1 = min(self.last_leg1_tick.ask_volume_1, self.last_leg2_tick.bid_volume_1)
ratio_tick.bid_price_1 = 100 * self.last_leg1_tick.bid_price_1 * self.leg1_ratio \
/ (self.last_leg2_tick.ask_price_1 * self.leg2_ratio) # noqa
ratio_tick.bid_price_1 = round_to(
target=self.price_tick,
value=ratio_tick.bid_price_1
)
ratio_tick.bid_volume_1 = min(self.last_leg1_tick.bid_volume_1, self.last_leg2_tick.ask_volume_1)
ratio_tick.last_price = (ratio_tick.ask_price_1 + ratio_tick.bid_price_1) / 2
ratio_tick.last_price = round_to(
target=self.price_tick,
value=ratio_tick.last_price
)
# 昨收盘价
if self.last_leg2_tick.pre_close > 0 and self.last_leg1_tick.pre_close > 0:
ratio_tick.pre_close = 100 * self.last_leg1_tick.pre_close * self.leg1_ratio / (
self.last_leg2_tick.pre_close * self.leg2_ratio) # noqa
ratio_tick.pre_close = round_to(
target=self.price_tick,
value=ratio_tick.pre_close
)
# 开盘价
if self.last_leg2_tick.open_price > 0 and self.last_leg1_tick.open_price > 0:
ratio_tick.open_price = 100 * self.last_leg1_tick.open_price * self.leg1_ratio / (
self.last_leg2_tick.open_price * self.leg2_ratio) # noqa
ratio_tick.open_price = round_to(
target=self.price_tick,
value=ratio_tick.open_price
)
# 最高价
if self.ratio_high:
self.ratio_high = max(self.ratio_high, ratio_tick.ask_price_1)
else:
self.ratio_high = ratio_tick.ask_price_1
ratio_tick.high_price = self.spread_high
# 最低价
if self.ratio_low:
self.ratio_low = min(self.ratio_low, ratio_tick.bid_price_1)
else:
self.ratio_low = ratio_tick.bid_price_1
ratio_tick.low_price = self.spread_low
self.gateway.on_tick(ratio_tick)
class LocalOrderManager: class LocalOrderManager:
""" """
Management tool to support use local order id for trading. Management tool to support use local order id for trading.