[增强功能] 数字策略间互访问,多合约模板
This commit is contained in:
parent
40d4170711
commit
b85188298e
@ -9,6 +9,7 @@ from .template import (
|
|||||||
Direction,
|
Direction,
|
||||||
Offset,
|
Offset,
|
||||||
Status,
|
Status,
|
||||||
|
OrderType,
|
||||||
Interval,
|
Interval,
|
||||||
TickData,
|
TickData,
|
||||||
BarData,
|
BarData,
|
||||||
|
@ -419,6 +419,15 @@ class BackTestingEngine(object):
|
|||||||
self.positions[k] = pos
|
self.positions[k] = pos
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
def get_strategy_value(self, strategy_name: str, parameter:str):
|
||||||
|
"""获取策略的某个参数值"""
|
||||||
|
strategy = self.strategies.get(strategy_name)
|
||||||
|
if not strategy:
|
||||||
|
return None
|
||||||
|
|
||||||
|
value = getattr(strategy, parameter, None)
|
||||||
|
return value
|
||||||
|
|
||||||
def set_name(self, test_name):
|
def set_name(self, test_name):
|
||||||
"""
|
"""
|
||||||
设置组合的运行实例名称
|
设置组合的运行实例名称
|
||||||
|
@ -62,7 +62,6 @@ from vnpy.trader.utility import (
|
|||||||
|
|
||||||
from vnpy.trader.util_logger import setup_logger, logging
|
from vnpy.trader.util_logger import setup_logger, logging
|
||||||
from vnpy.trader.util_wechat import send_wx_msg
|
from vnpy.trader.util_wechat import send_wx_msg
|
||||||
from vnpy.trader.converter import PositionHolding
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
APP_NAME,
|
APP_NAME,
|
||||||
@ -223,7 +222,6 @@ class CtaEngine(BaseEngine):
|
|||||||
# 推送到事件
|
# 推送到事件
|
||||||
self.put_all_strategy_pos_event(all_strategy_pos)
|
self.put_all_strategy_pos_event(all_strategy_pos)
|
||||||
|
|
||||||
|
|
||||||
def process_tick_event(self, event: Event):
|
def process_tick_event(self, event: Event):
|
||||||
"""处理tick到达事件"""
|
"""处理tick到达事件"""
|
||||||
tick = event.data
|
tick = event.data
|
||||||
@ -358,8 +356,6 @@ class CtaEngine(BaseEngine):
|
|||||||
contract = self.main_engine.get_contract(vt_symbol)
|
contract = self.main_engine.get_contract(vt_symbol)
|
||||||
is_bar = True if vt_symbol in self.bar_strategy_map else False
|
is_bar = True if vt_symbol in self.bar_strategy_map else False
|
||||||
if contract:
|
if contract:
|
||||||
dt = datetime.now()
|
|
||||||
|
|
||||||
self.write_log(f'重新提交合约{vt_symbol}订阅请求')
|
self.write_log(f'重新提交合约{vt_symbol}订阅请求')
|
||||||
for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]):
|
for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]):
|
||||||
self.subscribe_symbol(strategy_name=strategy_name,
|
self.subscribe_symbol(strategy_name=strategy_name,
|
||||||
@ -825,7 +821,7 @@ class CtaEngine(BaseEngine):
|
|||||||
else:
|
else:
|
||||||
return 0, 0, 0, 0
|
return 0, 0, 0, 0
|
||||||
|
|
||||||
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
|
def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''):
|
||||||
""" 查询合约在账号的持仓,需要指定方向"""
|
""" 查询合约在账号的持仓,需要指定方向"""
|
||||||
contract = self.main_engine.get_contract(vt_symbol)
|
contract = self.main_engine.get_contract(vt_symbol)
|
||||||
if contract:
|
if contract:
|
||||||
@ -888,19 +884,7 @@ class CtaEngine(BaseEngine):
|
|||||||
callback: Callable[[TickData], None]
|
callback: Callable[[TickData], None]
|
||||||
):
|
):
|
||||||
""""""
|
""""""
|
||||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
pass
|
||||||
end = datetime.now()
|
|
||||||
start = end - timedelta(days)
|
|
||||||
|
|
||||||
ticks = database_manager.load_tick_data(
|
|
||||||
symbol=symbol,
|
|
||||||
exchange=exchange,
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
)
|
|
||||||
|
|
||||||
for tick in ticks:
|
|
||||||
callback(tick)
|
|
||||||
|
|
||||||
def call_strategy_func(
|
def call_strategy_func(
|
||||||
self, strategy: CtaTemplate, func: Callable, params: Any = None
|
self, strategy: CtaTemplate, func: Callable, params: Any = None
|
||||||
@ -1460,7 +1444,7 @@ class CtaEngine(BaseEngine):
|
|||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
def get_strategy_parameters(self, strategy_name):
|
def get_strategy_parameters(self, strategy_name: str):
|
||||||
"""
|
"""
|
||||||
Get parameters of a strategy.
|
Get parameters of a strategy.
|
||||||
"""
|
"""
|
||||||
@ -1472,6 +1456,15 @@ class CtaEngine(BaseEngine):
|
|||||||
d.update(strategy.get_parameters())
|
d.update(strategy.get_parameters())
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
def get_strategy_value(self, strategy_name: str, parameter:str):
|
||||||
|
"""获取策略的某个参数值"""
|
||||||
|
strategy = self.strategies.get(strategy_name)
|
||||||
|
if not strategy:
|
||||||
|
return None
|
||||||
|
|
||||||
|
value = getattr(strategy, parameter, None)
|
||||||
|
return value
|
||||||
|
|
||||||
def compare_pos(self, strategy_pos_list=[]):
|
def compare_pos(self, strategy_pos_list=[]):
|
||||||
"""
|
"""
|
||||||
对比账号&策略的持仓,不同的话则发出微信提醒
|
对比账号&策略的持仓,不同的话则发出微信提醒
|
||||||
|
@ -19,6 +19,7 @@ from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_und
|
|||||||
from .base import StopOrder
|
from .base import StopOrder
|
||||||
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
|
from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade
|
||||||
from vnpy.component.cta_position import CtaPosition
|
from vnpy.component.cta_position import CtaPosition
|
||||||
|
from vnpy.component.cta_policy import CtaPolicy
|
||||||
|
|
||||||
|
|
||||||
class CtaTemplate(ABC):
|
class CtaTemplate(ABC):
|
||||||
@ -1376,6 +1377,8 @@ class CtaFutureTemplate(CtaTemplate):
|
|||||||
dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume',
|
dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume',
|
||||||
0) * self.cta_engine.get_margin_rate(
|
0) * self.cta_engine.get_margin_rate(
|
||||||
dist_data.get('symbol', self.vt_symbol))})
|
dist_data.get('symbol', self.vt_symbol))})
|
||||||
|
if 'datetime' not in dist_data:
|
||||||
|
dist_data.update({'datetime': self.cur_datetime})
|
||||||
if self.position and 'long_pos' not in dist_data:
|
if self.position and 'long_pos' not in dist_data:
|
||||||
dist_data.update({'long_pos': self.position.long_pos})
|
dist_data.update({'long_pos': self.position.long_pos})
|
||||||
if self.position and 'short_pos' not in dist_data:
|
if self.position and 'short_pos' not in dist_data:
|
||||||
@ -1408,3 +1411,92 @@ class CtaFutureTemplate(CtaTemplate):
|
|||||||
if self.backtesting:
|
if self.backtesting:
|
||||||
return
|
return
|
||||||
self.cta_engine.send_wechat(msg=msg, strategy=self)
|
self.cta_engine.send_wechat(msg=msg, strategy=self)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiContractPolicy(CtaPolicy):
|
||||||
|
"""多合约Policy,记录持仓"""
|
||||||
|
|
||||||
|
def __init__(self, strategy=None, **kwargs):
|
||||||
|
super().__init__(strategy, **kwargs)
|
||||||
|
self.debug = kwargs.get('debug', False)
|
||||||
|
self.positions = {} # vt_symbol: net_pos
|
||||||
|
|
||||||
|
def from_json(self, json_data):
|
||||||
|
"""将数据从json_data中恢复"""
|
||||||
|
super().from_json(json_data)
|
||||||
|
|
||||||
|
self.positions = json_data.get('positions')
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
"""转换至json文件"""
|
||||||
|
j = super().to_json()
|
||||||
|
j['positions'] = self.positions
|
||||||
|
return j
|
||||||
|
|
||||||
|
def on_trade(self, trade: TradeData):
|
||||||
|
"""更新交易"""
|
||||||
|
pos = self.positions.get(trade.vt_symbol)
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = 0
|
||||||
|
pre_pos = pos
|
||||||
|
if trade.direction == Direction.LONG:
|
||||||
|
pos = round(pos + trade.volume, 7)
|
||||||
|
|
||||||
|
elif trade.direction == Direction.SHORT:
|
||||||
|
pos = round(pos - trade.volume, 7)
|
||||||
|
|
||||||
|
self.positions.update({trade.vt_symbol: pos})
|
||||||
|
|
||||||
|
if self.debug and self.strategy:
|
||||||
|
self.strategy.write_log(f'{trade.vt_symbol} pos:{pre_pos}=>{pos}')
|
||||||
|
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
|
||||||
|
class MultiContractTemplate(CtaTemplate):
|
||||||
|
"""多合约交易模板"""
|
||||||
|
|
||||||
|
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||||
|
|
||||||
|
self.policy = None
|
||||||
|
self.cur_datetime = None
|
||||||
|
super().__init__(cta_engine, strategy_name, vt_symbol, setting)
|
||||||
|
|
||||||
|
self.policy = MultiContractPolicy(strategy=self, debug=True)
|
||||||
|
|
||||||
|
def sync_data(self):
|
||||||
|
"""同步更新数据"""
|
||||||
|
|
||||||
|
if self.inited and self.trading:
|
||||||
|
self.write_log(u'保存policy数据')
|
||||||
|
self.policy.save()
|
||||||
|
|
||||||
|
def on_trade(self, trade: TradeData):
|
||||||
|
"""成交回报事件处理"""
|
||||||
|
self.policy.on_trade(trade)
|
||||||
|
|
||||||
|
def get_positions(self):
|
||||||
|
""" 获取策略所有持仓详细"""
|
||||||
|
pos_list = []
|
||||||
|
|
||||||
|
for vt_symbol, pos in self.policy.positions.items():
|
||||||
|
pos_list.append({'vt_symbol': vt_symbol,
|
||||||
|
'direction': 'long' if pos >= 0 else 'short',
|
||||||
|
'volume': pos})
|
||||||
|
|
||||||
|
if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10:
|
||||||
|
self.write_log(u'{}当前持仓:{}'.format(self.strategy_name, pos_list))
|
||||||
|
return pos_list
|
||||||
|
|
||||||
|
def on_order(self, order: OrderData):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_init(self):
|
||||||
|
self.inited = True
|
||||||
|
|
||||||
|
def on_start(self):
|
||||||
|
self.trading = True
|
||||||
|
|
||||||
|
def on_stop(self):
|
||||||
|
self.trading = False
|
||||||
|
Loading…
Reference in New Issue
Block a user