[增强功能] 数字策略间互访问,多合约模板

This commit is contained in:
msincenselee 2020-04-24 11:41:20 +08:00
parent 40d4170711
commit b85188298e
4 changed files with 117 additions and 22 deletions

View File

@ -9,6 +9,7 @@ from .template import (
Direction, Direction,
Offset, Offset,
Status, Status,
OrderType,
Interval, Interval,
TickData, TickData,
BarData, BarData,

View File

@ -419,6 +419,15 @@ class BackTestingEngine(object):
self.positions[k] = pos self.positions[k] = pos
return pos return pos
def get_strategy_value(self, strategy_name: str, parameter:str):
"""获取策略的某个参数值"""
strategy = self.strategies.get(strategy_name)
if not strategy:
return None
value = getattr(strategy, parameter, None)
return value
def set_name(self, test_name): def set_name(self, test_name):
""" """
设置组合的运行实例名称 设置组合的运行实例名称

View File

@ -62,7 +62,6 @@ from vnpy.trader.utility import (
from vnpy.trader.util_logger import setup_logger, logging from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.util_wechat import send_wx_msg from vnpy.trader.util_wechat import send_wx_msg
from vnpy.trader.converter import PositionHolding
from .base import ( from .base import (
APP_NAME, APP_NAME,
@ -223,7 +222,6 @@ class CtaEngine(BaseEngine):
# 推送到事件 # 推送到事件
self.put_all_strategy_pos_event(all_strategy_pos) self.put_all_strategy_pos_event(all_strategy_pos)
def process_tick_event(self, event: Event): def process_tick_event(self, event: Event):
"""处理tick到达事件""" """处理tick到达事件"""
tick = event.data tick = event.data
@ -358,8 +356,6 @@ class CtaEngine(BaseEngine):
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
is_bar = True if vt_symbol in self.bar_strategy_map else False is_bar = True if vt_symbol in self.bar_strategy_map else False
if contract: if contract:
dt = datetime.now()
self.write_log(f'重新提交合约{vt_symbol}订阅请求') self.write_log(f'重新提交合约{vt_symbol}订阅请求')
for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]): for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]):
self.subscribe_symbol(strategy_name=strategy_name, self.subscribe_symbol(strategy_name=strategy_name,
@ -689,7 +685,7 @@ class CtaEngine(BaseEngine):
volume=volume, volume=volume,
type=order_type, type=order_type,
gateway_name=gateway_name gateway_name=gateway_name
) )
def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): def cancel_order(self, strategy: CtaTemplate, vt_orderid: str):
""" """
@ -825,7 +821,7 @@ class CtaEngine(BaseEngine):
else: else:
return 0, 0, 0, 0 return 0, 0, 0, 0
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''):
""" 查询合约在账号的持仓,需要指定方向""" """ 查询合约在账号的持仓,需要指定方向"""
contract = self.main_engine.get_contract(vt_symbol) contract = self.main_engine.get_contract(vt_symbol)
if contract: if contract:
@ -888,19 +884,7 @@ class CtaEngine(BaseEngine):
callback: Callable[[TickData], None] callback: Callable[[TickData], None]
): ):
"""""" """"""
symbol, exchange = extract_vt_symbol(vt_symbol) pass
end = datetime.now()
start = end - timedelta(days)
ticks = database_manager.load_tick_data(
symbol=symbol,
exchange=exchange,
start=start,
end=end,
)
for tick in ticks:
callback(tick)
def call_strategy_func( def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None self, strategy: CtaTemplate, func: Callable, params: Any = None
@ -1255,7 +1239,7 @@ class CtaEngine(BaseEngine):
# 通过事件方式传导到account_recorder # 通过事件方式传导到account_recorder
snapshot.update({ snapshot.update({
'account_id': self.engine_config.get('accountid', '-'), 'account_id': self.engine_config.get('accountid', '-'),
'strategy_group': self.engine_config.get('strategy_group', self.engine_name), 'strategy_group': self.engine_config.get('strategy_group', self.engine_name),
'guid': str(uuid1()) 'guid': str(uuid1())
}) })
event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot) event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot)
@ -1460,7 +1444,7 @@ class CtaEngine(BaseEngine):
return parameters return parameters
def get_strategy_parameters(self, strategy_name): def get_strategy_parameters(self, strategy_name: str):
""" """
Get parameters of a strategy. Get parameters of a strategy.
""" """
@ -1472,6 +1456,15 @@ class CtaEngine(BaseEngine):
d.update(strategy.get_parameters()) d.update(strategy.get_parameters())
return d return d
def get_strategy_value(self, strategy_name: str, parameter:str):
"""获取策略的某个参数值"""
strategy = self.strategies.get(strategy_name)
if not strategy:
return None
value = getattr(strategy, parameter, None)
return value
def compare_pos(self, strategy_pos_list=[]): def compare_pos(self, strategy_pos_list=[]):
""" """
对比账号&策略的持仓,不同的话则发出微信提醒 对比账号&策略的持仓,不同的话则发出微信提醒
@ -1535,7 +1528,7 @@ class CtaEngine(BaseEngine):
u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0))))
self.write_log(u'更新{}策略持空仓=>{}'.format(vt_symbol, symbol_pos.get('策略空单', 0))) self.write_log(u'更新{}策略持空仓=>{}'.format(vt_symbol, symbol_pos.get('策略空单', 0)))
if pos.get('direction') == 'long': if pos.get('direction') == 'long':
symbol_pos.update({'策略多单': round(symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0)),7)}) symbol_pos.update({'策略多单': round(symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0)), 7)})
symbol_pos['多单策略'].append( symbol_pos['多单策略'].append(
u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0))))
self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0)))

View File

@ -19,6 +19,7 @@ from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_und
from .base import StopOrder from .base import StopOrder
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
from vnpy.component.cta_position import CtaPosition from vnpy.component.cta_position import CtaPosition
from vnpy.component.cta_policy import CtaPolicy
class CtaTemplate(ABC): class CtaTemplate(ABC):
@ -1376,6 +1377,8 @@ class CtaFutureTemplate(CtaTemplate):
dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume', dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume',
0) * self.cta_engine.get_margin_rate( 0) * self.cta_engine.get_margin_rate(
dist_data.get('symbol', self.vt_symbol))}) dist_data.get('symbol', self.vt_symbol))})
if 'datetime' not in dist_data:
dist_data.update({'datetime': self.cur_datetime})
if self.position and 'long_pos' not in dist_data: if self.position and 'long_pos' not in dist_data:
dist_data.update({'long_pos': self.position.long_pos}) dist_data.update({'long_pos': self.position.long_pos})
if self.position and 'short_pos' not in dist_data: if self.position and 'short_pos' not in dist_data:
@ -1408,3 +1411,92 @@ class CtaFutureTemplate(CtaTemplate):
if self.backtesting: if self.backtesting:
return return
self.cta_engine.send_wechat(msg=msg, strategy=self) self.cta_engine.send_wechat(msg=msg, strategy=self)
class MultiContractPolicy(CtaPolicy):
"""多合约Policy记录持仓"""
def __init__(self, strategy=None, **kwargs):
super().__init__(strategy, **kwargs)
self.debug = kwargs.get('debug', False)
self.positions = {} # vt_symbol: net_pos
def from_json(self, json_data):
"""将数据从json_data中恢复"""
super().from_json(json_data)
self.positions = json_data.get('positions')
def to_json(self):
"""转换至json文件"""
j = super().to_json()
j['positions'] = self.positions
return j
def on_trade(self, trade: TradeData):
"""更新交易"""
pos = self.positions.get(trade.vt_symbol)
if pos is None:
pos = 0
pre_pos = pos
if trade.direction == Direction.LONG:
pos = round(pos + trade.volume, 7)
elif trade.direction == Direction.SHORT:
pos = round(pos - trade.volume, 7)
self.positions.update({trade.vt_symbol: pos})
if self.debug and self.strategy:
self.strategy.write_log(f'{trade.vt_symbol} pos:{pre_pos}=>{pos}')
self.save()
class MultiContractTemplate(CtaTemplate):
"""多合约交易模板"""
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
self.policy = None
self.cur_datetime = None
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
self.policy = MultiContractPolicy(strategy=self, debug=True)
def sync_data(self):
"""同步更新数据"""
if self.inited and self.trading:
self.write_log(u'保存policy数据')
self.policy.save()
def on_trade(self, trade: TradeData):
"""成交回报事件处理"""
self.policy.on_trade(trade)
def get_positions(self):
""" 获取策略所有持仓详细"""
pos_list = []
for vt_symbol, pos in self.policy.positions.items():
pos_list.append({'vt_symbol': vt_symbol,
'direction': 'long' if pos >= 0 else 'short',
'volume': pos})
if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10:
self.write_log(u'{}当前持仓:{}'.format(self.strategy_name, pos_list))
return pos_list
def on_order(self, order: OrderData):
pass
def on_init(self):
self.inited = True
def on_start(self):
self.trading = True
def on_stop(self):
self.trading = False