From fa22e74a8a47e2137331c3997da847f8c6954bd3 Mon Sep 17 00:00:00 2001 From: msincenselee Date: Mon, 18 May 2020 16:57:58 +0800 Subject: [PATCH] =?UTF-8?q?[=E5=A2=9E=E5=BC=BA=E5=8A=9F=E8=83=BD]=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0vnpy=E5=8E=9F=E7=89=88=E7=BB=84=E5=90=88?= =?UTF-8?q?=E5=BC=95=E6=93=8E=EF=BC=8C=E4=BF=AE=E6=94=B9gateway=EF=BC=8C?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E5=A4=9A=E4=BD=99event=EF=BC=8C=E6=8F=90?= =?UTF-8?q?=E9=AB=98=E6=80=A7=E8=83=BD=EF=BC=9B=E6=9B=B4=E6=96=B0K?= =?UTF-8?q?=E7=BA=BF=E7=BB=84=E4=BB=B6/=E7=BD=91=E6=A0=BC=E7=BB=84?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/app/account_recorder/engine.py | 5 +- vnpy/app/algo_trading/engine.py | 123 ++- vnpy/app/algo_trading/template.py | 2 +- vnpy/app/cta_crypto/back_testing.py | 8 +- vnpy/app/cta_strategy_pro/engine.py | 27 +- vnpy/app/portfolio_strategy/__init__.py | 23 + vnpy/app/portfolio_strategy/backtesting.py | 821 ++++++++++++++++++ vnpy/app/portfolio_strategy/base.py | 17 + vnpy/app/portfolio_strategy/engine.py | 624 +++++++++++++ .../portfolio_strategy/strategies/__init__.py | 0 .../strategies/trend_following_strategy.py | 139 +++ vnpy/app/portfolio_strategy/template.py | 247 ++++++ vnpy/app/portfolio_strategy/ui/__init__.py | 1 + vnpy/app/portfolio_strategy/ui/strategy.ico | Bin 0 -> 67646 bytes vnpy/app/portfolio_strategy/ui/widget.py | 439 ++++++++++ vnpy/component/cta_grid_trade.py | 20 +- vnpy/component/cta_line_bar.py | 59 +- vnpy/trader/engine.py | 123 ++- vnpy/trader/gateway.py | 10 +- vnpy/trader/utility.py | 56 +- 20 files changed, 2637 insertions(+), 107 deletions(-) create mode 100644 vnpy/app/portfolio_strategy/__init__.py create mode 100644 vnpy/app/portfolio_strategy/backtesting.py create mode 100644 vnpy/app/portfolio_strategy/base.py create mode 100644 vnpy/app/portfolio_strategy/engine.py create mode 100644 vnpy/app/portfolio_strategy/strategies/__init__.py create mode 100644 vnpy/app/portfolio_strategy/strategies/trend_following_strategy.py create mode 100644 vnpy/app/portfolio_strategy/template.py create mode 100644 vnpy/app/portfolio_strategy/ui/__init__.py create mode 100644 vnpy/app/portfolio_strategy/ui/strategy.ico create mode 100644 vnpy/app/portfolio_strategy/ui/widget.py diff --git a/vnpy/app/account_recorder/engine.py b/vnpy/app/account_recorder/engine.py index 2ff12352..9f660049 100644 --- a/vnpy/app/account_recorder/engine.py +++ b/vnpy/app/account_recorder/engine.py @@ -330,10 +330,13 @@ class AccountRecorder(BaseEngine): # self.write_log(u'记录委托日志:{}'.format(order.__dict__)) if len(order.sys_orderid) == 0: # 未有系统的委托编号,不做持久化 - return + order.sys_orderid = order.orderid dt = getattr(order, 'datetime') if not dt: order_date = datetime.now().strftime('%Y-%m-%d') + if len(order.time) > 0 and '.' not in order.time: + dt = datetime.strptime(f'{order_date} {order.time}', '%Y-%m-%d %H:%M:%S') + order.datetime = dt else: order_date = dt.strftime('%Y-%m-%d') diff --git a/vnpy/app/algo_trading/engine.py b/vnpy/app/algo_trading/engine.py index ee7af87b..e50ceb06 100644 --- a/vnpy/app/algo_trading/engine.py +++ b/vnpy/app/algo_trading/engine.py @@ -6,11 +6,12 @@ from functools import lru_cache from vnpy.event import EventEngine, Event from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.event import ( - EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE) + EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE, EVENT_POSITION) from vnpy.trader.constant import (Direction, Offset, OrderType, Status) from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData, CancelRequest) -from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path +from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path,print_dict from vnpy.trader.util_logger import setup_logger, logging +from vnpy.trader.converter import OffsetConverter from .template import AlgoTemplate @@ -35,13 +36,13 @@ class AlgoEngine(BaseEngine): self.symbol_algo_map = {} self.orderid_algo_map = {} - self.algo_vtorderid_order_map = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单 + self.spd_orders = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单 self.algo_templates = {} self.algo_settings = {} self.algo_loggers = {} # algo_name: logger - + self.offset_converter = OffsetConverter(self.main_engine) self.load_algo_template() self.register_event() @@ -60,6 +61,7 @@ class AlgoEngine(BaseEngine): from .algos.grid_algo import GridAlgo from .algos.dma_algo import DmaAlgo from .algos.arbitrage_algo import ArbitrageAlgo + from .algos.spread_algo_v2 import SpreadAlgoV2 self.add_algo_template(TwapAlgo) self.add_algo_template(IcebergAlgo) @@ -69,6 +71,7 @@ class AlgoEngine(BaseEngine): self.add_algo_template(GridAlgo) self.add_algo_template(DmaAlgo) self.add_algo_template(ArbitrageAlgo) + self.add_algo_template(SpreadAlgoV2) def add_algo_template(self, template: AlgoTemplate): """""" @@ -93,6 +96,7 @@ class AlgoEngine(BaseEngine): self.event_engine.register(EVENT_TIMER, self.process_timer_event) self.event_engine.register(EVENT_ORDER, self.process_order_event) self.event_engine.register(EVENT_TRADE, self.process_trade_event) + self.event_engine.register(EVENT_POSITION, self.process_position_event) def process_tick_event(self, event: Event): """""" @@ -114,7 +118,7 @@ class AlgoEngine(BaseEngine): def process_trade_event(self, event: Event): """""" trade = event.data - + self.offset_converter.update_trade(trade) algo = self.orderid_algo_map.get(trade.vt_orderid, None) if algo: algo.update_trade(trade) @@ -122,11 +126,17 @@ class AlgoEngine(BaseEngine): def process_order_event(self, event: Event): """""" order = event.data - + self.offset_converter.update_order(order) algo = self.orderid_algo_map.get(order.vt_orderid, None) if algo: algo.update_order(order) + def process_position_event(self, event: Event): + """""" + position = event.data + + self.offset_converter.update_position(position) + def start_algo(self, setting: dict): """""" template_name = setting["template_name"] @@ -144,6 +154,7 @@ class AlgoEngine(BaseEngine): if algo: algo.stop() self.algos.pop(algo_name) + return True def stop_all(self): """""" @@ -154,7 +165,7 @@ class AlgoEngine(BaseEngine): """""" contract = self.main_engine.get_contract(vt_symbol) if not contract: - self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo) + self.write_log(msg=f'订阅行情失败,找不到合约:{vt_symbol}', algo_name=algo.algo_name) return algos = self.symbol_algo_map.setdefault(vt_symbol, set()) @@ -186,9 +197,9 @@ class AlgoEngine(BaseEngine): volume = round_to(volume, contract.min_volume) if not volume: - return "" + return [] - req = OrderRequest( + original_req = OrderRequest( symbol=contract.symbol, exchange=contract.exchange, direction=direction, @@ -197,83 +208,88 @@ class AlgoEngine(BaseEngine): price=price, offset=offset ) - vt_orderid = self.main_engine.send_order(req, contract.gateway_name) + req_list = self.offset_converter.convert_order_request(req=original_req, lock=False, gateway_name=contract.gateway_name) + vt_orderids = [] + for req in req_list: + vt_orderid = self.main_engine.send_order(req, contract.gateway_name) + if not vt_orderid: + continue - self.orderid_algo_map[vt_orderid] = algo - return vt_orderid + vt_orderids.append(vt_orderid) + + self.offset_converter.update_order_request(req, vt_orderid, contract.gateway_name) + + self.orderid_algo_map[vt_orderid] = algo + + return vt_orderids def cancel_order(self, algo: AlgoTemplate, vt_orderid: str): """""" order = self.main_engine.get_order(vt_orderid) if not order: - self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo) - return + self.write_log(msg=f"委托撤单失败,找不到委托:{vt_orderid}", algo_name=algo.algo_name) + return False req = order.create_cancel_request() - self.main_engine.cancel_order(req, order.gateway_name) + return self.main_engine.cancel_order(req, order.gateway_name) - def send_algo_order(self, req: OrderRequest, gateway_name: str): - """发送算法交易指令""" - self.write_log(u'创建算法交易,gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}' + def send_spd_order(self, req: OrderRequest, gateway_name: str): + """发送SPD算法交易指令""" + self.write_log(u'[SPD算法交易],gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}' .format(gateway_name, req.strategy_name, req.vt_symbol, req.price, req.volume)) # 创建算法实例,由算法引擎启动 - trade_command = '' - if req.direction == Direction.LONG and req.offset == Offset.OPEN: - trade_command = 'Buy' - elif req.direction == Direction.SHORT and req.offset == Offset.OPEN: - trade_command = 'Short' - elif req.direction == Direction.SHORT and req.offset != Offset.OPEN: - trade_command = 'Sell' - elif req.direction == Direction.LONG and req.offset != Offset.OPEN: - trade_command = 'Cover' - - all_custom_contracts = self.main_engine.get_all_custom_contracts() - contract_setting = all_custom_contracts.get(req.vt_symbol, {}) - algo_setting = { - 'templateName': u'SpreadTrading套利', + custom_settings = self.main_engine.get_all_custom_contracts(rtn_setting=True) + contract = custom_settings.get(req.symbol, {}) + setting = { + 'template_name': u'SpreadAlgoV2', 'order_vt_symbol': req.vt_symbol, - 'order_command': trade_command, + 'order_direction': req.direction, + 'order_offset': req.offset, 'order_price': req.price, 'order_volume': req.volume, 'timer_interval': 60 * 60 * 24, 'strategy_name': req.strategy_name, 'gateway_name': gateway_name } - algo_setting.update(contract_setting) + # 更新算法配置 + setting.update(contract) # 算法引擎 - algo_name = self.start_algo(algo_setting) - self.write_log(u'send_algo_order(): start_algo {}={}'.format(algo_name, str(algo_setting))) + algo_name = self.start_algo(setting) + self.write_log(f'[SPD算法交易]: 实例id: {algo_name}, 配置:{print_dict(setting)}') - # 创建一个Order事件 + # 创建一个Order事件, 正在提交 order = req.create_order_data(orderid=algo_name, gateway_name=gateway_name) - order.orderTime = datetime.now().strftime('%H:%M:%S.%f') + order.datetime = datetime.now() + order.time = order.datetime.strftime('%H:%M:%S.%f') order.status = Status.SUBMITTING - event1 = Event(type=EVENT_ORDER, data=order) self.event_engine.put(event1) # 登记在本地的算法委托字典中 - self.algo_vtorderid_order_map.update({order.vt_orderid: order}) + self.spd_orders.update({order.orderid: order}) return order.vt_orderid - def is_algo_order(self, req: CancelRequest, gateway_name: str): + def get_spd_order(self, orderid): + """返回spd委托单""" + return self.spd_orders.get(orderid, None) + + def is_spd_order(self, req: CancelRequest): """是否为外部算法委托单""" - vt_orderid = '.'.join([req.orderid, gateway_name]) - if vt_orderid in self.algo_vtorderid_order_map: + if req.orderid in self.spd_orders: return True else: return False - def cancel_algo_order(self, req: CancelRequest, gateway_name: str): + def cancel_spd_order(self, req: CancelRequest): """外部算法单撤单""" - vt_orderid = '.'.join([req.orderid, gateway_name]) - order = self.algo_vtorderid_order_map.get(vt_orderid, None) + + order = self.spd_orders.get(req.orderid, None) if not order: - self.write_error(f'{vt_orderid}不在算法引擎中,撤单失败') + self.write_error(f'{req.orderid}不在算法引擎中,撤单失败') return False algo = self.algos.get(req.orderid, None) @@ -368,6 +384,19 @@ class AlgoEngine(BaseEngine): return contract + def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): + """ 查询合约在账号的持仓,需要指定方向""" + if len(gateway_name) == 0: + contract = self.main_engine.get_contract(vt_symbol) + if contract and contract.gateway_name: + gateway_name = contract.gateway_name + vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}" + return self.main_engine.get_position(vt_position_id) + + def get_position_holding(self, vt_symbol: str, gateway_name: str = ''): + """ 查询合约在账号的持仓(包含多空)""" + return self.offset_converter.get_position_holding(vt_symbol, gateway_name) + def write_log(self, msg: str, algo_name: str = None, level: int = logging.INFO): """增强版写日志""" if algo_name: diff --git a/vnpy/app/algo_trading/template.py b/vnpy/app/algo_trading/template.py index 1016aabd..31f30a86 100644 --- a/vnpy/app/algo_trading/template.py +++ b/vnpy/app/algo_trading/template.py @@ -154,7 +154,7 @@ class AlgoTemplate: def cancel_order(self, vt_orderid: str): """""" - self.algo_engine.cancel_order(self, vt_orderid) + return self.algo_engine.cancel_order(self, vt_orderid) def cancel_all(self): """""" diff --git a/vnpy/app/cta_crypto/back_testing.py b/vnpy/app/cta_crypto/back_testing.py index 84db0ee3..f52690e8 100644 --- a/vnpy/app/cta_crypto/back_testing.py +++ b/vnpy/app/cta_crypto/back_testing.py @@ -333,8 +333,8 @@ class BackTestingEngine(object): self.fix_commission.update({vt_symbol: rate}) def get_commission_rate(self, vt_symbol: str): - """ 获取保证金比例,缺省万分之一""" - return self.commission_rate.get(vt_symbol, float(0.00001)) + """ 获取保证金比例,缺省千分之2""" + return self.commission_rate.get(vt_symbol, float(0.0004)) def get_fix_commission(self, vt_symbol: str): return self.fix_commission.get(vt_symbol, 0) @@ -561,7 +561,7 @@ class BackTestingEngine(object): margin_rate = symbol_data.get('margin_rate', 0.1) self.set_margin_rate(symbol, margin_rate) - self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001))) + self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0004))) self.set_contract( symbol=symbol, @@ -2232,7 +2232,7 @@ class TradingResult(object): self.volume = volume # 交易数量(+/-代表方向) self.group_id = group_id # 主交易ID(针对多手平仓) - self.turnover = (self.open_price + self.exit_price) * abs(volume) * margin_rate # 成交金额(实际保证金金额) + self.turnover = (self.open_price + self.exit_price) * abs(volume) # 成交金额(实际保证金金额) if fix_commission > 0: self.commission = fix_commission * abs(self.volume) else: diff --git a/vnpy/app/cta_strategy_pro/engine.py b/vnpy/app/cta_strategy_pro/engine.py index 1e06cac2..54c33e71 100644 --- a/vnpy/app/cta_strategy_pro/engine.py +++ b/vnpy/app/cta_strategy_pro/engine.py @@ -42,6 +42,7 @@ from vnpy.trader.event import ( ) from vnpy.trader.constant import ( Direction, + Exchange, OrderType, Offset, Status @@ -349,12 +350,27 @@ class CtaEngine(BaseEngine): # Update GUI self.put_strategy_event(strategy) + if self.engine_config.get('trade_2_wx', False): + accountid = self.engine_config.get('accountid', '-') + d = { + 'account': accountid, + 'strategy': strategy_name, + 'symbol': trade.symbol, + 'action': f'{trade.direction.value} {trade.offset.value}', + 'price': str(trade.price), + 'volume': trade.volume, + 'remark': f'{accountid}:{strategy_name}', + 'timestamp': trade.time + } + send_wx_msg(content=d, target=accountid) + def process_position_event(self, event: Event): """""" position = event.data self.offset_converter.update_position(position) + def check_unsubscribed_symbols(self): """检查未订阅合约""" @@ -1666,7 +1682,8 @@ class CtaEngine(BaseEngine): gateway_names = self.main_engine.get_all_gateway_names() gateway_name = gateway_names[0] if len(gateway_names) > 0 else "" symbol, exchange = extract_vt_symbol(vt_symbol) - self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), gateway_name=gateway_name) + self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), + gateway_name=gateway_name) if volume > 0 and tick: contract = self.main_engine.get_contract(vt_symbol) req = OrderRequest( @@ -1821,9 +1838,13 @@ class CtaEngine(BaseEngine): if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]: print(f"{strategy_name}: {msg}" if strategy_name else msg, file=sys.stderr) - def write_error(self, msg: str, strategy_name: str = ''): + if level in [logging.CRITICAL, logging.WARN, logging.WARNING]: + send_wx_msg(content=f"{strategy_name}: {msg}" if strategy_name else msg) + + def write_error(self, msg: str, strategy_name: str = '', level: int = logging.ERROR): """写入错误日志""" - self.write_log(msg=msg, strategy_name=strategy_name, level=logging.ERROR) + self.write_log(msg=msg, strategy_name=strategy_name, level=level) + def send_email(self, msg: str, strategy: CtaTemplate = None): """ diff --git a/vnpy/app/portfolio_strategy/__init__.py b/vnpy/app/portfolio_strategy/__init__.py new file mode 100644 index 00000000..9e17fd99 --- /dev/null +++ b/vnpy/app/portfolio_strategy/__init__.py @@ -0,0 +1,23 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp +from vnpy.trader.constant import Direction +from vnpy.trader.object import TickData, BarData, TradeData, OrderData +from vnpy.trader.utility import BarGenerator, ArrayManager + +from .base import APP_NAME +from .engine import StrategyEngine +from .template import StrategyTemplate +from .backtesting import BacktestingEngine + + +class PortfolioStrategyApp(BaseApp): + """""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "组合策略" + engine_class = StrategyEngine + widget_name = "PortfolioStrategyManager" + icon_name = "strategy.ico" diff --git a/vnpy/app/portfolio_strategy/backtesting.py b/vnpy/app/portfolio_strategy/backtesting.py new file mode 100644 index 00000000..91bf2157 --- /dev/null +++ b/vnpy/app/portfolio_strategy/backtesting.py @@ -0,0 +1,821 @@ +from collections import defaultdict +from datetime import date, datetime, timedelta +from typing import Dict, List, Set, Tuple +from functools import lru_cache +import traceback + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from pandas import DataFrame + +from vnpy.trader.constant import Direction, Offset, Interval, Status +from vnpy.trader.database import database_manager +from vnpy.trader.object import OrderData, TradeData, BarData +from vnpy.trader.utility import round_to, extract_vt_symbol + +from .template import StrategyTemplate + +# Set seaborn style +sns.set_style("whitegrid") + + +INTERVAL_DELTA_MAP = { + Interval.MINUTE: timedelta(minutes=1), + Interval.HOUR: timedelta(hours=1), + Interval.DAILY: timedelta(days=1), +} + + +class BacktestingEngine: + """""" + + gateway_name = "BACKTESTING" + + def __init__(self): + """""" + self.vt_symbols: List[str] = [] + self.start: datetime = None + self.end: datetime = None + + self.rates: Dict[str, float] = 0 + self.slippages: Dict[str, float] = 0 + self.sizes: Dict[str, float] = 1 + self.priceticks: Dict[str, float] = 0 + + self.capital: float = 1_000_000 + + self.strategy: StrategyTemplate = None + self.bars: Dict[str, BarData] = {} + self.datetime: datetime = None + + self.interval: Interval = None + self.days: int = 0 + self.history_data: Dict[Tuple, BarData] = {} + self.dts: Set[datetime] = set() + + self.limit_order_count = 0 + self.limit_orders = {} + self.active_limit_orders = {} + + self.trade_count = 0 + self.trades = {} + + self.logs = [] + + self.daily_results = {} + self.daily_df = None + + def clear_data(self) -> None: + """ + Clear all data of last backtesting. + """ + self.strategy = None + self.bars = {} + self.datetime = None + + self.limit_order_count = 0 + self.limit_orders.clear() + self.active_limit_orders.clear() + + self.trade_count = 0 + self.trades.clear() + + self.logs.clear() + self.daily_results.clear() + self.daily_df = None + + def set_parameters( + self, + vt_symbols: List[str], + interval: Interval, + start: datetime, + rates: Dict[str, float], + slippages: Dict[str, float], + sizes: Dict[str, float], + priceticks: Dict[str, float], + capital: int = 0, + end: datetime = None + ) -> None: + """""" + self.vt_symbols = vt_symbols + self.interval = interval + + self.rates = rates + self.slippages = slippages + self.sizes = sizes + self.priceticks = priceticks + + self.start = start + self.end = end + self.capital = capital + + def add_strategy(self, strategy_class: type, setting: dict) -> None: + """""" + self.strategy = strategy_class( + self, strategy_class.__name__, self.vt_symbols, setting + ) + + def load_data(self) -> None: + """""" + self.output("开始加载历史数据") + + if not self.end: + self.end = datetime.now() + + if self.start >= self.end: + self.output("起始日期必须小于结束日期") + return + + # Clear previously loaded history data + self.history_data.clear() + self.dts.clear() + + # Load 30 days of data each time and allow for progress update + progress_delta = timedelta(days=30) + total_delta = self.end - self.start + interval_delta = INTERVAL_DELTA_MAP[self.interval] + + for vt_symbol in self.vt_symbols: + start = self.start + end = self.start + progress_delta + progress = 0 + + data_count = 0 + while start < self.end: + end = min(end, self.end) # Make sure end time stays within set range + + data = load_bar_data( + vt_symbol, + self.interval, + start, + end + ) + + for bar in data: + self.dts.add(bar.datetime) + self.history_data[(bar.datetime, vt_symbol)] = bar + data_count += 1 + + progress += progress_delta / total_delta + progress = min(progress, 1) + progress_bar = "#" * int(progress * 10) + self.output(f"{vt_symbol}加载进度:{progress_bar} [{progress:.0%}]") + + start = end + interval_delta + end += (progress_delta + interval_delta) + + self.output(f"{vt_symbol}历史数据加载完成,数据量:{data_count}") + + self.output("所有历史数据加载完成") + + def run_backtesting(self) -> None: + """""" + self.strategy.on_init() + + # Generate sorted datetime list + dts = list(self.dts) + dts.sort() + + # Use the first [days] of history data for initializing strategy + day_count = 0 + ix = 0 + + for ix, dt in enumerate(dts): + if self.datetime and dt.day != self.datetime.day: + day_count += 1 + if day_count >= self.days: + break + + try: + self.new_bars(dt) + except Exception: + self.output("触发异常,回测终止") + self.output(traceback.format_exc()) + return + + self.strategy.inited = True + self.output("策略初始化完成") + + self.strategy.on_start() + self.strategy.trading = True + self.output("开始回放历史数据") + + # Use the rest of history data for running backtesting + for dt in dts[ix:]: + try: + self.new_bars(dt) + except Exception: + self.output("触发异常,回测终止") + self.output(traceback.format_exc()) + return + + self.output("历史数据回放结束") + + def calculate_result(self) -> None: + """""" + self.output("开始计算逐日盯市盈亏") + + if not self.trades: + self.output("成交记录为空,无法计算") + return + + # Add trade data into daily reuslt. + for trade in self.trades.values(): + d = trade.datetime.date() + daily_result = self.daily_results[d] + daily_result.add_trade(trade) + + # Calculate daily result by iteration. + pre_closes = {} + start_poses = {} + + for daily_result in self.daily_results.values(): + daily_result.calculate_pnl( + pre_closes, + start_poses, + self.sizes, + self.rates, + self.slippages, + ) + + pre_closes = daily_result.close_prices + start_poses = daily_result.end_poses + + # Generate dataframe + results = defaultdict(list) + + for daily_result in self.daily_results.values(): + fields = [ + "date", "trade_count", "turnover", + "commission", "slippage", "trading_pnl", + "holding_pnl", "total_pnl", "net_pnl" + ] + for key in fields: + value = getattr(daily_result, key) + results[key].append(value) + + self.daily_df = DataFrame.from_dict(results).set_index("date") + + self.output("逐日盯市盈亏计算完成") + return self.daily_df + + def calculate_statistics(self, df: DataFrame = None, output=True) -> None: + """""" + self.output("开始计算策略统计指标") + + # Check DataFrame input exterior + if df is None: + df = self.daily_df + + # Check for init DataFrame + if df is None: + # Set all statistics to 0 if no trade. + start_date = "" + end_date = "" + total_days = 0 + profit_days = 0 + loss_days = 0 + end_balance = 0 + max_drawdown = 0 + max_ddpercent = 0 + max_drawdown_duration = 0 + total_net_pnl = 0 + daily_net_pnl = 0 + total_commission = 0 + daily_commission = 0 + total_slippage = 0 + daily_slippage = 0 + total_turnover = 0 + daily_turnover = 0 + total_trade_count = 0 + daily_trade_count = 0 + total_return = 0 + annual_return = 0 + daily_return = 0 + return_std = 0 + sharpe_ratio = 0 + return_drawdown_ratio = 0 + else: + # Calculate balance related time series data + df["balance"] = df["net_pnl"].cumsum() + self.capital + df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0) + df["highlevel"] = ( + df["balance"].rolling( + min_periods=1, window=len(df), center=False).max() + ) + df["drawdown"] = df["balance"] - df["highlevel"] + df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100 + + # Calculate statistics value + start_date = df.index[0] + end_date = df.index[-1] + + total_days = len(df) + profit_days = len(df[df["net_pnl"] > 0]) + loss_days = len(df[df["net_pnl"] < 0]) + + end_balance = df["balance"].iloc[-1] + max_drawdown = df["drawdown"].min() + max_ddpercent = df["ddpercent"].min() + max_drawdown_end = df["drawdown"].idxmin() + + if isinstance(max_drawdown_end, date): + max_drawdown_start = df["balance"][:max_drawdown_end].idxmax() + max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days + else: + max_drawdown_duration = 0 + + total_net_pnl = df["net_pnl"].sum() + daily_net_pnl = total_net_pnl / total_days + + total_commission = df["commission"].sum() + daily_commission = total_commission / total_days + + total_slippage = df["slippage"].sum() + daily_slippage = total_slippage / total_days + + total_turnover = df["turnover"].sum() + daily_turnover = total_turnover / total_days + + total_trade_count = df["trade_count"].sum() + daily_trade_count = total_trade_count / total_days + + total_return = (end_balance / self.capital - 1) * 100 + annual_return = total_return / total_days * 240 + daily_return = df["return"].mean() * 100 + return_std = df["return"].std() * 100 + + if return_std: + sharpe_ratio = daily_return / return_std * np.sqrt(240) + else: + sharpe_ratio = 0 + + return_drawdown_ratio = -total_return / max_ddpercent + + # Output + if output: + self.output("-" * 30) + self.output(f"首个交易日:\t{start_date}") + self.output(f"最后交易日:\t{end_date}") + + self.output(f"总交易日:\t{total_days}") + self.output(f"盈利交易日:\t{profit_days}") + self.output(f"亏损交易日:\t{loss_days}") + + self.output(f"起始资金:\t{self.capital:,.2f}") + self.output(f"结束资金:\t{end_balance:,.2f}") + + self.output(f"总收益率:\t{total_return:,.2f}%") + self.output(f"年化收益:\t{annual_return:,.2f}%") + self.output(f"最大回撤: \t{max_drawdown:,.2f}") + self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%") + self.output(f"最长回撤天数: \t{max_drawdown_duration}") + + self.output(f"总盈亏:\t{total_net_pnl:,.2f}") + self.output(f"总手续费:\t{total_commission:,.2f}") + self.output(f"总滑点:\t{total_slippage:,.2f}") + self.output(f"总成交金额:\t{total_turnover:,.2f}") + self.output(f"总成交笔数:\t{total_trade_count}") + + self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}") + self.output(f"日均手续费:\t{daily_commission:,.2f}") + self.output(f"日均滑点:\t{daily_slippage:,.2f}") + self.output(f"日均成交金额:\t{daily_turnover:,.2f}") + self.output(f"日均成交笔数:\t{daily_trade_count}") + + self.output(f"日均收益率:\t{daily_return:,.2f}%") + self.output(f"收益标准差:\t{return_std:,.2f}%") + self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}") + self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}") + + statistics = { + "start_date": start_date, + "end_date": end_date, + "total_days": total_days, + "profit_days": profit_days, + "loss_days": loss_days, + "capital": self.capital, + "end_balance": end_balance, + "max_drawdown": max_drawdown, + "max_ddpercent": max_ddpercent, + "max_drawdown_duration": max_drawdown_duration, + "total_net_pnl": total_net_pnl, + "daily_net_pnl": daily_net_pnl, + "total_commission": total_commission, + "daily_commission": daily_commission, + "total_slippage": total_slippage, + "daily_slippage": daily_slippage, + "total_turnover": total_turnover, + "daily_turnover": daily_turnover, + "total_trade_count": total_trade_count, + "daily_trade_count": daily_trade_count, + "total_return": total_return, + "annual_return": annual_return, + "daily_return": daily_return, + "return_std": return_std, + "sharpe_ratio": sharpe_ratio, + "return_drawdown_ratio": return_drawdown_ratio, + } + + # Filter potential error infinite value + for key, value in statistics.items(): + if value in (np.inf, -np.inf): + value = 0 + statistics[key] = np.nan_to_num(value) + + self.output("策略统计指标计算完成") + return statistics + + def show_chart(self, df: DataFrame = None) -> None: + """""" + # Check DataFrame input exterior + if df is None: + df = self.daily_df + + # Check for init DataFrame + if df is None: + return + + plt.figure(figsize=(10, 16)) + + balance_plot = plt.subplot(4, 1, 1) + balance_plot.set_title("Balance") + df["balance"].plot(legend=True) + + drawdown_plot = plt.subplot(4, 1, 2) + drawdown_plot.set_title("Drawdown") + drawdown_plot.fill_between(range(len(df)), df["drawdown"].values) + + pnl_plot = plt.subplot(4, 1, 3) + pnl_plot.set_title("Daily Pnl") + df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[]) + + distribution_plot = plt.subplot(4, 1, 4) + distribution_plot.set_title("Daily Pnl Distribution") + df["net_pnl"].hist(bins=50) + + plt.show() + + def update_daily_close(self, bars: Dict[str, BarData], dt: datetime) -> None: + """""" + d = dt.date() + + close_prices = {} + for bar in bars.values(): + close_prices[bar.vt_symbol] = bar.close_price + + daily_result = self.daily_results.get(d, None) + + if daily_result: + daily_result.update_close_prices(close_prices) + else: + self.daily_results[d] = PortfolioDailyResult(d, close_prices) + + def new_bars(self, dt: datetime) -> None: + """""" + self.datetime = dt + + self.bars.clear() + for vt_symbol in self.vt_symbols: + bar = self.history_data.get((dt, vt_symbol), None) + if bar: + self.bars[vt_symbol] = bar + else: + dt_str = dt.strftime("%Y-%m-%d %H:%M:%S") + self.output(f"数据缺失:{dt_str} {vt_symbol}") + + self.cross_limit_order() + self.strategy.on_bars(self.bars) + + self.update_daily_close(self.bars, dt) + + def cross_limit_order(self) -> None: + """ + Cross limit order with last bar/tick data. + """ + for order in list(self.active_limit_orders.values()): + bar = self.bars[order.vt_symbol] + + long_cross_price = bar.low_price + short_cross_price = bar.high_price + long_best_price = bar.open_price + short_best_price = bar.open_price + + # Push order update with status "not traded" (pending). + if order.status == Status.SUBMITTING: + order.status = Status.NOTTRADED + self.strategy.update_order(order) + + # Check whether limit orders can be filled. + long_cross = ( + order.direction == Direction.LONG + and order.price >= long_cross_price + and long_cross_price > 0 + ) + + short_cross = ( + order.direction == Direction.SHORT + and order.price <= short_cross_price + and short_cross_price > 0 + ) + + if not long_cross and not short_cross: + continue + + # Push order update with status "all traded" (filled). + order.traded = order.volume + order.status = Status.ALLTRADED + self.strategy.update_order(order) + + self.active_limit_orders.pop(order.vt_orderid) + + # Push trade update + self.trade_count += 1 + + if long_cross: + trade_price = min(order.price, long_best_price) + else: + trade_price = max(order.price, short_best_price) + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=str(self.trade_count), + direction=order.direction, + offset=order.offset, + price=trade_price, + volume=order.volume, + time=self.datetime.strftime("%H:%M:%S"), + gateway_name=self.gateway_name, + ) + trade.datetime = self.datetime + + self.strategy.update_trade(trade) + self.trades[trade.vt_tradeid] = trade + + def load_bars( + self, + strategy: StrategyTemplate, + days: int, + interval: Interval + ) -> None: + """""" + self.days = days + + def send_order( + self, + strategy: StrategyTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + lock: bool + ) -> List[str]: + """""" + price = round_to(price, self.priceticks[vt_symbol]) + symbol, exchange = extract_vt_symbol(vt_symbol) + + self.limit_order_count += 1 + + order = OrderData( + symbol=symbol, + exchange=exchange, + orderid=str(self.limit_order_count), + direction=direction, + offset=offset, + price=price, + volume=volume, + status=Status.SUBMITTING, + gateway_name=self.gateway_name, + ) + order.datetime = self.datetime + + self.active_limit_orders[order.vt_orderid] = order + self.limit_orders[order.vt_orderid] = order + + return [order.vt_orderid] + + def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str) -> None: + """ + Cancel order by vt_orderid. + """ + if vt_orderid not in self.active_limit_orders: + return + order = self.active_limit_orders.pop(vt_orderid) + + order.status = Status.CANCELLED + self.strategy.update_order(order) + + def write_log(self, msg: str, strategy: StrategyTemplate = None) -> None: + """ + Write log message. + """ + msg = f"{self.datetime}\t{msg}" + self.logs.append(msg) + + def send_email(self, msg: str, strategy: StrategyTemplate = None) -> None: + """ + Send email to default receiver. + """ + pass + + def sync_strategy_data(self, strategy: StrategyTemplate) -> None: + """ + Sync strategy data into json file. + """ + pass + + def put_strategy_event(self, strategy: StrategyTemplate) -> None: + """ + Put an event to update strategy status. + """ + pass + + def output(self, msg) -> None: + """ + Output message of backtesting engine. + """ + print(f"{datetime.now()}\t{msg}") + + def get_all_trades(self) -> List[TradeData]: + """ + Return all trade data of current backtesting result. + """ + return list(self.trades.values()) + + def get_all_orders(self) -> List[OrderData]: + """ + Return all limit order data of current backtesting result. + """ + return list(self.limit_orders.values()) + + def get_all_daily_results(self) -> List["PortfolioDailyResult"]: + """ + Return all daily result data. + """ + return list(self.daily_results.values()) + + +class ContractDailyResult: + """""" + + def __init__(self, result_date: date, close_price: float): + """""" + self.date: date = result_date + self.close_price: float = close_price + self.pre_close: float = 0 + + self.trades: List[TradeData] = [] + self.trade_count: int = 0 + + self.start_pos: float = 0 + self.end_pos: float = 0 + + self.turnover: float = 0 + self.commission: float = 0 + self.slippage: float = 0 + + self.trading_pnl: float = 0 + self.holding_pnl: float = 0 + self.total_pnl: float = 0 + self.net_pnl: float = 0 + + def add_trade(self, trade: TradeData) -> None: + """""" + self.trades.append(trade) + + def calculate_pnl( + self, + pre_close: float, + start_pos: float, + size: int, + rate: float, + slippage: float + ) -> None: + """""" + # If no pre_close provided on the first day, + # use value 1 to avoid zero division error + if pre_close: + self.pre_close = pre_close + else: + self.pre_close = 1 + + # Holding pnl is the pnl from holding position at day start + self.start_pos = start_pos + self.end_pos = start_pos + + self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size + + # Trading pnl is the pnl from new trade during the day + self.trade_count = len(self.trades) + + for trade in self.trades: + if trade.direction == Direction.LONG: + pos_change = trade.volume + else: + pos_change = -trade.volume + + self.end_pos += pos_change + + turnover = trade.volume * size * trade.price + + self.trading_pnl += pos_change * (self.close_price - trade.price) * size + self.slippage += trade.volume * size * slippage + self.turnover += turnover + self.commission += turnover * rate + + # Net pnl takes account of commission and slippage cost + self.total_pnl = self.trading_pnl + self.holding_pnl + self.net_pnl = self.total_pnl - self.commission - self.slippage + + def update_close_price(self, close_price: float) -> None: + """""" + self.close_price = close_price + + +class PortfolioDailyResult: + """""" + + def __init__(self, result_date: date, close_prices: Dict[str, float]): + """""" + self.date: date = result_date + self.close_prices: Dict[str, float] = close_prices + self.pre_closes: Dict[str, float] = {} + self.start_poses: Dict[str, float] = {} + self.end_poses: Dict[str, float] = {} + + self.contract_results: Dict[str, ContractDailyResult] = {} + + for vt_symbol, close_price in close_prices.items(): + self.contract_results[vt_symbol] = ContractDailyResult(result_date, close_price) + + self.trade_count: int = 0 + self.turnover: float = 0 + self.commission: float = 0 + self.slippage: float = 0 + self.trading_pnl: float = 0 + self.holding_pnl: float = 0 + self.total_pnl: float = 0 + self.net_pnl: float = 0 + + def add_trade(self, trade: TradeData) -> None: + """""" + contract_result = self.contract_results[trade.vt_symbol] + contract_result.add_trade(trade) + + def calculate_pnl( + self, + pre_closes: Dict[str, float], + start_poses: Dict[str, float], + sizes: Dict[str, float], + rates: Dict[str, float], + slippages: Dict[str, float], + ) -> None: + """""" + self.pre_closes = pre_closes + + for vt_symbol, contract_result in self.contract_results.items(): + contract_result.calculate_pnl( + pre_closes.get(vt_symbol, 0), + start_poses.get(vt_symbol, 0), + sizes[vt_symbol], + rates[vt_symbol], + slippages[vt_symbol] + ) + + self.trade_count += contract_result.trade_count + self.turnover += contract_result.turnover + self.commission += contract_result.commission + self.slippage += contract_result.slippage + self.trading_pnl += contract_result.trading_pnl + self.holding_pnl += contract_result.holding_pnl + self.total_pnl += contract_result.total_pnl + self.net_pnl += contract_result.net_pnl + + self.end_poses[vt_symbol] = contract_result.end_pos + + def update_close_prices(self, close_prices: Dict[str, float]) -> None: + """""" + self.close_prices = close_prices + + for vt_symbol, close_price in close_prices.items(): + contract_result = self.contract_results[vt_symbol] + contract_result.update_close_price(close_price) + + +@lru_cache(maxsize=999) +def load_bar_data( + vt_symbol: str, + interval: Interval, + start: datetime, + end: datetime +): + """""" + symbol, exchange = extract_vt_symbol(vt_symbol) + + return database_manager.load_bar_data( + symbol, exchange, interval, start, end + ) diff --git a/vnpy/app/portfolio_strategy/base.py b/vnpy/app/portfolio_strategy/base.py new file mode 100644 index 00000000..f89be90d --- /dev/null +++ b/vnpy/app/portfolio_strategy/base.py @@ -0,0 +1,17 @@ +""" +Defines constants and objects used in PortfolioStrategy App. +""" + +from enum import Enum + + +APP_NAME = "PortfolioStrategy" + + +class EngineType(Enum): + LIVE = "实盘" + BACKTESTING = "回测" + + +EVENT_PORTFOLIO_LOG = "ePortfolioLog" +EVENT_PORTFOLIO_STRATEGY = "ePortfolioStrategy" diff --git a/vnpy/app/portfolio_strategy/engine.py b/vnpy/app/portfolio_strategy/engine.py new file mode 100644 index 00000000..2d7127ce --- /dev/null +++ b/vnpy/app/portfolio_strategy/engine.py @@ -0,0 +1,624 @@ +"""""" + +import importlib +import os +import traceback +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Set, Tuple, Type, Any, Callable +from datetime import datetime, timedelta +from concurrent.futures import ThreadPoolExecutor + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.object import ( + OrderRequest, + SubscribeRequest, + HistoryRequest, + LogData, + TickData, + OrderData, + TradeData, + PositionData, + BarData, + ContractData +) +from vnpy.trader.event import ( + EVENT_TICK, + EVENT_ORDER, + EVENT_TRADE, + EVENT_POSITION +) +from vnpy.trader.constant import ( + Direction, + OrderType, + Interval, + Exchange, + Offset +) +from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to +from vnpy.trader.database import database_manager +from vnpy.trader.rqdata import rqdata_client +from vnpy.trader.converter import OffsetConverter + +from .base import ( + APP_NAME, + EVENT_PORTFOLIO_LOG, + EVENT_PORTFOLIO_STRATEGY +) +from .template import StrategyTemplate + + +class StrategyEngine(BaseEngine): + """""" + + setting_filename = "portfolio_strategy_setting.json" + data_filename = "portfolio_strategy_data.json" + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__(main_engine, event_engine, APP_NAME) + + self.strategy_data: Dict[str, Dict] = {} + + self.classes: Dict[str, Type[StrategyTemplate]] = {} + self.strategies: Dict[str, StrategyTemplate] = {} + + self.symbol_strategy_map: Dict[str, List[StrategyTemplate]] = defaultdict(list) + self.orderid_strategy_map: Dict[str, StrategyTemplate] = {} + + self.init_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=1) + + self.vt_tradeids: Set[str] = set() + + self.offset_converter: OffsetConverter = OffsetConverter(self.main_engine) + + def init_engine(self): + """ + """ + self.init_rqdata() + self.load_strategy_class() + self.load_strategy_setting() + self.load_strategy_data() + self.register_event() + self.write_log("组合策略引擎初始化成功") + + def close(self): + """""" + self.stop_all_strategies() + + def register_event(self): + """""" + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + self.event_engine.register(EVENT_POSITION, self.process_position_event) + + def init_rqdata(self): + """ + Init RQData client. + """ + result = rqdata_client.init() + if result: + self.write_log("RQData数据接口初始化成功") + + def query_bar_from_rq( + self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime + ): + """ + Query bar data from RQData. + """ + req = HistoryRequest( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end + ) + data = rqdata_client.query_history(req) + return data + + def process_tick_event(self, event: Event): + """""" + tick: TickData = event.data + + strategies = self.symbol_strategy_map[tick.vt_symbol] + if not strategies: + return + + for strategy in strategies: + if strategy.inited: + self.call_strategy_func(strategy, strategy.on_tick, tick) + + def process_order_event(self, event: Event): + """""" + order: OrderData = event.data + + self.offset_converter.update_order(order) + + strategy = self.orderid_strategy_map.get(order.vt_orderid, None) + if not strategy: + return + + self.call_strategy_func(strategy, strategy.update_order, order) + + def process_trade_event(self, event: Event): + """""" + trade: TradeData = event.data + + # Filter duplicate trade push + if trade.vt_tradeid in self.vt_tradeids: + return + self.vt_tradeids.add(trade.vt_tradeid) + + self.offset_converter.update_trade(trade) + + strategy = self.orderid_strategy_map.get(trade.vt_orderid, None) + if not strategy: + return + + self.call_strategy_func(strategy, strategy.update_trade, trade) + + def process_position_event(self, event: Event): + """""" + position: PositionData = event.data + + self.offset_converter.update_position(position) + + def send_order( + self, + strategy: StrategyTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + lock: bool + ): + """ + Send a new order to server. + """ + contract: ContractData = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(f"委托失败,找不到合约:{vt_symbol}", strategy) + return "" + + # Round order price and volume to nearest incremental value + price = round_to(price, contract.pricetick) + volume = round_to(volume, contract.min_volume) + + # Create request and send order. + original_req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + offset=offset, + type=OrderType.LIMIT, + price=price, + volume=volume, + ) + + # Convert with offset converter + req_list = self.offset_converter.convert_order_request(original_req, lock) + + # Send Orders + vt_orderids = [] + + for req in req_list: + req.reference = strategy.strategy_name # Add strategy name as order reference + + vt_orderid = self.main_engine.send_order( + req, contract.gateway_name) + + # Check if sending order successful + if not vt_orderid: + continue + + vt_orderids.append(vt_orderid) + + self.offset_converter.update_order_request(req, vt_orderid) + + # Save relationship between orderid and strategy. + self.orderid_strategy_map[vt_orderid] = strategy + + return vt_orderids + + def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str): + """ + """ + order = self.main_engine.get_order(vt_orderid) + if not order: + self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy) + return + + req = order.create_cancel_request() + self.main_engine.cancel_order(req, order.gateway_name) + + def load_bars(self, strategy: StrategyTemplate, days: int, interval: Interval): + """""" + vt_symbols = strategy.vt_symbols + dts: Set[datetime] = set() + history_data: Dict[Tuple, BarData] = {} + + # Load data from rqdata/gateway/database + for vt_symbol in vt_symbols: + data = self.load_bar(vt_symbol, days, interval) + + for bar in data: + dts.add(bar.datetime) + history_data[(bar.datetime, vt_symbol)] = bar + + # Convert data structure and push to strategy + dts = list(dts) + dts.sort() + + for dt in dts: + bars = {} + + for vt_symbol in vt_symbols: + bar = history_data.get((dt, vt_symbol), None) + if bar: + bars[vt_symbol] = bar + else: + dt_str = dt.strftime("%Y-%m-%d %H:%M:%S") + self.write_log(f"数据缺失:{dt_str} {vt_symbol}", strategy) + + self.call_strategy_func(strategy, strategy.on_bars, bars) + + def load_bar(self, vt_symbol: str, days: int, interval: Interval) -> List[BarData]: + """""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + contract: ContractData = self.main_engine.get_contract(vt_symbol) + data = [] + + # Query bars from gateway if available + if contract and contract.history_data: + req = HistoryRequest( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end + ) + data = self.main_engine.query_history(req, contract.gateway_name) + # Try to query bars from RQData, if not found, load from database. + else: + data = self.query_bar_from_rq(symbol, exchange, interval, start, end) + + if not data: + data = database_manager.load_bar_data( + symbol=symbol, + exchange=exchange, + interval=interval, + start=start, + end=end, + ) + + return data + + def call_strategy_func( + self, strategy: StrategyTemplate, func: Callable, params: Any = None + ): + """ + Call function of a strategy and catch any exception raised. + """ + try: + if params: + func(params) + else: + func() + except Exception: + strategy.trading = False + strategy.inited = False + + msg = f"触发异常已停止\n{traceback.format_exc()}" + self.write_log(msg, strategy) + + def add_strategy( + self, class_name: str, strategy_name: str, vt_symbols: str, setting: dict + ): + """ + Add a new strategy. + """ + if strategy_name in self.strategies: + self.write_log(f"创建策略失败,存在重名{strategy_name}") + return + + strategy_class = self.classes.get(class_name, None) + if not strategy_class: + self.write_log(f"创建策略失败,找不到策略类{class_name}") + return + + strategy = strategy_class(self, strategy_name, vt_symbols, setting) + self.strategies[strategy_name] = strategy + + # Add vt_symbol to strategy map. + for vt_symbol in vt_symbols: + strategies = self.symbol_strategy_map[vt_symbol] + strategies.append(strategy) + + self.save_strategy_setting() + self.put_strategy_event(strategy) + + def init_strategy(self, strategy_name: str): + """ + Init a strategy. + """ + self.init_executor.submit(self._init_strategy, strategy_name) + + def _init_strategy(self, strategy_name: str): + """ + Init strategies in queue. + """ + strategy = self.strategies[strategy_name] + + if strategy.inited: + self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作") + return + + self.write_log(f"{strategy_name}开始执行初始化") + + # Call on_init function of strategy + self.call_strategy_func(strategy, strategy.on_init) + + # Restore strategy data(variables) + data = self.strategy_data.get(strategy_name, None) + if data: + for name in strategy.variables: + value = data.get(name, None) + if value: + setattr(strategy, name, value) + + # Subscribe market data + for vt_symbol in strategy.vt_symbols: + contract: ContractData = self.main_engine.get_contract(vt_symbol) + if contract: + req = SubscribeRequest( + symbol=contract.symbol, exchange=contract.exchange) + self.main_engine.subscribe(req, contract.gateway_name) + else: + self.write_log(f"行情订阅失败,找不到合约{vt_symbol}", strategy) + + # Put event to update init completed status. + strategy.inited = True + self.put_strategy_event(strategy) + self.write_log(f"{strategy_name}初始化完成") + + def start_strategy(self, strategy_name: str): + """ + Start a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.inited: + self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化") + return + + if strategy.trading: + self.write_log(f"{strategy_name}已经启动,请勿重复操作") + return + + self.call_strategy_func(strategy, strategy.on_start) + strategy.trading = True + + self.put_strategy_event(strategy) + + def stop_strategy(self, strategy_name: str): + """ + Stop a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.trading: + return + + # Call on_stop function of the strategy + self.call_strategy_func(strategy, strategy.on_stop) + + # Change trading status of strategy to False + strategy.trading = False + + # Cancel all orders of the strategy + strategy.cancel_all() + + # Sync strategy variables to data file + self.sync_strategy_data(strategy) + + # Update GUI + self.put_strategy_event(strategy) + + def edit_strategy(self, strategy_name: str, setting: dict): + """ + Edit parameters of a strategy. + """ + strategy = self.strategies[strategy_name] + strategy.update_setting(setting) + + self.save_strategy_setting() + self.put_strategy_event(strategy) + + def remove_strategy(self, strategy_name: str): + """ + Remove a strategy. + """ + strategy = self.strategies[strategy_name] + if strategy.trading: + self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止") + return + + # Remove from symbol strategy map + for vt_symbol in strategy.vt_symbols: + strategies = self.symbol_strategy_map[vt_symbol] + strategies.remove(strategy) + + # Remove from vt_orderid strategy map + for vt_orderid in strategy.active_orderids: + if vt_orderid in self.orderid_strategy_map: + self.orderid_strategy_map.pop(vt_orderid) + + # Remove from strategies + self.strategies.pop(strategy_name) + self.save_strategy_setting() + + return True + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + path1 = Path(__file__).parent.joinpath("strategies") + self.load_strategy_class_from_folder(path1, "vnpy.app.portfolio_strategy.strategies") + + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(str(path)): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + elif filename.endswith(".pyd"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + else: + continue + + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, StrategyTemplate) and value is not StrategyTemplate): + self.classes[value.__name__] = value + except: # noqa + msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg) + + def load_strategy_data(self): + """ + Load strategy data from json file. + """ + self.strategy_data = load_json(self.data_filename) + + def sync_strategy_data(self, strategy: StrategyTemplate): + """ + Sync strategy data into json file. + """ + data = strategy.get_variables() + data.pop("inited") # Strategy status (inited, trading) should not be synced. + data.pop("trading") + + self.strategy_data[strategy.strategy_name] = data + save_json(self.data_filename, self.strategy_data) + + def get_all_strategy_class_names(self): + """ + Return names of strategy classes loaded. + """ + return list(self.classes.keys()) + + def get_strategy_class_parameters(self, class_name: str): + """ + Get default parameters of a strategy class. + """ + strategy_class = self.classes[class_name] + + parameters = {} + for name in strategy_class.parameters: + parameters[name] = getattr(strategy_class, name) + + return parameters + + def get_strategy_parameters(self, strategy_name): + """ + Get parameters of a strategy. + """ + strategy = self.strategies[strategy_name] + return strategy.get_parameters() + + def init_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.init_strategy(strategy_name) + + def start_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.start_strategy(strategy_name) + + def stop_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.stop_strategy(strategy_name) + + def load_strategy_setting(self): + """ + Load setting file. + """ + strategy_setting = load_json(self.setting_filename) + + for strategy_name, strategy_config in strategy_setting.items(): + self.add_strategy( + strategy_config["class_name"], + strategy_name, + strategy_config["vt_symbols"], + strategy_config["setting"] + ) + + def save_strategy_setting(self): + """ + Save setting file. + """ + strategy_setting = {} + + for name, strategy in self.strategies.items(): + strategy_setting[name] = { + "class_name": strategy.__class__.__name__, + "vt_symbols": strategy.vt_symbols, + "setting": strategy.get_parameters() + } + + save_json(self.setting_filename, strategy_setting) + + def put_strategy_event(self, strategy: StrategyTemplate): + """ + Put an event to update strategy status. + """ + data = strategy.get_data() + event = Event(EVENT_PORTFOLIO_STRATEGY, data) + self.event_engine.put(event) + + def write_log(self, msg: str, strategy: StrategyTemplate = None): + """ + Create cta engine log event. + """ + if strategy: + msg = f"{strategy.strategy_name}: {msg}" + + log = LogData(msg=msg, gateway_name="CtaStrategy") + event = Event(type=EVENT_PORTFOLIO_LOG, data=log) + self.event_engine.put(event) + + def send_email(self, msg: str, strategy: StrategyTemplate = None): + """ + Send email to default receiver. + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "组合策略引擎" + + self.main_engine.send_email(subject, msg) diff --git a/vnpy/app/portfolio_strategy/strategies/__init__.py b/vnpy/app/portfolio_strategy/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/portfolio_strategy/strategies/trend_following_strategy.py b/vnpy/app/portfolio_strategy/strategies/trend_following_strategy.py new file mode 100644 index 00000000..d5eac4c9 --- /dev/null +++ b/vnpy/app/portfolio_strategy/strategies/trend_following_strategy.py @@ -0,0 +1,139 @@ +from typing import List, Dict + +from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine +from vnpy.trader.utility import BarGenerator, ArrayManager +from vnpy.trader.object import TickData, BarData + + +class TrendFollowingStrategy(StrategyTemplate): + """""" + + author = "用Python的交易员" + + atr_window = 22 + atr_ma_window = 10 + rsi_window = 5 + rsi_entry = 16 + trailing_percent = 0.8 + fixed_size = 1 + + atr_value = 0 + atr_ma = 0 + rsi_value = 0 + rsi_buy = 0 + rsi_sell = 0 + intra_trade_high = 0 + intra_trade_low = 0 + + parameters = [ + "atr_window", + "atr_ma_window", + "rsi_window", + "rsi_entry", + "trailing_percent", + "fixed_size" + ] + variables = [ + "atr_value", + "atr_ma", + "rsi_value", + "rsi_buy", + "rsi_sell" + ] + + def __init__( + self, + strategy_engine: StrategyEngine, + strategy_name: str, + vt_symbols: List[str], + setting: dict + ): + """""" + super().__init__(strategy_engine, strategy_name, vt_symbols, setting) + + self.vt_symbol = vt_symbols[0] + self.bg = BarGenerator(self.on_bar) + self.am = ArrayManager() + + def on_init(self): + """ + Callback when strategy is inited. + """ + self.write_log("策略初始化") + + self.rsi_buy = 50 + self.rsi_entry + self.rsi_sell = 50 - self.rsi_entry + + self.load_bars(10) + + def on_start(self): + """ + Callback when strategy is started. + """ + self.write_log("策略启动") + + def on_stop(self): + """ + Callback when strategy is stopped. + """ + self.write_log("策略停止") + + def on_tick(self, tick: TickData): + """ + Callback of new tick data update. + """ + self.bg.update_tick(tick) + + def on_bar(self, bar: BarData): + """ + Callback of new bar data update. + """ + bars = {bar.vt_symbol: bar} + self.on_bars(bars) + + def on_bars(self, bars: Dict[str, BarData]): + """""" + self.cancel_all() + + bar = bars[self.vt_symbol] + am = self.am + am.update_bar(bar) + if not am.inited: + return + + atr_array = am.atr(self.atr_window, array=True) + self.atr_value = atr_array[-1] + self.atr_ma = atr_array[-self.atr_ma_window:].mean() + self.rsi_value = am.rsi(self.rsi_window) + + pos = self.get_pos(self.vt_symbol) + + if pos == 0: + self.intra_trade_high = bar.high_price + self.intra_trade_low = bar.low_price + + if self.atr_value > self.atr_ma: + if self.rsi_value > self.rsi_buy: + self.buy(self.vt_symbol, bar.close_price + 5, self.fixed_size) + elif self.rsi_value < self.rsi_sell: + self.short(self.vt_symbol, bar.close_price - 5, self.fixed_size) + + elif pos > 0: + self.intra_trade_high = max(self.intra_trade_high, bar.high_price) + self.intra_trade_low = bar.low_price + + long_stop = self.intra_trade_high * (1 - self.trailing_percent / 100) + + if bar.close_price <= long_stop: + self.sell(self.vt_symbol, bar.close_price - 5, abs(pos)) + + elif pos < 0: + self.intra_trade_low = min(self.intra_trade_low, bar.low_price) + self.intra_trade_high = bar.high_price + + short_stop = self.intra_trade_low * (1 + self.trailing_percent / 100) + + if bar.close_price >= short_stop: + self.cover(self.vt_symbol, bar.close_price + 5, abs(pos)) + + self.put_event() diff --git a/vnpy/app/portfolio_strategy/template.py b/vnpy/app/portfolio_strategy/template.py new file mode 100644 index 00000000..597a4a93 --- /dev/null +++ b/vnpy/app/portfolio_strategy/template.py @@ -0,0 +1,247 @@ +"""""" +from abc import ABC +from copy import copy +from typing import Dict, Set, List, TYPE_CHECKING +from collections import defaultdict + +from vnpy.trader.constant import Interval, Direction, Offset +from vnpy.trader.object import BarData, TickData, OrderData, TradeData +from vnpy.trader.utility import virtual + +if TYPE_CHECKING: + from .engine import StrategyEngine + + +class StrategyTemplate(ABC): + """""" + + author = "" + parameters = [] + variables = [] + + def __init__( + self, + strategy_engine: "StrategyEngine", + strategy_name: str, + vt_symbols: List[str], + setting: dict, + ): + """""" + self.strategy_engine: "StrategyEngine" = strategy_engine + self.strategy_name: str = strategy_name + self.vt_symbols: List[str] = vt_symbols + + self.inited: bool = False + self.trading: bool = False + self.pos: Dict[str, int] = defaultdict(int) + self.orders: Dict[str, OrderData] = {} + self.active_orderids: Set[str] = set() + + # Copy a new variables list here to avoid duplicate insert when multiple + # strategy instances are created with the same strategy class. + self.variables: Dict = copy(self.variables) + self.variables.insert(0, "inited") + self.variables.insert(1, "trading") + + self.update_setting(setting) + + def update_setting(self, setting: dict) -> None: + """ + Update strategy parameter wtih value in setting dict. + """ + for name in self.parameters: + if name in setting: + setattr(self, name, setting[name]) + + @classmethod + def get_class_parameters(cls) -> Dict: + """ + Get default parameters dict of strategy class. + """ + class_parameters = {} + for name in cls.parameters: + class_parameters[name] = getattr(cls, name) + return class_parameters + + def get_parameters(self) -> Dict: + """ + Get strategy parameters dict. + """ + strategy_parameters = {} + for name in self.parameters: + strategy_parameters[name] = getattr(self, name) + return strategy_parameters + + def get_variables(self) -> Dict: + """ + Get strategy variables dict. + """ + strategy_variables = {} + for name in self.variables: + strategy_variables[name] = getattr(self, name) + return strategy_variables + + def get_data(self) -> Dict: + """ + Get strategy data. + """ + strategy_data = { + "strategy_name": self.strategy_name, + "vt_symbols": self.vt_symbols, + "class_name": self.__class__.__name__, + "author": self.author, + "parameters": self.get_parameters(), + "variables": self.get_variables(), + } + return strategy_data + + @virtual + def on_init(self) -> None: + """ + Callback when strategy is inited. + """ + pass + + @virtual + def on_start(self) -> None: + """ + Callback when strategy is started. + """ + pass + + @virtual + def on_stop(self) -> None: + """ + Callback when strategy is stopped. + """ + pass + + @virtual + def on_tick(self, tick: TickData) -> None: + """ + Callback of new tick data update. + """ + pass + + @virtual + def on_bars(self, bars: Dict[str, BarData]) -> None: + """ + Callback of new bar data update. + """ + pass + + def update_trade(self, trade: TradeData) -> None: + """ + Callback of new trade data update. + """ + if trade.direction == Direction.LONG: + self.pos[trade.vt_symbol] += trade.volume + else: + self.pos[trade.vt_symbol] -= trade.volume + + def update_order(self, order: OrderData) -> None: + """ + Callback of new order data update. + """ + self.orders[order.vt_orderid] = order + + if order.is_active(): + self.active_orderids.add(order.vt_orderid) + elif order.vt_orderid in self.active_orderids: + self.active_orderids.remove(order.vt_orderid) + + def buy(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]: + """ + Send buy order to open a long position. + """ + return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume, lock) + + def sell(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]: + """ + Send sell order to close a long position. + """ + return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume, lock) + + def short(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]: + """ + Send short order to open as short position. + """ + return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume, lock) + + def cover(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]: + """ + Send cover order to close a short position. + """ + return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume, lock) + + def send_order( + self, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + lock: bool = False + ) -> List[str]: + """ + Send a new order. + """ + if self.trading: + vt_orderids = self.strategy_engine.send_order( + self, vt_symbol, direction, offset, price, volume, lock + ) + return vt_orderids + else: + return [] + + def cancel_order(self, vt_orderid: str) -> None: + """ + Cancel an existing order. + """ + if self.trading: + self.strategy_engine.cancel_order(self, vt_orderid) + + def cancel_all(self) -> None: + """ + Cancel all orders sent by strategy. + """ + for vt_orderid in list(self.active_orderids): + self.cancel_order(vt_orderid) + + def get_pos(self, vt_symbol: str) -> int: + """""" + return self.pos.get(vt_symbol, 0) + + def get_order(self, vt_orderid: str) -> OrderData: + """""" + return self.orders.get(vt_orderid, None) + + def get_all_active_orderids(self) -> List[OrderData]: + """""" + return list(self.active_orderids.values()) + + def write_log(self, msg: str) -> None: + """ + Write a log message. + """ + self.strategy_engine.write_log(msg, self) + + def load_bars(self, days: int, interval: Interval = Interval.MINUTE) -> None: + """ + Load historical bar data for initializing strategy. + """ + self.strategy_engine.load_bars(self, days, interval) + + def put_event(self) -> None: + """ + Put an strategy data event for ui update. + """ + if self.inited: + self.strategy_engine.put_strategy_event(self) + + def send_email(self, msg) -> None: + """ + Send email to default receiver. + """ + if self.inited: + self.strategy_engine.send_email(msg, self) diff --git a/vnpy/app/portfolio_strategy/ui/__init__.py b/vnpy/app/portfolio_strategy/ui/__init__.py new file mode 100644 index 00000000..90241f5e --- /dev/null +++ b/vnpy/app/portfolio_strategy/ui/__init__.py @@ -0,0 +1 @@ +from .widget import PortfolioStrategyManager diff --git a/vnpy/app/portfolio_strategy/ui/strategy.ico b/vnpy/app/portfolio_strategy/ui/strategy.ico new file mode 100644 index 0000000000000000000000000000000000000000..f67f9d8d492a21bdd72d5181928b6a94cb9fa280 GIT binary patch literal 67646 zcmeHQ37izg`5mHBqyL1MsDMVnfCe;PDEFxd2#PSfJA2P^i;4%H$R(hNir{So0TH2&_?%c;ST?A_Nf?M3!s0xBl-{clA{F^vrb6%;KWE{2tZaQ(edRzWVCysz9I$ z{tX)z!2h)Zb&shMI4KYa)CKT}59M=S^6K+H1|*F8%Ks}RP$_{*2~s_Zc6uOp#P~qb;=KvqJ#;y7kI5?! z^p}o5M|?-)zAjK-1MiOkEFDQlyk5dc#4~z2Sw;Co;jP-24jvw=&1Du9eT9r<+&BMukji9ELedIO_@S-?BM>Y`%tQ_;fG zlDr3Ilw@UGRuXJ{s!D%Ju+hmSnH`#xc*z68QZY@2mw@0~3Bx^(eTaK4FcYB7ITMJ|1;4U)`d9LVdp@vQ zrIoTL2PzL;kSFyU+SYmve{zE0zt26>k~RMB6YG$OKLRocTyfB$Op^6fgxaxu`Dk zTGbKv-QGjYocNHKFlMm0HnXWM5AeGII1lhu9+iK%nE>@ja@?nLL$Pv2u_!Am6Z|`P z@Su2m$vn~hvUBx#55Ru`;44lF#zkHy;aN9eA+Q~YCHvUVm+TkW*Y<mbG<0 z7UCV*4)CgXW_)bN&0u`2XJ~6aTe-p+oxER#=eaHauG?3>cLc+U_ve%c?*9T@1F+pq zdzzZ;(S|FWwN_-dy9jBkCfhATuFe#D_w07f$I}lDg&)IC1e?|r`C}KQgMXEWb;!@% z0Nabc^nh{%ZUHPC_S&q;u6g)*^ADnLR&(73e1i8_kH@3J_8a{I?8`j^$T)La)DkP- zUg{d>r|&N}_-|;{udJtJU(gn?cH@u%`)*x;H-O~JyHj7t|BLHIX8VScf3}tS_qYP~ zeury5#@=_k#(yoe8zap68^0BO zx-=A8TafLszc381fAni&09?sG>y)wg_lM0nV2hV+fr$^^soNX+k6GvFc{Z=lf;$P= z2`Cw*$%XxYwM1lPU18e)>Voy=2g{e(^1z$#H>?xaXSX!U4ekHM{}hEYR{3oMdQCTV zoNXGmX{!Kn^}yt!GhEpgV|)M3YqP|TZCj)AvGBzyk#>y19mi&V&oI3HnekBvngUw^ zvks^&W=?oe{Pg2SX+znDdj9E$MSklubpF`~Yz;^q(&NtZ#FwxS=s!)7 z{QiWzj{r^tV(Ne?%fG?>oOvKJ8lNH_yZa`wVclx6XLqUCUb00joi|gqZyf!x>&^Fd zDiU6s@37DD4C5o%4q&`}vzm!PeS+e;P;-%qd&d~Sboh;GZ|M2MHOJ5<18gU|$PY4S z-&-LS4p;~L`%978p|LgYL)%E5y{%gZF`(C#=+mBqJeX}}&VaIaS zJ8;hc{tH;~vMw;jfNcE?{?59l9$>}i%oFb018m>9@NPT{FWf1@Fa2PlXyLAye&dB> z{Ra8xJ@^9J^#6$H0Awe}PV_M>a^L7=k#ozPqG-`>#F;|bN57MO4=W$@wuqu72jb;p z1Jc_EU|Y!21M;Kf3HNw_WAGLi)LS(khq(;vIs9-e_)5UyA{N(f&((XT%kk!U0Q>fO z9P0WYp0l69KB|>oJs)Piu7|q`h%E>7xAMl!6qy~Hip(@jxk{fHj3+$jJA~iF$j5~zh^*|+Xiw}&#J`fwT*!%aWGo$E{n`nb z3VaHXi#5O-z!0DgAbFsh%AX#NdCfKfb+9JDao(c=D=uE^>Bcq3a;5^_`tuh5Cm3ZE zU9lr++=YxULCs{?xln!2#P|n4=Y0nH?qaAD^F| z@^4EA+n3;5B87V17N(g`+W%KJiBQj8Mm{i>oPEP>u>Zc>F`qzowss+9>?V8XGx|&S zpCgVV^?SyJwCXq=_w3_(=U(}JeGls3?G}n&-|IgN6;>h_?X(1Sa&mE_y&-LknU5s6 zSMh!c8Mf4At9aGqPcF)J?S&Q>_3td;Dc}O@d81Qw?oUH{7bL~K#TUz@LJ~S;h2!?F zQQq@LKZmx!Xro=mK3d#&+`>upOr>ABiTua>ojPE%1D4zkn}09vuLjCx?S&Q>xvvYn z39!xi4!|`>fr5$46#vzbR<>WtL7wRdQ*l2-w)caLP7yh`k1%X=x#&PT(k#PZT_D>7 z+1K2l>VV@Q`-cF|3vt{n?K*%kRa{Q)Ri9%8{L~Be=t2Jf2t0Fma9_{pU$D>L_|(Hi z2iPX8oi}P)0OJ&jd%aK2b$Myi0T0}pI#B&B^wZBR`j5B=={uMn{*}!%%4_Oj>gC?# zUiF{LRR>f)(qspU#hmnA zuf?%wMm8O^n=nw0sg}V-fdiVJh<5zpZk>cFx zrlbSDxYyDUDk_xy>6CPU{l7ty6k}L|7{l8S6yx7 zBXosadzwNYz~CM_l+*tn%um#G^6hq?ln(gjzCks#JDt{7L@X!NU-u>*4a zXPhx^f_YA^+seCdatiz(6nUfB*3)c0c~+b&mt&2*PKOSpXZ$B(FN|_;aNh)L)KvOa znpGa6UEeyO@y~gjWez#*z&+di9Pekl>0nl7huv~bwUwsoBGkP*<~_bQbk?s9kpB_d z_%HI9^!Ojg{2!QZwDXm$;AVEd#2B09Go(ACbv-#gYvonxz-@OK{Q|%9Ko3=kFbxzcB;ia=BY9-Zzh7&hY0T<<9fYz&)=v6*8s9_sKyRrx$vjj`)zT(5m4dp z#x>_3ueF|Wp7ingWjO=xi=+)@HnV*LQ84YJ%MQk z6X|Dp2h4Sy1(V+sp@JM~TV&e>>%$5rzbjC$M&t{>^uswk59`1+zKnW0{PGVXD?eo9 z1M&#nx58{>I^6j?>5DT|I|C z#XsG40Oxwx)))kEjHxGJ#K|>sTq8v}vtB}9nrrRy$1e6SGtEwf7hvp<M)M1V*^ZSH?r$^Yzi@;BC7oF6Ql`Bic@#}>Dq4_xcN`Mri+WSVjY zK47%3AoGF=Z^q)EI?xgr1)Ky}WjZeL+K>7>j_YXtkv7Y&v901aZ{_V@rAfK5_?J3o zreFCj@3i>|U!{pIWS)S3a&FBDkVC(JI`8;45@SxM;G5}ccKgrrFZoZxKKsdgI{n7u zKb`TEXIO6hH>mz(_6@fOx}0^CRo)ecez|x0N0|-;FE~Ds(e}a{xkd`%%6@2mCI7#S zy>sC|qxD%+di*GUmpklnITz#PFxh)@Pay0+y5PO{mx(dI$pomjsJcF)wm14J5GxYFvB?l?j&;Oi+^y>F%Yh+tvFaN3dR@9HNR>N zIQ9Y(1<(9fjyDy)_$|hQDhA{_bXDG+ZhZdv?YXu8fGW#=U8O^C`R(!s-~20m;TUB_ z1I8+?dJOGZ^lQwqT8um6d(mGW4KQ~x!>~zC?s*@d|GnUgdr+ig1JwI)ZvteTzRExO z$NY;m{;tLiE21msg3RwhD#A79A9C*=mn{Emn{dtjk$~S~Z~5oD!2bb@ffV?+=8)uf z$^2fa645fS?c0Gl9?pZx_LBedPyV?V&6z+_HbA`xw-Lbo(9*zvb#YBb6S>a@zaOnc zL+_%Ksgb`!w)R{hmAZNpWERKKcDP_u;67Ap^=}{>-P- z$vMx?`QLcszgGQk*njf>%D*Q>?O9Oyz>omjK{amHZz9C0&qy{`Y}9?opX6|2cR!vL8=n`7c*3CIc-|8Dlbw6sio``tmY{x^i3yNo>)b%iP?{+EEmW2pmP|4e=1JFf%f&i{KK z?H4`!Z5Q3H2X6S4!_4EG|BdY1sC1y5_-8**?j_)p|C)Gj?V;fR?f3SJ%pTiBm#)7J z1iNh$OWr(W{O<;b?Ep`00Ob#E3&6f7n@jy)F#ZjJJ@bum3j8j|E0l}pT+6H4w6WU% zE5LtN&uxNo`7I!$JNSRg1ON9-h$~y$`cKaPM&^GaHrue@+)LFj|K0Hp#}Fc6Ir+!9 zdG^)UiQK_sux9dsXyAFMTenL79eVAR|Lndu;oGNIMgC*2L<7$$_h6IL4Ee_A|A{2{ zzsttGTlX=#PGYV}QuGpV9+q|FQmuI!WF03Ns(c^AB#i{v?7q{}I6}9}~fj!$k1% zks_FZwJ3#O8~WghJ5yKkKM}z$lMx2Ksp>df1Urut!Mt}7&Q`#yJ|Ruv^&*({k_cWg z%1Bcu*yGRw5iI&XZkjyv?^V9^2RoZ{(_pp!!(MKUc6==#{Bv*O4*+{OSkHlPyyf>; z`47*An*i9i59^UrwEvKE{zoF%_8Qc6*uRnAaWz1FLo~v-Z!Iqo!QfQr$Cd=#fmia< z=|K@}-Votn4YLY|`q}hc5$rGo`t(g)?kt;-HBSUvcS0KQEv^*~>bl0*2mO*f=qkR= zh@nef`Om)YX1VuaLOFv6?rkW)Z*g*F+;jf!d>{OC50}rQ_-9yt)^oj;YWttue<*^j zIv6rm-x=$6k+P@EJ3Rz>mjG^k2y*|U2wv77yrZl%jI>EC`zfD6x7rk%IuOea;CTjY zVAHc>q{-Az>R4!@t1fxsAM<~C_dXN5{jd61y36%y?);bFd3_*}`ImhJ_~Up1#|>hI z=j}P`3%ULS;VLi5e`fl~*dM_ncw_QKR5 zhQoI=@GXP7Ho^jbv;(p(a0!R{*u06AChmPqsC|&Fy^yJwt=dADzCu`bUGmKT$fsk^ z|NG;{FaPI&`^~>7|7=%nMtY@7|~)k!RxJ6*iF;6?GkfU87i6SQfhU4dS(E|Bd3=#9;u;GS&; z(|&O8-5wp!mwTe~92}m5(iP_;SvtV**k0+zHfuze4g>zV&(zjmlz+CzJKkmEZ2ag! zVn61|)gSk5zp8TsQwLCPo1SO174)_|^(ea-dQ=Cz8TNyFWp7;njS5DT?%W}Dz-9;7 zCY1dP#T)Gr+7GPrWqptJIXLs=jn`&d;qZcUxS?j|{v1*#^NhCZ;l)4e3fc^Bc0QKB zlT; z3Ws)pb%C@2k(~u;TjQa(M*(d55|+}r+xx9#V496WG9+<$vdjenH!oHtzff6=V_Ye>IOJuonef7S(t z{fEsaZ?yBy@?W?X{I@XquO_a~X(>v!Y_^5<1I`iY-RUC3CNZ2=?O``I9B%yo^7fyy zXPaKtwP>%68GMbcye(h+vU&a>%0B(E&0X-hquin%$^9?dhiAbL+7SrnnqXt`)u$g= zys<7|dmv)}(Y_2~Esl+U=p)+}m*g7yfx4Y%2J#%91JFbM?cGxeM*QGu2 z-p@#rsz7Uc+9*3=6u% zz9Gk?q%Of;Sm|N9cn@40L|^1%gcGfgUir`E{O{5O-egTacy@(ueFY7mNvSm5U%FB$!KW#Z}y4Uhg{s(&DpXaJ)v^ftr*9CZv zbCgy6m^%Nu+z;Ptul!SXab=4$O66QU%U68(Og7#l8|D1Ei zzK=X(nfnt)0nc;qe6*5(M}6Xre~d34*4|83w=WG??2MS zzvs5U(ha36aEAlsUj8Wq93%gmg3i0TE+_t928W*3{<`pgIo{o4=|DM^f0h5=x?4{C z4*`FkZNFy!`A$t>O(p+_Ml;KS|59+*7VyQtrUUFV&P*r&^JY#wG{QU#V_7(7s?}HG zm|7a@|5f0wdJ6ngZ>~-k|D1cDK6aG&apQWi@!M}Ip~-+UdH(5#)6GA}aI27izkhtk z#944R0@33?vj6YB|HpB_f{V|TYbLpFvJ!R~6kPI;bn?FsoP|sd{n9V}oA0vi$nnFp z@Xz`3O2n3d8Y@qO{pa@^rbJ@O81`FR8+BzNEr{JIXiztTW_#A*@N@x`c`VtdZfI z3eTvM@{QK1HsdsR(IQX`t-%@i+ZCZuu`;i1&^ocd7m{3=hG!T;o( zTk*a0uaSSur5yJB<4il8nf*S_={PA}{8N8qSW6BBx-Ju~vv26XRqpliYb*bUJ^v^1 z%*wLxye&5rKKHrgEY-5F-Y5S!$2`z;r#K2{=gzX)&WYOp715T)vj2xW|L9j|w5`7? z|6dDEL7O|R{8JBV1DPF~;QYVp4_5MD0on-8`d{{Et1szt*6)w*a`w@x=qBxgW5*eS zC#(C1V)Ku4DJx=`XPQ*o&P(}Mek)>LRNCO$)_)k|8B4&;G11)lJNwh|{_(i4u03<- zpJ)9n*lB))P|@@NX~Q|pMX&Fc{b|?sKkYx(EAhJ)uB)n$dW^oS%*RrE6Ox$!--EN` zG~PVf0^L8Z{|fN@=+}YT0Bt}t|5)c2p1V`5s@-)E-lWzD(A4Da>d{aiy% z1rwKvtYBNg^Qk!=K^eOC{}3+E+~7AWa$ieT*8Ie$`=vhV(0@zfn&$zb{U!1rpD)8@ zpVi7or2O-oZgM{um;uxPlb)W z<5}Lwi6WzU9qBi-Lu2D?`y_ahy1Hn$D46`Nk&p4n%hZqJeV-LyzftbVeMakgB8cxE zd0rXXdG>J(gC}zICtx73+vI6B+*1Il1Bnzs`GXq*_5g~1x(fh?XFNljljA|5UibzJ z*YY`ZfZqZXJo|p6&h;Zht{KI;#~{v@kYlvaC+;C6%aNaO=ojN?+(VNxlWAicxp{55 zPa4|FCjV?VUx0iJ7yOn%jtMe<7=LHJiWcqyH$L^4It-oV_ggX_*rPKG=M5%RCKwjZ zq|eI4_l8zJFlLGL=Z+X(2Hh>Q(%i7dGUW4aKyh#6yJ3wHc;=V?W6iMk!d2nW-KXDY zw8EUFwQmN}g>M$}?wg#vu7iK%%djQp_b++(Pf3n{$gU9Mx|F>oBkE~mem{6hBu~6= zjBf!o9j3esXRJ!j?x5ZX&;3qhcEoq&w22rmTXm!jEn0&1iQh7#$~O5&of547S1C7^ zL#a=!N0lExF5csKe^$Px19p9m7OL0b4(lgp9B6cYh>hu-{yrnrNlqO zzhi`=ElwsxvvXI?&*5@)BWQ*u@0b|vfh;E)yX*vpY~Gc6ZvF&ExX?U zxt}TbH09ZxR{u@+>-E~J4v>4;?1CvPljmOY51ak!R>A$5nUCxn?|^J@mK*yYUeoO6 z5B4noVeC^Rb)bB>SNh?i1HQ}PX6FnCWvpH#Ss zbb#DF$t~N3eO_OHo%$Jk*lkzjKB;oAxYu1x4`>J6bbvZB zETdJuqtHg{0sXe08IV$5(>%_mbU<$pD4TypOW9ses-1c0@>qs`(3rx1e{n3@D%U~o zI{~K-q>X!}2Z%$}A5I-$TaazbqeH#>1TcSo1o$5a>`Tecm8U$F4yg73%l#ECj5;P= z<-YvVKMiuNn*Zp_0kmO{fZT5d^tOK5xsRg*+Y$GDfE;hcKF`KHp|N*3zc)Qt@pT|1 z#|8=}y`5TpUw(Y+?FH`P&u=mVXalJ7cqHV{aUPb#w*cR5J;kpV7vZaT|A6auz#p|| zihsI1Gx%}9etw8oIXUZp-g|x5t3TAq+{?Uh8ru1+pS`~4$DgYAmFwqSyWfRKUroS} zG3MemxwHKB1gt(De=Dy)EnX`PDlRL1Mmqn9d+LBsJAk(Hye*>exi19jIf?N6pTST0 z=s>w~zj&`4dyx5f?sGX;BiF*&dG(@q3n0VOb=fL;yKzw-+{cQ??YW*#T+Ha?CEPqmhqRb>%)za(>KTJaZwt(*Vo1 zl6yC~yFF8Sl8WmHw+il01zG^@P4_gozXPmv>FLEa&k|vq&|AL%d~gm%j+=3831c03 z_dlCH{>eAS<_n)+E%(nTz}S4@^iM_jwQWWm>8NXwz;hTZ z?sXZP*C)bl46rV~7&rsq+>gaY8o56U;XVp{1+d@2amHQ1YTyar9KecKPdBdFFJRyB zdmz5eBmcJXzY|j%|5NdiSJ-5(sSb7TX6&27HQ<_Id&#)u-#Px5 z?Ror<-)xQqSbxUi{$$8{2=F1W4ImdB`~4P}0kj5KRwWi28# zdxq)XidWAAt}g{v0(leC3h*c3mD(LpK$g~e{26fxi{aU(nNXSS=+q;es>#-`}&YE+xoif zo!39%_nQFA7WmZxysOG`GhF`eh;((EdCjCM^M*yeK~v-Z9c@@X;9@sC@7|3hg)AgR1}=7YxtKpZOm9ck=<0`)cz0 z&5^fs6%!vEBJR7bhbU}Q7dl{zw;2ApX2=$eBVTM6-U1}94`OlZ^*QVY`}Y~GvEQee zPU`HnnN7vxBW@N?jT#_s@7@vTw_pxX=`!X9^X{Ege*Hpw+Sx9846we{!?jIzlqJA6 z)?hRv7xs3u0- z6c*q8Z;iD>!1j_YV$vhS4IQ9-Q5R%|t`On*TT{0I&h*&tKSbWMf!_m`Y;5G?#*`(; zb~XT#3!XhPXUe}s>8>3zA7$7;XZ5O;;9;KkLdteS%`o?owO&)eY-3)}=DfSbxMd7ipdk z+Ks%u2(X_29bm~JmYnz;=^<{G>5^r!nt1hvXH>Fn?v~9zh`w3PbQ|z7-aCO|c>SYa z6BFP{dGP$54_BZvb@^Mha)mq#N#lPi-iyn>N{cQ7T(h6p9N;{MeecO+bOQY_eT_8v zzoTa-*vy?Sxe&4~5aGX;$RD>PUHo$ot=EA};8?)oo!5!}OfLEHslu=VO>p6m z3HO|64NL|$0m<0|+JDsHSs6I{&9wi8*kfbm3Mhu-k8FX&IbsD$gx2gM#9K64^H{up8ep9fKkBb zK6)w#9k2kyiGZAI=f` z=E8i3dh;|O73%Q^t%*jn#v>OlC?dSYnbkUX>Jsw;VhpweHsSg)Hm`Z>R4pJ(hY25th_mSJ6H zan9>x|4c5(;a!Rgo~dp0{&_P$@0(l>HzjJ2b2%ile9zlO$X!S z_>}MX++^dFpLx%^fqPPMU$U71`wfvjDqZ!W4E_1Z#(HvX3_y*YSmlK}fZwyP?I*_z z6U;f%7ID1pHQ)}Q5x{W`8|S=F^-swY?n%I_fF*OLU2+jU7RPbnd>|<<)O)6TF5E?c zzSh`E1D^AG1;DXF86Vwr{<+B&&saay1h`M>aDZo)d;zfSsMQt5*gw}1bFU9M4rJ;9 z(kbhE)(`wPfNM4RJ&R-fm+KfOnzrAz6sfc?|Zi5`;rpa`VyS6ur(Ae*qry!?Da*9O1aO`2Y83& zn(fcKfbKvGfV$?&H=ZY%8*i`avc)yWWs87$z&c;fV6O-hhtu|Zs$BB z?fgFgrp1bj*N5_ta`yo0MOE;6V%K*?EnJ_CbF41Fw{-P!enNfB3tfQu`g6LzCu-x& zf@=7kV^YTe3Bx{=d3DMe_gwo!dvF2pPeAtb;aYs8D}8#Hv98kb1P3tR6PNE))0M9GW69t7JYDfshEpklN(od-pi%;r5~!5G zuS)_^1IT_K46BMaN&*4tR1qR@G!Am1jD`xaE`ZA;MWabW0y0Q?m5ceuj>t#~iJ z^Mzr7ntO1+hd=WLc^^P9^6zMMKg_tVV&0ptBLKty0psUDpp|it#2J58jr(=#{%G7I z{$h1s#kg0ufq+b3&YI_`b*=g=3}1KVGe{y6lpmE%52KOOpN}>-lnL@)%=!hce4!@qhsyd?#)xVb^{FVvb9#^t-c9lE{{g^}A58!N literal 0 HcmV?d00001 diff --git a/vnpy/app/portfolio_strategy/ui/widget.py b/vnpy/app/portfolio_strategy/ui/widget.py new file mode 100644 index 00000000..6e03099f --- /dev/null +++ b/vnpy/app/portfolio_strategy/ui/widget.py @@ -0,0 +1,439 @@ +from typing import Dict + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtCore, QtGui, QtWidgets +from vnpy.trader.ui.widget import ( + MsgCell, + TimeCell, + BaseMonitor +) +from ..base import ( + APP_NAME, + EVENT_PORTFOLIO_LOG, + EVENT_PORTFOLIO_STRATEGY +) +from ..engine import StrategyEngine + + +class PortfolioStrategyManager(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_strategy = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """""" + super().__init__() + + self.main_engine: MainEngine = main_engine + self.event_engine: EventEngine = event_engine + self.strategy_engine: StrategyEngine = main_engine.get_engine(APP_NAME) + + self.managers: Dict[str, StrategyManager] = {} + + self.init_ui() + self.register_event() + self.strategy_engine.init_engine() + self.update_class_combo() + + def init_ui(self) -> None: + """""" + self.setWindowTitle("组合策略") + + # Create widgets + self.class_combo = QtWidgets.QComboBox() + + add_button = QtWidgets.QPushButton("添加策略") + add_button.clicked.connect(self.add_strategy) + + init_button = QtWidgets.QPushButton("全部初始化") + init_button.clicked.connect(self.strategy_engine.init_all_strategies) + + start_button = QtWidgets.QPushButton("全部启动") + start_button.clicked.connect(self.strategy_engine.start_all_strategies) + + stop_button = QtWidgets.QPushButton("全部停止") + stop_button.clicked.connect(self.strategy_engine.stop_all_strategies) + + clear_button = QtWidgets.QPushButton("清空日志") + clear_button.clicked.connect(self.clear_log) + + self.scroll_layout = QtWidgets.QVBoxLayout() + self.scroll_layout.addStretch() + + scroll_widget = QtWidgets.QWidget() + scroll_widget.setLayout(self.scroll_layout) + + scroll_area = QtWidgets.QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(scroll_widget) + + self.log_monitor = LogMonitor(self.main_engine, self.event_engine) + + # Set layout + hbox1 = QtWidgets.QHBoxLayout() + hbox1.addWidget(self.class_combo) + hbox1.addWidget(add_button) + hbox1.addStretch() + hbox1.addWidget(init_button) + hbox1.addWidget(start_button) + hbox1.addWidget(stop_button) + hbox1.addWidget(clear_button) + + hbox2 = QtWidgets.QHBoxLayout() + hbox2.addWidget(scroll_area) + hbox2.addWidget(self.log_monitor) + + vbox = QtWidgets.QVBoxLayout() + vbox.addLayout(hbox1) + vbox.addLayout(hbox2) + + self.setLayout(vbox) + + def update_class_combo(self): + """""" + self.class_combo.addItems( + self.strategy_engine.get_all_strategy_class_names() + ) + + def register_event(self): + """""" + self.signal_strategy.connect(self.process_strategy_event) + + self.event_engine.register( + EVENT_PORTFOLIO_STRATEGY, self.signal_strategy.emit + ) + + def process_strategy_event(self, event): + """ + Update strategy status onto its monitor. + """ + data = event.data + strategy_name = data["strategy_name"] + + if strategy_name in self.managers: + manager = self.managers[strategy_name] + manager.update_data(data) + else: + manager = StrategyManager(self, self.strategy_engine, data) + self.scroll_layout.insertWidget(0, manager) + self.managers[strategy_name] = manager + + def remove_strategy(self, strategy_name): + """""" + manager = self.managers.pop(strategy_name) + manager.deleteLater() + + def add_strategy(self): + """""" + class_name = str(self.class_combo.currentText()) + if not class_name: + return + + parameters = self.strategy_engine.get_strategy_class_parameters(class_name) + editor = SettingEditor(parameters, class_name=class_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + vt_symbols = setting.pop("vt_symbols").split(",") + strategy_name = setting.pop("strategy_name") + + self.strategy_engine.add_strategy( + class_name, strategy_name, vt_symbols, setting + ) + + def clear_log(self): + """""" + self.log_monitor.setRowCount(0) + + def show(self): + """""" + self.showMaximized() + + +class StrategyManager(QtWidgets.QFrame): + """ + Manager for a strategy + """ + + def __init__( + self, + strategy_manager: PortfolioStrategyManager, + strategy_engine: StrategyEngine, + data: dict + ): + """""" + super().__init__() + + self.strategy_manager = strategy_manager + self.strategy_engine = strategy_engine + + self.strategy_name = data["strategy_name"] + self._data = data + + self.init_ui() + + def init_ui(self): + """""" + self.setFixedHeight(300) + self.setFrameShape(self.Box) + self.setLineWidth(1) + + self.init_button = QtWidgets.QPushButton("初始化") + self.init_button.clicked.connect(self.init_strategy) + + self.start_button = QtWidgets.QPushButton("启动") + self.start_button.clicked.connect(self.start_strategy) + self.start_button.setEnabled(False) + + self.stop_button = QtWidgets.QPushButton("停止") + self.stop_button.clicked.connect(self.stop_strategy) + self.stop_button.setEnabled(False) + + self.edit_button = QtWidgets.QPushButton("编辑") + self.edit_button.clicked.connect(self.edit_strategy) + + self.remove_button = QtWidgets.QPushButton("移除") + self.remove_button.clicked.connect(self.remove_strategy) + + strategy_name = self._data["strategy_name"] + class_name = self._data["class_name"] + author = self._data["author"] + + label_text = ( + f"{strategy_name} - ({class_name} by {author})" + ) + label = QtWidgets.QLabel(label_text) + label.setAlignment(QtCore.Qt.AlignCenter) + + self.parameters_monitor = DataMonitor(self._data["parameters"]) + self.variables_monitor = DataMonitor(self._data["variables"]) + + hbox = QtWidgets.QHBoxLayout() + hbox.addWidget(self.init_button) + hbox.addWidget(self.start_button) + hbox.addWidget(self.stop_button) + hbox.addWidget(self.edit_button) + hbox.addWidget(self.remove_button) + + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(label) + vbox.addLayout(hbox) + vbox.addWidget(self.parameters_monitor) + vbox.addWidget(self.variables_monitor) + self.setLayout(vbox) + + def update_data(self, data: dict): + """""" + self._data = data + + self.parameters_monitor.update_data(data["parameters"]) + self.variables_monitor.update_data(data["variables"]) + + # Update button status + variables = data["variables"] + inited = variables["inited"] + trading = variables["trading"] + + if not inited: + return + self.init_button.setEnabled(False) + + if trading: + self.start_button.setEnabled(False) + self.stop_button.setEnabled(True) + self.edit_button.setEnabled(False) + self.remove_button.setEnabled(False) + else: + self.start_button.setEnabled(True) + self.stop_button.setEnabled(False) + self.edit_button.setEnabled(True) + self.remove_button.setEnabled(True) + + def init_strategy(self): + """""" + self.strategy_engine.init_strategy(self.strategy_name) + + def start_strategy(self): + """""" + self.strategy_engine.start_strategy(self.strategy_name) + + def stop_strategy(self): + """""" + self.strategy_engine.stop_strategy(self.strategy_name) + + def edit_strategy(self): + """""" + strategy_name = self._data["strategy_name"] + + parameters = self.strategy_engine.get_strategy_parameters(strategy_name) + editor = SettingEditor(parameters, strategy_name=strategy_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + self.strategy_engine.edit_strategy(strategy_name, setting) + + def remove_strategy(self): + """""" + result = self.strategy_engine.remove_strategy(self.strategy_name) + + # Only remove strategy gui manager if it has been removed from engine + if result: + self.strategy_manager.remove_strategy(self.strategy_name) + + +class DataMonitor(QtWidgets.QTableWidget): + """ + Table monitor for parameters and variables. + """ + + def __init__(self, data: dict): + """""" + super(DataMonitor, self).__init__() + + self._data = data + self.cells = {} + + self.init_ui() + + def init_ui(self): + """""" + labels = list(self._data.keys()) + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + + self.setRowCount(1) + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.Stretch + ) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + for column, name in enumerate(self._data.keys()): + value = self._data[name] + + cell = QtWidgets.QTableWidgetItem(str(value)) + cell.setTextAlignment(QtCore.Qt.AlignCenter) + + self.setItem(0, column, cell) + self.cells[name] = cell + + def update_data(self, data: dict): + """""" + for name, value in data.items(): + cell = self.cells[name] + cell.setText(str(value)) + + +class LogMonitor(BaseMonitor): + """ + Monitor for log data. + """ + + event_type = EVENT_PORTFOLIO_LOG + data_key = "" + sorting = False + + headers = { + "time": {"display": "时间", "cell": TimeCell, "update": False}, + "msg": {"display": "信息", "cell": MsgCell, "update": False}, + } + + def init_ui(self): + """ + Stretch last column. + """ + super(LogMonitor, self).init_ui() + + self.horizontalHeader().setSectionResizeMode( + 1, QtWidgets.QHeaderView.Stretch + ) + + def insert_new_row(self, data): + """ + Insert a new row at the top of table. + """ + super().insert_new_row(data) + self.resizeRowToContents(0) + + +class SettingEditor(QtWidgets.QDialog): + """ + For creating new strategy and editing strategy parameters. + """ + + def __init__( + self, parameters: dict, strategy_name: str = "", class_name: str = "" + ): + """""" + super(SettingEditor, self).__init__() + + self.parameters = parameters + self.strategy_name = strategy_name + self.class_name = class_name + + self.edits = {} + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + + # Add vt_symbols and name edit if add new strategy + if self.class_name: + self.setWindowTitle(f"添加策略:{self.class_name}") + button_text = "添加" + parameters = {"strategy_name": "", "vt_symbols": ""} + parameters.update(self.parameters) + else: + self.setWindowTitle(f"参数编辑:{self.strategy_name}") + button_text = "确定" + parameters = self.parameters + + for name, value in parameters.items(): + type_ = type(value) + + edit = QtWidgets.QLineEdit(str(value)) + if type_ is int: + validator = QtGui.QIntValidator() + edit.setValidator(validator) + elif type_ is float: + validator = QtGui.QDoubleValidator() + edit.setValidator(validator) + + form.addRow(f"{name} {type_}", edit) + + self.edits[name] = (edit, type_) + + button = QtWidgets.QPushButton(button_text) + button.clicked.connect(self.accept) + form.addRow(button) + + self.setLayout(form) + + def get_setting(self): + """""" + setting = {} + + if self.class_name: + setting["class_name"] = self.class_name + + for name, tp in self.edits.items(): + edit, type_ = tp + value_text = edit.text() + + if type_ == bool: + if value_text == "True": + value = True + else: + value = False + else: + value = type_(value_text) + + setting[name] = value + + return setting diff --git a/vnpy/component/cta_grid_trade.py b/vnpy/component/cta_grid_trade.py index eb060037..9b01270b 100644 --- a/vnpy/component/cta_grid_trade.py +++ b/vnpy/component/cta_grid_trade.py @@ -595,7 +595,7 @@ class CtaGridTrade(CtaComponent): x.type = type # 网格类型标签 # self.open_prices = {} # 套利使用,开仓价格,symbol:price - def rebuild_grids(self, direction: Direction, + def rebuild_grids(self, directions: list = [], upper_line: float = 0.0, down_line: float = 0.0, middle_line: float = 0.0, @@ -608,7 +608,7 @@ class CtaGridTrade(CtaComponent): upRate , 上轨网格高度比率 dnRate, 下轨网格高度比率 """ - self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(direction, upper_line, down_line)) + self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(directions, upper_line, down_line)) # 检查上下网格的高度比率,不能低于0.5 if upper_rate < 0.5 or down_rate < 0.5: @@ -616,7 +616,7 @@ class CtaGridTrade(CtaComponent): down_rate = max(0.5, down_rate) # 重建下网格(移除未挂单、保留开仓得网格、在最低价之下才增加网格 - if direction == Direction.LONG: + if len(directions) == 0 or Direction.LONG in directions: min_long_price = middle_line remove_grids = [] opened_grids = [] @@ -636,8 +636,8 @@ class CtaGridTrade(CtaComponent): self.write_log(u'保留下网格[{}]'.format(opened_grids)) # 需要重建的剩余网格数量 - remainLots = len(self.dn_grids) - lots = self.max_lots - remainLots + remain_lots = len(self.dn_grids) + lots = self.max_lots - remain_lots down_line = min(down_line, min_long_price - self.grid_height * down_rate) self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, down_line)) @@ -651,13 +651,13 @@ class CtaGridTrade(CtaComponent): grid = CtaGrid(direction=Direction.LONG, open_price=open_price, close_price=close_price, - volume=self.volume * self.get_volume_rate(remainLots + i)) + volume=self.volume * self.get_volume_rate(remain_lots + i)) grid.reuse_count = reuse_count self.dn_grids.append(grid) self.write_log(u'重新拉下网格:[{0}==>{1}]'.format(down_line, down_line - lots * self.grid_height * down_rate)) # 重建上网格(移除未挂单、保留开仓得网格、在最高价之上才增加网格 - if direction == Direction.SHORT: + if len(directions) == 0 or Direction.SHORT in directions: max_short_price = middle_line # 最高开空价 remove_grids = [] # 移除的网格列表 opened_grids = [] # 已开仓的网格列表 @@ -678,8 +678,8 @@ class CtaGridTrade(CtaComponent): self.write_log(u'保留上网格[{}]'.format(opened_grids)) # 需要重建的剩余网格数量 - remainLots = len(self.up_grids) - lots = self.max_lots - remainLots + remain_lots = len(self.up_grids) + lots = self.max_lots - remain_lots upper_line = max(upper_line, max_short_price + self.grid_height * upper_rate) self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, upper_line)) @@ -692,7 +692,7 @@ class CtaGridTrade(CtaComponent): grid = CtaGrid(direction=Direction.SHORT, open_price=open_price, close_price=close_price, - volume=self.volume * self.get_volume_rate(remainLots + i)) + volume=self.volume * self.get_volume_rate(remain_lots + i)) grid.reuse_count = reuse_count self.up_grids.append(grid) diff --git a/vnpy/component/cta_line_bar.py b/vnpy/component/cta_line_bar.py index 7b93f29c..7a156f09 100644 --- a/vnpy/component/cta_line_bar.py +++ b/vnpy/component/cta_line_bar.py @@ -178,13 +178,20 @@ class CtaLineBar(object): # (实时运行时,或者addbar小于bar得周期时,不包含最后一根Bar) self.open_array = np.zeros(self.max_hold_bars) # 与lineBar一致得开仓价清单 + self.open_array[:] = np.nan self.high_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最高价清单 + self.high_array[:] = np.nan self.low_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最低价清单 + self.low_array[:] = np.nan self.close_array = np.zeros(self.max_hold_bars) # 与lineBar一致得收盘价清单 + self.close_array[:] = np.nan self.mid3_array = np.zeros(self.max_hold_bars) # 收盘价/最高/最低价 的平均价 + self.mid3_array[:] = np.nan self.mid4_array = np.zeros(self.max_hold_bars) # 收盘价*2/最高/最低价 的平均价 + self.mid4_array[:] = np.nan self.mid5_array = np.zeros(self.max_hold_bars) # 收盘价*2/开仓价/最高/最低价 的平均价 + self.mid5_array[:] = np.nan self.export_filename = None self.export_fields = [] @@ -1263,7 +1270,8 @@ class CtaLineBar(object): # 2.计算前inputPreLen周期内(不包含当前周期)的Bar高点和低点 preHigh = max(self.high_array[-count_len:]) preLow = min(self.low_array[-count_len:]) - + if np.isnan(preHigh) or np.isnan(preLow): + return # 保存 if len(self.line_pre_high) > self.max_hold_bars: del self.line_pre_high[0] @@ -1452,6 +1460,8 @@ class CtaLineBar(object): count_len = min(self.para_ma1_len, self.bar_len) barMa1 = ta.MA(self.close_array[-count_len:], count_len)[-1] + if np.isnan(barMa1): + return barMa1 = round(barMa1, self.round_n) if len(self.line_ma1) > self.max_hold_bars: @@ -1470,6 +1480,8 @@ class CtaLineBar(object): if self.para_ma2_len > 0: count_len = min(self.para_ma2_len, self.bar_len) barMa2 = ta.MA(self.close_array[-count_len:], count_len)[-1] + if np.isnan(barMa2): + return barMa2 = round(barMa2, self.round_n) if len(self.line_ma2) > self.max_hold_bars: @@ -1488,6 +1500,8 @@ class CtaLineBar(object): if self.para_ma3_len > 0: count_len = min(self.para_ma3_len, self.bar_len) barMa3 = ta.MA(self.close_array[-count_len:], count_len)[-1] + if np.isnan(barMa3): + return barMa3 = round(barMa3, self.round_n) if len(self.line_ma3) > self.max_hold_bars: @@ -1685,6 +1699,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的K线 barEma1 = ta.EMA(self.close_array[-ema1_data_len:], count_len)[-1] + if np.isnan(barEma1): + return barEma1 = round(float(barEma1), self.round_n) if len(self.line_ema1) > self.max_hold_bars: @@ -1698,7 +1714,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的自适应均线 barEma2 = ta.EMA(self.close_array[-ema2_data_len:], count_len)[-1] - + if np.isnan(barEma2): + return barEma2 = round(float(barEma2), self.round_n) if len(self.line_ema2) > self.max_hold_bars: @@ -1711,6 +1728,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的自适应均线 barEma3 = ta.EMA(self.close_array[-ema3_data_len:], count_len)[-1] + if np.isnan(barEma3): + return barEma3 = round(float(barEma3), self.round_n) if len(self.line_ema3) > self.max_hold_bars: @@ -1737,6 +1756,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的K线 barEma1 = ta.EMA(np.append(self.close_array[-ema1_data_len:], [self.cur_price]), count_len)[-1] + if np.isnan(barEma1): + return self._rt_ema1 = round(float(barEma1), self.round_n) # 计算第二条EMA均线 @@ -1746,7 +1767,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的自适应均线 barEma2 = ta.EMA(np.append(self.close_array[-ema2_data_len:], [self.cur_price]), count_len)[-1] - + if np.isnan(barEma2): + return self._rt_ema2 = round(float(barEma2), self.round_n) # 计算第三条EMA均线 @@ -1755,6 +1777,8 @@ class CtaLineBar(object): # 3、获取前InputN周期(不包含当前周期)的自适应均线 barEma3 = ta.EMA(np.append(self.close_array[-ema3_data_len:], [self.cur_price]), count_len)[-1] + if np.isnan(barEma3): + return self._rt_ema3 = round(float(barEma3), self.round_n) @property @@ -2076,7 +2100,7 @@ class CtaLineBar(object): return if self.para_boll_len > 0: - if self.bar_len < min(7, self.para_boll_len): + if self.bar_len < min(20, self.para_boll_len): self.write_log(u'数据未充分,当前Bar数据数量:{0},计算Boll需要:{1}'. format(len(self.line_bar), min(14, self.para_boll_len) + 1)) else: @@ -2086,6 +2110,9 @@ class CtaLineBar(object): upper_list, middle_list, lower_list = ta.BBANDS(self.close_array, timeperiod=bollLen, nbdevup=self.para_boll_std_rate, nbdevdn=self.para_boll_std_rate, matype=0) + if np.isnan(upper_list[-1]): + return + if len(self.line_boll_upper) > self.max_hold_bars: del self.line_boll_upper[0] if len(self.line_boll_middle) > self.max_hold_bars: @@ -2144,6 +2171,8 @@ class CtaLineBar(object): upper_list, middle_list, lower_list = ta.BBANDS(self.close_array, timeperiod=boll2Len, nbdevup=self.para_boll2_std_rate, nbdevdn=self.para_boll2_std_rate, matype=0) + if np.isnan(upper_list[-1]): + return if len(self.line_boll2_upper) > self.max_hold_bars: del self.line_boll2_upper[0] if len(self.line_boll2_middle) > self.max_hold_bars: @@ -2510,14 +2539,19 @@ class CtaLineBar(object): hhv = max(self.high_array[-inputKdjLen:]) llv = min(self.low_array[-inputKdjLen:]) - + if np.isnan(hhv) or np.isnan(llv): + return if len(self.line_k) > 0: lastK = self.line_k[-1] + if np.isnan(lastK): + lastK = 0 else: lastK = 0 if len(self.line_d) > 0: lastD = self.line_d[-1] + if np.isnan(lastD): + lastD = 0 else: lastD = 0 @@ -2620,14 +2654,20 @@ class CtaLineBar(object): hhv = max(self.high_array[-data_len:]) llv = min(self.low_array[-data_len:]) + if np.isnan(hhv) or np.isnan(llv): + return if len(self.line_k) > 0: lastK = self.line_k[-1] + if np.isnan(lastK): + lastK = 0 else: lastK = 0 if len(self.line_d) > 0: lastD = self.line_d[-1] + if np.isnan(lastD): + lastD = 0 else: lastD = 0 @@ -2747,7 +2787,8 @@ class CtaLineBar(object): dif_list, dea_list, macd_list = ta.MACD(self.close_array[-2 * maxLen:], fastperiod=self.para_macd_fast_len, slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len) - + if np.isnan(dif_list[-1]) or np.isnan(dea_list[-1]) or np.isnan(macd_list[-1]): + return # dif, dea, macd = ta.MACDEXT(np.array(listClose, dtype=float), # fastperiod=self.inputMacdFastPeriodLen, fastmatype=1, # slowperiod=self.inputMacdSlowPeriodLen, slowmatype=1, @@ -2868,6 +2909,9 @@ class CtaLineBar(object): fastperiod=self.para_macd_fast_len, slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len) + if np.isnan(dif[-1]) or np.isnan(dea[-1]) or np.isnan(macd[-1]): + return + self._rt_dif = round(dif[-1], self.round_n) if len(dif) > 0 else None self._rt_dea = round(dea[-1], self.round_n) if len(dea) > 0 else None self._rt_macd = round(macd[-1] * 2, self.round_n) if len(macd) > 0 else None @@ -3958,6 +4002,9 @@ class CtaLineBar(object): hhv = max(self.high_array[-bar_len:]) llv = min(self.low_array[-bar_len:]) + if np.isnan(hhv) or np.isnan(llv): + return + self.cur_p192 = hhv - (hhv - llv) * 0.192 self.cur_p382 = hhv - (hhv - llv) * 0.382 self.cur_p500 = (hhv + llv) / 2 diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 0836b159..519c4d00 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -25,6 +25,7 @@ from .event import ( ) from .gateway import BaseGateway from .object import ( + Direction, Exchange, CancelRequest, LogData, @@ -214,7 +215,7 @@ class MainEngine: """ # 自定义套利合约,交给算法引擎处理 if self.algo_engine and req.exchange == Exchange.SPD: - return self.algo_engine.send_algo_order( + return self.algo_engine.send_spd_order( req=req, gateway_name=gateway_name) @@ -228,6 +229,11 @@ class MainEngine: """ Send cancel order request to a specific gateway. """ + # 自定义套利合约,交给算法引擎处理 + if self.algo_engine and req.exchange == Exchange.SPD: + return self.algo_engine.cancel_spd_order( + req=req) + gateway = self.get_gateway(gateway_name) if gateway: return gateway.cancel_order(req) @@ -422,13 +428,19 @@ class OmsEngine(BaseEngine): self.accounts: Dict[str, AccountData] = {} self.contracts: Dict[str, ContractData] = {} self.today_contracts: Dict[str, ContractData] = {} - self.custom_contracts = {} + + # 自定义合约 + self.custom_contracts = {} # vt_symbol: ContractData + self.custom_settings = {} # symbol: dict + self.symbol_spd_maping = {} # symbol: [spd_symbol] + self.prices = {} self.active_orders: Dict[str, OrderData] = {} self.add_function() self.register_event() + self.load_contracts() def __del__(self): """保存缓存""" @@ -445,6 +457,30 @@ class OmsEngine(BaseEngine): self.contracts = pickle.load(f) self.write_log(f'加载缓存合约字典:{contract_file_name}') + # 更新自定义合约 + custom_contracts = self.get_all_custom_contracts() + for contract in custom_contracts.values(): + + # 更新合约缓存 + self.contracts.update({contract.symbol: contract}) + self.contracts.update({contract.vt_symbol: contract}) + self.today_contracts[contract.vt_symbol] = contract + self.today_contracts[contract.symbol] = contract + + # 获取自定义合约的主动腿/被动腿 + setting = self.custom_settings.get(contract.symbol, {}) + leg1_symbol = setting.get('leg1_symbol') + leg2_symbol = setting.get('leg2_symbol') + + # 构建映射关系 + for symbol in [leg1_symbol, leg2_symbol]: + spd_mapping_list = self.symbol_spd_maping.get(symbol, []) + + # 更新映射 symbol => spd_symbol + if contract.symbol not in spd_mapping_list: + spd_mapping_list.append(contract.symbol) + self.symbol_spd_maping.update({symbol: spd_mapping_list}) + def save_contracts(self) -> None: """持久化合约对象到缓存文件""" import bz2 @@ -474,6 +510,7 @@ class OmsEngine(BaseEngine): self.main_engine.get_all_contracts = self.get_all_contracts self.main_engine.get_all_active_orders = self.get_all_active_orders self.main_engine.get_all_custom_contracts = self.get_all_custom_contracts + self.main_engine.get_mapping_spd = self.get_mapping_spd self.main_engine.save_contracts = self.save_contracts def register_event(self) -> None: @@ -515,6 +552,76 @@ class OmsEngine(BaseEngine): position = event.data self.positions[position.vt_positionid] = position + def reverse_direction(self, direction): + """返回反向持仓""" + if direction == Direction.LONG: + return Direction.SHORT + elif direction == Direction.SHORT: + return Direction.LONG + return direction + + def create_spd_position_event(self, symbol, direction ): + """创建自定义品种对持仓信息""" + spd_symbols = self.symbol_spd_maping.get(symbol, []) + if not spd_symbols: + return + for spd_symbol in spd_symbols: + spd_setting = self.custom_settings.get(spd_symbol, None) + if not spd_setting: + continue + + leg1_symbol = spd_setting.get('leg1_symbol') + leg2_symbol = spd_setting.get('leg2_symbol') + leg1_contract = self.contracts.get(leg1_symbol) + leg2_contract = self.contracts.get(leg2_symbol) + spd_contract = self.contracts.get(spd_symbol) + + if leg1_contract is None or leg2_contract is None: + continue + leg1_ratio = spd_setting.get('leg1_ratio', 1) + leg2_ratio = spd_setting.get('leg2_ratio', 1) + + # 找出leg1,leg2的持仓,并判断出spd的方向 + if leg1_symbol == symbol: + k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{direction.value}" + leg1_pos = self.positions.get(k1) + k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{self.reverse_direction(direction).value}" + leg2_pos = self.positions.get(k2) + spd_direction = direction + elif leg2_symbol == symbol: + k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{self.reverse_direction(direction).value}" + leg1_pos = self.positions.get(k1) + k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{direction.value}" + leg2_pos = self.positions.get(k2) + spd_direction = self.reverse_direction(direction) + else: + continue + + if leg1_pos is None or leg2_pos is None or leg1_pos.volume ==0 or leg2_pos.volume == 0: + continue + + # 根据leg1/leg2的volume ratio,计算出最小spd_volume + spd_volume = min(int(leg1_pos.volume/leg1_ratio), int(leg2_pos.volume/leg2_ratio)) + if spd_volume <= 0: + continue + if spd_setting.get('is_ratio', False) and leg2_pos.price > 0: + spd_price = 100 * (leg2_pos.price * leg1_ratio) / (leg2_pos.price * leg2_ratio) + elif spd_setting.get('is_spread', False): + spd_price = leg1_pos.price * leg1_ratio - leg2_pos.price * leg2_ratio + else: + spd_price = 0 + + spd_pos = PositionData( + gateway_name=spd_contract.gateway_name, + symbol=spd_symbol, + exchange=Exchange.SPD, + direction=spd_direction, + volume=spd_volume, + price=spd_price + ) + event = Event(EVENT_POSITION, data=spd_pos) + self.event_engine.put(event) + def process_account_event(self, event: Event) -> None: """""" account = event.data @@ -624,16 +731,25 @@ class OmsEngine(BaseEngine): ] return active_orders - def get_all_custom_contracts(self): + def get_all_custom_contracts(self, rtn_setting=False): """ 获取所有自定义合约 :return: """ + if rtn_setting: + if len(self.custom_settings) == 0: + c = CustomContract() + self.custom_settings = c.get_config() + return self.custom_settings + if len(self.custom_contracts) == 0: c = CustomContract() self.custom_contracts = c.get_contracts() return self.custom_contracts + def get_mapping_spd(self, symbol): + """根据主动腿/被动腿symbol,获取自定义套利对的symbol list""" + return self.symbol_spd_maping.get(symbol, []) class CustomContract(object): """ @@ -668,6 +784,7 @@ class CustomContract(object): exchange=vn_exchange, name=setting.get('name', symbol), size=setting.get('size', 100), + product=None, pricetick=setting.get('price_tick', 0.01), margin_rate=setting.get('margin_rate', 0.1) ) diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 94b3510f..34da6863 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -123,7 +123,7 @@ class BaseGateway(ABC): Tick event of a specific vt_symbol is also pushed. """ self.on_event(EVENT_TICK, tick) - self.on_event(EVENT_TICK + tick.vt_symbol, tick) + # self.on_event(EVENT_TICK + tick.vt_symbol, tick) # 推送Bar kline = self.klines.get(tick.vt_symbol, None) @@ -142,7 +142,7 @@ class BaseGateway(ABC): Trade event of a specific vt_symbol is also pushed. """ self.on_event(EVENT_TRADE, trade) - self.on_event(EVENT_TRADE + trade.vt_symbol, trade) + # self.on_event(EVENT_TRADE + trade.vt_symbol, trade) def on_order(self, order: OrderData) -> None: """ @@ -150,7 +150,7 @@ class BaseGateway(ABC): Order event of a specific vt_orderid is also pushed. """ self.on_event(EVENT_ORDER, order) - self.on_event(EVENT_ORDER + order.vt_orderid, order) + # self.on_event(EVENT_ORDER + order.vt_orderid, order) def on_position(self, position: PositionData) -> None: """ @@ -158,7 +158,7 @@ class BaseGateway(ABC): Position event of a specific vt_symbol is also pushed. """ self.on_event(EVENT_POSITION, position) - self.on_event(EVENT_POSITION + position.vt_symbol, position) + # self.on_event(EVENT_POSITION + position.vt_symbol, position) def on_account(self, account: AccountData) -> None: """ @@ -166,7 +166,7 @@ class BaseGateway(ABC): Account event of a specific vt_accountid is also pushed. """ self.on_event(EVENT_ACCOUNT, account) - self.on_event(EVENT_ACCOUNT + account.vt_accountid, account) + # self.on_event(EVENT_ACCOUNT + account.vt_accountid, account) def on_log(self, log: LogData) -> None: """ diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 4691f5df..5140a9db 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -674,11 +674,11 @@ class BarGenerator: """ def __init__( - self, - on_bar: Callable, - window: int = 0, - on_window_bar: Callable = None, - interval: Interval = Interval.MINUTE + self, + on_bar: Callable, + window: int = 0, + on_window_bar: Callable = None, + interval: Interval = Interval.MINUTE ): """Constructor""" self.bar: BarData = None @@ -804,11 +804,14 @@ class BarGenerator: """ Generate the bar data and call callback immediately. """ - self.bar.datetime = self.bar.datetime.replace( - second=0, microsecond=0 - ) - self.on_bar(self.bar) + bar = self.bar + + if self.bar: + bar.datetime = bar.datetime.replace(second=0, microsecond=0) + self.on_bar(bar) + self.bar = None + return bar class ArrayManager(object): @@ -1067,11 +1070,11 @@ class ArrayManager(object): return result[-1] def macd( - self, - fast_period: int, - slow_period: int, - signal_period: int, - array: bool = False + self, + fast_period: int, + slow_period: int, + signal_period: int, + array: bool = False ) -> Union[ Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[float, float, float] @@ -1159,10 +1162,10 @@ class ArrayManager(object): return result[-1] def boll( - self, - n: int, - dev: float, - array: bool = False + self, + n: int, + dev: float, + array: bool = False ) -> Union[ Tuple[np.ndarray, np.ndarray], Tuple[float, float] @@ -1179,10 +1182,10 @@ class ArrayManager(object): return up, down def keltner( - self, - n: int, - dev: float, - array: bool = False + self, + n: int, + dev: float, + array: bool = False ) -> Union[ Tuple[np.ndarray, np.ndarray], Tuple[float, float] @@ -1199,7 +1202,7 @@ class ArrayManager(object): return up, down def donchian( - self, n: int, array: bool = False + self, n: int, array: bool = False ) -> Union[ Tuple[np.ndarray, np.ndarray], Tuple[float, float] @@ -1215,10 +1218,9 @@ class ArrayManager(object): return up[-1], down[-1] def aroon( - self, - n: int, - dev: float, - array: bool = False + self, + n: int, + array: bool = False ) -> Union[ Tuple[np.ndarray, np.ndarray], Tuple[float, float]