[增强功能] 数字策略间互访问,多合约模板
This commit is contained in:
parent
40d4170711
commit
b85188298e
@ -9,6 +9,7 @@ from .template import (
|
||||
Direction,
|
||||
Offset,
|
||||
Status,
|
||||
OrderType,
|
||||
Interval,
|
||||
TickData,
|
||||
BarData,
|
||||
|
@ -419,6 +419,15 @@ class BackTestingEngine(object):
|
||||
self.positions[k] = pos
|
||||
return pos
|
||||
|
||||
def get_strategy_value(self, strategy_name: str, parameter:str):
|
||||
"""获取策略的某个参数值"""
|
||||
strategy = self.strategies.get(strategy_name)
|
||||
if not strategy:
|
||||
return None
|
||||
|
||||
value = getattr(strategy, parameter, None)
|
||||
return value
|
||||
|
||||
def set_name(self, test_name):
|
||||
"""
|
||||
设置组合的运行实例名称
|
||||
|
@ -62,7 +62,6 @@ from vnpy.trader.utility import (
|
||||
|
||||
from vnpy.trader.util_logger import setup_logger, logging
|
||||
from vnpy.trader.util_wechat import send_wx_msg
|
||||
from vnpy.trader.converter import PositionHolding
|
||||
|
||||
from .base import (
|
||||
APP_NAME,
|
||||
@ -223,7 +222,6 @@ class CtaEngine(BaseEngine):
|
||||
# 推送到事件
|
||||
self.put_all_strategy_pos_event(all_strategy_pos)
|
||||
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
"""处理tick到达事件"""
|
||||
tick = event.data
|
||||
@ -358,8 +356,6 @@ class CtaEngine(BaseEngine):
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
is_bar = True if vt_symbol in self.bar_strategy_map else False
|
||||
if contract:
|
||||
dt = datetime.now()
|
||||
|
||||
self.write_log(f'重新提交合约{vt_symbol}订阅请求')
|
||||
for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]):
|
||||
self.subscribe_symbol(strategy_name=strategy_name,
|
||||
@ -825,7 +821,7 @@ class CtaEngine(BaseEngine):
|
||||
else:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
|
||||
def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''):
|
||||
""" 查询合约在账号的持仓,需要指定方向"""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract:
|
||||
@ -888,19 +884,7 @@ class CtaEngine(BaseEngine):
|
||||
callback: Callable[[TickData], None]
|
||||
):
|
||||
""""""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
|
||||
ticks = database_manager.load_tick_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for tick in ticks:
|
||||
callback(tick)
|
||||
pass
|
||||
|
||||
def call_strategy_func(
|
||||
self, strategy: CtaTemplate, func: Callable, params: Any = None
|
||||
@ -1460,7 +1444,7 @@ class CtaEngine(BaseEngine):
|
||||
|
||||
return parameters
|
||||
|
||||
def get_strategy_parameters(self, strategy_name):
|
||||
def get_strategy_parameters(self, strategy_name: str):
|
||||
"""
|
||||
Get parameters of a strategy.
|
||||
"""
|
||||
@ -1472,6 +1456,15 @@ class CtaEngine(BaseEngine):
|
||||
d.update(strategy.get_parameters())
|
||||
return d
|
||||
|
||||
def get_strategy_value(self, strategy_name: str, parameter:str):
|
||||
"""获取策略的某个参数值"""
|
||||
strategy = self.strategies.get(strategy_name)
|
||||
if not strategy:
|
||||
return None
|
||||
|
||||
value = getattr(strategy, parameter, None)
|
||||
return value
|
||||
|
||||
def compare_pos(self, strategy_pos_list=[]):
|
||||
"""
|
||||
对比账号&策略的持仓,不同的话则发出微信提醒
|
||||
|
@ -19,6 +19,7 @@ from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_und
|
||||
from .base import StopOrder
|
||||
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
|
||||
from vnpy.component.cta_position import CtaPosition
|
||||
from vnpy.component.cta_policy import CtaPolicy
|
||||
|
||||
|
||||
class CtaTemplate(ABC):
|
||||
@ -1376,6 +1377,8 @@ class CtaFutureTemplate(CtaTemplate):
|
||||
dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume',
|
||||
0) * self.cta_engine.get_margin_rate(
|
||||
dist_data.get('symbol', self.vt_symbol))})
|
||||
if 'datetime' not in dist_data:
|
||||
dist_data.update({'datetime': self.cur_datetime})
|
||||
if self.position and 'long_pos' not in dist_data:
|
||||
dist_data.update({'long_pos': self.position.long_pos})
|
||||
if self.position and 'short_pos' not in dist_data:
|
||||
@ -1408,3 +1411,92 @@ class CtaFutureTemplate(CtaTemplate):
|
||||
if self.backtesting:
|
||||
return
|
||||
self.cta_engine.send_wechat(msg=msg, strategy=self)
|
||||
|
||||
|
||||
class MultiContractPolicy(CtaPolicy):
|
||||
"""多合约Policy,记录持仓"""
|
||||
|
||||
def __init__(self, strategy=None, **kwargs):
|
||||
super().__init__(strategy, **kwargs)
|
||||
self.debug = kwargs.get('debug', False)
|
||||
self.positions = {} # vt_symbol: net_pos
|
||||
|
||||
def from_json(self, json_data):
|
||||
"""将数据从json_data中恢复"""
|
||||
super().from_json(json_data)
|
||||
|
||||
self.positions = json_data.get('positions')
|
||||
|
||||
def to_json(self):
|
||||
"""转换至json文件"""
|
||||
j = super().to_json()
|
||||
j['positions'] = self.positions
|
||||
return j
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
"""更新交易"""
|
||||
pos = self.positions.get(trade.vt_symbol)
|
||||
|
||||
if pos is None:
|
||||
pos = 0
|
||||
pre_pos = pos
|
||||
if trade.direction == Direction.LONG:
|
||||
pos = round(pos + trade.volume, 7)
|
||||
|
||||
elif trade.direction == Direction.SHORT:
|
||||
pos = round(pos - trade.volume, 7)
|
||||
|
||||
self.positions.update({trade.vt_symbol: pos})
|
||||
|
||||
if self.debug and self.strategy:
|
||||
self.strategy.write_log(f'{trade.vt_symbol} pos:{pre_pos}=>{pos}')
|
||||
|
||||
self.save()
|
||||
|
||||
|
||||
class MultiContractTemplate(CtaTemplate):
|
||||
"""多合约交易模板"""
|
||||
|
||||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||
|
||||
self.policy = None
|
||||
self.cur_datetime = None
|
||||
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
|
||||
|
||||
self.policy = MultiContractPolicy(strategy=self, debug=True)
|
||||
|
||||
def sync_data(self):
|
||||
"""同步更新数据"""
|
||||
|
||||
if self.inited and self.trading:
|
||||
self.write_log(u'保存policy数据')
|
||||
self.policy.save()
|
||||
|
||||
def on_trade(self, trade: TradeData):
|
||||
"""成交回报事件处理"""
|
||||
self.policy.on_trade(trade)
|
||||
|
||||
def get_positions(self):
|
||||
""" 获取策略所有持仓详细"""
|
||||
pos_list = []
|
||||
|
||||
for vt_symbol, pos in self.policy.positions.items():
|
||||
pos_list.append({'vt_symbol': vt_symbol,
|
||||
'direction': 'long' if pos >= 0 else 'short',
|
||||
'volume': pos})
|
||||
|
||||
if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10:
|
||||
self.write_log(u'{}当前持仓:{}'.format(self.strategy_name, pos_list))
|
||||
return pos_list
|
||||
|
||||
def on_order(self, order: OrderData):
|
||||
pass
|
||||
|
||||
def on_init(self):
|
||||
self.inited = True
|
||||
|
||||
def on_start(self):
|
||||
self.trading = True
|
||||
|
||||
def on_stop(self):
|
||||
self.trading = False
|
||||
|
Loading…
Reference in New Issue
Block a user