[功能增强]股票复利修复,指数行情推送兼容1.x,CTA引擎支持账号仓位

This commit is contained in:
msincenselee 2020-05-13 11:28:35 +08:00
parent e6452844a9
commit 66de40e5f7
9 changed files with 184 additions and 20 deletions

View File

@ -598,7 +598,8 @@ class BackTestingEngine(object):
adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1 adj_factor = adj_factor / adj_factor.iloc[0] # 保证第一个复权因子是1
# 把raw_data的第一个日期插入复权因子df使用后填充 # 把raw_data的第一个日期插入复权因子df使用后填充
adj_factor.loc[raw_data.index[0]] = np.nan if adj_factor.index[0] != raw_data.index[0]:
adj_factor.loc[raw_data.index[0]] = np.nan
adj_factor.sort_index(inplace=True) adj_factor.sort_index(inplace=True)
adj_factor = adj_factor.ffill() adj_factor = adj_factor.ffill()

View File

@ -1457,7 +1457,24 @@ class CtaEngine(BaseEngine):
d.update(strategy.get_parameters()) d.update(strategy.get_parameters())
return d return d
def compare_pos(self,strategy_pos_list=[]): def get_none_strategy_pos_list(self):
"""获取非策略持有的仓位"""
# 格式 [ 'strategy_name':'account', 'pos': [{'vt_symbol': '', 'direction': 'xxx', 'volume':xxx }] } ]
none_strategy_pos_file = os.path.abspath(os.path.join(os.getcwd(), 'data', 'none_strategy_pos.json'))
if not os.path.exists(none_strategy_pos_file):
return []
try:
with open(none_strategy_pos_file, encoding='utf8') as f:
pos_list = json.load(f)
if isinstance(pos_list, list):
return pos_list
return []
except Exception as ex:
self.write_error(u'未能读取或解释{}'.format(none_strategy_pos_file))
return []
def compare_pos(self, strategy_pos_list=[]):
""" """
对比账号&策略的持仓,不同的话则发出微信提醒 对比账号&策略的持仓,不同的话则发出微信提醒
:return: :return:
@ -1473,13 +1490,15 @@ class CtaEngine(BaseEngine):
strategy_pos_list = self.get_all_strategy_pos() strategy_pos_list = self.get_all_strategy_pos()
self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list))
none_strategy_pos = self.get_none_strategy_pos_list()
if len(none_strategy_pos) > 0:
strategy_pos_list.extend(none_strategy_pos)
# 需要进行对比得合约集合(来自策略持仓/账号持仓) # 需要进行对比得合约集合(来自策略持仓/账号持仓)
vt_symbols = set() vt_symbols = set()
# 账号的持仓处理 => account_pos # 账号的持仓处理 => account_pos
compare_pos = dict() # vt_symbol: {'账号多单': xx,'策略多单':[]} compare_pos = dict() # vt_symbol: {'账号多单': xx,'策略多单':[]}
for position in list(self.positions.values()): for position in list(self.positions.values()):
# gateway_name.symbol.exchange => symbol.exchange # gateway_name.symbol.exchange => symbol.exchange
vt_symbol = position.vt_symbol vt_symbol = position.vt_symbol

View File

@ -1211,7 +1211,7 @@ class CtaEngine(BaseEngine):
# 通过事件方式传导到account_recorder # 通过事件方式传导到account_recorder
snapshot.update({ snapshot.update({
'account_id': self.engine_config.get('accountid', '-'), 'account_id': self.engine_config.get('accountid', '-'),
'strategy_group': self.engine_config.get('strategy_group', self.engine_name), 'strategy_group': self.engine_config.get('strategy_group', self.engine_name),
'guid': str(uuid1()) 'guid': str(uuid1())
}) })
event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot) event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot)
@ -1474,7 +1474,7 @@ class CtaEngine(BaseEngine):
d.update(strategy.get_parameters()) d.update(strategy.get_parameters())
return d return d
def get_strategy_value(self, strategy_name: str, parameter:str): def get_strategy_value(self, strategy_name: str, parameter: str):
"""获取策略的某个参数值""" """获取策略的某个参数值"""
strategy = self.strategies.get(strategy_name) strategy = self.strategies.get(strategy_name)
if not strategy: if not strategy:
@ -1483,7 +1483,24 @@ class CtaEngine(BaseEngine):
value = getattr(strategy, parameter, None) value = getattr(strategy, parameter, None)
return value return value
def compare_pos(self, strategy_pos_list=[]): def get_none_strategy_pos_list(self):
"""获取非策略持有的仓位"""
# 格式 [ 'strategy_name':'account', 'pos': [{'vt_symbol': '', 'direction': 'xxx', 'volume':xxx }] } ]
none_strategy_pos_file = os.path.abspath(os.path.join(os.getcwd(), 'data', 'none_strategy_pos.json'))
if not os.path.exists(none_strategy_pos_file):
return []
try:
with open(none_strategy_pos_file, encoding='utf8') as f:
pos_list = json.load(f)
if isinstance(pos_list, list):
return pos_list
return []
except Exception as ex:
self.write_error(u'未能读取或解释{}'.format(none_strategy_pos_file))
return []
def compare_pos(self, strategy_pos_list=[], auto_balance=False):
""" """
对比账号&策略的持仓,不同的话则发出微信提醒 对比账号&策略的持仓,不同的话则发出微信提醒
:return: :return:
@ -1499,6 +1516,10 @@ class CtaEngine(BaseEngine):
strategy_pos_list = self.get_all_strategy_pos() strategy_pos_list = self.get_all_strategy_pos()
self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list))
none_strategy_pos = self.get_none_strategy_pos_list()
if len(none_strategy_pos) > 0:
strategy_pos_list.extend(none_strategy_pos)
# 需要进行对比得合约集合(来自策略持仓/账号持仓) # 需要进行对比得合约集合(来自策略持仓/账号持仓)
vt_symbols = set() vt_symbols = set()
@ -1584,8 +1605,9 @@ class CtaEngine(BaseEngine):
compare_info += msg compare_info += msg
else: else:
pos_compare_result += '\n{}: '.format(vt_symbol) pos_compare_result += '\n{}: '.format(vt_symbol)
# 多单不一致 # 判断是多单不一致?
if round(symbol_pos['策略多单'], 7) != round(symbol_pos['账号多单'], 7): diff_long_volume = round(symbol_pos['账号多单'], 7) - round(symbol_pos['策略多单'], 7)
if diff_long_volume != 0:
msg = '{}多单[账号({}), 策略{},共({})], ' \ msg = '{}多单[账号({}), 策略{},共({})], ' \
.format(vt_symbol, .format(vt_symbol,
symbol_pos['账号多单'], symbol_pos['账号多单'],
@ -1595,8 +1617,13 @@ class CtaEngine(BaseEngine):
pos_compare_result += msg pos_compare_result += msg
self.write_error(u'{}不一致:{}'.format(vt_symbol, msg)) self.write_error(u'{}不一致:{}'.format(vt_symbol, msg))
compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg) compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg)
# 空单不一致 if auto_balance:
if round(symbol_pos['策略空单'], 7) != round(symbol_pos['账号空单'], 7): self.balance_pos(vt_symbol, Direction.LONG, diff_long_volume)
# 判断是空单不一致:
diff_short_volume = round(symbol_pos['账号空单'], 7) - round(symbol_pos['策略空单'], 7)
if diff_short_volume != 0:
msg = '{}空单[账号({}), 策略{},共({})], ' \ msg = '{}空单[账号({}), 策略{},共({})], ' \
.format(vt_symbol, .format(vt_symbol,
symbol_pos['账号空单'], symbol_pos['账号空单'],
@ -1605,6 +1632,8 @@ class CtaEngine(BaseEngine):
pos_compare_result += msg pos_compare_result += msg
self.write_error(u'{}不一致:{}'.format(vt_symbol, msg)) self.write_error(u'{}不一致:{}'.format(vt_symbol, msg))
compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg) compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg)
if auto_balance:
self.balance_pos(vt_symbol, Direction.SHORT, diff_short_volume)
# 不匹配输入到stdErr通道 # 不匹配输入到stdErr通道
if pos_compare_result != '': if pos_compare_result != '':
@ -1614,7 +1643,7 @@ class CtaEngine(BaseEngine):
try: try:
from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.util_wechat import send_wx_msg
send_wx_msg(content=msg) send_wx_msg(content=msg)
except Exception as ex: # noqa except Exception as ex: # noqa
pass pass
ret_msg = u'持仓不匹配: {}' \ ret_msg = u'持仓不匹配: {}' \
.format(pos_compare_result) .format(pos_compare_result)
@ -1624,6 +1653,51 @@ class CtaEngine(BaseEngine):
self.write_log(u'账户持仓与策略一致') self.write_log(u'账户持仓与策略一致')
return True, compare_info return True, compare_info
def balance_pos(self, vt_symbol, direction, volume):
"""
平衡仓位
:param vt_symbol: 需要平衡得合约
:param direction: 合约原始方向
:param volume: 合约需要调整得数量正数需要平仓 负数需要开仓
:return:
"""
tick = self.get_tick(vt_symbol)
if tick is None:
gateway_names = self.main_engine.get_all_gateway_names()
gateway_name = gateway_names[0] if len(gateway_names) > 0 else ""
symbol, exchange = extract_vt_symbol(vt_symbol)
self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), gateway_name=gateway_name)
if volume > 0 and tick:
contract = self.main_engine.get_contract(vt_symbol)
req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
direction=Direction.SHORT if direction == Direction.LONG else Direction.LONG,
offset=Offset.CLOSE,
type=OrderType.FAK,
price=tick.ask_price_1 if direction == Direction.SHORT else tick.bid_price_1,
volume=round(volume, 7)
)
reqs = self.offset_converter.convert_order_request(req=req, lock=False)
self.write_log(f'平衡仓位,减少 {vt_symbol},方向:{direction},数量:{req.volume} ')
for req in reqs:
self.main_engine.send_order(req, contract.gateway_name)
elif volume < 0 and tick:
contract = self.main_engine.get_contract(vt_symbol)
req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
direction=direction,
offset=Offset.OPEN,
type=OrderType.FAK,
price=tick.ask_price_1 if direction == Direction.LONG else tick.bid_price_1,
volume=round(abs(volume), 7)
)
reqs = self.offset_converter.convert_order_request(req=req, lock=False)
self.write_log(f'平衡仓位, 增加{vt_symbol} 方向:{direction}, 数量: {req.volume}')
for req in reqs:
self.main_engine.send_order(req, contract.gateway_name)
def init_all_strategies(self): def init_all_strategies(self):
""" """
""" """

View File

@ -628,6 +628,8 @@ class CtaProTemplate(CtaTemplate):
if self.idx_symbol is None: if self.idx_symbol is None:
symbol, exchange = extract_vt_symbol(self.vt_symbol) symbol, exchange = extract_vt_symbol(self.vt_symbol)
self.idx_symbol = get_underlying_symbol(symbol).upper() + '99.' + exchange.value self.idx_symbol = get_underlying_symbol(symbol).upper() + '99.' + exchange.value
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=self.idx_symbol)
if self.vt_symbol != self.idx_symbol: if self.vt_symbol != self.idx_symbol:
self.write_log(f'指数合约:{self.idx_symbol}, 主力合约:{self.vt_symbol}') self.write_log(f'指数合约:{self.idx_symbol}, 主力合约:{self.vt_symbol}')
self.price_tick = self.cta_engine.get_price_tick(self.vt_symbol) self.price_tick = self.cta_engine.get_price_tick(self.vt_symbol)

View File

@ -26,7 +26,7 @@ from vnpy.component.base import (
from vnpy.amqp.producer import publisher from vnpy.amqp.producer import publisher
APP_NAME = 'INDEXDATAPUBLISHER' APP_NAME = 'Idx_Publisher'
class IndexTickPublisher(BaseEngine): class IndexTickPublisher(BaseEngine):
@ -472,6 +472,6 @@ class IndexTickPublisher(BaseEngine):
d = copy.copy(tick.__dict__) d = copy.copy(tick.__dict__)
if isinstance(tick.datetime, datetime): if isinstance(tick.datetime, datetime):
d.update({'datetime': tick.datetime.strftime('%Y-%m-%d %H:%M:%S.%f')}) d.update({'datetime': tick.datetime.strftime('%Y-%m-%d %H:%M:%S.%f')})
d.update({'exchange': tick.exchange.value()}) d.update({'exchange': tick.exchange.value})
d = json.dumps(d) d = json.dumps(d)
self.pub.pub(d) self.pub.pub(d)

View File

@ -3774,7 +3774,7 @@ class CtaLineBar(object):
if runtime: if runtime:
# 兼容写法,如果老策略没有配置实时运行,又用到实时数据,就添加 # 兼容写法,如果老策略没有配置实时运行,又用到实时数据,就添加
if self.rt_count_skd not in self.rt_funcs: if self.rt_count_skd not in self.rt_funcs:
self.write_log(u'skd_is_high_dead_cross(),添加rt_countSkd到实时函数中') self.write_log(u'rt_count_skd(),添加rt_countSkd到实时函数中')
self.rt_funcs.add(self.rt_count_skd) self.rt_funcs.add(self.rt_count_skd)
self.rt_count_sk_sd() self.rt_count_sk_sd()
if self._rt_sk is None or self._rt_sd is None: if self._rt_sk is None or self._rt_sd is None:

View File

@ -366,16 +366,16 @@ class TdxStockData(object):
cache_date: str): cache_date: str):
"""加载缓存数据""" """加载缓存数据"""
if not os.path.exists(cache_folder): if not os.path.exists(cache_folder):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder)) #self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
return None return None
cache_folder_year_month = os.path.join(cache_folder, cache_date[:6]) cache_folder_year_month = os.path.join(cache_folder, cache_date[:6])
if not os.path.exists(cache_folder_year_month): if not os.path.exists(cache_folder_year_month):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month)) #self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder_year_month))
return None return None
cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date)) cache_file = os.path.join(cache_folder_year_month, '{}_{}.pkb2'.format(cache_symbol, cache_date))
if not os.path.isfile(cache_file): if not os.path.isfile(cache_file):
self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file)) #self.write_error('缓存文件:{}不存在,不能读取'.format(cache_file))
return None return None
with bz2.BZ2File(cache_file, 'rb') as f: with bz2.BZ2File(cache_file, 'rb') as f:
data = pickle.load(f) data = pickle.load(f)

View File

@ -1583,6 +1583,10 @@ class SubMdApi():
try: try:
str_tick = body.decode('utf-8') str_tick = body.decode('utf-8')
d = json.loads(str_tick) d = json.loads(str_tick)
d.pop('rawData', None)
d = self.conver_update(d)
symbol = d.pop('symbol', None) symbol = d.pop('symbol', None)
str_datetime = d.pop('datetime', None) str_datetime = d.pop('datetime', None)
if symbol not in self.registed_symbol_set or str_datetime is None: if symbol not in self.registed_symbol_set or str_datetime is None:
@ -1592,14 +1596,12 @@ class SubMdApi():
else: else:
dt = datetime.strptime(str_datetime, '%Y-%m-%d %H:%M:%S') dt = datetime.strptime(str_datetime, '%Y-%m-%d %H:%M:%S')
d.pop('rawData', None)
tick = TickData(gateway_name=self.gateway_name, tick = TickData(gateway_name=self.gateway_name,
exchange=Exchange(d.get('exchange')), exchange=Exchange(d.get('exchange')),
symbol=d.get('symbol'), symbol=d.get('symbol'),
datetime=dt) datetime=dt)
d.pop('exchange', None) d.pop('exchange', None)
d.pop('symbol', None) d.pop('symbol', None)
d.pop()
tick.__dict__.update(d) tick.__dict__.update(d)
self.symbol_tick_dict[symbol] = tick self.symbol_tick_dict[symbol] = tick
@ -1610,6 +1612,62 @@ class SubMdApi():
self.gateway.write_error(u'RabbitMQ on_message 异常:{}'.format(str(ex))) self.gateway.write_error(u'RabbitMQ on_message 异常:{}'.format(str(ex)))
self.gateway.write_error(traceback.format_exc()) self.gateway.write_error(traceback.format_exc())
def conver_update(self, d):
"""转换dict vnpy1 tick dict => vnpy2 tick dict"""
if 'vtSymbol' not in d:
return d
symbol= d.get('symbol')
exchange = d.get('exchange')
vtSymbol = d.pop('vtSymbol', symbol)
if '.' not in symbol:
d.update({'vt_symbol': f'{symbol}.{exchange}'})
else:
d.update({'vt_symbol': f'{symbol}.{Exchange.LOCAL.value}'})
# 成交数据
d.update({'last_price': d.pop('lastPrice',0.0)}) # 最新成交价
d.update({'last_volume': d.pop('lastVolume', 0)}) # 最新成交量
d.update({'open_interest': d.pop('openInterest', 0)}) # 昨持仓量
d.update({'open_interest': d.pop('tradingDay', get_trading_date())})
# 常规行情
d.update({'open_price': d.pop('openPrice', 0)}) # 今日开盘价
d.update({'high_price': d.pop('highPrice', 0)}) # 今日最高价
d.update({'low_price': d.pop('lowPrice', 0)}) # 今日最低价
d.update({'pre_close': d.pop('preClosePrice', 0)}) # 昨收盘价
d.update({'limit_up': d.pop('upperLimit', 0)}) # 涨停价
d.update({'limit_down': d.pop('lowerLimit', 0)}) # 跌停价
# 五档行情
d.update({'bid_price_1': d.pop('bidPrice1', 0.0)})
d.update({'bid_price_2': d.pop('bidPrice2', 0.0)})
d.update({'bid_price_3': d.pop('bidPrice3', 0.0)})
d.update({'bid_price_4': d.pop('bidPrice4', 0.0)})
d.update({'bid_price_5': d.pop('bidPrice5', 0.0)})
d.update({'ask_price_1': d.pop('askPrice1', 0.0)})
d.update({'ask_price_2': d.pop('askPrice2', 0.0)})
d.update({'ask_price_3': d.pop('askPrice3', 0.0)})
d.update({'ask_price_4': d.pop('askPrice4', 0.0)})
d.update({'ask_price_5': d.pop('askPrice5', 0.0)})
d.update({'bid_volume_1': d.pop('bidVolume1', 0.0)})
d.update({'bid_volume_2': d.pop('bidVolume2', 0.0)})
d.update({'bid_volume_3': d.pop('bidVolume3', 0.0)})
d.update({'bid_volume_4': d.pop('bidVolume4', 0.0)})
d.update({'bid_volume_5': d.pop('bidVolume5', 0.0)})
d.update({'ask_volume_1': d.pop('askVolume1', 0.0)})
d.update({'ask_volume_2': d.pop('askVolume2', 0.0)})
d.update({'ask_volume_3': d.pop('askVolume3', 0.0)})
d.update({'ask_volume_4': d.pop('askVolume4', 0.0)})
d.update({'ask_volume_5': d.pop('askVolume5', 0.0)})
return d
def close(self): def close(self):
"""退出API""" """退出API"""
self.gateway.write_log(u'退出rabbit行情订阅API') self.gateway.write_log(u'退出rabbit行情订阅API')

View File

@ -158,3 +158,13 @@ class Interval(Enum):
DAILY = "d" DAILY = "d"
WEEKLY = "w" WEEKLY = "w"
RENKO = 'renko' RENKO = 'renko'
class StockType(Enum):
"""股票类型tdx"""
STOCK = 'stock_cn' # 股票
STOCKB = 'stockB_cn' # 深圳B股票特别
INDEX = 'index_cn' # 指数
BOND = 'bond_cn' # 企业债券
ETF = 'etf_cn' # ETF
CB = 'cb_cn' # 可转债
UNDEFINED = 'undefined' # 未定义