From b85188298e1372f3ffa6216249baed1f93a6cf17 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Fri, 24 Apr 2020 11:41:20 +0800 Subject: [PATCH] =?UTF-8?q?[=E5=A2=9E=E5=BC=BA=E5=8A=9F=E8=83=BD]=20?= =?UTF-8?q?=E6=95=B0=E5=AD=97=E7=AD=96=E7=95=A5=E9=97=B4=E4=BA=92=E8=AE=BF?= =?UTF-8?q?=E9=97=AE=EF=BC=8C=E5=A4=9A=E5=90=88=E7=BA=A6=E6=A8=A1=E6=9D=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/app/cta_crypto/__init__.py | 1 + vnpy/app/cta_crypto/back_testing.py | 9 +++ vnpy/app/cta_crypto/engine.py | 37 +++++------- vnpy/app/cta_crypto/template.py | 92 +++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 22 deletions(-) diff --git a/vnpy/app/cta_crypto/__init__.py b/vnpy/app/cta_crypto/__init__.py index ddb80071..326ff12a 100644 --- a/vnpy/app/cta_crypto/__init__.py +++ b/vnpy/app/cta_crypto/__init__.py @@ -9,6 +9,7 @@ from .template import ( Direction, Offset, Status, + OrderType, Interval, TickData, BarData, diff --git a/vnpy/app/cta_crypto/back_testing.py b/vnpy/app/cta_crypto/back_testing.py index ae0333b0..6d79acae 100644 --- a/vnpy/app/cta_crypto/back_testing.py +++ b/vnpy/app/cta_crypto/back_testing.py @@ -419,6 +419,15 @@ class BackTestingEngine(object): self.positions[k] = pos return pos + def get_strategy_value(self, strategy_name: str, parameter:str): + """获取策略的某个参数值""" + strategy = self.strategies.get(strategy_name) + if not strategy: + return None + + value = getattr(strategy, parameter, None) + return value + def set_name(self, test_name): """ 设置组合的运行实例名称 diff --git a/vnpy/app/cta_crypto/engine.py b/vnpy/app/cta_crypto/engine.py index de0363dc..e61acd91 100644 --- a/vnpy/app/cta_crypto/engine.py +++ b/vnpy/app/cta_crypto/engine.py @@ -62,7 +62,6 @@ from vnpy.trader.utility import ( from vnpy.trader.util_logger import setup_logger, logging from vnpy.trader.util_wechat import send_wx_msg -from vnpy.trader.converter import PositionHolding from .base import ( APP_NAME, @@ -223,7 +222,6 @@ class CtaEngine(BaseEngine): # 推送到事件 self.put_all_strategy_pos_event(all_strategy_pos) - def process_tick_event(self, event: Event): """处理tick到达事件""" tick = event.data @@ -358,8 +356,6 @@ class CtaEngine(BaseEngine): contract = self.main_engine.get_contract(vt_symbol) is_bar = True if vt_symbol in self.bar_strategy_map else False if contract: - dt = datetime.now() - self.write_log(f'重新提交合约{vt_symbol}订阅请求') for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]): self.subscribe_symbol(strategy_name=strategy_name, @@ -689,7 +685,7 @@ class CtaEngine(BaseEngine): volume=volume, type=order_type, gateway_name=gateway_name - ) + ) def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): """ @@ -825,7 +821,7 @@ class CtaEngine(BaseEngine): else: return 0, 0, 0, 0 - def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): + def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''): """ 查询合约在账号的持仓,需要指定方向""" contract = self.main_engine.get_contract(vt_symbol) if contract: @@ -888,19 +884,7 @@ class CtaEngine(BaseEngine): callback: Callable[[TickData], None] ): """""" - symbol, exchange = extract_vt_symbol(vt_symbol) - end = datetime.now() - start = end - timedelta(days) - - ticks = database_manager.load_tick_data( - symbol=symbol, - exchange=exchange, - start=start, - end=end, - ) - - for tick in ticks: - callback(tick) + pass def call_strategy_func( self, strategy: CtaTemplate, func: Callable, params: Any = None @@ -1255,7 +1239,7 @@ class CtaEngine(BaseEngine): # 通过事件方式,传导到account_recorder snapshot.update({ 'account_id': self.engine_config.get('accountid', '-'), - 'strategy_group': self.engine_config.get('strategy_group', self.engine_name), + 'strategy_group': self.engine_config.get('strategy_group', self.engine_name), 'guid': str(uuid1()) }) event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot) @@ -1460,7 +1444,7 @@ class CtaEngine(BaseEngine): return parameters - def get_strategy_parameters(self, strategy_name): + def get_strategy_parameters(self, strategy_name: str): """ Get parameters of a strategy. """ @@ -1472,6 +1456,15 @@ class CtaEngine(BaseEngine): d.update(strategy.get_parameters()) return d + def get_strategy_value(self, strategy_name: str, parameter:str): + """获取策略的某个参数值""" + strategy = self.strategies.get(strategy_name) + if not strategy: + return None + + value = getattr(strategy, parameter, None) + return value + def compare_pos(self, strategy_pos_list=[]): """ 对比账号&策略的持仓,不同的话则发出微信提醒 @@ -1535,7 +1528,7 @@ class CtaEngine(BaseEngine): u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) self.write_log(u'更新{}策略持空仓=>{}'.format(vt_symbol, symbol_pos.get('策略空单', 0))) if pos.get('direction') == 'long': - symbol_pos.update({'策略多单': round(symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0)),7)}) + symbol_pos.update({'策略多单': round(symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0)), 7)}) symbol_pos['多单策略'].append( u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) diff --git a/vnpy/app/cta_crypto/template.py b/vnpy/app/cta_crypto/template.py index e8d25597..99d5dcc9 100644 --- a/vnpy/app/cta_crypto/template.py +++ b/vnpy/app/cta_crypto/template.py @@ -19,6 +19,7 @@ from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_und from .base import StopOrder from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade from vnpy.component.cta_position import CtaPosition +from vnpy.component.cta_policy import CtaPolicy class CtaTemplate(ABC): @@ -1376,6 +1377,8 @@ class CtaFutureTemplate(CtaTemplate): dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume', 0) * self.cta_engine.get_margin_rate( dist_data.get('symbol', self.vt_symbol))}) + if 'datetime' not in dist_data: + dist_data.update({'datetime': self.cur_datetime}) if self.position and 'long_pos' not in dist_data: dist_data.update({'long_pos': self.position.long_pos}) if self.position and 'short_pos' not in dist_data: @@ -1408,3 +1411,92 @@ class CtaFutureTemplate(CtaTemplate): if self.backtesting: return self.cta_engine.send_wechat(msg=msg, strategy=self) + + +class MultiContractPolicy(CtaPolicy): + """多合约Policy,记录持仓""" + + def __init__(self, strategy=None, **kwargs): + super().__init__(strategy, **kwargs) + self.debug = kwargs.get('debug', False) + self.positions = {} # vt_symbol: net_pos + + def from_json(self, json_data): + """将数据从json_data中恢复""" + super().from_json(json_data) + + self.positions = json_data.get('positions') + + def to_json(self): + """转换至json文件""" + j = super().to_json() + j['positions'] = self.positions + return j + + def on_trade(self, trade: TradeData): + """更新交易""" + pos = self.positions.get(trade.vt_symbol) + + if pos is None: + pos = 0 + pre_pos = pos + if trade.direction == Direction.LONG: + pos = round(pos + trade.volume, 7) + + elif trade.direction == Direction.SHORT: + pos = round(pos - trade.volume, 7) + + self.positions.update({trade.vt_symbol: pos}) + + if self.debug and self.strategy: + self.strategy.write_log(f'{trade.vt_symbol} pos:{pre_pos}=>{pos}') + + self.save() + + +class MultiContractTemplate(CtaTemplate): + """多合约交易模板""" + + def __init__(self, cta_engine, strategy_name, vt_symbol, setting): + + self.policy = None + self.cur_datetime = None + super().__init__(cta_engine, strategy_name, vt_symbol, setting) + + self.policy = MultiContractPolicy(strategy=self, debug=True) + + def sync_data(self): + """同步更新数据""" + + if self.inited and self.trading: + self.write_log(u'保存policy数据') + self.policy.save() + + def on_trade(self, trade: TradeData): + """成交回报事件处理""" + self.policy.on_trade(trade) + + def get_positions(self): + """ 获取策略所有持仓详细""" + pos_list = [] + + for vt_symbol, pos in self.policy.positions.items(): + pos_list.append({'vt_symbol': vt_symbol, + 'direction': 'long' if pos >= 0 else 'short', + 'volume': pos}) + + if self.cur_datetime and (datetime.now() - self.cur_datetime).total_seconds() < 10: + self.write_log(u'{}当前持仓:{}'.format(self.strategy_name, pos_list)) + return pos_list + + def on_order(self, order: OrderData): + pass + + def on_init(self): + self.inited = True + + def on_start(self): + self.trading = True + + def on_stop(self): + self.trading = False