diff --git a/vnpy/app/cta_option/__init__.py b/vnpy/app/cta_option/__init__.py new file mode 100644 index 00000000..7dfc0301 --- /dev/null +++ b/vnpy/app/cta_option/__init__.py @@ -0,0 +1,37 @@ +from pathlib import Path + +from vnpy.trader.app import BaseApp +from .base import APP_NAME + +# 期权CTA策略引擎 +from .engine import CtaOptionEngine + +from .template import ( + Direction, + Offset, + Exchange, + Status, + Color, + ContractData, + HistoryRequest, + TickData, + BarData, + TradeData, + OrderData, + CtaTemplate, + CtaOptionTemplate, + CtaOptionPolicy + ) # noqa +from vnpy.trader.utility import BarGenerator, ArrayManager # noqa + + +class CtaOptionApp(BaseApp): + """期权引擎App""" + + app_name = APP_NAME + app_module = __module__ + app_path = Path(__file__).parent + display_name = "CTA期权策略" + engine_class = CtaOptionEngine + widget_name = "CtaOption" + icon_name = "cta.ico" diff --git a/vnpy/app/cta_option/base.py b/vnpy/app/cta_option/base.py new file mode 100644 index 00000000..32e770e8 --- /dev/null +++ b/vnpy/app/cta_option/base.py @@ -0,0 +1,53 @@ +""" +Defines constants and objects used in CtaStrategyPro App. +""" + +from dataclasses import dataclass, field +from enum import Enum +from datetime import timedelta +from vnpy.trader.constant import Direction, Offset, Interval + +APP_NAME = "CtaOption" +STOPORDER_PREFIX = "STOP" + + +class StopOrderStatus(Enum): + WAITING = "等待中" + CANCELLED = "已撤销" + TRIGGERED = "已触发" + + +class EngineType(Enum): + LIVE = "实盘" + BACKTESTING = "回测" + + +class BacktestingMode(Enum): + BAR = 1 + TICK = 2 + + +@dataclass +class StopOrder: + vt_symbol: str + direction: Direction + offset: Offset + price: float + volume: float + stop_orderid: str + strategy_name: str + lock: bool = False + vt_orderids: list = field(default_factory=list) + status: StopOrderStatus = StopOrderStatus.WAITING + gateway_name: str = None + + +EVENT_CTA_LOG = "eCtaLog" +EVENT_CTA_OPTION = "eCtaOption" +EVENT_CTA_STOPORDER = "eCtaStopOrder" + +INTERVAL_DELTA_MAP = { + Interval.MINUTE: timedelta(minutes=1), + Interval.HOUR: timedelta(hours=1), + Interval.DAILY: timedelta(days=1), +} diff --git a/vnpy/app/cta_option/engine.py b/vnpy/app/cta_option/engine.py new file mode 100644 index 00000000..c8adfc0a --- /dev/null +++ b/vnpy/app/cta_option/engine.py @@ -0,0 +1,2379 @@ +""" +CTA期权策略运行引擎 +华富资产 +""" + +import importlib +import os +import sys +import traceback +import json +import pickle +import bz2 +import pandas as pd +import numpy as np + +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, List, Dict +from datetime import datetime, timedelta +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from copy import copy +from functools import lru_cache +from uuid import uuid1 + +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import BaseEngine, MainEngine +from vnpy.trader.object import ( + OrderRequest, + SubscribeRequest, + LogData, + TickData, + BarData, + PositionData, + ContractData, + HistoryRequest, + Interval + +) +from vnpy.trader.event import ( + EVENT_TIMER, + EVENT_TICK, + EVENT_BAR, + EVENT_ORDER, + EVENT_TRADE, + EVENT_POSITION, + EVENT_STRATEGY_POS, + EVENT_STRATEGY_SNAPSHOT +) +from vnpy.trader.constant import ( + Direction, + Exchange, + Product, + OrderType, + Offset, + Status +) +from vnpy.trader.utility import ( + load_json, + save_json, + extract_vt_symbol, + round_to, + TRADER_DIR, + get_folder_path, + get_underlying_symbol, + append_data, + import_module_by_str, +get_csv_last_dt) + +from vnpy.trader.util_logger import setup_logger, logging +from vnpy.trader.util_wechat import send_wx_msg +from vnpy.data.mongo.mongo_data import MongoData +from vnpy.trader.setting import SETTINGS +from vnpy.data.stock.adjust_factor import get_all_adjust_factor +from vnpy.data.stock.stock_base import get_stock_base +from vnpy.data.common import stock_to_adj + +from .base import ( + APP_NAME, + EVENT_CTA_LOG, + EVENT_CTA_OPTION, + EVENT_CTA_STOPORDER, + EngineType, + StopOrder, + StopOrderStatus, + STOPORDER_PREFIX, +) +from .template import CtaTemplate +from vnpy.component.base import MARKET_DAY_ONLY, MyEncoder +from vnpy.component.cta_position import CtaPosition + +STOP_STATUS_MAP = { + Status.SUBMITTING: StopOrderStatus.WAITING, + Status.NOTTRADED: StopOrderStatus.WAITING, + Status.PARTTRADED: StopOrderStatus.TRIGGERED, + Status.ALLTRADED: StopOrderStatus.TRIGGERED, + Status.CANCELLED: StopOrderStatus.CANCELLED, + Status.REJECTED: StopOrderStatus.CANCELLED +} + +# 假期,后续可以从cta_option_config.json文件中获取更新 +holiday_dict = { + # 放假第一天:放假最后一天 + "2000124": "20200130", + "20200501": "20200505", + "20201001": "20201008", + "20210211": "20210217", + "20210501": "20210505", + "20211001": "20211007", +} + + +class CtaOptionEngine(BaseEngine): + """ + 期权策略引擎 + + """ + + engine_type = EngineType.LIVE # live trading engine + + # 策略配置文件 + setting_filename = "cta_option_setting.json" + # 引擎配置文件 + config_filename = "cta_option_config.json" + + # 期权策略引擎得特殊参数配置 + # "accountid" : "xxxx", 资金账号,一般用于推送消息时附带,后续入数据库时,可根据accountid归结到统一个账号中 + # "strategy_group": "cta_option", # 当前实例名。多个实例时,区分开 + # "trade_2_wx": true # 是否交易记录转发至微信通知 + # "event_log: false # 是否转发日志到event bus,显示在图形界面 + # "snapshot2file": false # 是否保存切片到文件 + # "compare_pos": false # False,强制不进行 账号 <=> 引擎实例 得仓位比对。(一般分布式RPC运行时,其他得实例都不进行比对) + # "get_pos_from_db": false # True,使用数据库得 策略<=>pos 数据作为比较(一般分布式RPC运行时,其中一个使用即可); False,使用当前引擎实例得 策略.pos进行比对 + # "holiday_dict": { "开始日期":"结束日期"} + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + """ + 构造函数 + :param main_engine: 主引擎 + :param event_engine: 事件引擎 + """ + super().__init__(main_engine, event_engine, APP_NAME) + + self.engine_config = {} + # 是否激活 write_log写入event bus(比较耗资源) + self.event_log = False + + self.strategy_setting = {} # strategy_name: dict + self.strategy_data = {} # strategy_name: dict + + self.classes = {} # class_name: stategy_class + self.class_module_map = {} # class_name: mudule_name + self.strategies = {} # strategy_name: strategy + + # Strategy pos dict,key:strategy instance name, value: pos dict + self.strategy_pos_dict = {} + + self.strategy_loggers = {} # strategy_name: logger + + # 未能订阅的symbols,支持策略启动时,并未接入gateway,如果没有收到tick,再定时重新订阅 + # gateway_name.vt_symbol: set() of (strategy_name, is_bar) + self.pending_subcribe_symbol_map = defaultdict(set) + + self.symbol_strategy_map = defaultdict(list) # vt_symbol: strategy list + self.bar_strategy_map = defaultdict(list) # vt_symbol: strategy list + self.strategy_symbol_map = defaultdict(set) # strategy_name: vt_symbol set + + self.orderid_strategy_map = {} # vt_orderid: strategy + self.strategy_orderid_map = defaultdict( + set) # strategy_name: orderid list + + self.stop_order_count = 0 # for generating stop_orderid + self.stop_orders = {} # stop_orderid: stop_order + + # 异步线程执行,一般用于策略得初始化数据等加载,不影响交易 + self.thread_executor = ThreadPoolExecutor(max_workers=1) + self.thread_tasks = [] + + self.vt_tradeids = set() # for filtering duplicate trade + + self.last_minute = None + self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar) + + self.stock_adjust_factors = get_all_adjust_factor() + # 获取全量股票信息 + self.write_log(f'获取全量股票信息') + self.symbol_dict = get_stock_base() + self.write_log(f'共{len(self.symbol_dict)}个股票') + # 除权因子 + self.write_log(f'获取所有除权因子') + self.adjust_factor_dict = get_all_adjust_factor() + self.write_log(f'共{len(self.adjust_factor_dict)}条除权信息') + + # 寻找数据文件所在目录 + vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + self.write_log(f'项目所在目录:{vnpy_root}') + self.bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data')) + if os.path.exists(self.bar_data_folder): + SSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SSE')) + if os.path.exists(SSE_folder): + self.write_log(f'上交所bar数据目录:{SSE_folder}') + else: + self.write_error(f'不存在上交所数据目录:{SSE_folder}') + + SZSE_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data', 'SZSE')) + if os.path.exists(SZSE_folder): + self.write_log(f'深交所bar数据目录:{SZSE_folder}') + else: + self.write_error(f'不存在深交所数据目录:{SZSE_folder}') + else: + self.write_error(f'不存在bar数据目录:{self.bar_data_folder}') + self.bar_data_folder = None + + def init_engine(self): + """ + """ + self.register_event() + self.register_funcs() + + self.load_strategy_class() + self.load_strategy_setting() + + self.write_log("CTA策略引擎初始化成功") + + if self.engine_config.get('get_pos_from_db',False): + self.write_log(f'激活数据库策略仓位比对模式') + self.init_mongo_data() + + def init_mongo_data(self): + """初始化hams数据库""" + host = SETTINGS.get('hams.host', 'localhost') + port = SETTINGS.get('hams.port', 27017) + self.write_log(f'初始化hams数据库连接:{host}:{port}') + try: + # Mongo数据连接客户端 + self.mongo_data = MongoData(host=host, port=port) + + if self.mongo_data and self.mongo_data.db_has_connected: + self.write_log(f'连接成功') + else: + self.write_error(f'HAMS数据库{host}:{port}连接异常.') + except Exception as ex: + self.write_error(f'HAMS数据库{host}:{port}连接异常.{str(ex)}') + + def close(self): + """停止所属有的策略""" + self.stop_all_strategies() + + def register_event(self): + """注册事件""" + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + self.event_engine.register(EVENT_TICK, self.process_tick_event) + self.event_engine.register(EVENT_BAR, self.process_bar_event) + self.event_engine.register(EVENT_ORDER, self.process_order_event) + self.event_engine.register(EVENT_TRADE, self.process_trade_event) + + def register_funcs(self): + """ + register the funcs to main_engine + :return: + """ + self.main_engine.get_name = self.get_name + self.main_engine.get_strategy_status = self.get_strategy_status + self.main_engine.get_strategy_pos = self.get_strategy_pos + self.main_engine.compare_pos = self.compare_pos + self.main_engine.add_strategy = self.add_strategy + self.main_engine.init_strategy = self.init_strategy + self.main_engine.start_strategy = self.start_strategy + self.main_engine.stop_strategy = self.stop_strategy + self.main_engine.remove_strategy = self.remove_strategy + self.main_engine.reload_strategy = self.reload_strategy + self.main_engine.save_strategy_data = self.save_strategy_data + self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot + self.main_engine.clean_strategy_cache = self.clean_strategy_cache + + # 注册到远程服务调用 + if self.main_engine.rpc_service: + self.main_engine.rpc_service.register(self.main_engine.get_strategy_status) + self.main_engine.rpc_service.register(self.main_engine.get_strategy_pos) + self.main_engine.rpc_service.register(self.main_engine.compare_pos) + self.main_engine.rpc_service.register(self.main_engine.add_strategy) + self.main_engine.rpc_service.register(self.main_engine.init_strategy) + self.main_engine.rpc_service.register(self.main_engine.start_strategy) + self.main_engine.rpc_service.register(self.main_engine.stop_strategy) + self.main_engine.rpc_service.register(self.main_engine.remove_strategy) + self.main_engine.rpc_service.register(self.main_engine.reload_strategy) + self.main_engine.rpc_service.register(self.main_engine.save_strategy_data) + self.main_engine.rpc_service.register(self.main_engine.save_strategy_snapshot) + self.main_engine.rpc_service.register(self.main_engine.clean_strategy_cache) + + def process_timer_event(self, event: Event): + """ 处理定时器事件""" + + all_trading = True + dt = datetime.now() + + # 触发每个策略的定时接口 + for strategy in list(self.strategies.values()): + strategy.on_timer() + if not strategy.trading: + all_trading = False + + # 临近夜晚收盘前,强制发出撤单 + if dt.hour == 2 and dt.minute == 59 and dt.second >= 55: + self.cancel_all(strategy) + + # 每分钟执行的逻辑 + if self.last_minute != dt.minute: + self.last_minute = dt.minute + + if all_trading: + # 主动获取所有策略得持仓信息 + all_strategy_pos = self.get_all_strategy_pos() + + # 每5分钟检查一次 + if dt.minute % 5 == 0 and self.engine_config.get('compare_pos', True): + # 比对仓位,使用上述获取得持仓信息,不用重复获取 + self.compare_pos(strategy_pos_list=copy(all_strategy_pos)) + + # 推送到事件 + self.put_all_strategy_pos_event(all_strategy_pos) + + for strategy_name in list(self.strategies.keys()): + strategy = self.strategies.get(strategy_name, None) + if strategy and strategy.inited: + self.call_strategy_func(strategy, strategy.on_timer) + + + def process_tick_event(self, event: Event): + """处理tick到达事件""" + tick = event.data + + key = f'{tick.gateway_name}.{tick.vt_symbol}' + v = self.pending_subcribe_symbol_map.pop(key, None) + if v: + # 这里不做tick/bar的判断了,因为基本有tick就有bar + self.write_log(f'{key} tick已经到达,移除未订阅记录:{v}') + + strategies = self.symbol_strategy_map[tick.vt_symbol] + if not strategies: + return + + self.check_stop_order(tick) + + for strategy in strategies: + if strategy.inited: + self.call_strategy_func(strategy, strategy.on_tick, {tick.vt_symbol:tick}) + + def process_bar_event(self, event: Event): + """处理bar到达事件""" + bar = event.data + # 更新bar + self.symbol_bar_dict[bar.vt_symbol] = bar + # 寻找订阅了该bar的策略 + strategies = self.symbol_strategy_map[bar.vt_symbol] + if not strategies: + return + for strategy in strategies: + if strategy.inited: + self.call_strategy_func(strategy, strategy.on_bar, {bar.vt_symbol: bar}) + + def process_order_event(self, event: Event): + """""" + order = event.data + + strategy = self.orderid_strategy_map.get(order.vt_orderid, None) + if not strategy: + self.write_log(f'委托单没有对应的策略设置:order:{order.__dict__}') + self.write_log(f'当前策略侦听委托单:{list(self.orderid_strategy_map.keys())}') + return + self.write_log(f'委托更新:{order.vt_orderid} => 策略:{strategy.strategy_name}') + # Remove vt_orderid if order is no longer active. + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if order.vt_orderid in vt_orderids and not order.is_active(): + vt_orderids.remove(order.vt_orderid) + + # For server stop order, call strategy on_stop_order function + if order.type == OrderType.STOP: + so = StopOrder( + vt_symbol=order.vt_symbol, + direction=order.direction, + offset=order.offset, + price=order.price, + volume=order.volume, + stop_orderid=order.vt_orderid, + strategy_name=strategy.strategy_name, + status=STOP_STATUS_MAP[order.status], + vt_orderids=[order.vt_orderid], + ) + self.call_strategy_func(strategy, strategy.on_stop_order, so) + + # Call strategy on_order function + self.call_strategy_func(strategy, strategy.on_order, order) + + def process_trade_event(self, event: Event): + """""" + trade = event.data + + # Filter duplicate trade push + if trade.vt_tradeid in self.vt_tradeids: + self.write_log(f'成交单的交易编号{trade.vt_tradeid}已处理完毕,不再处理') + return + self.vt_tradeids.add(trade.vt_tradeid) + + strategy = self.orderid_strategy_map.get(trade.vt_orderid, None) + if not strategy: + self.write_log(f'成交单没有对应的策略设置:trade:{trade.__dict__}') + self.write_log(f'当前策略侦听委托单:{list(self.orderid_strategy_map.keys())}') + return + + self.write_log(f'成交更新:{trade.vt_orderid} => 策略:{strategy.strategy_name}') + + # Update strategy pos before calling on_trade method + # 取消外部干预策略pos,由策略自行完成更新 + # if trade.direction == Direction.LONG: + # strategy.pos += trade.volume + # else: + # strategy.pos -= trade.volume + # 根据策略名称,写入 data\straetgy_name_trade.csv文件 + strategy_name = getattr(strategy, 'strategy_name') + trade_fields = ['datetime', 'symbol', 'exchange', 'vt_symbol', 'name','tradeid', 'vt_tradeid', 'orderid', 'vt_orderid', + 'direction', 'offset', 'price', 'volume'] + trade_dict = OrderedDict() + try: + for k in trade_fields: + if k == 'datetime': + dt = getattr(trade, 'datetime') + if isinstance(dt, datetime): + trade_dict[k] = dt.strftime('%Y-%m-%d %H:%M:%S') + else: + trade_dict[k] = datetime.now().strftime('%Y-%m-%d') + ' ' + getattr(trade, 'time', '') + if k in ['exchange', 'direction', 'offset']: + trade_dict[k] = getattr(trade, k).value + else: + trade_dict[k] = getattr(trade, k, '') + + if strategy_name is not None: + trade_file = str(get_folder_path('data').joinpath('{}_trade.csv'.format(strategy_name))) + append_data(file_name=trade_file, dict_data=trade_dict) + except Exception as ex: + self.write_error(u'写入交易记录csv出错:{},{}'.format(str(ex), traceback.format_exc())) + + self.call_strategy_func(strategy, strategy.on_trade, trade) + + # Sync strategy variables to data file + # 取消此功能,由策略自身完成数据持久化 + # self.sync_strategy_data(strategy) + + # Update GUI + self.put_strategy_event(strategy) + + # 如果配置文件 cta_stock_config.json中,有trade_2_wx的设置项,则发送微信通知 + if self.engine_config.get('trade_2_wx', False): + accountid = self.engine_config.get('accountid', 'XXX') + d = { + 'account': accountid, + 'strategy': strategy_name, + 'symbol': trade.symbol, + 'action': f'{trade.direction.value} {trade.offset.value}', + 'price': str(trade.price), + 'volume': trade.volume, + 'remark': f'{accountid}:{strategy_name}', + 'timestamp': trade.time + } + send_wx_msg(content=d, target=accountid, msg_type='TRADE') + + def check_unsubscribed_symbols(self): + """检查未订阅合约""" + + for key in self.pending_subcribe_symbol_map.keys(): + # gateway_name.symbol.exchange = > gateway_name, vt_symbol + keys = key.split('.') + gateway_name = keys[0] + vt_symbol = '.'.join(keys[1:]) + + contract = self.main_engine.get_contract(vt_symbol) + is_bar = True if vt_symbol in self.bar_strategy_map else False + if contract: + dt = datetime.now() + + self.write_log(f'重新提交合约{vt_symbol}订阅请求') + for strategy_name, is_bar in list(self.pending_subcribe_symbol_map[vt_symbol]): + self.subscribe_symbol(strategy_name=strategy_name, + vt_symbol=vt_symbol, + gateway_name=gateway_name, + is_bar=is_bar) + else: + try: + self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口') + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest(symbol=symbol, exchange=exchange) + req.is_bar = is_bar + self.main_engine.subscribe(req, gateway_name) + + except Exception as ex: + self.write_error( + u'重新订阅{}.{}异常:{},{}'.format(gateway_name, vt_symbol, str(ex), traceback.format_exc())) + return + + def check_stop_order(self, tick: TickData): + """""" + for stop_order in list(self.stop_orders.values()): + if stop_order.vt_symbol != tick.vt_symbol: + continue + + long_triggered = stop_order.direction == Direction.LONG and tick.last_price >= stop_order.price + short_triggered = stop_order.direction == Direction.SHORT and tick.last_price <= stop_order.price + + if long_triggered or short_triggered: + strategy = self.strategies[stop_order.strategy_name] + + # To get excuted immediately after stop order is + # triggered, use limit price if available, otherwise + # use ask_price_5 or bid_price_5 + if stop_order.direction == Direction.LONG: + if tick.limit_up: + price = tick.limit_up + else: + price = tick.ask_price_5 + else: + if tick.limit_down: + price = tick.limit_down + else: + price = tick.bid_price_5 + + contract = self.main_engine.get_contract(stop_order.vt_symbol) + + vt_orderids = self.send_limit_order( + strategy, + contract, + stop_order.direction, + stop_order.offset, + price, + stop_order.volume + ) + + # Update stop order status if placed successfully + if vt_orderids: + # Remove from relation map. + self.stop_orders.pop(stop_order.stop_orderid) + + strategy_vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if stop_order.stop_orderid in strategy_vt_orderids: + strategy_vt_orderids.remove(stop_order.stop_orderid) + + # Change stop order status to cancelled and update to strategy. + stop_order.status = StopOrderStatus.TRIGGERED + stop_order.vt_orderids = vt_orderids + + self.call_strategy_func( + strategy, strategy.on_stop_order, stop_order + ) + self.put_stop_order_event(stop_order) + + def send_server_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + type: OrderType, + gateway_name: str = None + ): + """ + Send a new order to server. + """ + # Create request and send order. + original_req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + offset=offset, + type=type, + price=price, + volume=volume, + strategy_name=strategy.strategy_name + ) + + # 如果没有指定网关,则使用合约信息内的网关 + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + + # Convert with offset converter + req_list = [original_req] + + # Send Orders + vt_orderids = [] + + for req in req_list: + vt_orderid = self.main_engine.send_order( + req, gateway_name) + + # Check if sending order successful + if not vt_orderid: + continue + + vt_orderids.append(vt_orderid) + + # Save relationship between orderid and strategy. + self.orderid_strategy_map[vt_orderid] = strategy + self.strategy_orderid_map[strategy.strategy_name].add(vt_orderid) + + return vt_orderids + + def send_limit_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a limit order to server. + """ + return self.send_server_order( + strategy, + contract, + direction, + offset, + price, + volume, + OrderType.LIMIT, + gateway_name + ) + + def send_fak_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a limit order to server. + """ + return self.send_server_order( + strategy, + contract, + direction, + offset, + price, + volume, + OrderType.FAK, + gateway_name + ) + + def send_server_stop_order( + self, + strategy: CtaTemplate, + contract: ContractData, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Send a stop order to server. + + Should only be used if stop order supported + on the trading server. + """ + return self.send_server_order( + strategy, + contract, + direction, + offset, + price, + volume, + OrderType.STOP, + gateway_name + ) + + def send_local_stop_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + gateway_name: str = None + ): + """ + Create a new local stop order. + """ + self.stop_order_count += 1 + stop_orderid = f"{STOPORDER_PREFIX}.{self.stop_order_count}" + + stop_order = StopOrder( + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop_orderid=stop_orderid, + strategy_name=strategy.strategy_name, + gateway_name=gateway_name + ) + + self.stop_orders[stop_orderid] = stop_order + + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + vt_orderids.add(stop_orderid) + + self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) + self.put_stop_order_event(stop_order) + + return [stop_orderid] + + def cancel_server_order(self, strategy: CtaTemplate, vt_orderid: str): + """ + Cancel existing order by vt_orderid. + """ + order = self.main_engine.get_order(vt_orderid) + if not order: + self.write_log(msg=f"撤单失败,找不到委托{vt_orderid}", + strategy_name=strategy.strategy_name, + level=logging.ERROR) + return False + + req = order.create_cancel_request() + return self.main_engine.cancel_order(req, order.gateway_name) + + def cancel_local_stop_order(self, strategy: CtaTemplate, stop_orderid: str): + """ + Cancel a local stop order. + """ + stop_order = self.stop_orders.get(stop_orderid, None) + if not stop_order: + return False + strategy = self.strategies[stop_order.strategy_name] + + # Remove from relation map. + self.stop_orders.pop(stop_orderid) + + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if stop_orderid in vt_orderids: + vt_orderids.remove(stop_orderid) + + # Change stop order status to cancelled and update to strategy. + stop_order.status = StopOrderStatus.CANCELLED + + self.call_strategy_func(strategy, strategy.on_stop_order, stop_order) + self.put_stop_order_event(stop_order) + return True + + def send_order( + self, + strategy: CtaTemplate, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool, + order_type: OrderType = OrderType.LIMIT, + gateway_name: str = None + ): + """ + 该方法供策略使用,发送委托。 + """ + contract = self.main_engine.get_contract(vt_symbol) + if not contract: + self.write_log(msg=f"委托失败,找不到合约:{vt_symbol}", + strategy_name=strategy.strategy_name, + level=logging.ERROR) + return "" + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + # Round order price and volume to nearest incremental value + price = round_to(price, contract.pricetick) + volume = round_to(volume, contract.min_volume) + + if stop: + if contract.stop_supported: + return self.send_server_stop_order(strategy, contract, direction, offset, price, volume, + gateway_name) + else: + return self.send_local_stop_order(strategy, vt_symbol, direction, offset, price, volume, + gateway_name) + if order_type == OrderType.FAK: + return self.send_fak_order(strategy, contract, direction, offset, price, volume, gateway_name) + else: + return self.send_limit_order(strategy, contract, direction, offset, price, volume, gateway_name) + + def cancel_order(self, strategy: CtaTemplate, vt_orderid: str): + """ + """ + if vt_orderid.startswith(STOPORDER_PREFIX): + return self.cancel_local_stop_order(strategy, vt_orderid) + else: + return self.cancel_server_order(strategy, vt_orderid) + + def cancel_all(self, strategy: CtaTemplate): + """ + Cancel all active orders of a strategy. + """ + vt_orderids = self.strategy_orderid_map[strategy.strategy_name] + if not vt_orderids: + return + + for vt_orderid in copy(vt_orderids): + self.cancel_order(strategy, vt_orderid) + + def subscribe_symbol(self, strategy_name: str, vt_symbol: str, gateway_name: str = '', is_bar: bool = False): + """订阅合约""" + strategy = self.strategies.get(strategy_name, None) + if not strategy: + return False + if len(vt_symbol) == 0: + self.write_error(f'不能为{strategy_name}订阅空白合约') + return False + contract = self.main_engine.get_contract(vt_symbol) + if contract: + if contract.gateway_name and not gateway_name: + gateway_name = contract.gateway_name + req = SubscribeRequest( + symbol=contract.symbol, exchange=contract.exchange) + self.main_engine.subscribe(req, gateway_name) + else: + self.write_log(msg=f"找不到合约{vt_symbol},添加到待订阅列表", + strategy_name=strategy.strategy_name) + self.pending_subcribe_symbol_map[f'{gateway_name}.{vt_symbol}'].add((strategy_name, is_bar)) + try: + self.write_log(f'找不到合约{vt_symbol}信息,尝试请求所有接口') + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest(symbol=symbol, exchange=exchange) + req.is_bar = is_bar + self.main_engine.subscribe(req, gateway_name) + + except Exception as ex: + self.write_error(u'重新订阅{}异常:{},{}'.format(vt_symbol, str(ex), traceback.format_exc())) + + # 如果是订阅bar + if is_bar: + strategies = self.bar_strategy_map[vt_symbol] + if strategy not in strategies: + strategies.append(strategy) + self.bar_strategy_map.update({vt_symbol: strategies}) + else: + # 添加 合约订阅 vt_symbol <=> 策略实例 strategy 映射. + strategies = self.symbol_strategy_map[vt_symbol] + if strategy not in strategies: + strategies.append(strategy) + + # 添加 策略名 strategy_name <=> 合约订阅 vt_symbol 的映射 + subscribe_symbol_set = self.strategy_symbol_map[strategy.strategy_name] + subscribe_symbol_set.add(vt_symbol) + + return True + + @lru_cache() + def get_exchange(self, symbol): + return self.main_engine.get_exchange(symbol) + + @lru_cache() + def get_name(self, vt_symbol: str): + """查询合约的name""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return vt_symbol + return contract.name + + @lru_cache() + def get_size(self, vt_symbol: str): + """查询合约的size""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 10 + return contract.size + + @lru_cache() + def get_margin_rate(self, vt_symbol: str): + """查询保证金比率""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息') + return 0.1 + if contract.margin_rate == 0: + return 0.1 + return contract.margin_rate + + @lru_cache() + def get_price_tick(self, vt_symbol: str): + """查询价格最小跳动""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用1作为价格跳动') + return 0.0001 + + return contract.pricetick + + @lru_cache() + def get_volume_tick(self, vt_symbol: str): + """查询合约的最小成交数量""" + contract = self.main_engine.get_contract(vt_symbol) + if contract is None: + self.write_error(f'查询不到{vt_symbol}合约信息,缺省使用1作为最小成交数量') + return 1 + + return contract.min_volume + + def get_margin(self, vt_symbol: str): + """ + 按照当前价格,计算1手合约需要得保证金 + :param vt_symbol: + :return: 普通合约/期权 => 当前价格 * size * margin_rate + + """ + + cur_price = self.get_price(vt_symbol) + cur_size = self.get_size(vt_symbol) + cur_margin_rate = self.get_margin_rate(vt_symbol) + if cur_price and cur_size and cur_margin_rate: + return abs(cur_price * cur_size * cur_margin_rate) + else: + # 取不到价格,取不到size,或者取不到保证金比例 + self.write_error(f'无法计算{vt_symbol}的保证金,价格:{cur_price}或size:{cur_size}或margin_rate:{cur_margin_rate}') + return None + + def get_tick(self, vt_symbol: str): + """获取合约得最新tick""" + return self.main_engine.get_tick(vt_symbol) + + def get_price(self, vt_symbol: str): + """查询合约的最新价格""" + price = self.main_engine.get_price(vt_symbol) + if price: + return price + + tick = self.main_engine.get_tick(vt_symbol) + if tick: + if '&' in tick.symbol: + return (tick.ask_price_1 + tick.bid_price_1) / 2 + else: + return tick.last_price + + return None + + def get_contract(self, vt_symbol): + return self.main_engine.get_contract(vt_symbol) + + def get_all_contracts(self): + return self.main_engine.get_all_contracts() + + def get_option_list(self, underlying_symbol, year_month): + """ + 获取ETF期权的交易合约 + :param underlying_symbol: 标的物合约,例如 510050.SSE + :param year_month 2112, 表示2021年12月 + :return: + """ + symbol = underlying_symbol.split('.')[0] + + all_contracts = self.get_all_contracts() + + # 510050C2112M03100 + cur_month_contracts = [c for c in all_contracts + if c.product == Product.OPTION \ + and len(c.option_index) >= 17 \ + and c.option_index.startswith(symbol) \ + and c.option_index[7:11] == str(year_month)] + d = {} + for c in cur_month_contracts: + if c.vt_symbol in d: + continue + d[c.vt_symbol] = c + return list(d.values()) + + def get_holiday(self): + """获取假日""" + return self.engine_config.get('holiday_dict', holiday_dict) + + def get_option_rest_days(self, cur_date: str, expire_date: str): + """ + 获取期权从当前日到行权价得结束天数 + :param cur_date: 当前日期 + :param expire_date: 行权日期 + :return: + """ + holidays = self.get_holiday() + if cur_date > expire_date: + return 0 + rest_days = 0 # 剩余天数 + # 开始日期 > 结束日期 + for d in range(int(cur_date), int(expire_date)): + _s = str(d) + + # 判断是否周六日 + try: + _c = datetime.strptime(_s, '%Y%m%d') + + if _c.isoweekday() in [6, 7]: + continue + + except Exception as ex: + continue + # 剩余天数先增加一天 + rest_days += 1 + + # 如果存在假期内,就减除1天 + for s, e in holidays.items(): + if s <= _s <= e: + rest_days -= 1 + break + + return max(0, rest_days) + + def get_account(self, vt_accountid: str = ""): + """ 查询账号的资金""" + # 如果启动风控,则使用风控中的最大仓位 + if self.main_engine.rm_engine: + return self.main_engine.rm_engine.get_account(vt_accountid) + + if len(vt_accountid) > 0: + account = self.main_engine.get_account(vt_accountid) + return account.balance, account.available, round(account.frozen * 100 / (account.balance + 0.01), 2), 100 + else: + accounts = self.main_engine.get_all_accounts() + if len(accounts) > 0: + account = accounts[0] + return account.balance, account.available, round(account.frozen * 100 / (account.balance + 0.01), + 2), 100 + else: + return 0, 0, 0, 0 + + def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''): + """ 查询合约在账号的持仓,需要指定方向""" + if len(gateway_name) == 0: + contract = self.main_engine.get_contract(vt_symbol) + if contract and contract.gateway_name: + gateway_name = contract.gateway_name + vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}" + return self.main_engine.get_position(vt_position_id) + + def get_engine_type(self): + """""" + return self.engine_type + + @lru_cache() + def get_data_path(self): + data_path = os.path.abspath(os.path.join(TRADER_DIR, 'data')) + return data_path + + @lru_cache() + def get_logs_path(self): + log_path = os.path.abspath(os.path.join(TRADER_DIR, 'log')) + return log_path + + def load_bar( + self, + vt_symbol: str, + days: int, + interval: Interval, + callback: Callable[[BarData], None], + interval_num: int = 1 + ): + """获取历史记录""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + bars = [] + + # Query bars from gateway if available + contract = self.main_engine.get_contract(vt_symbol) + + if contract and contract.history_data: + req = HistoryRequest( + symbol=symbol, + exchange=exchange, + interval=interval, + interval_num=interval_num, + start=start, + end=end + ) + bars = self.main_engine.query_history(req, contract.gateway_name) + + if bars is None: + self.write_error(f'获取不到历史K线:{req.__dict__}') + return + + for bar in bars: + if bar.trading_day: + bar.trading_day = bar.datetime.strftime('%Y-%m-%d') + + callback(bar) + + def get_bars( + self, + vt_symbol: str, + days: int, + interval: Interval, + interval_num: int = 1 + ): + """获取历史记录""" + symbol, exchange = extract_vt_symbol(vt_symbol) + end = datetime.now() + start = end - timedelta(days) + bars = [] + + # 检查股票代码 + if vt_symbol not in self.symbol_dict: + self.write_error(f'{vt_symbol}不在基础配置股票信息中') + return bars + + # 检查数据文件目录 + if not self.bar_data_folder: + self.write_error(f'没有bar数据目录') + return bars + # 按照交易所的存放目录 + bar_file_folder = os.path.abspath(os.path.join(self.bar_data_folder, f'{exchange.value}')) + + resample_min = False + resample_hour = False + resample_day = False + file_interval_num = 1 + # 只有1,5,15,30分钟,日线数据 + if interval == Interval.MINUTE: + # 如果存在相应的分钟文件,直接读取 + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}m.csv')) + if interval_num in [1, 5, 15, 30] and os.path.exists(bar_file_path): + file_interval_num = interval + # 需要resample + else: + resample_min = True + if interval_num > 5: + file_interval_num = 5 + + elif interval == Interval.HOUR: + file_interval_num = 5 + resample_hour = True + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv')) + elif interval == Interval.DAILY: + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{interval_num}d.csv')) + if not os.path.exists(bar_file_path): + file_interval_num = 5 + resample_day = True + bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{symbol}_{file_interval_num}m.csv')) + else: + self.write_error(f'目前仅支持分钟,小时,日线数据') + return bars + + bar_interval_seconds = interval_num * 60 + + if not os.path.exists(bar_file_path): + self.write_error(f'没有bar数据文件:{bar_file_path}') + return bars + + try: + data_types = { + "datetime": str, + "open": float, + "high": float, + "low": float, + "close": float, + "volume": float, + "amount": float, + "symbol": str, + "trading_day": str, + "date": str, + "time": str + } + + symbol_df = None + qfq_bar_file_path = bar_file_path.replace('.csv', '_qfq.csv') + use_qfq_file = False + last_qfq_dt = get_csv_last_dt(qfq_bar_file_path) + if last_qfq_dt is not None: + last_dt = get_csv_last_dt(bar_file_path) + + if last_qfq_dt == last_dt: + use_qfq_file = True + + if use_qfq_file: + self.write_log(f'使用前复权文件:{qfq_bar_file_path}') + symbol_df = pd.read_csv(qfq_bar_file_path, dtype=data_types) + else: + # 加载csv文件 =》 dateframe + self.write_log(f'使用未复权文件:{bar_file_path}') + symbol_df = pd.read_csv(bar_file_path, dtype=data_types) + + # 转换时间,str =》 datetime + symbol_df["datetime"] = pd.to_datetime(symbol_df["datetime"], format="%Y-%m-%d %H:%M:%S") + # 设置时间为索引 + symbol_df = symbol_df.set_index("datetime") + + # 裁剪数据 + symbol_df = symbol_df.loc[start:end] + + if resample_day: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}day') + symbol_df = self.resample_bars(df=symbol_df, to_day=True) + elif resample_hour: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}hour') + symbol_df = self.resample_bars(df=symbol_df, x_hour=interval_num) + elif resample_min: + self.write_log(f'{vt_symbol} resample:{file_interval_num}m => {interval}m') + symbol_df = self.resample_bars(df=symbol_df, x_min=interval_num) + + if len(symbol_df) == 0: + return bars + + if not use_qfq_file: + # 复权转换 + adj_list = self.adjust_factor_dict.get(vt_symbol, []) + # 按照结束日期,裁剪复权记录 + adj_list = [row for row in adj_list if row['dividOperateDate'].replace('-', '') <= end.strftime('%Y%m%d')] + + if len(adj_list) > 0: + self.write_log(f'需要对{vt_symbol}进行前复权处理') + for row in adj_list: + row.update({'dividOperateDate': row.get('dividOperateDate')[:10] + ' 09:30:00'}) + # list -> dataframe, 转换复权日期格式 + adj_data = pd.DataFrame(adj_list) + adj_data["dividOperateDate"] = pd.to_datetime(adj_data["dividOperateDate"], format="%Y-%m-%d %H:%M:%S") + adj_data = adj_data.set_index("dividOperateDate") + # 调用转换方法,对open,high,low,close, volume进行复权, fore, 前复权, 其他,后复权 + symbol_df = stock_to_adj(symbol_df, adj_data, adj_type='fore') + + for dt, bar_data in symbol_df.iterrows(): + bar_datetime = dt #- timedelta(seconds=bar_interval_seconds) + + bar = BarData( + gateway_name='backtesting', + symbol=symbol, + exchange=exchange, + datetime=bar_datetime + ) + if 'open' in bar_data: + bar.open_price = float(bar_data['open']) + bar.close_price = float(bar_data['close']) + bar.high_price = float(bar_data['high']) + bar.low_price = float(bar_data['low']) + else: + bar.open_price = float(bar_data['open_price']) + bar.close_price = float(bar_data['close_price']) + bar.high_price = float(bar_data['high_price']) + bar.low_price = float(bar_data['low_price']) + + bar.volume = int(bar_data['volume']) if not np.isnan(bar_data['volume']) else 0 + bar.date = dt.strftime('%Y-%m-%d') + bar.time = dt.strftime('%H:%M:%S') + str_td = str(bar_data.get('trading_day', '')) + if len(str_td) == 8: + bar.trading_day = str_td[0:4] + '-' + str_td[4:6] + '-' + str_td[6:8] + else: + bar.trading_day = bar.date + + bars.append(bar) + + except Exception as ex: + self.write_error(u'回测时读取{} csv文件{}失败:{}'.format(vt_symbol, bar_file_path, ex)) + self.write_error(traceback.format_exc()) + return bars + + return bars + + + def resample_bars(self, df, x_min=None, x_hour=None, to_day=False): + """ + 重建x分钟K线(或日线) + :param df: 输入分钟数 + :param x_min: 5, 15, 30, 60 + :param x_hour: 1, 2, 3, 4 + :param include_day: 重建日线, True得时候,不会重建分钟数 + :return: + """ + # 设置df数据中每列的规则 + ohlc_rule = { + 'open': 'first', # open列:序列中第一个的值 + 'high': 'max', # high列:序列中最大的值 + 'low': 'min', # low列:序列中最小的值 + 'close': 'last', # close列:序列中最后一个的值 + 'volume': 'sum', # volume列:将所有序列里的volume值作和 + 'amount': 'sum', # amount列:将所有序列里的amount值作和 + "symbol": 'first', + "trading_date": 'first', + "date": 'first', + "time": 'first' + } + + if isinstance(x_min, int) and not to_day: + # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'{x_min}min', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, + how='any') + return df_target + if isinstance(x_hour, int) and not to_day: + # 合成x小时K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'{x_hour}hour', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, + how='any') + return df_target + + if to_day: + # 合成x分钟K线并删除为空的行 参数 closed:left类似向上取值既 09:30的k线数据是包含09:30-09:35之间的数据 + df_target = df.resample(f'D', closed='left', label='left').agg(ohlc_rule).dropna(axis=0, how='any') + return df_target + + return df + + def call_strategy_func( + self, strategy: CtaTemplate, func: Callable, params: Any = None + ): + """ + Call function of a strategy and catch any exception raised. + """ + try: + if params: + func(params) + else: + func() + except Exception: + strategy.trading = False + strategy.inited = False + accountid = self.engine_config.get('accountid', 'XXX') + + msg = f"{accountid}/{strategy.strategy_name}触发异常已停止\n{traceback.format_exc()}" + self.write_log(msg=msg, + strategy_name=strategy.strategy_name, + level=logging.CRITICAL) + self.send_wechat(msg) + + def add_strategy( + self, class_name: str, + strategy_name: str, + vt_symbols: List[str], + setting: dict, + auto_init: bool = False, + auto_start: bool = False + ): + """ + Add a new strategy. + """ + try: + if strategy_name in self.strategies: + msg = f"创建策略失败,存在重名{strategy_name}" + self.write_log(msg=msg, + level=logging.CRITICAL) + return False, msg + + strategy_class = self.classes.get(class_name, None) + if not strategy_class: + msg = f"创建策略失败,找不到策略类{class_name}" + self.write_log(msg=msg, + level=logging.CRITICAL) + return False, msg + + self.write_log(f'开始添加策略类{class_name},实例名:{strategy_name}') + strategy = strategy_class(self, strategy_name, vt_symbols, setting) + self.strategies[strategy_name] = strategy + + # Add vt_symbol to strategy map. + subscribe_symbol_set = self.strategy_symbol_map[strategy_name] + for vt_symbol in vt_symbols: + strategies = self.symbol_strategy_map[vt_symbol] + strategies.append(strategy) + subscribe_symbol_set.add(vt_symbol) + + # Update to setting file. + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) + + self.put_strategy_event(strategy) + + # 判断设置中是否由自动初始化和自动启动项目 + if auto_init: + self.init_strategy(strategy_name, auto_start=auto_start) + + except Exception as ex: + msg = f'添加策略实例{strategy_name}失败,{str(ex)}' + self.write_error(msg) + self.write_error(traceback.format_exc()) + self.send_wechat(msg) + + return False, f'添加策略实例{strategy_name}失败' + + return True, f'成功添加{strategy_name}' + + def init_strategy(self, strategy_name: str, auto_start: bool = False): + """ + Init a strategy. + """ + task = self.thread_executor.submit(self._init_strategy, strategy_name, auto_start) + self.thread_tasks.append(task) + + def _init_strategy(self, strategy_name: str, auto_start: bool = False): + """ + Init strategies in queue. + """ + strategy = self.strategies[strategy_name] + + if strategy.inited: + self.write_error(f"{strategy_name}已经完成初始化,禁止重复操作") + return + + self.write_log(f"{strategy_name}开始执行初始化") + + # Call on_init function of strategy + self.call_strategy_func(strategy, strategy.on_init) + + # Restore strategy data(variables) + # Pro 版本不使用自动恢复除了内部数据功能,由策略自身初始化时完成 + # data = self.strategy_data.get(strategy_name, None) + # if data: + # for name in strategy.variables: + # value = data.get(name, None) + # if value: + # setattr(strategy, name, value) + + # Subscribe market data 订阅缺省的vt_symbol, 如果有其他合约需要订阅,由策略内部初始化时提交订阅即可。 + for vt_symbol in strategy.vt_symbols: + self.subscribe_symbol(strategy_name=strategy_name, vt_symbol=vt_symbol) + + # Put event to update init completed status. + strategy.inited = True + self.put_strategy_event(strategy) + self.write_log(f"{strategy_name}初始化完成") + + # 初始化后,自动启动策略交易 + if auto_start: + self.start_strategy(strategy_name) + + def start_strategy(self, strategy_name: str): + """ + Start a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.inited: + msg = f"策略{strategy.strategy_name}启动失败,请先初始化" + self.write_error(msg) + return False, msg + + if strategy.trading: + msg = f"{strategy_name}已经启动,请勿重复操作" + self.write_log(msg) + return False, msg + + self.call_strategy_func(strategy, strategy.on_start) + strategy.trading = True + + self.put_strategy_event(strategy) + + return True, f'成功启动策略{strategy_name}' + + def stop_strategy(self, strategy_name: str): + """ + Stop a strategy. + """ + strategy = self.strategies[strategy_name] + if not strategy.trading: + msg = f'{strategy_name}策略实例已处于停止交易状态' + self.write_log(msg) + return False, msg + + # Call on_stop function of the strategy + self.write_log(f'调用{strategy_name}的on_stop,停止交易') + self.call_strategy_func(strategy, strategy.on_stop) + + # Change trading status of strategy to False + strategy.trading = False + + # Cancel all orders of the strategy + self.write_log(f'撤销{strategy_name}所有委托') + self.cancel_all(strategy) + + # Sync strategy variables to data file + # 取消此功能,由策略自身完成数据的持久化 + # self.sync_strategy_data(strategy) + + # Update GUI + self.put_strategy_event(strategy) + return True, f'成功停止策略{strategy_name}' + + def edit_strategy(self, strategy_name: str, setting: dict): + """ + Edit parameters of a strategy. + 风险警示: 该方法强行干预策略的配置 + """ + strategy = self.strategies[strategy_name] + auto_init = setting.pop('auto_init', False) + auto_start = setting.pop('auto_start', False) + + strategy.update_setting(setting) + + self.update_strategy_setting(strategy_name, setting, auto_init, auto_start) + self.put_strategy_event(strategy) + + def remove_strategy(self, strategy_name: str): + """ + Remove a strategy. + """ + strategy = self.strategies[strategy_name] + if strategy.trading: + # err_msg = f"策略{strategy.strategy_name}正在运行,先停止" + # self.write_error(err_msg) + # return False, err_msg + ret, msg = self.stop_strategy(strategy_name) + if not ret: + return False, msg + else: + self.write_log(msg) + + # Remove setting + self.remove_strategy_setting(strategy_name) + + # 移除订阅合约与策略的关联关系 + for vt_symbol in self.strategy_symbol_map[strategy_name]: + # Remove from symbol strategy map + self.write_log(f'移除{vt_symbol}《=》{strategy_name}的订阅关系') + strategies = self.symbol_strategy_map[vt_symbol] + if strategy in strategies: + strategies.remove(strategy) + + # Remove from active orderid map + if strategy_name in self.strategy_orderid_map: + vt_orderids = self.strategy_orderid_map.pop(strategy_name) + self.write_log(f'移除{strategy_name}的所有委托订单映射关系') + # Remove vt_orderid strategy map + for vt_orderid in vt_orderids: + if vt_orderid in self.orderid_strategy_map: + self.orderid_strategy_map.pop(vt_orderid) + + # Remove from strategies + self.write_log(f'移除{strategy_name}策略实例') + self.strategies.pop(strategy_name) + + return True, f'成功移除{strategy_name}策略实例' + + def reload_strategy(self, strategy_name: str, vt_symbols: List[str] = [], setting: dict = {}): + """ + 重新加载策略 + 一般使用于在线更新策略代码,或者更新策略参数,需要重新启动策略 + :param strategy_name: + :param setting: + :return: + """ + self.write_log(f'开始重新加载策略{strategy_name}') + + # 优先判断重启的策略,是否已经加载 + if strategy_name not in self.strategies or strategy_name not in self.strategy_setting: + err_msg = f"{strategy_name}不在运行策略中,不能重启" + self.write_error(err_msg) + return False, err_msg + + # 从本地配置文件中读取 + if len(setting) == 0: + strategies_setting = load_json(self.setting_filename) + old_strategy_config = strategies_setting.get(strategy_name, {}) + self.write_log(f'使用配置文件的配置:{old_strategy_config}') + else: + old_strategy_config = copy(self.strategy_setting[strategy_name]) + self.write_log(f'使用已经运行的配置:{old_strategy_config}') + + class_name = old_strategy_config.get('class_name') + self.write_log(f'使用策略类名:{class_name}') + + # 没有配置vt_symbol时,使用配置文件/旧配置中的vt_symbol + if len(vt_symbols) == 0: + vt_symbols = old_strategy_config.get('vt_symbols') + self.write_log(f'使用配置文件/已运行配置的vt_symbols:{vt_symbols}') + + # 没有新配置时,使用配置文件/旧配置中的setting + if len(setting) == 0: + setting = old_strategy_config.get('setting') + self.write_log(f'没有新策略参数,使用配置文件/旧配置中的setting:{setting}') + + module_name = self.class_module_map[class_name] + # 重新load class module + # if not self.load_strategy_class_from_module(module_name): + # err_msg = f'不能加载模块:{module_name}' + # self.write_error(err_msg) + # return False, err_msg + if module_name: + new_class_name = module_name + '.' + class_name + self.write_log(u'转换策略为全路径:{}'.format(new_class_name)) + old_strategy_class = self.classes[class_name] + self.write_log(f'旧策略ID:{id(old_strategy_class)}') + strategy_class = import_module_by_str(new_class_name) + if strategy_class is None: + err_msg = u'加载策略模块失败:{}'.format(new_class_name) + self.write_error(err_msg) + return False, err_msg + + self.write_log(f'重新加载模块成功,使用新模块:{new_class_name}') + self.write_log(f'新策略ID:{id(strategy_class)}') + self.classes[class_name] = strategy_class + else: + self.write_log(f'没有{class_name}的module_name,无法重新加载模块') + + # 停止当前策略实例的运行,撤单 + self.stop_strategy(strategy_name) + + # 移除运行中的策略实例 + self.remove_strategy(strategy_name) + + # 重新添加策略 + self.add_strategy(class_name=class_name, + strategy_name=strategy_name, + vt_symbols=vt_symbols, + setting=setting, + auto_init=old_strategy_config.get('auto_init', False), + auto_start=old_strategy_config.get('auto_start', False)) + + msg = f'成功重载策略{strategy_name}' + self.write_log(msg) + return True, msg + + def save_strategy_data(self, select_name: str = 'ALL'): + """ save strategy data""" + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_data, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存数据\n' + has_executed = True + else: + self.write_log(f'{strategy_name}未初始化/未启动交易,不进行保存数据') + return has_executed, msg + + def thread_save_strategy_data(self, strategy_name): + """异步线程保存策略数据""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + try: + # 保存策略数据 + strategy.sync_data() + except Exception as ex: + self.write_error(u'保存策略{}数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + + def clean_strategy_cache(self, strategy_name): + """清除策略K线缓存文件""" + cache_file = os.path.abspath(os.path.join(self.get_data_path(), f'{strategy_name}_klines.pkb2')) + if os.path.exists(cache_file): + self.write_log(f'移除策略缓存文件:{cache_file}') + os.remove(cache_file) + else: + self.write_log(f'策略缓存文件不存在:{cache_file}') + + def get_strategy_kline_names(self, strategy_name): + """ + 获取策略实例内的K线名称 + :param strategy_name:策略实例名称 + :return: + """ + info = {} + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return info + if hasattr(strategy, 'get_klines_info'): + info = strategy.get_klines_info() + return info + + def get_strategy_snapshot(self, strategy_name, include_kline_names=[]): + """ + 实时获取策略的K线切片(比较耗性能) + :param strategy_name: 策略实例 + :param include_kline_names: 指定若干kline名称 + :return: + """ + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return None + + try: + # 5.获取策略切片 + snapshot = strategy.get_klines_snapshot(include_kline_names) + if not snapshot: + self.write_log(f'{strategy_name}返回得K线切片数据为空') + return None + return snapshot + + except Exception as ex: + self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + return None + + def save_strategy_snapshot(self, select_name: str = 'ALL'): + """ + 保存策略K线切片数据 + :param select_name: + :return: + """ + has_executed = False + msg = "" + # 1.判断策略名称是否存在字典中 + for strategy_name in list(self.strategies.keys()): + if select_name != 'ALL': + if strategy_name != select_name: + continue + # 2.提取策略 + strategy = self.strategies.get(strategy_name, None) + if not strategy: + continue + + if not hasattr(strategy, 'get_klines_snapshot'): + continue + + # 3.判断策略是否运行 + if strategy.inited and strategy.trading: + task = self.thread_executor.submit(self.thread_save_strategy_snapshot, strategy_name) + self.thread_tasks.append(task) + msg += f'{strategy_name}执行保存K线切片\n' + has_executed = True + + return has_executed, msg + + def thread_save_strategy_snapshot(self, strategy_name): + """异步线程保存策略切片""" + strategy = self.strategies.get(strategy_name, None) + if strategy is None: + return + + try: + # 5.保存策略切片 + snapshot = strategy.get_klines_snapshot() + if not snapshot: + self.write_log(f'{strategy_name}返回得K线切片数据为空') + return + + if self.engine_config.get('snapshot2file', False): + # 剩下工作:保存本地文件/数据库 + snapshot_folder = get_folder_path(f'data/snapshots/{strategy_name}') + snapshot_file = snapshot_folder.joinpath('{}.pkb2'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))) + with bz2.BZ2File(str(snapshot_file), 'wb') as f: + pickle.dump(snapshot, f) + self.write_log(u'切片保存成功:{}'.format(str(snapshot_file))) + + # 通过事件方式,传导到account_recorder + snapshot.update({ + 'account_id': self.engine_config.get('accountid', '-'), + 'strategy_group': self.engine_config.get('strategy_group', self.engine_name), + 'guid': str(uuid1()) + }) + event = Event(EVENT_STRATEGY_SNAPSHOT, snapshot) + self.event_engine.put(event) + + except Exception as ex: + self.write_error(u'获取策略{}切片数据异常:'.format(strategy_name, str(ex))) + self.write_error(traceback.format_exc()) + + def load_strategy_class(self): + """ + Load strategy class from source code. + """ + # 加载 vnpy/app/cta_strategy_pro/strategies的所有策略 + path1 = Path(__file__).parent.joinpath("strategies") + self.load_strategy_class_from_folder( + path1, "vnpy.app.cta_option.strategies") + + # 加载 当前运行目录下strategies子目录的所有策略 + path2 = Path.cwd().joinpath("strategies") + self.load_strategy_class_from_folder(path2, "strategies") + + def load_strategy_class_from_folder(self, path: Path, module_name: str = ""): + """ + Load strategy class from certain folder. + """ + for dirpath, dirnames, filenames in os.walk(str(path)): + for filename in filenames: + if filename.endswith(".py"): + strategy_module_name = ".".join( + [module_name, filename.replace(".py", "")]) + elif filename.endswith(".pyd"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + elif filename.endswith(".so"): + strategy_module_name = ".".join( + [module_name, filename.split(".")[0]]) + else: + continue + self.load_strategy_class_from_module(strategy_module_name) + + def load_strategy_class_from_module(self, module_name: str): + """ + Load/Reload strategy class from module file. + """ + try: + module = importlib.import_module(module_name) + + for name in dir(module): + value = getattr(module, name) + if (isinstance(value, type) and issubclass(value, CtaTemplate) and value is not CtaTemplate): + class_name = value.__name__ + if class_name not in self.classes: + self.write_log(f"加载策略类{module_name}.{class_name}") + else: + self.write_log(f"更新策略类{module_name}.{class_name}") + self.classes[class_name] = value + self.class_module_map[class_name] = module_name + return True + except: # noqa + account = self.engine_config.get('accountid', '') + msg = f"cta_stock:{account}策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}" + self.write_log(msg=msg, level=logging.CRITICAL) + return False + + def load_strategy_data(self): + """ + Load strategy data from json file. + """ + print(f'load_strategy_data 此功能已取消,由策略自身完成数据的持久化加载', file=sys.stderr) + return + # self.strategy_data = load_json(self.data_filename) + + def sync_strategy_data(self, strategy: CtaTemplate): + """ + Sync strategy data into json file. + """ + # data = strategy.get_variables() + # data.pop("inited") # Strategy status (inited, trading) should not be synced. + # data.pop("trading") + # self.strategy_data[strategy.strategy_name] = data + # save_json(self.data_filename, self.strategy_data) + print(f'sync_strategy_data此功能已取消,由策略自身完成数据的持久化保存', file=sys.stderr) + + def get_all_strategy_class_names(self): + """ + Return names of strategy classes loaded. + """ + return list(self.classes.keys()) + + def get_strategy_status(self): + """ + return strategy inited/trading status + :param strategy_name: + :return: + """ + return {k: {'inited': v.inited, 'trading': v.trading} for k, v in self.strategies.items()} + + def get_strategy_pos(self, name, strategy=None): + """ + 获取策略的持仓字典 + :param name:策略名 + :return: [ {},{}] + """ + # 兼容处理,如果strategy是None,通过name获取 + if strategy is None: + if name not in self.strategies: + self.write_log(u'get_strategy_pos 策略实例不存在:' + name) + return [] + # 获取策略实例 + strategy = self.strategies[name] + + pos_list = [] + + if strategy.inited: + # 如果策略具有getPositions得方法,则调用该方法 + if hasattr(strategy, 'get_positions'): + pos_list = strategy.get_positions() + for pos in pos_list: + vt_symbol = pos.get('vt_symbol', None) + if vt_symbol: + symbol, exchange = extract_vt_symbol(vt_symbol) + pos.update({'symbol': symbol}) + + # update local pos dict + self.strategy_pos_dict.update({name: pos_list}) + + return pos_list + + def get_all_strategy_pos(self): + """ + 获取所有得策略仓位明细 + """ + strategy_pos_list = [] + for strategy_name in list(self.strategies.keys()): + d = OrderedDict() + d['accountid'] = self.engine_config.get('accountid', '-') + d['strategy_group'] = self.engine_config.get('strategy_group', '-') + d['strategy_name'] = strategy_name + dt = datetime.now() + d['date'] = dt.strftime('%Y%m%d') + d['hour'] = dt.hour + d['datetime'] = datetime.now() + strategy = self.strategies.get(strategy_name) + d['inited'] = strategy.inited + d['trading'] = strategy.trading + try: + d['pos'] = self.get_strategy_pos(name=strategy_name) + except Exception as ex: + self.write_error( + u'get_strategy_pos exception:{},{}'.format(str(ex), traceback.format_exc())) + d['pos'] = [] + strategy_pos_list.append(d) + + return strategy_pos_list + + def get_all_strategy_pos_from_hams(self): + """ + 获取hams中该账号下所有策略仓位明细 + """ + strategy_pos_list = [] + if not self.mongo_data: + self.init_mongo_data() + + if self.mongo_data and self.mongo_data.db_has_connected: + filter = {'account_id':self.engine_config.get('accountid','-')} + + pos_list = self.mongo_data.db_query( + db_name='Account', + col_name='today_strategy_pos', + filter_dict=filter + ) + for pos in pos_list: + strategy_pos_list.append(pos) + + return strategy_pos_list + + def get_strategy_class_parameters(self, class_name: str): + """ + Get default parameters of a strategy class. + """ + strategy_class = self.classes[class_name] + + parameters = {} + for name in strategy_class.parameters: + parameters[name] = getattr(strategy_class, name) + + return parameters + + def get_strategy_parameters(self, strategy_name): + """ + Get parameters of a strategy. + """ + strategy = self.strategies[strategy_name] + strategy_config = self.strategy_setting.get(strategy_name, {}) + d = {} + d.update({'auto_init': strategy_config.get('auto_init', False)}) + d.update({'auto_start': strategy_config.get('auto_start', False)}) + d.update(strategy.get_parameters()) + return d + + def get_strategy_value(self, strategy_name: str, parameter: str): + """获取策略的某个参数值""" + strategy = self.strategies.get(strategy_name) + if not strategy: + return None + + value = getattr(strategy, parameter, None) + return value + + def get_none_strategy_pos_list(self): + """获取非策略持有的仓位""" + # 格式 [ 'strategy_name':'account', 'pos': [{'vt_symbol': '', 'direction': 'xxx', 'volume':xxx }] } ] + none_strategy_pos_file = os.path.abspath(os.path.join(os.getcwd(), 'data', 'none_strategy_pos.json')) + if not os.path.exists(none_strategy_pos_file): + return [] + try: + with open(none_strategy_pos_file, encoding='utf8') as f: + pos_list = json.load(f) + if isinstance(pos_list, list): + return pos_list + + return [] + except Exception as ex: + self.write_error(u'未能读取或解释{}'.format(none_strategy_pos_file)) + return [] + + def compare_pos(self, strategy_pos_list=[], auto_balance=False): + """ + 对比账号&策略的持仓,不同的话则发出微信提醒 + :return: + """ + # 当前没有接入网关 + if len(self.main_engine.gateways) == 0: + return False, u'当前没有接入网关' + + self.write_log(u'开始对比账号&策略的持仓') + + # 获取hams数据库中所有运行实例得策略 + if self.engine_config.get("get_pos_from_db", False): + strategy_pos_list = self.get_all_strategy_pos_from_hams() + else: + # 获取当前实例运行策略得持仓 + if len(strategy_pos_list) == 0: + strategy_pos_list = self.get_all_strategy_pos() + self.write_log(u'策略持仓清单:{}'.format(strategy_pos_list)) + + none_strategy_pos = self.get_none_strategy_pos_list() + if len(none_strategy_pos) > 0: + strategy_pos_list.extend(none_strategy_pos) + + # 需要进行对比得合约集合(来自策略持仓/账号持仓) + vt_symbols = set() + + # 账号的持仓处理 => account_pos + + compare_pos = dict() # vt_symbol: {'账号多单': xx, '账号空单':xxx, '策略空单':[], '策略多单':[]} + + for pos in self.main_engine.get_all_positions(): + vt_symbols.add(pos.vt_symbol) + vt_symbol_pos = compare_pos.get(pos.vt_symbol, { + "账号空单": 0, + '账号多单': 0, + '策略空单': 0, + '策略多单': 0, + '空单策略': [], + '多单策略': [] + }) + if pos.direction == Direction.LONG: + vt_symbol_pos['账号多单'] = vt_symbol_pos['账号多单'] + pos.volume + else: + vt_symbol_pos['账号空单'] = vt_symbol_pos['账号空单'] + pos.volume + + compare_pos.update({pos.vt_symbol:vt_symbol_pos}) + + # 逐一根据策略仓位,与Account_pos进行处理比对 + for strategy_pos in strategy_pos_list: + for pos in strategy_pos.get('pos', []): + vt_symbol = pos.get('vt_symbol') + if not vt_symbol: + continue + vt_symbols.add(vt_symbol) + symbol_pos = compare_pos.get(vt_symbol, None) + if symbol_pos is None: + # self.write_log(u'账号持仓信息获取不到{},创建一个'.format(vt_symbol)) + symbol_pos = OrderedDict( + { + "账号空单": 0, + '账号多单': 0, + '策略空单': 0, + '策略多单': 0, + '空单策略': [], + '多单策略': [] + } + ) + + if pos.get('direction') == 'short': + symbol_pos.update({'策略空单': symbol_pos.get('策略空单', 0) + abs(pos.get('volume', 0))}) + symbol_pos['空单策略'].append( + u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) + self.write_log(u'更新{}策略持空仓=>{}'.format(vt_symbol, symbol_pos.get('策略空单', 0))) + if pos.get('direction') == 'long': + symbol_pos.update({'策略多单': symbol_pos.get('策略多单', 0) + abs(pos.get('volume', 0))}) + symbol_pos['多单策略'].append( + u'{}({})'.format(strategy_pos['strategy_name'], abs(pos.get('volume', 0)))) + self.write_log(u'更新{}策略持多仓=>{}'.format(vt_symbol, symbol_pos.get('策略多单', 0))) + + compare_pos.update({vt_symbol: symbol_pos}) + + pos_compare_result = '' + # 精简输出 + compare_info = '' + diff_pos_dict = {} + for vt_symbol in sorted(vt_symbols): + # 发送不一致得结果 + symbol_pos = compare_pos.pop(vt_symbol, {}) + + d_long = { + 'account_id': self.engine_config.get('accountid', '-'), + 'vt_symbol': vt_symbol, + 'direction': Direction.LONG.value, + 'strategy_list': symbol_pos.get('多单策略', [])} + + d_short = { + 'account_id': self.engine_config.get('accountid', '-'), + 'vt_symbol': vt_symbol, + 'direction': Direction.SHORT.value, + 'strategy_list': symbol_pos.get('空单策略', [])} + + # 股指期货: 帐号多/空轧差, vs 策略多空轧差 是否一致; + # 其他期货:帐号多单 vs 除了多单, 空单 vs 空单 + if vt_symbol.endswith(".CFFEX"): + diff_match = (symbol_pos.get('账号多单', 0) - symbol_pos.get('账号空单', 0)) == ( + symbol_pos.get('策略多单', 0) - symbol_pos.get('策略空单', 0)) + pos_match = symbol_pos.get('账号空单', 0) == symbol_pos.get('策略空单', 0) and \ + symbol_pos.get('账号多单', 0) == symbol_pos.get('策略多单', 0) + match = diff_match + # 轧差一致,帐号/策略持仓不一致 + if diff_match and not pos_match: + if symbol_pos.get('账号多单', 0) > symbol_pos.get('策略多单', 0): + self.write_log('{}轧差持仓:多:{},空:{} 大于 策略持仓 多:{},空:{}'.format( + vt_symbol, + symbol_pos.get('账号多单', 0), + symbol_pos.get('账号空单', 0), + symbol_pos.get('策略多单', 0), + symbol_pos.get('策略空单', 0) + )) + diff_pos_dict.update({vt_symbol: {"long": symbol_pos.get('账号多单', 0) - symbol_pos.get('策略多单', 0), + "short": symbol_pos.get('账号空单', 0) - symbol_pos.get('策略空单', + 0)}}) + else: + match = round(symbol_pos.get('账号空单', 0), 7) == round(symbol_pos.get('策略空单', 0), 7) and \ + round(symbol_pos.get('账号多单', 0), 7) == round(symbol_pos.get('策略多单', 0), 7) + # 多空都一致 + if match: + msg = u'{}多空都一致.{}\n'.format(vt_symbol, json.dumps(symbol_pos, indent=2, ensure_ascii=False)) + self.write_log(msg) + compare_info += msg + else: + pos_compare_result += '\n{}: '.format(vt_symbol) + # 判断是多单不一致? + diff_long_volume = round(symbol_pos.get('账号多单', 0), 7) - round(symbol_pos.get('策略多单', 0), 7) + if diff_long_volume != 0: + msg = '{}多单[账号({}), 策略{},共({})], ' \ + .format(vt_symbol, + symbol_pos.get('账号多单'), + symbol_pos.get('多单策略'), + symbol_pos.get('策略多单')) + + pos_compare_result += msg + self.write_error(u'{}不一致:{}'.format(vt_symbol, msg)) + compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg) + if auto_balance: + self.balance_pos(vt_symbol, Direction.LONG, diff_long_volume) + + # 判断是空单不一致: + diff_short_volume = round(symbol_pos.get('账号空单', 0), 7) - round(symbol_pos.get('策略空单', 0), 7) + + if diff_short_volume != 0: + msg = '{}空单[账号({}), 策略{},共({})], ' \ + .format(vt_symbol, + symbol_pos.get('账号空单'), + symbol_pos.get('空单策略'), + symbol_pos.get('策略空单')) + pos_compare_result += msg + self.write_error(u'{}不一致:{}'.format(vt_symbol, msg)) + compare_info += u'{}不一致:{}\n'.format(vt_symbol, msg) + if auto_balance: + self.balance_pos(vt_symbol, Direction.SHORT, diff_short_volume) + + # 不匹配,输入到stdErr通道 + if pos_compare_result != '': + msg = u'账户{}持仓不匹配: {}' \ + .format(self.engine_config.get('accountid', '-'), + pos_compare_result) + try: + from vnpy.trader.util_wechat import send_wx_msg + send_wx_msg(content=msg) + except Exception as ex: # noqa + pass + ret_msg = u'持仓不匹配: {}' \ + .format(pos_compare_result) + self.write_error(ret_msg) + return True, compare_info + ret_msg + else: + self.write_log(u'账户持仓与策略一致') + if len(diff_pos_dict) > 0: + for k, v in diff_pos_dict.items(): + self.write_log(f'{k} 存在大于策略的轧差持仓:{v}') + return True, compare_info + + def balance_pos(self, vt_symbol, direction, volume): + """ + 平衡仓位 + :param vt_symbol: 需要平衡得合约 + :param direction: 合约原始方向 + :param volume: 合约需要调整得数量(正数,需要平仓, 负数,需要开仓) + :return: + """ + tick = self.get_tick(vt_symbol) + if tick is None: + gateway_names = self.main_engine.get_all_gateway_names() + gateway_name = gateway_names[0] if len(gateway_names) > 0 else "" + symbol, exchange = extract_vt_symbol(vt_symbol) + self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), + gateway_name=gateway_name) + self.write_log(f'{vt_symbol}无最新tick,订阅行情') + + if volume > 0 and tick: + contract = self.main_engine.get_contract(vt_symbol) + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=Direction.SHORT if direction == Direction.LONG else Direction.LONG, + offset=Offset.CLOSE, + type=OrderType.LIMIT, + price=tick.ask_price_1 if direction == Direction.SHORT else tick.bid_price_1, + volume=round(volume, 7) + ) + reqs = [req] + self.write_log(f'平衡仓位,减少 {vt_symbol},方向:{direction},数量:{req.volume} ') + for req in reqs: + self.main_engine.send_order(req, contract.gateway_name) + elif volume < 0 and tick: + contract = self.main_engine.get_contract(vt_symbol) + req = OrderRequest( + symbol=contract.symbol, + exchange=contract.exchange, + direction=direction, + offset=Offset.OPEN, + type=OrderType.FAK, + price=tick.ask_price_1 if direction == Direction.LONG else tick.bid_price_1, + volume=round(abs(volume), 7) + ) + reqs = [req] + self.write_log(f'平衡仓位, 增加{vt_symbol}, 方向:{direction}, 数量: {req.volume}') + for req in reqs: + self.main_engine.send_order(req, contract.gateway_name) + + def init_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.init_strategy(strategy_name) + + def start_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.start_strategy(strategy_name) + + def stop_all_strategies(self): + """ + """ + for strategy_name in self.strategies.keys(): + self.stop_strategy(strategy_name) + + def load_strategy_setting(self): + """ + Load setting file. + """ + # 读取引擎得配置 + self.engine_config = load_json(self.config_filename) + # 是否产生event log 日志(一般GUI界面才产生,而且比好消耗资源) + self.event_log = self.engine_config.get('event_log', False) + + # 读取策略得配置 + self.strategy_setting = load_json(self.setting_filename) + + for strategy_name, strategy_config in self.strategy_setting.items(): + self.add_strategy( + class_name=strategy_config["class_name"], + strategy_name=strategy_name, + vt_symbols=strategy_config["vt_symbols"], + setting=strategy_config["setting"], + auto_init=strategy_config.get('auto_init', False), + auto_start=strategy_config.get('auto_start', False) + ) + + def update_strategy_setting(self, strategy_name: str, setting: dict, auto_init: bool = False, + auto_start: bool = False): + """ + Update setting file. + """ + strategy = self.strategies[strategy_name] + # 原配置 + old_config = self.strategy_setting.get('strategy_name', {}) + new_config = { + "class_name": strategy.__class__.__name__, + "vt_symbols": strategy.vt_symbols, + "auto_init": auto_init, + "auto_start": auto_start, + "setting": setting + } + + if old_config: + self.write_log(f'{strategy_name} 配置变更:\n{old_config} \n=> \n{new_config}') + + self.strategy_setting[strategy_name] = new_config + + sorted_setting = OrderedDict() + for k in sorted(self.strategy_setting.keys()): + sorted_setting.update({k: self.strategy_setting.get(k)}) + + save_json(self.setting_filename, sorted_setting) + + def remove_strategy_setting(self, strategy_name: str): + """ + Update setting file. + """ + if strategy_name not in self.strategy_setting: + return + self.write_log(f'移除CTA期权引擎{strategy_name}的配置') + self.strategy_setting.pop(strategy_name) + sorted_setting = OrderedDict() + for k in sorted(self.strategy_setting.keys()): + sorted_setting.update({k: self.strategy_setting.get(k)}) + + save_json(self.setting_filename, sorted_setting) + + def put_stop_order_event(self, stop_order: StopOrder): + """ + Put an event to update stop order status. + """ + event = Event(EVENT_CTA_STOPORDER, stop_order) + self.event_engine.put(event) + + def put_strategy_event(self, strategy: CtaTemplate): + """ + Put an event to update strategy status. + """ + data = strategy.get_data() + event = Event(EVENT_CTA_OPTION, data) + self.event_engine.put(event) + + def put_all_strategy_pos_event(self, strategy_pos_list: list = []): + """推送所有策略得持仓事件""" + for strategy_pos in strategy_pos_list: + event = Event(EVENT_STRATEGY_POS, copy(strategy_pos)) + self.event_engine.put(event) + + def write_log(self, msg: str, strategy_name: str = '', level: int = logging.INFO): + """ + Create cta engine log event. + """ + if self.event_log: + # 推送至全局CTA_LOG Event + log = LogData(msg=f"{strategy_name}: {msg}" if strategy_name else msg, + gateway_name="CtaStrategy", + level=level) + event = Event(type=EVENT_CTA_LOG, data=log) + self.event_engine.put(event) + + # 保存单独的策略日志 + if strategy_name: + strategy_logger = self.strategy_loggers.get(strategy_name, None) + if not strategy_logger: + log_path = get_folder_path('log') + log_filename = str(log_path.joinpath(str(strategy_name))) + print(u'create logger:{}'.format(log_filename)) + self.strategy_loggers[strategy_name] = setup_logger(file_name=log_filename, + name=str(strategy_name)) + strategy_logger = self.strategy_loggers.get(strategy_name) + if strategy_logger: + strategy_logger.log(level, msg) + else: + if self.logger: + self.logger.log(level, msg) + + # 如果日志数据异常,错误和告警,输出至sys.stderr + if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]: + print(f"{strategy_name}: {msg}" if strategy_name else msg, file=sys.stderr) + + if level in [logging.CRITICAL, logging.WARN, logging.WARNING]: + send_wx_msg(content=f"{strategy_name}: {msg}" if strategy_name else msg, + target=self.engine_config.get('accountid', 'XXX')) + + def write_error(self, msg: str, strategy_name: str = '', level: int = logging.ERROR): + """写入错误日志""" + self.write_log(msg=msg, strategy_name=strategy_name, level=level) + + def send_email(self, msg: str, strategy: CtaTemplate = None): + """ + Send email to default receiver. + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "CTA期权策略引擎" + + self.main_engine.send_email(subject, msg) + + def send_wechat(self, msg: str, strategy: CtaTemplate = None): + """ + send wechat message to default receiver + :param msg: + :param strategy: + :return: + """ + if strategy: + subject = f"{strategy.strategy_name}" + else: + subject = "CTAOPtion引擎" + + send_wx_msg(content=f'{subject}:{msg}') diff --git a/vnpy/app/cta_option/option_utility.py b/vnpy/app/cta_option/option_utility.py new file mode 100644 index 00000000..c8a7d6b7 --- /dev/null +++ b/vnpy/app/cta_option/option_utility.py @@ -0,0 +1,120 @@ + +import numpy as np +from scipy import stats + +################# +# BSM模型相关 +def get_option_d(s, k, t, r, sigma, q): + d1 = (np.log(s/k) + (r - q + 0.5*sigma**2)*t)/(sigma*np.sqrt(t)) + d2 = (np.log(s/k) + (r - q - 0.5*sigma**2)*t)/(sigma*np.sqrt(t)) + return d1, d2 + +def get_option_greeks(cp, s, k, t, r, sigma, q): + """ + 计算期权希腊值 + :param cp: + :param s: + :param k: + :param t: + :param r: + :param sigma: + :param q: + :return: + """ + d1, d2 = get_option_d(s, k, t, r, sigma, q) + delta = cp * stats.norm.cdf(cp * d1) + gamma = stats.norm.pdf(d1) / (s * sigma * np.sqrt(t)) + vega = (s * stats.norm.pdf(d1) * np.sqrt(t)) + theta = (-1 * (s * stats.norm.pdf(d1) * sigma) / (2 * np.sqrt(t)) - cp * r * k * np.exp(-r * t) * stats.norm.cdf(cp * d2)) + return delta, gamma, vega, theta + +def bsm_value(cp, s, k, t, r, sigma, q): + d1, d2 = get_option_d(s, k, t, r, sigma, q) + if cp > 0: + value = ( + s*np.exp(-q*t)*stats.norm.cdf(d1) - + k*np.exp(-r*t)*stats.norm.cdf(d2) + ) + else: + value = ( + k * np.exp(-r * t) * stats.norm.cdf(-d2) - + s*np.exp(-q*t) * stats.norm.cdf(-d1) + ) + return value +############## + +# 二分法迭代计算隐波 +def calculate_single_option_iv_by_bsm( + cp, s, k, c, t, r, q, + initial_iv=0.5, # 迭代起始值,如果上一个分钟有计算过隐波,这里把上一分钟的结果输入进来,有助于加快收敛 +): + + c_est = 0 # 期权价格估计值 + top = 1 # 波动率上限 + floor = 0 # 波动率下限 + sigma = initial_iv # 波动率初始值 + count = 0 # 计数器 + best_result = 0 + error = abs(c - c_est) + last_error = error + while error > 0.0001: + c_est = bsm_value(cp, s, k, t, r, sigma, q) + error = abs(c - c_est) + if error < last_error: + best_result = sigma + + # 根据价格判断波动率是被低估还是高估,并对波动率做修正 + count += 1 + if count > 100: # 时间价值为0的期权是算不出隐含波动率的,因此迭代到一定次数就不再迭代了 + sigma = 0 + break + + if c - c_est > 0: # f(x)>0 + floor = sigma + sigma = (sigma + top)/2 + else: + top = sigma + sigma = (sigma + floor)/2 + return best_result + +# 计算隐含分红率 +# 我们目前不计算这个 +def calculate_dividend_rate( + underlying_price, # 当前标的价格 + call_price, + put_price, + rest_days, # 剩余时间 + exercise_price, # 行权价 + free_rate, + ): + c = call_price + c_p = put_price + r = free_rate + t = rest_days / 360 + k = exercise_price + s = underlying_price + q = -np.log((c+k*np.exp(-r*t)-c_p)/(s))/t + return q + +# 计算隐波和Greeks +def calculate_single_option_greeks( + underlying_price, # 当前标的价格 + option_price, # 期权价格 + call_put, # 期权方向, CALL=1 PUT=-1 + rest_days, # 剩余时间,按自然日计算,也可以用小数来表示不完整的日子 + exercise_price, # 行权价 + free_rate = 0.03, # 无风险利率,如果没有数据,指定为3% + dividend_rate = 0, # 分红率,目前指定为0 + initial_iv = 0.5, # 初始迭代的隐波 +): + cp = call_put + s = underlying_price + r = free_rate + k = exercise_price + t = rest_days / 360 + c = option_price + q = dividend_rate + sigma = calculate_single_option_iv_by_bsm(cp, s, k, c, t, r, q, initial_iv) + delta, gamma, vega, theta = get_option_greeks(cp, s, k, t, r, sigma, q) + # sigma就是iv + return sigma, delta, gamma, vega, theta diff --git a/vnpy/app/cta_option/strategies/__init__.py b/vnpy/app/cta_option/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vnpy/app/cta_option/template.py b/vnpy/app/cta_option/template.py new file mode 100644 index 00000000..98ff2a24 --- /dev/null +++ b/vnpy/app/cta_option/template.py @@ -0,0 +1,1068 @@ +# 期权模板 +# 华富资产 @ 李来佳 + +import os +import traceback + +import bz2 +import pickle +import zlib +from vnpy.trader.utility import append_data, extract_vt_symbol + +from abc import ABC +from copy import copy, deepcopy +from typing import Any, Callable, List, Dict +from logging import INFO, ERROR +from datetime import datetime +from vnpy.trader.constant import Interval, Direction, Offset, Status, OrderType, Color, Exchange +from vnpy.trader.object import BarData, TickData, OrderData, TradeData, PositionData, ContractData, HistoryRequest +from vnpy.trader.utility import virtual, append_data, extract_vt_symbol, get_underlying_symbol +# from vnpy.app.cta_option import CtaOptionEngine +from .base import StopOrder, EngineType +from vnpy.component.cta_grid_trade import CtaGrid, CtaGridTrade, LOCK_GRID + +from vnpy.component.cta_policy import CtaPolicy # noqa +from vnpy.trader.utility import print_dict + +DIRECTION_MAP = { + Direction.LONG.value: 'long', + Direction.SHORT.value: 'short', + Direction.NET.value: 'long' +} +class CtaOptionPolicy(CtaPolicy): + """ + 期权策略逻辑&持仓持久化组件 + 满足使用target_pos方式得策略 + """ + def __init__(self, strategy): + super().__init__(strategy) + self.cur_trading_date = None # 已执行pre_trading方法后更新的当前交易日 + self.signals = {} # kline_name: { 'last_signal': '', 'last_signal_time': datetime } + self.sub_tns = {} # 子事务, 事务名称: 事务内容dict + self.datas = {} # 数据名称: 数据内容 + self.holding_pos = {} # 当前策略得持仓, 合约_方向: 数量 + self.target_pos = {} # 当前策略得目标持仓,合约_方向: 数量 + + def from_json(self, json_data): + """将数据从json_data中恢复""" + super().from_json(json_data) + + self.cur_trading_date = json_data.get('cur_trading_date', None) + self.sub_tns = json_data.get('sub_tns',{}) + signals = json_data.get('signals', {}) + for k, signal in signals.items(): + last_signal = signal.get('last_signal', "") + str_ast_signal_time = signal.get('last_signal_time', "") + try: + if len(str_ast_signal_time) > 0: + last_signal_time = datetime.strptime(str_ast_signal_time, '%Y-%m-%d %H:%M:%S') + else: + last_signal_time = None + except Exception as ex: + last_signal_time = None + self.signals.update({k: {'last_signal': last_signal, 'last_signal_time': last_signal_time}}) + + self.datas = json_data.get('datas', {}) + self.holding_pos = json_data.get('holding_pos', {}) + self.target_pos = json_data.get('target_pos', {}) + + def to_json(self): + """转换至json文件""" + j = super().to_json() + j['cur_trading_date'] = self.cur_trading_date + j['sub_tns'] = self.sub_tns + d = {} + for kline_name, signal in self.signals.items(): + last_signal_time = signal.get('last_signal_time', None) + c_signal = {} + c_signal.update(signal) + c_signal.update({'last_signal': signal.get('last_signal', ''), + 'last_signal_time': last_signal_time.strftime( + '%Y-%m-%d %H:%M:%S') if last_signal_time is not None else "" + }) + d.update({kline_name: c_signal}) + j['signals'] = d + j['datas'] = self.datas + j['holding_pos'] = self.holding_pos + j['target_pos'] = self.target_pos + + return j + +class CtaTemplate(ABC): + """CTA策略模板""" + + author = "" + parameters = [] + variables = [] + + # 保存委托单编号和相关委托单的字典 + # key为委托单编号 + # value为该合约相关的委托单 + active_orders = {} + # 是否回测状态 + backtesting = False + + def __init__( + self, + cta_engine: Any, + strategy_name: str, + vt_symbols: List[str], + setting: dict, + ): + """""" + self.cta_engine = cta_engine + self.strategy_name = strategy_name + self.vt_symbols = vt_symbols + + self.backtesting = False # True, 回测状态; False,实盘状态 + self.inited = False # 是否初始化完毕 + self.trading = False # 是否开始交易 + self.positions = {} # 持仓,vt_symbol_direction: position data + self.entrust = 0 # 是否正在委托, 0, 无委托 , 1, 委托方向是LONG, -1, 委托方向是SHORT + + self.cur_datetime = datetime.now() # 当前时间 + + self.tick_dict = {} # 记录所有on_tick传入最新tick + self.active_orders = {} + # Copy a new variables list here to avoid duplicate insert when multiple + # strategy instances are created with the same strategy class. + self.variables = copy(self.variables) + self.variables.insert(0, "inited") + self.variables.insert(1, "trading") + self.variables.insert(2, "entrust") + + def update_setting(self, setting: dict): + """ + Update strategy parameter wtih value in setting dict. + """ + for name in self.parameters: + if name in setting: + setattr(self, name, setting[name]) + + @classmethod + def get_class_parameters(cls): + """ + Get default parameters dict of strategy class. + """ + class_parameters = {} + for name in cls.parameters: + class_parameters[name] = getattr(cls, name) + return class_parameters + + def get_parameters(self): + """ + Get strategy parameters dict. + """ + strategy_parameters = {} + for name in self.parameters: + strategy_parameters[name] = getattr(self, name) + return strategy_parameters + + def get_variables(self): + """ + Get strategy variables dict. + """ + strategy_variables = {} + for name in self.variables: + strategy_variables[name] = getattr(self, name) + return strategy_variables + + def get_data(self): + """ + Get strategy data. + """ + strategy_data = { + "strategy_name": self.strategy_name, + "vt_symbols": self.vt_symbols, + "class_name": self.__class__.__name__, + "author": self.author, + "parameters": self.get_parameters(), + "variables": self.get_variables(), + } + return strategy_data + + def get_position(self, vt_symbol, direction) -> PositionData: + """ + 获取策略内某vt_symbol+方向得持仓 + :return: + """ + k = f'{vt_symbol}_{direction.value}' + pos = self.positions.get(k, None) + if pos is None: + symbol, exchange = extract_vt_symbol(vt_symbol) + contract = self.cta_engine.get_contract(vt_symbol) + pos = PositionData( + gateway_name=contract.gateway_name if contract else '', + symbol=symbol, + name=contract.name, + exchange=exchange, + direction=direction + ) + self.positions.update({k: pos}) + + return pos + + def get_positions(self): + """ 返回持仓数量""" + pos_list = [] + for k, v in self.positions.items(): + # 分解出vt_symbol和方向 + vt_symbol, direction = k.split('_') + pos_list.append({ + "vt_symbol": vt_symbol, + "direction": DIRECTION_MAP.get(direction,'long'), + "name": v.name, + "volume": v.volume, + "price": v.price, + 'pnl': v.pnl + }) + + if len(pos_list) > 0: + self.write_log(f'策略返回持仓信息:{pos_list}') + return pos_list + + @virtual + def on_timer(self): + pass + + @virtual + def on_init(self): + """ + Callback when strategy is inited. + """ + pass + + @virtual + def on_start(self): + """ + Callback when strategy is started. + """ + pass + + @virtual + def on_stop(self): + """ + Callback when strategy is stopped. + """ + pass + + @virtual + def on_tick(self, tick_dict: Dict[str, TickData]): + """ + Callback of new tick data update. + """ + pass + + @virtual + def on_bar(self, bar_dict: Dict[str, BarData]): + """ + Callback of new bar data update. + """ + pass + + @virtual + def on_trade(self, trade: TradeData): + """ + Callback of new trade data update. + """ + pass + + @virtual + def on_order(self, order: OrderData): + """ + Callback of new order data update. + """ + pass + + @virtual + def on_stop_order(self, stop_order: StopOrder): + """ + Callback of stop order update. + """ + pass + + def before_trading(self): + """开盘前/初始化后调用一次""" + self.write_log('开盘前调用') + + def after_trading(self): + """收盘后调用一次""" + self.write_log('收盘后调用') + + def buy(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send buy order to open a long position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_upper_limit(vt_symbol): + self.write_error(u'涨停价不做FAK/FOK委托') + return [] + if volume == 0: + self.write_error(f'委托数量有误,必须大于0,{vt_symbol}, price:{price}') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.LONG, + offset=Offset.OPEN, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def sell(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send sell order to close a long position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_lower_limit(vt_symbol): + self.write_error(u'跌停价不做FAK/FOK sell委托') + return [] + if volume == 0: + self.write_error(f'委托数量有误,必须大于0,{vt_symbol}, price:{price}') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.SHORT, + offset=Offset.CLOSE, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def short(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send short order to open as short position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_lower_limit(vt_symbol): + self.write_error(u'跌停价不做FAK/FOK short委托') + return [] + if volume == 0: + self.write_error(f'委托数量有误,必须大于0,{vt_symbol}, price:{price}') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.SHORT, + offset=Offset.OPEN, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def cover(self, price: float, volume: float, stop: bool = False, + vt_symbol: str = '', order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, grid: CtaGrid = None): + """ + Send cover order to close a short position. + """ + if order_type in [OrderType.FAK, OrderType.FOK]: + if self.is_upper_limit(vt_symbol): + self.write_error(u'涨停价不做FAK/FOK cover委托') + return [] + if volume == 0: + self.write_error(f'委托数量有误,必须大于0,{vt_symbol}, price:{price}') + return [] + return self.send_order(vt_symbol=vt_symbol, + direction=Direction.LONG, + offset=Offset.CLOSE, + price=price, + volume=volume, + stop=stop, + order_type=order_type, + order_time=order_time, + grid=grid) + + def send_order( + self, + vt_symbol: str, + direction: Direction, + offset: Offset, + price: float, + volume: float, + stop: bool = False, + order_type: OrderType = OrderType.LIMIT, + order_time: datetime = None, + grid: CtaGrid = None + ): + """ + Send a new order. + """ + # 兼容cta_strategy的模板,缺省不指定vt_symbol时,使用策略配置的vt_symbol + if vt_symbol == '': + return [] + + if not self.trading: + self.write_log(f'非交易状态') + return [] + + vt_orderids = self.cta_engine.send_order( + strategy=self, + vt_symbol=vt_symbol, + direction=direction, + offset=offset, + price=price, + volume=volume, + stop=stop, + order_type=order_type + ) + if len(vt_orderids) == 0: + self.write_error(f'{self.strategy_name}调用cta_engine.send_order委托返回失败,vt_symbol:{vt_symbol}') + # f',direction:{direction.value},offset:{offset.value},' + # f'price:{price},volume:{volume},stop:{stop},lock:{lock},' + # f'order_type:{order_type}') + + if order_time is None: + order_time = datetime.now() + + for vt_orderid in vt_orderids: + d = { + 'direction': direction, + 'offset': offset, + 'vt_symbol': vt_symbol, + 'price': price, + 'volume': volume, + 'order_type': order_type, + 'traded': 0, + 'order_time': order_time, + 'status': Status.SUBMITTING + } + if grid: + d.update({'grid': grid}) + if len(vt_orderid) > 0: + grid.order_ids.append(vt_orderid) + grid.order_time = order_time + self.active_orders.update({vt_orderid: d}) + if direction == Direction.LONG: + self.entrust = 1 + elif direction == Direction.SHORT: + self.entrust = -1 + return vt_orderids + + def cancel_order(self, vt_orderid: str): + """ + Cancel an existing order. + """ + if self.trading: + return self.cta_engine.cancel_order(self, vt_orderid) + + return False + + def cancel_all(self): + """ + Cancel all orders sent by strategy. + """ + if self.trading: + self.cta_engine.cancel_all(self) + + def is_upper_limit(self, symbol): + """是否涨停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_up is None or tick.limit_up == 0: + return False + if tick.bid_price_1 == tick.limit_up: + return True + + def is_lower_limit(self, symbol): + """是否跌停""" + tick = self.tick_dict.get(symbol, None) + if tick is None or tick.limit_down is None or tick.limit_down == 0: + return False + if tick.ask_price_1 == tick.limit_down: + return True + + def write_log(self, msg: str, level: int = INFO): + """ + Write a log message. + """ + self.cta_engine.write_log(msg=msg, strategy_name=self.strategy_name, level=level) + + def write_error(self, msg: str): + """write error log message""" + self.write_log(msg=msg, level=ERROR) + + def get_engine_type(self): + """ + Return whether the cta_engine is backtesting or live trading. + """ + return self.cta_engine.get_engine_type() + + def load_bar( + self, + vt_symbol:str, + days: int, + interval: Interval = Interval.MINUTE, + callback: Callable = None, + interval_num: int = 1 + ): + """ + Load historical bar data for initializing strategy. + """ + if not callback: + callback = self.on_bar + + self.cta_engine.load_bar(vt_symbol, days, interval, callback, interval_num) + + def load_tick(self, vt_symbol: str, days: int): + """ + Load historical tick data for initializing strategy. + """ + self.cta_engine.load_tick(vt_symbol, days, self.on_tick) + + def put_event(self): + """ + Put an strategy data event for ui update. + """ + if self.inited: + self.cta_engine.put_strategy_event(self) + + def send_email(self, msg): + """ + Send email to default receiver. + """ + if self.inited: + self.cta_engine.send_email(msg, self) + + def sync_data(self): + """ + Sync strategy variables value into disk storage. + """ + if self.trading: + self.cta_engine.sync_strategy_data(self) + +class CtaOptionTemplate(CtaTemplate): + """期权交易增强版模板""" + + + # 逻辑过程日志 + dist_fieldnames = ['datetime', 'vt_symbol', 'name', 'volume', 'price', + 'operation', 'signal', 'stop_price', 'target_price', + 'long_pos','short_pos'] + + def __init__(self, cta_engine, strategy_name, vt_symbol, setting): + """""" + super().__init__(cta_engine, strategy_name, vt_symbol, setting) + + self.cancel_seconds = 60 # 撤单时间 + + self.klines = {} # 所有K线 + + # 策略事务逻辑与持仓组件 + self.policy = CtaOptionPolicy(strategy=self) + + + def update_setting(self, setting: dict): + """更新配置参数""" + super().update_setting(setting) + + def init_policy(self): + """加载policy""" + self.write_log(f'{self.strategy_name} => 初始化Policy') + self.policy.load() + self.write_log(u'Policy:{}'.format(print_dict(self.policy.to_json()))) + + # self.policy持仓 => self.positions + for k in list(self.policy.holding_pos.keys()): + v = self.policy.holding_pos.get(k) + if v == 0: + self.policy.holding_pos.pop(k,None) + continue + vt_symbol, direction = k.split('_') + cur_pos = self.get_position(vt_symbol, Direction(direction)) + cur_pos.volume = v + self.positions.update({k:cur_pos}) + # 订阅行情 + self.cta_engine.subscribe_symbol( + strategy_name=self.strategy_name, + vt_symbol=vt_symbol) + self.write_log(f'{self.strategy_name} => 恢复持仓 {cur_pos.vt_symbol}[{cur_pos.name}] {cur_pos.direction}:{cur_pos.volume}') + + def display_tns(self): + """ + 打印日志 + :return: + """ + if self.backtesting: + return + + self.write_log('当前policy:\n{}'.format(print_dict(self.policy.to_json()))) + + + def sync_data(self): + """同步更新数据""" + if not self.backtesting: + self.write_log(u'保存k线缓存数据') + self.save_klines_to_cache() + + if self.inited and self.trading: + self.write_log(u'保存policy数据') + self.policy.save() + + def save_klines_to_cache(self, kline_names: list = [], vt_symbol: str = ""): + """ + 保存K线数据到缓存 + :param kline_names: 一般为self.klines的keys + :param vt_symbol: 指定股票代码, + 如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2 + 如果空白,加载 data/strategyname_klines.pkb2 + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + try: + # 如果是指定合约的话,使用klines子目录 + if len(vt_symbol) > 0: + kline_names = [n for n in kline_names if vt_symbol in n] + save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines')) + if not os.path.exists(save_path): + os.makedirs(save_path) + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2')) + else: + # 获取保存路径 + save_path = self.cta_engine.get_data_path() + # 保存缓存的文件名 + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + + with bz2.BZ2File(file_name, 'wb') as f: + klines = {} + for kline_name in kline_names: + kline = self.klines.get(kline_name, None) + # if kline: + # kline.strategy = None + # kline.cb_on_bar = None + klines.update({kline_name: kline}) + pickle.dump(klines, f) + self.write_log(f'保存{vt_symbol} K线数据成功=>{file_name}') + except Exception as ex: + self.write_error(f'保存k线数据异常:{str(ex)}') + self.write_error(traceback.format_exc()) + + def load_klines_from_cache(self, kline_names: list = [], vt_symbol: str = ""): + """ + 从缓存加载K线数据 + :param kline_names: 指定需要加载的k线名称列表 + :param vt_symbol: 指定股票代码, + 如果使用该选项,加载 data/klines/strategyname_vtsymbol_klines.pkb2 + 如果空白,加载 data/strategyname_klines.pkb2 + :return: + """ + if len(kline_names) == 0: + kline_names = list(self.klines.keys()) + + # 如果是指定合约的话,使用klines子目录 + if len(vt_symbol) > 0: + kline_names = [n for n in kline_names if vt_symbol in n] + save_path = os.path.abspath(os.path.join(self.cta_engine.get_data_path(), 'klines')) + if not os.path.exists(save_path): + os.makedirs(save_path) + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_{vt_symbol}_klines.pkb2')) + else: + save_path = self.cta_engine.get_data_path() + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_klines.pkb2')) + try: + last_bar_dt = None + with bz2.BZ2File(file_name, 'rb') as f: + klines = pickle.load(f) + # 逐一恢复K线 + for kline_name in kline_names: + # 缓存的k线实例 + cache_kline = klines.get(kline_name, None) + # 当前策略实例的K线实例 + strategy_kline = self.klines.get(kline_name, None) + + if cache_kline and strategy_kline: + # 临时保存当前的回调函数 + cb_on_bar = strategy_kline.cb_on_bar + # 缓存实例数据 =》 当前实例数据 + strategy_kline.__dict__.update(cache_kline.__dict__) + + # 所有K线的最后时间 + if last_bar_dt and strategy_kline.cur_datetime: + last_bar_dt = max(last_bar_dt, strategy_kline.cur_datetime) + else: + last_bar_dt = strategy_kline.cur_datetime + + # 重新绑定k线策略与on_bar回调函数 + strategy_kline.strategy = self + strategy_kline.cb_on_bar = cb_on_bar + + self.write_log(f'恢复{kline_name}缓存数据,最新bar结束时间:{last_bar_dt}') + + self.write_log(u'加载缓存k线数据完毕') + return last_bar_dt + except Exception as ex: + self.write_error(f'加载缓存K线数据失败:{str(ex)}') + return None + + def get_klines_info(self): + """ + 返回当前所有kline的信息 + :return: {"股票中文":[kline_name1, kline_name2]} + """ + info = {} + for kline_name in list(self.klines.keys()): + # 策略中如果kline不是按照 vtsymbol_xxxx 的命名方式,需要策略内部自行实现方法 + vt_symbol = kline_name.split('_')[0] + # vt_symbol => 中文名 + cn_name = self.cta_engine.get_name(vt_symbol) + + # 添加到列表 => 排序 + kline_names = info.get(cn_name, []) + kline_names.append(kline_name) + kline_names = sorted(kline_names) + + # 更新 + info[cn_name] = kline_names + + return info + + def get_klines_snapshot(self, include_kline_names=[]): + """ + 返回当前klines的切片数据 + :param include_kline_names: 如果存在,则只保留这些指定得K线 + :return: + """ + try: + self.write_log(f'获取{self.strategy_name}的切片数据') + d = { + 'strategy': self.strategy_name, + 'datetime': datetime.now()} + klines = {} + for kline_name in sorted(self.klines.keys()): + if len(include_kline_names) > 0: + if kline_name not in include_kline_names: + continue + klines.update({kline_name: self.klines.get(kline_name).get_data()}) + kline_names = list(klines.keys()) + binary_data = zlib.compress(pickle.dumps(klines)) + d.update({'kline_names': kline_names, 'klines': binary_data, 'zlib': True}) + return d + except Exception as ex: + self.write_error(f'获取klines切片数据失败:{str(ex)}') + return {} + + def on_start(self): + """启动策略(必须由用户继承实现)""" + self.write_log(f'{self.strategy_name} => 策略启动') + self.trading = True + self.put_event() + + def on_stop(self): + """停止策略(必须由用户继承实现)""" + self.active_orders.clear() + self.entrust = 0 + + self.write_log(f'{self.strategy_name} => 策略停止') + self.put_event() + + def on_trade(self, trade: TradeData): + """ + 交易更新 + :param trade: + :return: + """ + + if (trade.direction == Direction.LONG and trade.offset == Offset.OPEN) \ + or (trade.direction == Direction.SHORT and trade.offset != Offset.OPEN): + cur_pos = self.get_position(trade.vt_symbol, Direction.LONG) + else: + cur_pos = self.get_position(trade.vt_symbol, Direction.SHORT) + + self.write_log(u'{},交易更新 =>{}\n,\n 当前持仓:\n{} ' + .format(self.cur_datetime, + print_dict(trade.__dict__), + print_dict(cur_pos.__dict__))) + + dist_record = dict() + if self.backtesting: + dist_record['datetime'] = trade.time + else: + dist_record['datetime'] = ' '.join([self.cur_datetime.strftime('%Y-%m-%d'), trade.time]) + dist_record['volume'] = trade.volume + dist_record['price'] = trade.price + dist_record['symbol'] = trade.vt_symbol + + if trade.direction == Direction.LONG and trade.offset == Offset.OPEN: + dist_record['operation'] = 'buy' + cur_pos.volume += trade.volume + dist_record['long_pos'] = cur_pos.volume + dist_record['short_pos'] = 0 + + if trade.direction == Direction.SHORT and trade.offset == Offset.OPEN: + dist_record['operation'] = 'short' + cur_pos.volume += trade.volume + dist_record['long_pos'] = 0 + dist_record['short_pos'] = cur_pos.volume + + if trade.direction == Direction.LONG and trade.offset != Offset.OPEN: + dist_record['operation'] = 'cover' + cur_pos.volume = max(0, cur_pos.volume - trade.volume) + dist_record['long_pos'] = 0 + dist_record['short_pos'] = cur_pos.volume + + if trade.direction == Direction.SHORT and trade.offset != Offset.OPEN: + dist_record['operation'] = 'sell' + cur_pos.volume = max(0, cur_pos.volume - trade.volume) + dist_record['long_pos'] = cur_pos.volume + dist_record['short_pos'] = 0 + + k = f'{cur_pos.vt_symbol}_{cur_pos.direction.value}' + # 更新 self.positions + self.positions.update({k: cur_pos}) + self.write_log(f'{self.strategy_name} ,positions[{k}]持仓更新 =>\n{print_dict(cur_pos.__dict__)}') + + # 更新 policy.holding_pos + self.write_log(f'{self.strategy_name} ,policy.holding_pos[{k}]持仓更新 => {cur_pos.volume}') + if cur_pos.volume == 0: + self.policy.holding_pos.pop(k, None) + else: + self.policy.holding_pos[k] = int(cur_pos.volume) + self.policy.save() + + # 这里要判断订单是否全部完成,如果完成,就移除活动订单 + if trade.vt_orderid in self.active_orders: + if self.active_orders[trade.vt_orderid].get('volume',-1) == self.active_orders[trade.vt_orderid].get('traded',0): + self.write_log(f'{trade.vt_orderid}全部执行完毕,移除活动订单') + self.active_orders.pop(trade.vt_orderid, None) + + self.save_dist(dist_record) + + def on_order(self, order: OrderData): + """报单更新""" + # 未执行的订单中,存在是异常,删除 + self.write_log(u'{}报单更新 =>\n {}'.format(self.cur_datetime, print_dict(order.__dict__))) + + if order.vt_orderid in self.active_orders: + d = self.active_orders[order.vt_orderid] + if d['traded'] != order.traded: + self.write_log(f'委托单交易 已交易{d["traded"]} => {order.traded}, 总委托:{order.volume}') + d['traded'] = order.traded + + if order.status in [Status.ALLTRADED]: + # 全部成交 + self.write_log(f'报单更新 => 委托开仓 => {order.status}') + # 这里不去掉active_orders,由on_trade进行去除 + + elif order.status in [Status.CANCELLED, Status.REJECTED]: + # 撤单、拒单 + self.write_log(f'报单更新 => 委托开仓 => {order.status}') + self.active_orders.pop(order.vt_orderid, None) + else: + # 未完成、部分成交.. + self.write_log(u'委托单未完成,total:{},traded:{},tradeStatus:{}' + .format(order.volume, order.traded, order.status)) + else: + self.write_error(u'委托单{}不在策略的未完成订单列表中:{}'.format(order.vt_orderid, self.active_orders)) + + def on_stop_order(self, stop_order: StopOrder): + """ + 停止单更新 + 需要自己重载,处理各类触发、撤单等情况 + """ + self.write_log(f'停止单触发:{stop_order.__dict__}') + + def cancel_all_orders(self): + """ + 重载撤销所有正在进行得委托 + :return: + """ + self.write_log(u'撤销所有正在进行得委托') + self.tns_cancel_logic(dt=datetime.now(), force=True) + + def tns_cancel_logic(self, dt, force=False): + "撤单逻辑""" + if len(self.active_orders) < 1: + self.entrust = 0 + return + + canceled_ids = [] + + for vt_orderid in list(self.active_orders.keys()): + order_info = self.active_orders[vt_orderid] + order_vt_symbol = order_info.get('vt_symbol') + order_time = order_info['order_time'] + + order_status = order_info.get('status', Status.NOTTRADED) + order_type = order_info.get('order_type', OrderType.LIMIT) + over_seconds = (dt - order_time).total_seconds() + + # 只处理未成交的限价委托单 + if order_status in [Status.NOTTRADED, Status.SUBMITTING] and order_type == OrderType.LIMIT: + if over_seconds > self.cancel_seconds or force: # 超过设置的时间还未成交 + self.write_log(u'撤单逻辑 => 超时{}秒未成交,取消委托单:vt_orderid:{},order:{}' + .format(over_seconds, vt_orderid, order_info)) + order_info.update({'status': Status.CANCELLING}) + self.active_orders.update({vt_orderid: order_info}) + ret = self.cancel_order(str(vt_orderid)) + if not ret: + self.write_error(f'{self.strategy_name}撤单逻辑 => {order_vt_symbol}撤单失败') + + continue + + # 处理状态为‘撤销’的委托单 + elif order_status == Status.CANCELLED: + self.write_log(u'撤单逻辑 => 委托单{}已成功撤单,将删除未完成订单{}'.format(vt_orderid, order_info)) + canceled_ids.append(vt_orderid) + + # 删除撤单的订单 + for vt_orderid in canceled_ids: + self.write_log(u'撤单逻辑 => 删除未完成订单:{}'.format(vt_orderid)) + self.active_orders.pop(vt_orderid, None) + + if len(self.active_orders) == 0: + self.entrust = 0 + + def tns_balance_pos(self, vt_symbol): + """ + 事务自动平衡 policy得holding_pos & target_pos 仓位 + 这里委托单时,还需要看看当前tick得ask1&bid1差距情况 + :param vt_symbol: + :return: + """ + option_name = self.cta_engine.get_name(vt_symbol) + c = self.cta_engine.get_contract(vt_symbol) + + for direction in [Direction.LONG, Direction.SHORT]: + k = f'{vt_symbol}_{direction.value}' + + target_pos = self.policy.target_pos.get(k, 0) + holding_pos = self.policy.holding_pos.get(k, 0) + diff_pos = target_pos - holding_pos + + if diff_pos == 0: + continue + + # 获取最新价 + cur_price = self.cta_engine.get_price(vt_symbol) + if not cur_price: + continue + # 获取最新tick + cur_tick = self.cta_engine.get_tick(vt_symbol) + price_tick = self.cta_engine.get_price_tick(vt_symbol) + if diff_pos > 0: # 需要增加仓位,增加多单或空单 + self.write_log(f'平衡仓位,{vt_symbol} [{c.name}]{direction.value}单,{holding_pos} =>{target_pos} => 增加 {diff_pos}手') + # 检查是否存在相同得开仓委托 + if self.exist_order(vt_symbol, direction, Offset.OPEN): + self.write_log(f'存在相同得开仓委托,暂不处理') + continue + + if direction == Direction.LONG: # 买入多单 + # 发出委托 + vt_orderid = self.buy(vt_symbol=vt_symbol, + price=cur_price, + volume=diff_pos, + order_type=OrderType.LIMIT, + order_time=self.cur_datetime) + if vt_orderid: + self.write_log(f'{self.strategy_name} 调整目标:{vt_symbol}[{option_name}]' + + f' {holding_pos} =>{target_pos} 开多:{diff_pos} ' + + f'价格:{cur_price} 委托编号:{vt_orderid}') + else: # 卖出空单 + # 发出委托 + vt_orderid = self.short(vt_symbol=vt_symbol, + price=cur_price, + volume=diff_pos, + order_type=OrderType.LIMIT, + order_time=self.cur_datetime) + if vt_orderid: + self.write_log(f'{self.strategy_name} 调整目标:{vt_symbol}[{option_name}]' + + f' {holding_pos} =>{target_pos} 开空:{diff_pos} ' + + f'价格:{cur_price} 委托编号:{vt_orderid}') + + else: # 需要减少仓位,平多单或平空单 + self.write_log( + f'平衡仓位,{vt_symbol} [{c.name}]{direction.value}单,{holding_pos} =>{target_pos} => 减少 {abs(diff_pos)}手') + close_direction = Direction.LONG if direction == Direction.SHORT else Direction.SHORT + # 检查是否存在相同得平仓委托 + if self.exist_order(vt_symbol, close_direction, Offset.CLOSE): + self.write_log(f'存在相同得平仓委托,暂不处理') + continue + + if direction == Direction.LONG: # 平仓多单 + sell_price = cur_price + # 价格 tick检查,叫卖价降低1个跳卖出 + if cur_tick and cur_tick.ask_price_1 and cur_tick.bid_price_1: + sell_price = max(cur_tick.ask_price_1 - price_tick,sell_price) + + # 发出委托 + vt_orderid = self.sell(vt_symbol=vt_symbol, + price=sell_price, + volume=abs(diff_pos), + order_type=OrderType.LIMIT, + order_time=self.cur_datetime) + if vt_orderid: + self.write_log(f'{self.strategy_name} 调整目标:{vt_symbol}[{option_name}]' + + f' {holding_pos} =>{target_pos} 多单平仓:{abs(diff_pos)} ' + + f'价格:{sell_price} 委托编号:{vt_orderid}') + else: # 平仓空单 + cover_price = cur_price + # 价格 tick检查,叫买价提高1个跳卖出 + if cur_tick and cur_tick.ask_price_1 and cur_tick.bid_price_1: + cover_price = min(cur_tick.bid_price_1 + price_tick, cover_price) + + # 发出委托 + vt_orderid = self.cover(vt_symbol=vt_symbol, + price=cover_price, + volume=abs(diff_pos), + order_type=OrderType.LIMIT, + order_time=self.cur_datetime) + if vt_orderid: + self.write_log(f'{self.strategy_name} 调整目标:{vt_symbol}[{option_name}]' + + f' {holding_pos} =>{target_pos} 空单平仓:{abs(diff_pos)} ' + + f'价格:{cover_price} 委托编号:{vt_orderid}') + + def exist_order(self, vt_symbol, direction, offset): + """ + 是否存在相同得委托 + :param vt_symbol: + :param direction: + :param offset: + :return: + """ + if len(self.active_orders) == 0: + self.write_log(f'当前活动订单中,数量为零. 查询{vt_symbol},方向:{direction.value}, 开平:{offset.value}') + return False + + for orderid, order in self.active_orders.items(): + self.write_log(f'当前活动订单:\n{print_dict(order)}') + if order['vt_symbol'] == vt_symbol and order['direction'] == direction and order['offset'] == offset: + self.write_log(f'存在相同得活动订单') + return True + + return False + + def save_dist(self, dist_data): + """ + 保存策略逻辑过程记录=》 csv文件按 + :param dist_data: + :return: + """ + if self.backtesting: + save_path = self.cta_engine.get_logs_path() + else: + save_path = self.cta_engine.get_data_path() + try: + + if 'datetime' not in dist_data: + dist_data.update({'datetime': self.cur_datetime}) + if 'long_pos' not in dist_data: + vt_symbol = dist_data.get('vt_symbol') + if vt_symbol: + # pos = self.get_position(vt_symbol) + # dist_data.update({'long_pos': pos.volume}) + if 'name' not in dist_data: + dist_data['name'] = self.cta_engine.get_name(vt_symbol) + + file_name = os.path.abspath(os.path.join(save_path, f'{self.strategy_name}_dist.csv')) + append_data(file_name=file_name, dict_data=dist_data, field_names=self.dist_fieldnames) + except Exception as ex: + self.write_error(u'save_dist 异常:{} {}'.format(str(ex), traceback.format_exc())) diff --git a/vnpy/app/cta_option/ui/__init__.py b/vnpy/app/cta_option/ui/__init__.py new file mode 100644 index 00000000..96215ecb --- /dev/null +++ b/vnpy/app/cta_option/ui/__init__.py @@ -0,0 +1 @@ +from .widget import CtaOption diff --git a/vnpy/app/cta_option/ui/cta.ico b/vnpy/app/cta_option/ui/cta.ico new file mode 100644 index 00000000..25cbaa73 Binary files /dev/null and b/vnpy/app/cta_option/ui/cta.ico differ diff --git a/vnpy/app/cta_option/ui/widget.py b/vnpy/app/cta_option/ui/widget.py new file mode 100644 index 00000000..2e283b15 --- /dev/null +++ b/vnpy/app/cta_option/ui/widget.py @@ -0,0 +1,552 @@ +from vnpy.event import Event, EventEngine +from vnpy.trader.engine import MainEngine +from vnpy.trader.ui import QtCore, QtGui, QtWidgets +from vnpy.trader.ui.widget import ( + BaseCell, + EnumCell, + MsgCell, + TimeCell, + BaseMonitor +) +from vnpy.trader.ui.kline.ui_snapshot import UiSnapshot +from ..base import ( + APP_NAME, + EVENT_CTA_LOG, + EVENT_CTA_STOPORDER, + EVENT_CTA_OPTION +) + +from ..engine import CtaOptionEngine + + +class CtaOption(QtWidgets.QWidget): + """""" + + signal_log = QtCore.pyqtSignal(Event) + signal_strategy = QtCore.pyqtSignal(Event) + + def __init__(self, main_engine: MainEngine, event_engine: EventEngine): + super(CtaOption, self).__init__() + + self.main_engine = main_engine + self.event_engine = event_engine + self.cta_engine = main_engine.get_engine(APP_NAME) + + self.managers = {} + + self.init_ui() + self.register_event() + self.cta_engine.init_engine() + self.update_class_combo() + + def init_ui(self): + """""" + self.setWindowTitle("CTA策略") + + # Create widgets + self.class_combo = QtWidgets.QComboBox() + + add_button = QtWidgets.QPushButton("添加策略") + add_button.clicked.connect(self.add_strategy) + + init_button = QtWidgets.QPushButton("全部初始化") + init_button.clicked.connect(self.cta_engine.init_all_strategies) + + start_button = QtWidgets.QPushButton("全部启动") + start_button.clicked.connect(self.cta_engine.start_all_strategies) + + stop_button = QtWidgets.QPushButton("全部停止") + stop_button.clicked.connect(self.cta_engine.stop_all_strategies) + + clear_button = QtWidgets.QPushButton("清空日志") + clear_button.clicked.connect(self.clear_log) + + self.scroll_layout = QtWidgets.QVBoxLayout() + self.scroll_layout.addStretch() + + scroll_widget = QtWidgets.QWidget() + scroll_widget.setLayout(self.scroll_layout) + + scroll_area = QtWidgets.QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setWidget(scroll_widget) + + self.log_monitor = LogMonitor(self.main_engine, self.event_engine) + + self.stop_order_monitor = StopOrderMonitor( + self.main_engine, self.event_engine + ) + + # Set layout + hbox1 = QtWidgets.QHBoxLayout() + hbox1.addWidget(self.class_combo) + hbox1.addWidget(add_button) + hbox1.addStretch() + hbox1.addWidget(init_button) + hbox1.addWidget(start_button) + hbox1.addWidget(stop_button) + hbox1.addWidget(clear_button) + + grid = QtWidgets.QGridLayout() + grid.addWidget(scroll_area, 0, 0, 2, 1) + grid.addWidget(self.stop_order_monitor, 0, 1) + grid.addWidget(self.log_monitor, 1, 1) + + vbox = QtWidgets.QVBoxLayout() + vbox.addLayout(hbox1) + vbox.addLayout(grid) + + self.setLayout(vbox) + + def update_class_combo(self): + """""" + self.class_combo.addItems( + self.cta_engine.get_all_strategy_class_names() + ) + + def register_event(self): + """""" + self.signal_strategy.connect(self.process_strategy_event) + + self.event_engine.register( + EVENT_CTA_OPTION, self.signal_strategy.emit + ) + + def process_strategy_event(self, event): + """ + Update strategy status onto its monitor. + """ + data = event.data + strategy_name = data["strategy_name"] + + if strategy_name in self.managers: + manager = self.managers[strategy_name] + manager.update_data(data) + else: + manager = StrategyManager(self, self.cta_engine, data) + self.scroll_layout.insertWidget(0, manager) + self.managers[strategy_name] = manager + + def remove_strategy(self, strategy_name): + """""" + manager = self.managers.pop(strategy_name) + manager.deleteLater() + + def add_strategy(self): + """""" + class_name = str(self.class_combo.currentText()) + if not class_name: + return + + parameters = self.cta_engine.get_strategy_class_parameters(class_name) + editor = SettingEditor(parameters, class_name=class_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + vt_symbols = setting.pop("vt_symbols").split(",") + strategy_name = setting.pop("strategy_name") + auto_init = setting.pop("auto_init", False) + auto_start = setting.pop("auto_start", False) + self.cta_engine.add_strategy( + class_name, strategy_name, vt_symbols, setting, auto_init, auto_start + ) + + def clear_log(self): + """""" + self.log_monitor.setRowCount(0) + + def show(self): + """""" + self.showMaximized() + + +class StrategyManager(QtWidgets.QFrame): + """ + Manager for a strategy + """ + + def __init__( + self, cta_manager: CtaOption, cta_engine: CtaOptionEngine, data: dict + ): + """""" + super(StrategyManager, self).__init__() + + self.cta_manager = cta_manager + self.cta_engine = cta_engine + + self.strategy_name = data["strategy_name"] + self._data = data + + self.init_ui() + + def init_ui(self): + """""" + self.setFixedHeight(300) + self.setFrameShape(self.Box) + self.setLineWidth(1) + + init_button = QtWidgets.QPushButton("初始化") + init_button.clicked.connect(self.init_strategy) + + start_button = QtWidgets.QPushButton("启动") + start_button.clicked.connect(self.start_strategy) + + stop_button = QtWidgets.QPushButton("停止") + stop_button.clicked.connect(self.stop_strategy) + + edit_button = QtWidgets.QPushButton("编辑") + edit_button.clicked.connect(self.edit_strategy) + + remove_button = QtWidgets.QPushButton("移除") + remove_button.clicked.connect(self.remove_strategy) + + reload_button = QtWidgets.QPushButton("重载") + reload_button.clicked.connect(self.reload_strategy) + + save_button = QtWidgets.QPushButton("保存") + save_button.clicked.connect(self.save_strategy) + + view_button = QtWidgets.QPushButton("K线") + view_button.clicked.connect(self.view_strategy_snapshot) + + strategy_name = self._data["strategy_name"] + #vt_symbol = self._data["vt_symbol"] + class_name = self._data["class_name"] + author = self._data["author"] + + label_text = ( + f"{strategy_name} - ({class_name} by {author})" + ) + label = QtWidgets.QLabel(label_text) + label.setAlignment(QtCore.Qt.AlignCenter) + + self.parameters_monitor = DataMonitor(self._data["parameters"]) + self.variables_monitor = DataMonitor(self._data["variables"]) + + hbox = QtWidgets.QHBoxLayout() + hbox.addWidget(init_button) + hbox.addWidget(start_button) + hbox.addWidget(stop_button) + hbox.addWidget(edit_button) + hbox.addWidget(remove_button) + hbox.addWidget(reload_button) + hbox.addWidget(save_button) + hbox.addWidget(view_button) + + vbox = QtWidgets.QVBoxLayout() + vbox.addWidget(label) + vbox.addLayout(hbox) + vbox.addWidget(self.parameters_monitor) + vbox.addWidget(self.variables_monitor) + self.setLayout(vbox) + + def update_data(self, data: dict): + """""" + self._data = data + + self.parameters_monitor.update_data(data["parameters"]) + self.variables_monitor.update_data(data["variables"]) + + def init_strategy(self): + """""" + self.cta_engine.init_strategy(self.strategy_name) + + def start_strategy(self): + """""" + self.cta_engine.start_strategy(self.strategy_name) + + def stop_strategy(self): + """""" + self.cta_engine.stop_strategy(self.strategy_name) + + def edit_strategy(self): + """""" + strategy_name = self._data["strategy_name"] + + parameters = self.cta_engine.get_strategy_parameters(strategy_name) + editor = SettingEditor(parameters, strategy_name=strategy_name) + n = editor.exec_() + + if n == editor.Accepted: + setting = editor.get_setting() + self.cta_engine.edit_strategy(strategy_name, setting) + + def remove_strategy(self): + """""" + result = self.cta_engine.remove_strategy(self.strategy_name) + + # Only remove strategy gui manager if it has been removed from engine + if result: + self.cta_manager.remove_strategy(self.strategy_name) + + def reload_strategy(self): + """重新加载策略""" + self.cta_engine.reload_strategy(self.strategy_name) + + def save_strategy(self): + """保存策略缓存数据""" + self.cta_engine.save_strategy_data(self.strategy_name) + + def view_strategy_snapshot(self): + """实时查看策略切片""" + kline_info = self.cta_engine.get_strategy_kline_names(self.strategy_name) + + selector = KlineSelectDialog(kline_info,self.strategy_name) + n = selector.exec_() + + if n == selector.Accepted: + klines = selector.get_klines() + if len(klines) > 0: + snapshot = self.cta_engine.get_strategy_snapshot(self.strategy_name,klines) + if snapshot is None: + return + ui_snapshot = UiSnapshot() + ui_snapshot.show(snapshot_file="", d=snapshot) + +class DataMonitor(QtWidgets.QTableWidget): + """ + Table monitor for parameters and variables. + """ + + def __init__(self, data: dict): + """""" + super(DataMonitor, self).__init__() + + self._data = data + self.cells = {} + + self.init_ui() + + def init_ui(self): + """""" + labels = list(self._data.keys()) + self.setColumnCount(len(labels)) + self.setHorizontalHeaderLabels(labels) + + self.setRowCount(1) + self.verticalHeader().setSectionResizeMode( + QtWidgets.QHeaderView.Stretch + ) + self.verticalHeader().setVisible(False) + self.setEditTriggers(self.NoEditTriggers) + + for column, name in enumerate(self._data.keys()): + value = self._data[name] + + cell = QtWidgets.QTableWidgetItem(str(value)) + cell.setTextAlignment(QtCore.Qt.AlignCenter) + + self.setItem(0, column, cell) + self.cells[name] = cell + + def update_data(self, data: dict): + """""" + for name, value in data.items(): + cell = self.cells[name] + cell.setText(str(value)) + + +class StopOrderMonitor(BaseMonitor): + """ + Monitor for local stop order. + """ + + event_type = EVENT_CTA_STOPORDER + data_key = "stop_orderid" + sorting = True + + headers = { + "stop_orderid": { + "display": "停止委托号", + "cell": BaseCell, + "update": False, + }, + "vt_orderids": {"display": "限价委托号", "cell": BaseCell, "update": True}, + "vt_symbol": {"display": "本地代码", "cell": BaseCell, "update": False}, + "direction": {"display": "方向", "cell": EnumCell, "update": False}, + "offset": {"display": "开平", "cell": EnumCell, "update": False}, + "price": {"display": "价格", "cell": BaseCell, "update": False}, + "volume": {"display": "数量", "cell": BaseCell, "update": False}, + "status": {"display": "状态", "cell": EnumCell, "update": True}, + "lock": {"display": "锁仓", "cell": BaseCell, "update": False}, + "strategy_name": {"display": "策略名", "cell": BaseCell, "update": False}, + } + + +class LogMonitor(BaseMonitor): + """ + Monitor for log data. + """ + + event_type = EVENT_CTA_LOG + data_key = "" + sorting = False + + headers = { + "time": {"display": "时间", "cell": TimeCell, "update": False}, + "msg": {"display": "信息", "cell": MsgCell, "update": False}, + } + + def init_ui(self): + """ + Stretch last column. + """ + super(LogMonitor, self).init_ui() + + self.horizontalHeader().setSectionResizeMode( + 1, QtWidgets.QHeaderView.Stretch + ) + + def insert_new_row(self, data): + """ + Insert a new row at the top of table. + """ + super(LogMonitor, self).insert_new_row(data) + self.resizeRowToContents(0) + +class KlineSelectDialog(QtWidgets.QDialog): + """ + 多K线选择窗口 + """ + def __init__( + self, info: dict, strategy_name:str + ): + """ + 构造函数 + :param info: 所有k线的配置 + :param strategy_name: + """ + super(KlineSelectDialog, self).__init__() + + self.info = info + self.strategy_name = strategy_name + self.t = None + self.select_names = [] + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + self.t = QtWidgets.QTableWidget(len(self.info), 2) + + self.t.setHorizontalHeaderLabels(['股票', 'K线']) + row = 0 + for k, v in self.info.items(): + + item = QtWidgets.QTableWidgetItem() + item.setText(k) + self.t.setItem(row, 0, item) + + klines = QtWidgets.QTableWidgetItem() + klines.setText(','.join(v)) + self.t.setItem(row,1, klines) + row +=1 + + # 单选 + self.t.setSelectionMode(QtWidgets.QAbstractItemView.SingleSelection) + # self.t.cellPressed.conect(self.cell_select) + form.addWidget(self.t) + button = QtWidgets.QPushButton('确定') + button.clicked.connect(self.accept) + form.addRow(button) + self.setLayout(form) + + def cell_select(self,row,col): + try: + content = self.t.item(row,0).text() + self.select_names = self.info.get(content,[]) + except Exception as ex: + pass + + def get_klines(self): + """""" + selectedItems = self.t.selectedItems() + for item in selectedItems: + cur_row = item.row() + content = item.text() + self.select_names = self.info.get(content, []) + if len(self.select_names) > 0: + return self.select_names + return self.select_names + + +class SettingEditor(QtWidgets.QDialog): + """ + For creating new strategy and editing strategy parameters. + """ + + def __init__( + self, parameters: dict, strategy_name: str = "", class_name: str = "" + ): + """""" + super(SettingEditor, self).__init__() + + self.parameters = parameters + self.strategy_name = strategy_name + self.class_name = class_name + + self.edits = {} + + self.init_ui() + + def init_ui(self): + """""" + form = QtWidgets.QFormLayout() + + # Add vt_symbol and name edit if add new strategy + if self.class_name: + self.setWindowTitle(f"添加策略:{self.class_name}") + button_text = "添加" + parameters = {"strategy_name": "", "vt_symbols": "", "auto_init": True, "auto_start": True} + parameters.update(self.parameters) + + else: + self.setWindowTitle(f"参数编辑:{self.strategy_name}") + button_text = "确定" + parameters = self.parameters + + for name, value in parameters.items(): + type_ = type(value) + + edit = QtWidgets.QLineEdit(str(value)) + if type_ is int: + validator = QtGui.QIntValidator() + edit.setValidator(validator) + elif type_ is float: + validator = QtGui.QDoubleValidator() + edit.setValidator(validator) + + form.addRow(f"{name} {type_}", edit) + + self.edits[name] = (edit, type_) + + button = QtWidgets.QPushButton(button_text) + button.clicked.connect(self.accept) + form.addRow(button) + + self.setLayout(form) + + def get_setting(self): + """""" + setting = {} + + if self.class_name: + setting["class_name"] = self.class_name + + for name, tp in self.edits.items(): + edit, type_ = tp + value_text = edit.text() + + if type_ == bool: + if value_text == "True": + value = True + else: + value = False + else: + value = type_(value_text) + + setting[name] = value + + return setting