[增强功能] 添加vnpy原版组合引擎,修改gateway,取消多余event,提高性能;更新K线组件/网格组件
This commit is contained in:
parent
57448173a6
commit
fa22e74a8a
@ -330,10 +330,13 @@ class AccountRecorder(BaseEngine):
|
||||
# self.write_log(u'记录委托日志:{}'.format(order.__dict__))
|
||||
if len(order.sys_orderid) == 0:
|
||||
# 未有系统的委托编号,不做持久化
|
||||
return
|
||||
order.sys_orderid = order.orderid
|
||||
dt = getattr(order, 'datetime')
|
||||
if not dt:
|
||||
order_date = datetime.now().strftime('%Y-%m-%d')
|
||||
if len(order.time) > 0 and '.' not in order.time:
|
||||
dt = datetime.strptime(f'{order_date} {order.time}', '%Y-%m-%d %H:%M:%S')
|
||||
order.datetime = dt
|
||||
else:
|
||||
order_date = dt.strftime('%Y-%m-%d')
|
||||
|
||||
|
@ -6,11 +6,12 @@ from functools import lru_cache
|
||||
from vnpy.event import EventEngine, Event
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
|
||||
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE, EVENT_POSITION)
|
||||
from vnpy.trader.constant import (Direction, Offset, OrderType, Status)
|
||||
from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData, CancelRequest)
|
||||
from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path
|
||||
from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path,print_dict
|
||||
from vnpy.trader.util_logger import setup_logger, logging
|
||||
from vnpy.trader.converter import OffsetConverter
|
||||
|
||||
from .template import AlgoTemplate
|
||||
|
||||
@ -35,13 +36,13 @@ class AlgoEngine(BaseEngine):
|
||||
self.symbol_algo_map = {}
|
||||
self.orderid_algo_map = {}
|
||||
|
||||
self.algo_vtorderid_order_map = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
|
||||
self.spd_orders = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
|
||||
|
||||
self.algo_templates = {}
|
||||
self.algo_settings = {}
|
||||
|
||||
self.algo_loggers = {} # algo_name: logger
|
||||
|
||||
self.offset_converter = OffsetConverter(self.main_engine)
|
||||
self.load_algo_template()
|
||||
self.register_event()
|
||||
|
||||
@ -60,6 +61,7 @@ class AlgoEngine(BaseEngine):
|
||||
from .algos.grid_algo import GridAlgo
|
||||
from .algos.dma_algo import DmaAlgo
|
||||
from .algos.arbitrage_algo import ArbitrageAlgo
|
||||
from .algos.spread_algo_v2 import SpreadAlgoV2
|
||||
|
||||
self.add_algo_template(TwapAlgo)
|
||||
self.add_algo_template(IcebergAlgo)
|
||||
@ -69,6 +71,7 @@ class AlgoEngine(BaseEngine):
|
||||
self.add_algo_template(GridAlgo)
|
||||
self.add_algo_template(DmaAlgo)
|
||||
self.add_algo_template(ArbitrageAlgo)
|
||||
self.add_algo_template(SpreadAlgoV2)
|
||||
|
||||
def add_algo_template(self, template: AlgoTemplate):
|
||||
""""""
|
||||
@ -93,6 +96,7 @@ class AlgoEngine(BaseEngine):
|
||||
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
|
||||
self.event_engine.register(EVENT_ORDER, self.process_order_event)
|
||||
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
|
||||
self.event_engine.register(EVENT_POSITION, self.process_position_event)
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
""""""
|
||||
@ -114,7 +118,7 @@ class AlgoEngine(BaseEngine):
|
||||
def process_trade_event(self, event: Event):
|
||||
""""""
|
||||
trade = event.data
|
||||
|
||||
self.offset_converter.update_trade(trade)
|
||||
algo = self.orderid_algo_map.get(trade.vt_orderid, None)
|
||||
if algo:
|
||||
algo.update_trade(trade)
|
||||
@ -122,11 +126,17 @@ class AlgoEngine(BaseEngine):
|
||||
def process_order_event(self, event: Event):
|
||||
""""""
|
||||
order = event.data
|
||||
|
||||
self.offset_converter.update_order(order)
|
||||
algo = self.orderid_algo_map.get(order.vt_orderid, None)
|
||||
if algo:
|
||||
algo.update_order(order)
|
||||
|
||||
def process_position_event(self, event: Event):
|
||||
""""""
|
||||
position = event.data
|
||||
|
||||
self.offset_converter.update_position(position)
|
||||
|
||||
def start_algo(self, setting: dict):
|
||||
""""""
|
||||
template_name = setting["template_name"]
|
||||
@ -144,6 +154,7 @@ class AlgoEngine(BaseEngine):
|
||||
if algo:
|
||||
algo.stop()
|
||||
self.algos.pop(algo_name)
|
||||
return True
|
||||
|
||||
def stop_all(self):
|
||||
""""""
|
||||
@ -154,7 +165,7 @@ class AlgoEngine(BaseEngine):
|
||||
""""""
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo)
|
||||
self.write_log(msg=f'订阅行情失败,找不到合约:{vt_symbol}', algo_name=algo.algo_name)
|
||||
return
|
||||
|
||||
algos = self.symbol_algo_map.setdefault(vt_symbol, set())
|
||||
@ -186,9 +197,9 @@ class AlgoEngine(BaseEngine):
|
||||
|
||||
volume = round_to(volume, contract.min_volume)
|
||||
if not volume:
|
||||
return ""
|
||||
return []
|
||||
|
||||
req = OrderRequest(
|
||||
original_req = OrderRequest(
|
||||
symbol=contract.symbol,
|
||||
exchange=contract.exchange,
|
||||
direction=direction,
|
||||
@ -197,83 +208,88 @@ class AlgoEngine(BaseEngine):
|
||||
price=price,
|
||||
offset=offset
|
||||
)
|
||||
req_list = self.offset_converter.convert_order_request(req=original_req, lock=False, gateway_name=contract.gateway_name)
|
||||
vt_orderids = []
|
||||
for req in req_list:
|
||||
vt_orderid = self.main_engine.send_order(req, contract.gateway_name)
|
||||
if not vt_orderid:
|
||||
continue
|
||||
|
||||
vt_orderids.append(vt_orderid)
|
||||
|
||||
self.offset_converter.update_order_request(req, vt_orderid, contract.gateway_name)
|
||||
|
||||
self.orderid_algo_map[vt_orderid] = algo
|
||||
return vt_orderid
|
||||
|
||||
return vt_orderids
|
||||
|
||||
def cancel_order(self, algo: AlgoTemplate, vt_orderid: str):
|
||||
""""""
|
||||
order = self.main_engine.get_order(vt_orderid)
|
||||
|
||||
if not order:
|
||||
self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo)
|
||||
return
|
||||
self.write_log(msg=f"委托撤单失败,找不到委托:{vt_orderid}", algo_name=algo.algo_name)
|
||||
return False
|
||||
|
||||
req = order.create_cancel_request()
|
||||
self.main_engine.cancel_order(req, order.gateway_name)
|
||||
return self.main_engine.cancel_order(req, order.gateway_name)
|
||||
|
||||
def send_algo_order(self, req: OrderRequest, gateway_name: str):
|
||||
"""发送算法交易指令"""
|
||||
self.write_log(u'创建算法交易,gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
|
||||
def send_spd_order(self, req: OrderRequest, gateway_name: str):
|
||||
"""发送SPD算法交易指令"""
|
||||
self.write_log(u'[SPD算法交易],gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
|
||||
.format(gateway_name, req.strategy_name, req.vt_symbol, req.price, req.volume))
|
||||
|
||||
# 创建算法实例,由算法引擎启动
|
||||
trade_command = ''
|
||||
if req.direction == Direction.LONG and req.offset == Offset.OPEN:
|
||||
trade_command = 'Buy'
|
||||
elif req.direction == Direction.SHORT and req.offset == Offset.OPEN:
|
||||
trade_command = 'Short'
|
||||
elif req.direction == Direction.SHORT and req.offset != Offset.OPEN:
|
||||
trade_command = 'Sell'
|
||||
elif req.direction == Direction.LONG and req.offset != Offset.OPEN:
|
||||
trade_command = 'Cover'
|
||||
|
||||
all_custom_contracts = self.main_engine.get_all_custom_contracts()
|
||||
contract_setting = all_custom_contracts.get(req.vt_symbol, {})
|
||||
algo_setting = {
|
||||
'templateName': u'SpreadTrading套利',
|
||||
custom_settings = self.main_engine.get_all_custom_contracts(rtn_setting=True)
|
||||
contract = custom_settings.get(req.symbol, {})
|
||||
setting = {
|
||||
'template_name': u'SpreadAlgoV2',
|
||||
'order_vt_symbol': req.vt_symbol,
|
||||
'order_command': trade_command,
|
||||
'order_direction': req.direction,
|
||||
'order_offset': req.offset,
|
||||
'order_price': req.price,
|
||||
'order_volume': req.volume,
|
||||
'timer_interval': 60 * 60 * 24,
|
||||
'strategy_name': req.strategy_name,
|
||||
'gateway_name': gateway_name
|
||||
}
|
||||
algo_setting.update(contract_setting)
|
||||
# 更新算法配置
|
||||
setting.update(contract)
|
||||
|
||||
# 算法引擎
|
||||
algo_name = self.start_algo(algo_setting)
|
||||
self.write_log(u'send_algo_order(): start_algo {}={}'.format(algo_name, str(algo_setting)))
|
||||
algo_name = self.start_algo(setting)
|
||||
self.write_log(f'[SPD算法交易]: 实例id: {algo_name}, 配置:{print_dict(setting)}')
|
||||
|
||||
# 创建一个Order事件
|
||||
# 创建一个Order事件, 正在提交
|
||||
order = req.create_order_data(orderid=algo_name, gateway_name=gateway_name)
|
||||
order.orderTime = datetime.now().strftime('%H:%M:%S.%f')
|
||||
order.datetime = datetime.now()
|
||||
order.time = order.datetime.strftime('%H:%M:%S.%f')
|
||||
order.status = Status.SUBMITTING
|
||||
|
||||
event1 = Event(type=EVENT_ORDER, data=order)
|
||||
self.event_engine.put(event1)
|
||||
|
||||
# 登记在本地的算法委托字典中
|
||||
self.algo_vtorderid_order_map.update({order.vt_orderid: order})
|
||||
self.spd_orders.update({order.orderid: order})
|
||||
|
||||
return order.vt_orderid
|
||||
|
||||
def is_algo_order(self, req: CancelRequest, gateway_name: str):
|
||||
def get_spd_order(self, orderid):
|
||||
"""返回spd委托单"""
|
||||
return self.spd_orders.get(orderid, None)
|
||||
|
||||
def is_spd_order(self, req: CancelRequest):
|
||||
"""是否为外部算法委托单"""
|
||||
vt_orderid = '.'.join([req.orderid, gateway_name])
|
||||
if vt_orderid in self.algo_vtorderid_order_map:
|
||||
if req.orderid in self.spd_orders:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def cancel_algo_order(self, req: CancelRequest, gateway_name: str):
|
||||
def cancel_spd_order(self, req: CancelRequest):
|
||||
"""外部算法单撤单"""
|
||||
vt_orderid = '.'.join([req.orderid, gateway_name])
|
||||
order = self.algo_vtorderid_order_map.get(vt_orderid, None)
|
||||
|
||||
order = self.spd_orders.get(req.orderid, None)
|
||||
if not order:
|
||||
self.write_error(f'{vt_orderid}不在算法引擎中,撤单失败')
|
||||
self.write_error(f'{req.orderid}不在算法引擎中,撤单失败')
|
||||
return False
|
||||
|
||||
algo = self.algos.get(req.orderid, None)
|
||||
@ -368,6 +384,19 @@ class AlgoEngine(BaseEngine):
|
||||
|
||||
return contract
|
||||
|
||||
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
|
||||
""" 查询合约在账号的持仓,需要指定方向"""
|
||||
if len(gateway_name) == 0:
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
if contract and contract.gateway_name:
|
||||
gateway_name = contract.gateway_name
|
||||
vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}"
|
||||
return self.main_engine.get_position(vt_position_id)
|
||||
|
||||
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
|
||||
""" 查询合约在账号的持仓(包含多空)"""
|
||||
return self.offset_converter.get_position_holding(vt_symbol, gateway_name)
|
||||
|
||||
def write_log(self, msg: str, algo_name: str = None, level: int = logging.INFO):
|
||||
"""增强版写日志"""
|
||||
if algo_name:
|
||||
|
@ -154,7 +154,7 @@ class AlgoTemplate:
|
||||
|
||||
def cancel_order(self, vt_orderid: str):
|
||||
""""""
|
||||
self.algo_engine.cancel_order(self, vt_orderid)
|
||||
return self.algo_engine.cancel_order(self, vt_orderid)
|
||||
|
||||
def cancel_all(self):
|
||||
""""""
|
||||
|
@ -333,8 +333,8 @@ class BackTestingEngine(object):
|
||||
self.fix_commission.update({vt_symbol: rate})
|
||||
|
||||
def get_commission_rate(self, vt_symbol: str):
|
||||
""" 获取保证金比例,缺省万分之一"""
|
||||
return self.commission_rate.get(vt_symbol, float(0.00001))
|
||||
""" 获取保证金比例,缺省千分之2"""
|
||||
return self.commission_rate.get(vt_symbol, float(0.0004))
|
||||
|
||||
def get_fix_commission(self, vt_symbol: str):
|
||||
return self.fix_commission.get(vt_symbol, 0)
|
||||
@ -561,7 +561,7 @@ class BackTestingEngine(object):
|
||||
margin_rate = symbol_data.get('margin_rate', 0.1)
|
||||
self.set_margin_rate(symbol, margin_rate)
|
||||
|
||||
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001)))
|
||||
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0004)))
|
||||
|
||||
self.set_contract(
|
||||
symbol=symbol,
|
||||
@ -2232,7 +2232,7 @@ class TradingResult(object):
|
||||
self.volume = volume # 交易数量(+/-代表方向)
|
||||
self.group_id = group_id # 主交易ID(针对多手平仓)
|
||||
|
||||
self.turnover = (self.open_price + self.exit_price) * abs(volume) * margin_rate # 成交金额(实际保证金金额)
|
||||
self.turnover = (self.open_price + self.exit_price) * abs(volume) # 成交金额(实际保证金金额)
|
||||
if fix_commission > 0:
|
||||
self.commission = fix_commission * abs(self.volume)
|
||||
else:
|
||||
|
@ -42,6 +42,7 @@ from vnpy.trader.event import (
|
||||
)
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
Exchange,
|
||||
OrderType,
|
||||
Offset,
|
||||
Status
|
||||
@ -349,12 +350,27 @@ class CtaEngine(BaseEngine):
|
||||
# Update GUI
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
if self.engine_config.get('trade_2_wx', False):
|
||||
accountid = self.engine_config.get('accountid', '-')
|
||||
d = {
|
||||
'account': accountid,
|
||||
'strategy': strategy_name,
|
||||
'symbol': trade.symbol,
|
||||
'action': f'{trade.direction.value} {trade.offset.value}',
|
||||
'price': str(trade.price),
|
||||
'volume': trade.volume,
|
||||
'remark': f'{accountid}:{strategy_name}',
|
||||
'timestamp': trade.time
|
||||
}
|
||||
send_wx_msg(content=d, target=accountid)
|
||||
|
||||
def process_position_event(self, event: Event):
|
||||
""""""
|
||||
position = event.data
|
||||
|
||||
self.offset_converter.update_position(position)
|
||||
|
||||
|
||||
def check_unsubscribed_symbols(self):
|
||||
"""检查未订阅合约"""
|
||||
|
||||
@ -1666,7 +1682,8 @@ class CtaEngine(BaseEngine):
|
||||
gateway_names = self.main_engine.get_all_gateway_names()
|
||||
gateway_name = gateway_names[0] if len(gateway_names) > 0 else ""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), gateway_name=gateway_name)
|
||||
self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange),
|
||||
gateway_name=gateway_name)
|
||||
if volume > 0 and tick:
|
||||
contract = self.main_engine.get_contract(vt_symbol)
|
||||
req = OrderRequest(
|
||||
@ -1821,9 +1838,13 @@ class CtaEngine(BaseEngine):
|
||||
if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]:
|
||||
print(f"{strategy_name}: {msg}" if strategy_name else msg, file=sys.stderr)
|
||||
|
||||
def write_error(self, msg: str, strategy_name: str = ''):
|
||||
if level in [logging.CRITICAL, logging.WARN, logging.WARNING]:
|
||||
send_wx_msg(content=f"{strategy_name}: {msg}" if strategy_name else msg)
|
||||
|
||||
def write_error(self, msg: str, strategy_name: str = '', level: int = logging.ERROR):
|
||||
"""写入错误日志"""
|
||||
self.write_log(msg=msg, strategy_name=strategy_name, level=logging.ERROR)
|
||||
self.write_log(msg=msg, strategy_name=strategy_name, level=level)
|
||||
|
||||
|
||||
def send_email(self, msg: str, strategy: CtaTemplate = None):
|
||||
"""
|
||||
|
23
vnpy/app/portfolio_strategy/__init__.py
Normal file
23
vnpy/app/portfolio_strategy/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
|
||||
from vnpy.trader.app import BaseApp
|
||||
from vnpy.trader.constant import Direction
|
||||
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
|
||||
from vnpy.trader.utility import BarGenerator, ArrayManager
|
||||
|
||||
from .base import APP_NAME
|
||||
from .engine import StrategyEngine
|
||||
from .template import StrategyTemplate
|
||||
from .backtesting import BacktestingEngine
|
||||
|
||||
|
||||
class PortfolioStrategyApp(BaseApp):
|
||||
""""""
|
||||
|
||||
app_name = APP_NAME
|
||||
app_module = __module__
|
||||
app_path = Path(__file__).parent
|
||||
display_name = "组合策略"
|
||||
engine_class = StrategyEngine
|
||||
widget_name = "PortfolioStrategyManager"
|
||||
icon_name = "strategy.ico"
|
821
vnpy/app/portfolio_strategy/backtesting.py
Normal file
821
vnpy/app/portfolio_strategy/backtesting.py
Normal file
@ -0,0 +1,821 @@
|
||||
from collections import defaultdict
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from functools import lru_cache
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pandas import DataFrame
|
||||
|
||||
from vnpy.trader.constant import Direction, Offset, Interval, Status
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.object import OrderData, TradeData, BarData
|
||||
from vnpy.trader.utility import round_to, extract_vt_symbol
|
||||
|
||||
from .template import StrategyTemplate
|
||||
|
||||
# Set seaborn style
|
||||
sns.set_style("whitegrid")
|
||||
|
||||
|
||||
INTERVAL_DELTA_MAP = {
|
||||
Interval.MINUTE: timedelta(minutes=1),
|
||||
Interval.HOUR: timedelta(hours=1),
|
||||
Interval.DAILY: timedelta(days=1),
|
||||
}
|
||||
|
||||
|
||||
class BacktestingEngine:
|
||||
""""""
|
||||
|
||||
gateway_name = "BACKTESTING"
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
self.vt_symbols: List[str] = []
|
||||
self.start: datetime = None
|
||||
self.end: datetime = None
|
||||
|
||||
self.rates: Dict[str, float] = 0
|
||||
self.slippages: Dict[str, float] = 0
|
||||
self.sizes: Dict[str, float] = 1
|
||||
self.priceticks: Dict[str, float] = 0
|
||||
|
||||
self.capital: float = 1_000_000
|
||||
|
||||
self.strategy: StrategyTemplate = None
|
||||
self.bars: Dict[str, BarData] = {}
|
||||
self.datetime: datetime = None
|
||||
|
||||
self.interval: Interval = None
|
||||
self.days: int = 0
|
||||
self.history_data: Dict[Tuple, BarData] = {}
|
||||
self.dts: Set[datetime] = set()
|
||||
|
||||
self.limit_order_count = 0
|
||||
self.limit_orders = {}
|
||||
self.active_limit_orders = {}
|
||||
|
||||
self.trade_count = 0
|
||||
self.trades = {}
|
||||
|
||||
self.logs = []
|
||||
|
||||
self.daily_results = {}
|
||||
self.daily_df = None
|
||||
|
||||
def clear_data(self) -> None:
|
||||
"""
|
||||
Clear all data of last backtesting.
|
||||
"""
|
||||
self.strategy = None
|
||||
self.bars = {}
|
||||
self.datetime = None
|
||||
|
||||
self.limit_order_count = 0
|
||||
self.limit_orders.clear()
|
||||
self.active_limit_orders.clear()
|
||||
|
||||
self.trade_count = 0
|
||||
self.trades.clear()
|
||||
|
||||
self.logs.clear()
|
||||
self.daily_results.clear()
|
||||
self.daily_df = None
|
||||
|
||||
def set_parameters(
|
||||
self,
|
||||
vt_symbols: List[str],
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
rates: Dict[str, float],
|
||||
slippages: Dict[str, float],
|
||||
sizes: Dict[str, float],
|
||||
priceticks: Dict[str, float],
|
||||
capital: int = 0,
|
||||
end: datetime = None
|
||||
) -> None:
|
||||
""""""
|
||||
self.vt_symbols = vt_symbols
|
||||
self.interval = interval
|
||||
|
||||
self.rates = rates
|
||||
self.slippages = slippages
|
||||
self.sizes = sizes
|
||||
self.priceticks = priceticks
|
||||
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.capital = capital
|
||||
|
||||
def add_strategy(self, strategy_class: type, setting: dict) -> None:
|
||||
""""""
|
||||
self.strategy = strategy_class(
|
||||
self, strategy_class.__name__, self.vt_symbols, setting
|
||||
)
|
||||
|
||||
def load_data(self) -> None:
|
||||
""""""
|
||||
self.output("开始加载历史数据")
|
||||
|
||||
if not self.end:
|
||||
self.end = datetime.now()
|
||||
|
||||
if self.start >= self.end:
|
||||
self.output("起始日期必须小于结束日期")
|
||||
return
|
||||
|
||||
# Clear previously loaded history data
|
||||
self.history_data.clear()
|
||||
self.dts.clear()
|
||||
|
||||
# Load 30 days of data each time and allow for progress update
|
||||
progress_delta = timedelta(days=30)
|
||||
total_delta = self.end - self.start
|
||||
interval_delta = INTERVAL_DELTA_MAP[self.interval]
|
||||
|
||||
for vt_symbol in self.vt_symbols:
|
||||
start = self.start
|
||||
end = self.start + progress_delta
|
||||
progress = 0
|
||||
|
||||
data_count = 0
|
||||
while start < self.end:
|
||||
end = min(end, self.end) # Make sure end time stays within set range
|
||||
|
||||
data = load_bar_data(
|
||||
vt_symbol,
|
||||
self.interval,
|
||||
start,
|
||||
end
|
||||
)
|
||||
|
||||
for bar in data:
|
||||
self.dts.add(bar.datetime)
|
||||
self.history_data[(bar.datetime, vt_symbol)] = bar
|
||||
data_count += 1
|
||||
|
||||
progress += progress_delta / total_delta
|
||||
progress = min(progress, 1)
|
||||
progress_bar = "#" * int(progress * 10)
|
||||
self.output(f"{vt_symbol}加载进度:{progress_bar} [{progress:.0%}]")
|
||||
|
||||
start = end + interval_delta
|
||||
end += (progress_delta + interval_delta)
|
||||
|
||||
self.output(f"{vt_symbol}历史数据加载完成,数据量:{data_count}")
|
||||
|
||||
self.output("所有历史数据加载完成")
|
||||
|
||||
def run_backtesting(self) -> None:
|
||||
""""""
|
||||
self.strategy.on_init()
|
||||
|
||||
# Generate sorted datetime list
|
||||
dts = list(self.dts)
|
||||
dts.sort()
|
||||
|
||||
# Use the first [days] of history data for initializing strategy
|
||||
day_count = 0
|
||||
ix = 0
|
||||
|
||||
for ix, dt in enumerate(dts):
|
||||
if self.datetime and dt.day != self.datetime.day:
|
||||
day_count += 1
|
||||
if day_count >= self.days:
|
||||
break
|
||||
|
||||
try:
|
||||
self.new_bars(dt)
|
||||
except Exception:
|
||||
self.output("触发异常,回测终止")
|
||||
self.output(traceback.format_exc())
|
||||
return
|
||||
|
||||
self.strategy.inited = True
|
||||
self.output("策略初始化完成")
|
||||
|
||||
self.strategy.on_start()
|
||||
self.strategy.trading = True
|
||||
self.output("开始回放历史数据")
|
||||
|
||||
# Use the rest of history data for running backtesting
|
||||
for dt in dts[ix:]:
|
||||
try:
|
||||
self.new_bars(dt)
|
||||
except Exception:
|
||||
self.output("触发异常,回测终止")
|
||||
self.output(traceback.format_exc())
|
||||
return
|
||||
|
||||
self.output("历史数据回放结束")
|
||||
|
||||
def calculate_result(self) -> None:
|
||||
""""""
|
||||
self.output("开始计算逐日盯市盈亏")
|
||||
|
||||
if not self.trades:
|
||||
self.output("成交记录为空,无法计算")
|
||||
return
|
||||
|
||||
# Add trade data into daily reuslt.
|
||||
for trade in self.trades.values():
|
||||
d = trade.datetime.date()
|
||||
daily_result = self.daily_results[d]
|
||||
daily_result.add_trade(trade)
|
||||
|
||||
# Calculate daily result by iteration.
|
||||
pre_closes = {}
|
||||
start_poses = {}
|
||||
|
||||
for daily_result in self.daily_results.values():
|
||||
daily_result.calculate_pnl(
|
||||
pre_closes,
|
||||
start_poses,
|
||||
self.sizes,
|
||||
self.rates,
|
||||
self.slippages,
|
||||
)
|
||||
|
||||
pre_closes = daily_result.close_prices
|
||||
start_poses = daily_result.end_poses
|
||||
|
||||
# Generate dataframe
|
||||
results = defaultdict(list)
|
||||
|
||||
for daily_result in self.daily_results.values():
|
||||
fields = [
|
||||
"date", "trade_count", "turnover",
|
||||
"commission", "slippage", "trading_pnl",
|
||||
"holding_pnl", "total_pnl", "net_pnl"
|
||||
]
|
||||
for key in fields:
|
||||
value = getattr(daily_result, key)
|
||||
results[key].append(value)
|
||||
|
||||
self.daily_df = DataFrame.from_dict(results).set_index("date")
|
||||
|
||||
self.output("逐日盯市盈亏计算完成")
|
||||
return self.daily_df
|
||||
|
||||
def calculate_statistics(self, df: DataFrame = None, output=True) -> None:
|
||||
""""""
|
||||
self.output("开始计算策略统计指标")
|
||||
|
||||
# Check DataFrame input exterior
|
||||
if df is None:
|
||||
df = self.daily_df
|
||||
|
||||
# Check for init DataFrame
|
||||
if df is None:
|
||||
# Set all statistics to 0 if no trade.
|
||||
start_date = ""
|
||||
end_date = ""
|
||||
total_days = 0
|
||||
profit_days = 0
|
||||
loss_days = 0
|
||||
end_balance = 0
|
||||
max_drawdown = 0
|
||||
max_ddpercent = 0
|
||||
max_drawdown_duration = 0
|
||||
total_net_pnl = 0
|
||||
daily_net_pnl = 0
|
||||
total_commission = 0
|
||||
daily_commission = 0
|
||||
total_slippage = 0
|
||||
daily_slippage = 0
|
||||
total_turnover = 0
|
||||
daily_turnover = 0
|
||||
total_trade_count = 0
|
||||
daily_trade_count = 0
|
||||
total_return = 0
|
||||
annual_return = 0
|
||||
daily_return = 0
|
||||
return_std = 0
|
||||
sharpe_ratio = 0
|
||||
return_drawdown_ratio = 0
|
||||
else:
|
||||
# Calculate balance related time series data
|
||||
df["balance"] = df["net_pnl"].cumsum() + self.capital
|
||||
df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
|
||||
df["highlevel"] = (
|
||||
df["balance"].rolling(
|
||||
min_periods=1, window=len(df), center=False).max()
|
||||
)
|
||||
df["drawdown"] = df["balance"] - df["highlevel"]
|
||||
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
|
||||
|
||||
# Calculate statistics value
|
||||
start_date = df.index[0]
|
||||
end_date = df.index[-1]
|
||||
|
||||
total_days = len(df)
|
||||
profit_days = len(df[df["net_pnl"] > 0])
|
||||
loss_days = len(df[df["net_pnl"] < 0])
|
||||
|
||||
end_balance = df["balance"].iloc[-1]
|
||||
max_drawdown = df["drawdown"].min()
|
||||
max_ddpercent = df["ddpercent"].min()
|
||||
max_drawdown_end = df["drawdown"].idxmin()
|
||||
|
||||
if isinstance(max_drawdown_end, date):
|
||||
max_drawdown_start = df["balance"][:max_drawdown_end].idxmax()
|
||||
max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days
|
||||
else:
|
||||
max_drawdown_duration = 0
|
||||
|
||||
total_net_pnl = df["net_pnl"].sum()
|
||||
daily_net_pnl = total_net_pnl / total_days
|
||||
|
||||
total_commission = df["commission"].sum()
|
||||
daily_commission = total_commission / total_days
|
||||
|
||||
total_slippage = df["slippage"].sum()
|
||||
daily_slippage = total_slippage / total_days
|
||||
|
||||
total_turnover = df["turnover"].sum()
|
||||
daily_turnover = total_turnover / total_days
|
||||
|
||||
total_trade_count = df["trade_count"].sum()
|
||||
daily_trade_count = total_trade_count / total_days
|
||||
|
||||
total_return = (end_balance / self.capital - 1) * 100
|
||||
annual_return = total_return / total_days * 240
|
||||
daily_return = df["return"].mean() * 100
|
||||
return_std = df["return"].std() * 100
|
||||
|
||||
if return_std:
|
||||
sharpe_ratio = daily_return / return_std * np.sqrt(240)
|
||||
else:
|
||||
sharpe_ratio = 0
|
||||
|
||||
return_drawdown_ratio = -total_return / max_ddpercent
|
||||
|
||||
# Output
|
||||
if output:
|
||||
self.output("-" * 30)
|
||||
self.output(f"首个交易日:\t{start_date}")
|
||||
self.output(f"最后交易日:\t{end_date}")
|
||||
|
||||
self.output(f"总交易日:\t{total_days}")
|
||||
self.output(f"盈利交易日:\t{profit_days}")
|
||||
self.output(f"亏损交易日:\t{loss_days}")
|
||||
|
||||
self.output(f"起始资金:\t{self.capital:,.2f}")
|
||||
self.output(f"结束资金:\t{end_balance:,.2f}")
|
||||
|
||||
self.output(f"总收益率:\t{total_return:,.2f}%")
|
||||
self.output(f"年化收益:\t{annual_return:,.2f}%")
|
||||
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
|
||||
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
|
||||
self.output(f"最长回撤天数: \t{max_drawdown_duration}")
|
||||
|
||||
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
|
||||
self.output(f"总手续费:\t{total_commission:,.2f}")
|
||||
self.output(f"总滑点:\t{total_slippage:,.2f}")
|
||||
self.output(f"总成交金额:\t{total_turnover:,.2f}")
|
||||
self.output(f"总成交笔数:\t{total_trade_count}")
|
||||
|
||||
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
|
||||
self.output(f"日均手续费:\t{daily_commission:,.2f}")
|
||||
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
|
||||
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
|
||||
self.output(f"日均成交笔数:\t{daily_trade_count}")
|
||||
|
||||
self.output(f"日均收益率:\t{daily_return:,.2f}%")
|
||||
self.output(f"收益标准差:\t{return_std:,.2f}%")
|
||||
self.output(f"Sharpe Ratio:\t{sharpe_ratio:,.2f}")
|
||||
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
|
||||
|
||||
statistics = {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"total_days": total_days,
|
||||
"profit_days": profit_days,
|
||||
"loss_days": loss_days,
|
||||
"capital": self.capital,
|
||||
"end_balance": end_balance,
|
||||
"max_drawdown": max_drawdown,
|
||||
"max_ddpercent": max_ddpercent,
|
||||
"max_drawdown_duration": max_drawdown_duration,
|
||||
"total_net_pnl": total_net_pnl,
|
||||
"daily_net_pnl": daily_net_pnl,
|
||||
"total_commission": total_commission,
|
||||
"daily_commission": daily_commission,
|
||||
"total_slippage": total_slippage,
|
||||
"daily_slippage": daily_slippage,
|
||||
"total_turnover": total_turnover,
|
||||
"daily_turnover": daily_turnover,
|
||||
"total_trade_count": total_trade_count,
|
||||
"daily_trade_count": daily_trade_count,
|
||||
"total_return": total_return,
|
||||
"annual_return": annual_return,
|
||||
"daily_return": daily_return,
|
||||
"return_std": return_std,
|
||||
"sharpe_ratio": sharpe_ratio,
|
||||
"return_drawdown_ratio": return_drawdown_ratio,
|
||||
}
|
||||
|
||||
# Filter potential error infinite value
|
||||
for key, value in statistics.items():
|
||||
if value in (np.inf, -np.inf):
|
||||
value = 0
|
||||
statistics[key] = np.nan_to_num(value)
|
||||
|
||||
self.output("策略统计指标计算完成")
|
||||
return statistics
|
||||
|
||||
def show_chart(self, df: DataFrame = None) -> None:
|
||||
""""""
|
||||
# Check DataFrame input exterior
|
||||
if df is None:
|
||||
df = self.daily_df
|
||||
|
||||
# Check for init DataFrame
|
||||
if df is None:
|
||||
return
|
||||
|
||||
plt.figure(figsize=(10, 16))
|
||||
|
||||
balance_plot = plt.subplot(4, 1, 1)
|
||||
balance_plot.set_title("Balance")
|
||||
df["balance"].plot(legend=True)
|
||||
|
||||
drawdown_plot = plt.subplot(4, 1, 2)
|
||||
drawdown_plot.set_title("Drawdown")
|
||||
drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
|
||||
|
||||
pnl_plot = plt.subplot(4, 1, 3)
|
||||
pnl_plot.set_title("Daily Pnl")
|
||||
df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
|
||||
|
||||
distribution_plot = plt.subplot(4, 1, 4)
|
||||
distribution_plot.set_title("Daily Pnl Distribution")
|
||||
df["net_pnl"].hist(bins=50)
|
||||
|
||||
plt.show()
|
||||
|
||||
def update_daily_close(self, bars: Dict[str, BarData], dt: datetime) -> None:
|
||||
""""""
|
||||
d = dt.date()
|
||||
|
||||
close_prices = {}
|
||||
for bar in bars.values():
|
||||
close_prices[bar.vt_symbol] = bar.close_price
|
||||
|
||||
daily_result = self.daily_results.get(d, None)
|
||||
|
||||
if daily_result:
|
||||
daily_result.update_close_prices(close_prices)
|
||||
else:
|
||||
self.daily_results[d] = PortfolioDailyResult(d, close_prices)
|
||||
|
||||
def new_bars(self, dt: datetime) -> None:
|
||||
""""""
|
||||
self.datetime = dt
|
||||
|
||||
self.bars.clear()
|
||||
for vt_symbol in self.vt_symbols:
|
||||
bar = self.history_data.get((dt, vt_symbol), None)
|
||||
if bar:
|
||||
self.bars[vt_symbol] = bar
|
||||
else:
|
||||
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.output(f"数据缺失:{dt_str} {vt_symbol}")
|
||||
|
||||
self.cross_limit_order()
|
||||
self.strategy.on_bars(self.bars)
|
||||
|
||||
self.update_daily_close(self.bars, dt)
|
||||
|
||||
def cross_limit_order(self) -> None:
|
||||
"""
|
||||
Cross limit order with last bar/tick data.
|
||||
"""
|
||||
for order in list(self.active_limit_orders.values()):
|
||||
bar = self.bars[order.vt_symbol]
|
||||
|
||||
long_cross_price = bar.low_price
|
||||
short_cross_price = bar.high_price
|
||||
long_best_price = bar.open_price
|
||||
short_best_price = bar.open_price
|
||||
|
||||
# Push order update with status "not traded" (pending).
|
||||
if order.status == Status.SUBMITTING:
|
||||
order.status = Status.NOTTRADED
|
||||
self.strategy.update_order(order)
|
||||
|
||||
# Check whether limit orders can be filled.
|
||||
long_cross = (
|
||||
order.direction == Direction.LONG
|
||||
and order.price >= long_cross_price
|
||||
and long_cross_price > 0
|
||||
)
|
||||
|
||||
short_cross = (
|
||||
order.direction == Direction.SHORT
|
||||
and order.price <= short_cross_price
|
||||
and short_cross_price > 0
|
||||
)
|
||||
|
||||
if not long_cross and not short_cross:
|
||||
continue
|
||||
|
||||
# Push order update with status "all traded" (filled).
|
||||
order.traded = order.volume
|
||||
order.status = Status.ALLTRADED
|
||||
self.strategy.update_order(order)
|
||||
|
||||
self.active_limit_orders.pop(order.vt_orderid)
|
||||
|
||||
# Push trade update
|
||||
self.trade_count += 1
|
||||
|
||||
if long_cross:
|
||||
trade_price = min(order.price, long_best_price)
|
||||
else:
|
||||
trade_price = max(order.price, short_best_price)
|
||||
|
||||
trade = TradeData(
|
||||
symbol=order.symbol,
|
||||
exchange=order.exchange,
|
||||
orderid=order.orderid,
|
||||
tradeid=str(self.trade_count),
|
||||
direction=order.direction,
|
||||
offset=order.offset,
|
||||
price=trade_price,
|
||||
volume=order.volume,
|
||||
time=self.datetime.strftime("%H:%M:%S"),
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
trade.datetime = self.datetime
|
||||
|
||||
self.strategy.update_trade(trade)
|
||||
self.trades[trade.vt_tradeid] = trade
|
||||
|
||||
def load_bars(
|
||||
self,
|
||||
strategy: StrategyTemplate,
|
||||
days: int,
|
||||
interval: Interval
|
||||
) -> None:
|
||||
""""""
|
||||
self.days = days
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
strategy: StrategyTemplate,
|
||||
vt_symbol: str,
|
||||
direction: Direction,
|
||||
offset: Offset,
|
||||
price: float,
|
||||
volume: float,
|
||||
lock: bool
|
||||
) -> List[str]:
|
||||
""""""
|
||||
price = round_to(price, self.priceticks[vt_symbol])
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
|
||||
self.limit_order_count += 1
|
||||
|
||||
order = OrderData(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
orderid=str(self.limit_order_count),
|
||||
direction=direction,
|
||||
offset=offset,
|
||||
price=price,
|
||||
volume=volume,
|
||||
status=Status.SUBMITTING,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
order.datetime = self.datetime
|
||||
|
||||
self.active_limit_orders[order.vt_orderid] = order
|
||||
self.limit_orders[order.vt_orderid] = order
|
||||
|
||||
return [order.vt_orderid]
|
||||
|
||||
def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str) -> None:
|
||||
"""
|
||||
Cancel order by vt_orderid.
|
||||
"""
|
||||
if vt_orderid not in self.active_limit_orders:
|
||||
return
|
||||
order = self.active_limit_orders.pop(vt_orderid)
|
||||
|
||||
order.status = Status.CANCELLED
|
||||
self.strategy.update_order(order)
|
||||
|
||||
def write_log(self, msg: str, strategy: StrategyTemplate = None) -> None:
|
||||
"""
|
||||
Write log message.
|
||||
"""
|
||||
msg = f"{self.datetime}\t{msg}"
|
||||
self.logs.append(msg)
|
||||
|
||||
def send_email(self, msg: str, strategy: StrategyTemplate = None) -> None:
|
||||
"""
|
||||
Send email to default receiver.
|
||||
"""
|
||||
pass
|
||||
|
||||
def sync_strategy_data(self, strategy: StrategyTemplate) -> None:
|
||||
"""
|
||||
Sync strategy data into json file.
|
||||
"""
|
||||
pass
|
||||
|
||||
def put_strategy_event(self, strategy: StrategyTemplate) -> None:
|
||||
"""
|
||||
Put an event to update strategy status.
|
||||
"""
|
||||
pass
|
||||
|
||||
def output(self, msg) -> None:
|
||||
"""
|
||||
Output message of backtesting engine.
|
||||
"""
|
||||
print(f"{datetime.now()}\t{msg}")
|
||||
|
||||
def get_all_trades(self) -> List[TradeData]:
|
||||
"""
|
||||
Return all trade data of current backtesting result.
|
||||
"""
|
||||
return list(self.trades.values())
|
||||
|
||||
def get_all_orders(self) -> List[OrderData]:
|
||||
"""
|
||||
Return all limit order data of current backtesting result.
|
||||
"""
|
||||
return list(self.limit_orders.values())
|
||||
|
||||
def get_all_daily_results(self) -> List["PortfolioDailyResult"]:
|
||||
"""
|
||||
Return all daily result data.
|
||||
"""
|
||||
return list(self.daily_results.values())
|
||||
|
||||
|
||||
class ContractDailyResult:
|
||||
""""""
|
||||
|
||||
def __init__(self, result_date: date, close_price: float):
|
||||
""""""
|
||||
self.date: date = result_date
|
||||
self.close_price: float = close_price
|
||||
self.pre_close: float = 0
|
||||
|
||||
self.trades: List[TradeData] = []
|
||||
self.trade_count: int = 0
|
||||
|
||||
self.start_pos: float = 0
|
||||
self.end_pos: float = 0
|
||||
|
||||
self.turnover: float = 0
|
||||
self.commission: float = 0
|
||||
self.slippage: float = 0
|
||||
|
||||
self.trading_pnl: float = 0
|
||||
self.holding_pnl: float = 0
|
||||
self.total_pnl: float = 0
|
||||
self.net_pnl: float = 0
|
||||
|
||||
def add_trade(self, trade: TradeData) -> None:
|
||||
""""""
|
||||
self.trades.append(trade)
|
||||
|
||||
def calculate_pnl(
|
||||
self,
|
||||
pre_close: float,
|
||||
start_pos: float,
|
||||
size: int,
|
||||
rate: float,
|
||||
slippage: float
|
||||
) -> None:
|
||||
""""""
|
||||
# If no pre_close provided on the first day,
|
||||
# use value 1 to avoid zero division error
|
||||
if pre_close:
|
||||
self.pre_close = pre_close
|
||||
else:
|
||||
self.pre_close = 1
|
||||
|
||||
# Holding pnl is the pnl from holding position at day start
|
||||
self.start_pos = start_pos
|
||||
self.end_pos = start_pos
|
||||
|
||||
self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size
|
||||
|
||||
# Trading pnl is the pnl from new trade during the day
|
||||
self.trade_count = len(self.trades)
|
||||
|
||||
for trade in self.trades:
|
||||
if trade.direction == Direction.LONG:
|
||||
pos_change = trade.volume
|
||||
else:
|
||||
pos_change = -trade.volume
|
||||
|
||||
self.end_pos += pos_change
|
||||
|
||||
turnover = trade.volume * size * trade.price
|
||||
|
||||
self.trading_pnl += pos_change * (self.close_price - trade.price) * size
|
||||
self.slippage += trade.volume * size * slippage
|
||||
self.turnover += turnover
|
||||
self.commission += turnover * rate
|
||||
|
||||
# Net pnl takes account of commission and slippage cost
|
||||
self.total_pnl = self.trading_pnl + self.holding_pnl
|
||||
self.net_pnl = self.total_pnl - self.commission - self.slippage
|
||||
|
||||
def update_close_price(self, close_price: float) -> None:
|
||||
""""""
|
||||
self.close_price = close_price
|
||||
|
||||
|
||||
class PortfolioDailyResult:
|
||||
""""""
|
||||
|
||||
def __init__(self, result_date: date, close_prices: Dict[str, float]):
|
||||
""""""
|
||||
self.date: date = result_date
|
||||
self.close_prices: Dict[str, float] = close_prices
|
||||
self.pre_closes: Dict[str, float] = {}
|
||||
self.start_poses: Dict[str, float] = {}
|
||||
self.end_poses: Dict[str, float] = {}
|
||||
|
||||
self.contract_results: Dict[str, ContractDailyResult] = {}
|
||||
|
||||
for vt_symbol, close_price in close_prices.items():
|
||||
self.contract_results[vt_symbol] = ContractDailyResult(result_date, close_price)
|
||||
|
||||
self.trade_count: int = 0
|
||||
self.turnover: float = 0
|
||||
self.commission: float = 0
|
||||
self.slippage: float = 0
|
||||
self.trading_pnl: float = 0
|
||||
self.holding_pnl: float = 0
|
||||
self.total_pnl: float = 0
|
||||
self.net_pnl: float = 0
|
||||
|
||||
def add_trade(self, trade: TradeData) -> None:
|
||||
""""""
|
||||
contract_result = self.contract_results[trade.vt_symbol]
|
||||
contract_result.add_trade(trade)
|
||||
|
||||
def calculate_pnl(
|
||||
self,
|
||||
pre_closes: Dict[str, float],
|
||||
start_poses: Dict[str, float],
|
||||
sizes: Dict[str, float],
|
||||
rates: Dict[str, float],
|
||||
slippages: Dict[str, float],
|
||||
) -> None:
|
||||
""""""
|
||||
self.pre_closes = pre_closes
|
||||
|
||||
for vt_symbol, contract_result in self.contract_results.items():
|
||||
contract_result.calculate_pnl(
|
||||
pre_closes.get(vt_symbol, 0),
|
||||
start_poses.get(vt_symbol, 0),
|
||||
sizes[vt_symbol],
|
||||
rates[vt_symbol],
|
||||
slippages[vt_symbol]
|
||||
)
|
||||
|
||||
self.trade_count += contract_result.trade_count
|
||||
self.turnover += contract_result.turnover
|
||||
self.commission += contract_result.commission
|
||||
self.slippage += contract_result.slippage
|
||||
self.trading_pnl += contract_result.trading_pnl
|
||||
self.holding_pnl += contract_result.holding_pnl
|
||||
self.total_pnl += contract_result.total_pnl
|
||||
self.net_pnl += contract_result.net_pnl
|
||||
|
||||
self.end_poses[vt_symbol] = contract_result.end_pos
|
||||
|
||||
def update_close_prices(self, close_prices: Dict[str, float]) -> None:
|
||||
""""""
|
||||
self.close_prices = close_prices
|
||||
|
||||
for vt_symbol, close_price in close_prices.items():
|
||||
contract_result = self.contract_results[vt_symbol]
|
||||
contract_result.update_close_price(close_price)
|
||||
|
||||
|
||||
@lru_cache(maxsize=999)
|
||||
def load_bar_data(
|
||||
vt_symbol: str,
|
||||
interval: Interval,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
|
||||
return database_manager.load_bar_data(
|
||||
symbol, exchange, interval, start, end
|
||||
)
|
17
vnpy/app/portfolio_strategy/base.py
Normal file
17
vnpy/app/portfolio_strategy/base.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""
|
||||
Defines constants and objects used in PortfolioStrategy App.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
APP_NAME = "PortfolioStrategy"
|
||||
|
||||
|
||||
class EngineType(Enum):
|
||||
LIVE = "实盘"
|
||||
BACKTESTING = "回测"
|
||||
|
||||
|
||||
EVENT_PORTFOLIO_LOG = "ePortfolioLog"
|
||||
EVENT_PORTFOLIO_STRATEGY = "ePortfolioStrategy"
|
624
vnpy/app/portfolio_strategy/engine.py
Normal file
624
vnpy/app/portfolio_strategy/engine.py
Normal file
@ -0,0 +1,624 @@
|
||||
""""""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Set, Tuple, Type, Any, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import BaseEngine, MainEngine
|
||||
from vnpy.trader.object import (
|
||||
OrderRequest,
|
||||
SubscribeRequest,
|
||||
HistoryRequest,
|
||||
LogData,
|
||||
TickData,
|
||||
OrderData,
|
||||
TradeData,
|
||||
PositionData,
|
||||
BarData,
|
||||
ContractData
|
||||
)
|
||||
from vnpy.trader.event import (
|
||||
EVENT_TICK,
|
||||
EVENT_ORDER,
|
||||
EVENT_TRADE,
|
||||
EVENT_POSITION
|
||||
)
|
||||
from vnpy.trader.constant import (
|
||||
Direction,
|
||||
OrderType,
|
||||
Interval,
|
||||
Exchange,
|
||||
Offset
|
||||
)
|
||||
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
|
||||
from vnpy.trader.database import database_manager
|
||||
from vnpy.trader.rqdata import rqdata_client
|
||||
from vnpy.trader.converter import OffsetConverter
|
||||
|
||||
from .base import (
|
||||
APP_NAME,
|
||||
EVENT_PORTFOLIO_LOG,
|
||||
EVENT_PORTFOLIO_STRATEGY
|
||||
)
|
||||
from .template import StrategyTemplate
|
||||
|
||||
|
||||
class StrategyEngine(BaseEngine):
|
||||
""""""
|
||||
|
||||
setting_filename = "portfolio_strategy_setting.json"
|
||||
data_filename = "portfolio_strategy_data.json"
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__(main_engine, event_engine, APP_NAME)
|
||||
|
||||
self.strategy_data: Dict[str, Dict] = {}
|
||||
|
||||
self.classes: Dict[str, Type[StrategyTemplate]] = {}
|
||||
self.strategies: Dict[str, StrategyTemplate] = {}
|
||||
|
||||
self.symbol_strategy_map: Dict[str, List[StrategyTemplate]] = defaultdict(list)
|
||||
self.orderid_strategy_map: Dict[str, StrategyTemplate] = {}
|
||||
|
||||
self.init_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
self.vt_tradeids: Set[str] = set()
|
||||
|
||||
self.offset_converter: OffsetConverter = OffsetConverter(self.main_engine)
|
||||
|
||||
def init_engine(self):
|
||||
"""
|
||||
"""
|
||||
self.init_rqdata()
|
||||
self.load_strategy_class()
|
||||
self.load_strategy_setting()
|
||||
self.load_strategy_data()
|
||||
self.register_event()
|
||||
self.write_log("组合策略引擎初始化成功")
|
||||
|
||||
def close(self):
|
||||
""""""
|
||||
self.stop_all_strategies()
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.event_engine.register(EVENT_TICK, self.process_tick_event)
|
||||
self.event_engine.register(EVENT_ORDER, self.process_order_event)
|
||||
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
|
||||
self.event_engine.register(EVENT_POSITION, self.process_position_event)
|
||||
|
||||
def init_rqdata(self):
|
||||
"""
|
||||
Init RQData client.
|
||||
"""
|
||||
result = rqdata_client.init()
|
||||
if result:
|
||||
self.write_log("RQData数据接口初始化成功")
|
||||
|
||||
def query_bar_from_rq(
|
||||
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
|
||||
):
|
||||
"""
|
||||
Query bar data from RQData.
|
||||
"""
|
||||
req = HistoryRequest(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end
|
||||
)
|
||||
data = rqdata_client.query_history(req)
|
||||
return data
|
||||
|
||||
def process_tick_event(self, event: Event):
|
||||
""""""
|
||||
tick: TickData = event.data
|
||||
|
||||
strategies = self.symbol_strategy_map[tick.vt_symbol]
|
||||
if not strategies:
|
||||
return
|
||||
|
||||
for strategy in strategies:
|
||||
if strategy.inited:
|
||||
self.call_strategy_func(strategy, strategy.on_tick, tick)
|
||||
|
||||
def process_order_event(self, event: Event):
|
||||
""""""
|
||||
order: OrderData = event.data
|
||||
|
||||
self.offset_converter.update_order(order)
|
||||
|
||||
strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
|
||||
if not strategy:
|
||||
return
|
||||
|
||||
self.call_strategy_func(strategy, strategy.update_order, order)
|
||||
|
||||
def process_trade_event(self, event: Event):
|
||||
""""""
|
||||
trade: TradeData = event.data
|
||||
|
||||
# Filter duplicate trade push
|
||||
if trade.vt_tradeid in self.vt_tradeids:
|
||||
return
|
||||
self.vt_tradeids.add(trade.vt_tradeid)
|
||||
|
||||
self.offset_converter.update_trade(trade)
|
||||
|
||||
strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
|
||||
if not strategy:
|
||||
return
|
||||
|
||||
self.call_strategy_func(strategy, strategy.update_trade, trade)
|
||||
|
||||
def process_position_event(self, event: Event):
|
||||
""""""
|
||||
position: PositionData = event.data
|
||||
|
||||
self.offset_converter.update_position(position)
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
strategy: StrategyTemplate,
|
||||
vt_symbol: str,
|
||||
direction: Direction,
|
||||
offset: Offset,
|
||||
price: float,
|
||||
volume: float,
|
||||
lock: bool
|
||||
):
|
||||
"""
|
||||
Send a new order to server.
|
||||
"""
|
||||
contract: ContractData = self.main_engine.get_contract(vt_symbol)
|
||||
if not contract:
|
||||
self.write_log(f"委托失败,找不到合约:{vt_symbol}", strategy)
|
||||
return ""
|
||||
|
||||
# Round order price and volume to nearest incremental value
|
||||
price = round_to(price, contract.pricetick)
|
||||
volume = round_to(volume, contract.min_volume)
|
||||
|
||||
# Create request and send order.
|
||||
original_req = OrderRequest(
|
||||
symbol=contract.symbol,
|
||||
exchange=contract.exchange,
|
||||
direction=direction,
|
||||
offset=offset,
|
||||
type=OrderType.LIMIT,
|
||||
price=price,
|
||||
volume=volume,
|
||||
)
|
||||
|
||||
# Convert with offset converter
|
||||
req_list = self.offset_converter.convert_order_request(original_req, lock)
|
||||
|
||||
# Send Orders
|
||||
vt_orderids = []
|
||||
|
||||
for req in req_list:
|
||||
req.reference = strategy.strategy_name # Add strategy name as order reference
|
||||
|
||||
vt_orderid = self.main_engine.send_order(
|
||||
req, contract.gateway_name)
|
||||
|
||||
# Check if sending order successful
|
||||
if not vt_orderid:
|
||||
continue
|
||||
|
||||
vt_orderids.append(vt_orderid)
|
||||
|
||||
self.offset_converter.update_order_request(req, vt_orderid)
|
||||
|
||||
# Save relationship between orderid and strategy.
|
||||
self.orderid_strategy_map[vt_orderid] = strategy
|
||||
|
||||
return vt_orderids
|
||||
|
||||
def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str):
|
||||
"""
|
||||
"""
|
||||
order = self.main_engine.get_order(vt_orderid)
|
||||
if not order:
|
||||
self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)
|
||||
return
|
||||
|
||||
req = order.create_cancel_request()
|
||||
self.main_engine.cancel_order(req, order.gateway_name)
|
||||
|
||||
def load_bars(self, strategy: StrategyTemplate, days: int, interval: Interval):
|
||||
""""""
|
||||
vt_symbols = strategy.vt_symbols
|
||||
dts: Set[datetime] = set()
|
||||
history_data: Dict[Tuple, BarData] = {}
|
||||
|
||||
# Load data from rqdata/gateway/database
|
||||
for vt_symbol in vt_symbols:
|
||||
data = self.load_bar(vt_symbol, days, interval)
|
||||
|
||||
for bar in data:
|
||||
dts.add(bar.datetime)
|
||||
history_data[(bar.datetime, vt_symbol)] = bar
|
||||
|
||||
# Convert data structure and push to strategy
|
||||
dts = list(dts)
|
||||
dts.sort()
|
||||
|
||||
for dt in dts:
|
||||
bars = {}
|
||||
|
||||
for vt_symbol in vt_symbols:
|
||||
bar = history_data.get((dt, vt_symbol), None)
|
||||
if bar:
|
||||
bars[vt_symbol] = bar
|
||||
else:
|
||||
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.write_log(f"数据缺失:{dt_str} {vt_symbol}", strategy)
|
||||
|
||||
self.call_strategy_func(strategy, strategy.on_bars, bars)
|
||||
|
||||
def load_bar(self, vt_symbol: str, days: int, interval: Interval) -> List[BarData]:
|
||||
""""""
|
||||
symbol, exchange = extract_vt_symbol(vt_symbol)
|
||||
end = datetime.now()
|
||||
start = end - timedelta(days)
|
||||
contract: ContractData = self.main_engine.get_contract(vt_symbol)
|
||||
data = []
|
||||
|
||||
# Query bars from gateway if available
|
||||
if contract and contract.history_data:
|
||||
req = HistoryRequest(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end
|
||||
)
|
||||
data = self.main_engine.query_history(req, contract.gateway_name)
|
||||
# Try to query bars from RQData, if not found, load from database.
|
||||
else:
|
||||
data = self.query_bar_from_rq(symbol, exchange, interval, start, end)
|
||||
|
||||
if not data:
|
||||
data = database_manager.load_bar_data(
|
||||
symbol=symbol,
|
||||
exchange=exchange,
|
||||
interval=interval,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def call_strategy_func(
|
||||
self, strategy: StrategyTemplate, func: Callable, params: Any = None
|
||||
):
|
||||
"""
|
||||
Call function of a strategy and catch any exception raised.
|
||||
"""
|
||||
try:
|
||||
if params:
|
||||
func(params)
|
||||
else:
|
||||
func()
|
||||
except Exception:
|
||||
strategy.trading = False
|
||||
strategy.inited = False
|
||||
|
||||
msg = f"触发异常已停止\n{traceback.format_exc()}"
|
||||
self.write_log(msg, strategy)
|
||||
|
||||
def add_strategy(
|
||||
self, class_name: str, strategy_name: str, vt_symbols: str, setting: dict
|
||||
):
|
||||
"""
|
||||
Add a new strategy.
|
||||
"""
|
||||
if strategy_name in self.strategies:
|
||||
self.write_log(f"创建策略失败,存在重名{strategy_name}")
|
||||
return
|
||||
|
||||
strategy_class = self.classes.get(class_name, None)
|
||||
if not strategy_class:
|
||||
self.write_log(f"创建策略失败,找不到策略类{class_name}")
|
||||
return
|
||||
|
||||
strategy = strategy_class(self, strategy_name, vt_symbols, setting)
|
||||
self.strategies[strategy_name] = strategy
|
||||
|
||||
# Add vt_symbol to strategy map.
|
||||
for vt_symbol in vt_symbols:
|
||||
strategies = self.symbol_strategy_map[vt_symbol]
|
||||
strategies.append(strategy)
|
||||
|
||||
self.save_strategy_setting()
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def init_strategy(self, strategy_name: str):
|
||||
"""
|
||||
Init a strategy.
|
||||
"""
|
||||
self.init_executor.submit(self._init_strategy, strategy_name)
|
||||
|
||||
def _init_strategy(self, strategy_name: str):
|
||||
"""
|
||||
Init strategies in queue.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
|
||||
if strategy.inited:
|
||||
self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
|
||||
return
|
||||
|
||||
self.write_log(f"{strategy_name}开始执行初始化")
|
||||
|
||||
# Call on_init function of strategy
|
||||
self.call_strategy_func(strategy, strategy.on_init)
|
||||
|
||||
# Restore strategy data(variables)
|
||||
data = self.strategy_data.get(strategy_name, None)
|
||||
if data:
|
||||
for name in strategy.variables:
|
||||
value = data.get(name, None)
|
||||
if value:
|
||||
setattr(strategy, name, value)
|
||||
|
||||
# Subscribe market data
|
||||
for vt_symbol in strategy.vt_symbols:
|
||||
contract: ContractData = self.main_engine.get_contract(vt_symbol)
|
||||
if contract:
|
||||
req = SubscribeRequest(
|
||||
symbol=contract.symbol, exchange=contract.exchange)
|
||||
self.main_engine.subscribe(req, contract.gateway_name)
|
||||
else:
|
||||
self.write_log(f"行情订阅失败,找不到合约{vt_symbol}", strategy)
|
||||
|
||||
# Put event to update init completed status.
|
||||
strategy.inited = True
|
||||
self.put_strategy_event(strategy)
|
||||
self.write_log(f"{strategy_name}初始化完成")
|
||||
|
||||
def start_strategy(self, strategy_name: str):
|
||||
"""
|
||||
Start a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
if not strategy.inited:
|
||||
self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
|
||||
return
|
||||
|
||||
if strategy.trading:
|
||||
self.write_log(f"{strategy_name}已经启动,请勿重复操作")
|
||||
return
|
||||
|
||||
self.call_strategy_func(strategy, strategy.on_start)
|
||||
strategy.trading = True
|
||||
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def stop_strategy(self, strategy_name: str):
|
||||
"""
|
||||
Stop a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
if not strategy.trading:
|
||||
return
|
||||
|
||||
# Call on_stop function of the strategy
|
||||
self.call_strategy_func(strategy, strategy.on_stop)
|
||||
|
||||
# Change trading status of strategy to False
|
||||
strategy.trading = False
|
||||
|
||||
# Cancel all orders of the strategy
|
||||
strategy.cancel_all()
|
||||
|
||||
# Sync strategy variables to data file
|
||||
self.sync_strategy_data(strategy)
|
||||
|
||||
# Update GUI
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def edit_strategy(self, strategy_name: str, setting: dict):
|
||||
"""
|
||||
Edit parameters of a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
strategy.update_setting(setting)
|
||||
|
||||
self.save_strategy_setting()
|
||||
self.put_strategy_event(strategy)
|
||||
|
||||
def remove_strategy(self, strategy_name: str):
|
||||
"""
|
||||
Remove a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
if strategy.trading:
|
||||
self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
|
||||
return
|
||||
|
||||
# Remove from symbol strategy map
|
||||
for vt_symbol in strategy.vt_symbols:
|
||||
strategies = self.symbol_strategy_map[vt_symbol]
|
||||
strategies.remove(strategy)
|
||||
|
||||
# Remove from vt_orderid strategy map
|
||||
for vt_orderid in strategy.active_orderids:
|
||||
if vt_orderid in self.orderid_strategy_map:
|
||||
self.orderid_strategy_map.pop(vt_orderid)
|
||||
|
||||
# Remove from strategies
|
||||
self.strategies.pop(strategy_name)
|
||||
self.save_strategy_setting()
|
||||
|
||||
return True
|
||||
|
||||
def load_strategy_class(self):
|
||||
"""
|
||||
Load strategy class from source code.
|
||||
"""
|
||||
path1 = Path(__file__).parent.joinpath("strategies")
|
||||
self.load_strategy_class_from_folder(path1, "vnpy.app.portfolio_strategy.strategies")
|
||||
|
||||
path2 = Path.cwd().joinpath("strategies")
|
||||
self.load_strategy_class_from_folder(path2, "strategies")
|
||||
|
||||
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
|
||||
"""
|
||||
Load strategy class from certain folder.
|
||||
"""
|
||||
for dirpath, dirnames, filenames in os.walk(str(path)):
|
||||
for filename in filenames:
|
||||
if filename.endswith(".py"):
|
||||
strategy_module_name = ".".join(
|
||||
[module_name, filename.replace(".py", "")])
|
||||
elif filename.endswith(".pyd"):
|
||||
strategy_module_name = ".".join(
|
||||
[module_name, filename.split(".")[0]])
|
||||
else:
|
||||
continue
|
||||
|
||||
self.load_strategy_class_from_module(strategy_module_name)
|
||||
|
||||
def load_strategy_class_from_module(self, module_name: str):
|
||||
"""
|
||||
Load strategy class from module file.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for name in dir(module):
|
||||
value = getattr(module, name)
|
||||
if (isinstance(value, type) and issubclass(value, StrategyTemplate) and value is not StrategyTemplate):
|
||||
self.classes[value.__name__] = value
|
||||
except: # noqa
|
||||
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
|
||||
self.write_log(msg)
|
||||
|
||||
def load_strategy_data(self):
|
||||
"""
|
||||
Load strategy data from json file.
|
||||
"""
|
||||
self.strategy_data = load_json(self.data_filename)
|
||||
|
||||
def sync_strategy_data(self, strategy: StrategyTemplate):
|
||||
"""
|
||||
Sync strategy data into json file.
|
||||
"""
|
||||
data = strategy.get_variables()
|
||||
data.pop("inited") # Strategy status (inited, trading) should not be synced.
|
||||
data.pop("trading")
|
||||
|
||||
self.strategy_data[strategy.strategy_name] = data
|
||||
save_json(self.data_filename, self.strategy_data)
|
||||
|
||||
def get_all_strategy_class_names(self):
|
||||
"""
|
||||
Return names of strategy classes loaded.
|
||||
"""
|
||||
return list(self.classes.keys())
|
||||
|
||||
def get_strategy_class_parameters(self, class_name: str):
|
||||
"""
|
||||
Get default parameters of a strategy class.
|
||||
"""
|
||||
strategy_class = self.classes[class_name]
|
||||
|
||||
parameters = {}
|
||||
for name in strategy_class.parameters:
|
||||
parameters[name] = getattr(strategy_class, name)
|
||||
|
||||
return parameters
|
||||
|
||||
def get_strategy_parameters(self, strategy_name):
|
||||
"""
|
||||
Get parameters of a strategy.
|
||||
"""
|
||||
strategy = self.strategies[strategy_name]
|
||||
return strategy.get_parameters()
|
||||
|
||||
def init_all_strategies(self):
|
||||
"""
|
||||
"""
|
||||
for strategy_name in self.strategies.keys():
|
||||
self.init_strategy(strategy_name)
|
||||
|
||||
def start_all_strategies(self):
|
||||
"""
|
||||
"""
|
||||
for strategy_name in self.strategies.keys():
|
||||
self.start_strategy(strategy_name)
|
||||
|
||||
def stop_all_strategies(self):
|
||||
"""
|
||||
"""
|
||||
for strategy_name in self.strategies.keys():
|
||||
self.stop_strategy(strategy_name)
|
||||
|
||||
def load_strategy_setting(self):
|
||||
"""
|
||||
Load setting file.
|
||||
"""
|
||||
strategy_setting = load_json(self.setting_filename)
|
||||
|
||||
for strategy_name, strategy_config in strategy_setting.items():
|
||||
self.add_strategy(
|
||||
strategy_config["class_name"],
|
||||
strategy_name,
|
||||
strategy_config["vt_symbols"],
|
||||
strategy_config["setting"]
|
||||
)
|
||||
|
||||
def save_strategy_setting(self):
|
||||
"""
|
||||
Save setting file.
|
||||
"""
|
||||
strategy_setting = {}
|
||||
|
||||
for name, strategy in self.strategies.items():
|
||||
strategy_setting[name] = {
|
||||
"class_name": strategy.__class__.__name__,
|
||||
"vt_symbols": strategy.vt_symbols,
|
||||
"setting": strategy.get_parameters()
|
||||
}
|
||||
|
||||
save_json(self.setting_filename, strategy_setting)
|
||||
|
||||
def put_strategy_event(self, strategy: StrategyTemplate):
|
||||
"""
|
||||
Put an event to update strategy status.
|
||||
"""
|
||||
data = strategy.get_data()
|
||||
event = Event(EVENT_PORTFOLIO_STRATEGY, data)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def write_log(self, msg: str, strategy: StrategyTemplate = None):
|
||||
"""
|
||||
Create cta engine log event.
|
||||
"""
|
||||
if strategy:
|
||||
msg = f"{strategy.strategy_name}: {msg}"
|
||||
|
||||
log = LogData(msg=msg, gateway_name="CtaStrategy")
|
||||
event = Event(type=EVENT_PORTFOLIO_LOG, data=log)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def send_email(self, msg: str, strategy: StrategyTemplate = None):
|
||||
"""
|
||||
Send email to default receiver.
|
||||
"""
|
||||
if strategy:
|
||||
subject = f"{strategy.strategy_name}"
|
||||
else:
|
||||
subject = "组合策略引擎"
|
||||
|
||||
self.main_engine.send_email(subject, msg)
|
0
vnpy/app/portfolio_strategy/strategies/__init__.py
Normal file
0
vnpy/app/portfolio_strategy/strategies/__init__.py
Normal file
@ -0,0 +1,139 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine
|
||||
from vnpy.trader.utility import BarGenerator, ArrayManager
|
||||
from vnpy.trader.object import TickData, BarData
|
||||
|
||||
|
||||
class TrendFollowingStrategy(StrategyTemplate):
|
||||
""""""
|
||||
|
||||
author = "用Python的交易员"
|
||||
|
||||
atr_window = 22
|
||||
atr_ma_window = 10
|
||||
rsi_window = 5
|
||||
rsi_entry = 16
|
||||
trailing_percent = 0.8
|
||||
fixed_size = 1
|
||||
|
||||
atr_value = 0
|
||||
atr_ma = 0
|
||||
rsi_value = 0
|
||||
rsi_buy = 0
|
||||
rsi_sell = 0
|
||||
intra_trade_high = 0
|
||||
intra_trade_low = 0
|
||||
|
||||
parameters = [
|
||||
"atr_window",
|
||||
"atr_ma_window",
|
||||
"rsi_window",
|
||||
"rsi_entry",
|
||||
"trailing_percent",
|
||||
"fixed_size"
|
||||
]
|
||||
variables = [
|
||||
"atr_value",
|
||||
"atr_ma",
|
||||
"rsi_value",
|
||||
"rsi_buy",
|
||||
"rsi_sell"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy_engine: StrategyEngine,
|
||||
strategy_name: str,
|
||||
vt_symbols: List[str],
|
||||
setting: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__(strategy_engine, strategy_name, vt_symbols, setting)
|
||||
|
||||
self.vt_symbol = vt_symbols[0]
|
||||
self.bg = BarGenerator(self.on_bar)
|
||||
self.am = ArrayManager()
|
||||
|
||||
def on_init(self):
|
||||
"""
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
self.write_log("策略初始化")
|
||||
|
||||
self.rsi_buy = 50 + self.rsi_entry
|
||||
self.rsi_sell = 50 - self.rsi_entry
|
||||
|
||||
self.load_bars(10)
|
||||
|
||||
def on_start(self):
|
||||
"""
|
||||
Callback when strategy is started.
|
||||
"""
|
||||
self.write_log("策略启动")
|
||||
|
||||
def on_stop(self):
|
||||
"""
|
||||
Callback when strategy is stopped.
|
||||
"""
|
||||
self.write_log("策略停止")
|
||||
|
||||
def on_tick(self, tick: TickData):
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
"""
|
||||
self.bg.update_tick(tick)
|
||||
|
||||
def on_bar(self, bar: BarData):
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
"""
|
||||
bars = {bar.vt_symbol: bar}
|
||||
self.on_bars(bars)
|
||||
|
||||
def on_bars(self, bars: Dict[str, BarData]):
|
||||
""""""
|
||||
self.cancel_all()
|
||||
|
||||
bar = bars[self.vt_symbol]
|
||||
am = self.am
|
||||
am.update_bar(bar)
|
||||
if not am.inited:
|
||||
return
|
||||
|
||||
atr_array = am.atr(self.atr_window, array=True)
|
||||
self.atr_value = atr_array[-1]
|
||||
self.atr_ma = atr_array[-self.atr_ma_window:].mean()
|
||||
self.rsi_value = am.rsi(self.rsi_window)
|
||||
|
||||
pos = self.get_pos(self.vt_symbol)
|
||||
|
||||
if pos == 0:
|
||||
self.intra_trade_high = bar.high_price
|
||||
self.intra_trade_low = bar.low_price
|
||||
|
||||
if self.atr_value > self.atr_ma:
|
||||
if self.rsi_value > self.rsi_buy:
|
||||
self.buy(self.vt_symbol, bar.close_price + 5, self.fixed_size)
|
||||
elif self.rsi_value < self.rsi_sell:
|
||||
self.short(self.vt_symbol, bar.close_price - 5, self.fixed_size)
|
||||
|
||||
elif pos > 0:
|
||||
self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
|
||||
self.intra_trade_low = bar.low_price
|
||||
|
||||
long_stop = self.intra_trade_high * (1 - self.trailing_percent / 100)
|
||||
|
||||
if bar.close_price <= long_stop:
|
||||
self.sell(self.vt_symbol, bar.close_price - 5, abs(pos))
|
||||
|
||||
elif pos < 0:
|
||||
self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
|
||||
self.intra_trade_high = bar.high_price
|
||||
|
||||
short_stop = self.intra_trade_low * (1 + self.trailing_percent / 100)
|
||||
|
||||
if bar.close_price >= short_stop:
|
||||
self.cover(self.vt_symbol, bar.close_price + 5, abs(pos))
|
||||
|
||||
self.put_event()
|
247
vnpy/app/portfolio_strategy/template.py
Normal file
247
vnpy/app/portfolio_strategy/template.py
Normal file
@ -0,0 +1,247 @@
|
||||
""""""
|
||||
from abc import ABC
|
||||
from copy import copy
|
||||
from typing import Dict, Set, List, TYPE_CHECKING
|
||||
from collections import defaultdict
|
||||
|
||||
from vnpy.trader.constant import Interval, Direction, Offset
|
||||
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
|
||||
from vnpy.trader.utility import virtual
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .engine import StrategyEngine
|
||||
|
||||
|
||||
class StrategyTemplate(ABC):
|
||||
""""""
|
||||
|
||||
author = ""
|
||||
parameters = []
|
||||
variables = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy_engine: "StrategyEngine",
|
||||
strategy_name: str,
|
||||
vt_symbols: List[str],
|
||||
setting: dict,
|
||||
):
|
||||
""""""
|
||||
self.strategy_engine: "StrategyEngine" = strategy_engine
|
||||
self.strategy_name: str = strategy_name
|
||||
self.vt_symbols: List[str] = vt_symbols
|
||||
|
||||
self.inited: bool = False
|
||||
self.trading: bool = False
|
||||
self.pos: Dict[str, int] = defaultdict(int)
|
||||
self.orders: Dict[str, OrderData] = {}
|
||||
self.active_orderids: Set[str] = set()
|
||||
|
||||
# Copy a new variables list here to avoid duplicate insert when multiple
|
||||
# strategy instances are created with the same strategy class.
|
||||
self.variables: Dict = copy(self.variables)
|
||||
self.variables.insert(0, "inited")
|
||||
self.variables.insert(1, "trading")
|
||||
|
||||
self.update_setting(setting)
|
||||
|
||||
def update_setting(self, setting: dict) -> None:
|
||||
"""
|
||||
Update strategy parameter wtih value in setting dict.
|
||||
"""
|
||||
for name in self.parameters:
|
||||
if name in setting:
|
||||
setattr(self, name, setting[name])
|
||||
|
||||
@classmethod
|
||||
def get_class_parameters(cls) -> Dict:
|
||||
"""
|
||||
Get default parameters dict of strategy class.
|
||||
"""
|
||||
class_parameters = {}
|
||||
for name in cls.parameters:
|
||||
class_parameters[name] = getattr(cls, name)
|
||||
return class_parameters
|
||||
|
||||
def get_parameters(self) -> Dict:
|
||||
"""
|
||||
Get strategy parameters dict.
|
||||
"""
|
||||
strategy_parameters = {}
|
||||
for name in self.parameters:
|
||||
strategy_parameters[name] = getattr(self, name)
|
||||
return strategy_parameters
|
||||
|
||||
def get_variables(self) -> Dict:
|
||||
"""
|
||||
Get strategy variables dict.
|
||||
"""
|
||||
strategy_variables = {}
|
||||
for name in self.variables:
|
||||
strategy_variables[name] = getattr(self, name)
|
||||
return strategy_variables
|
||||
|
||||
def get_data(self) -> Dict:
|
||||
"""
|
||||
Get strategy data.
|
||||
"""
|
||||
strategy_data = {
|
||||
"strategy_name": self.strategy_name,
|
||||
"vt_symbols": self.vt_symbols,
|
||||
"class_name": self.__class__.__name__,
|
||||
"author": self.author,
|
||||
"parameters": self.get_parameters(),
|
||||
"variables": self.get_variables(),
|
||||
}
|
||||
return strategy_data
|
||||
|
||||
@virtual
|
||||
def on_init(self) -> None:
|
||||
"""
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_start(self) -> None:
|
||||
"""
|
||||
Callback when strategy is started.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_stop(self) -> None:
|
||||
"""
|
||||
Callback when strategy is stopped.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_tick(self, tick: TickData) -> None:
|
||||
"""
|
||||
Callback of new tick data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
@virtual
|
||||
def on_bars(self, bars: Dict[str, BarData]) -> None:
|
||||
"""
|
||||
Callback of new bar data update.
|
||||
"""
|
||||
pass
|
||||
|
||||
def update_trade(self, trade: TradeData) -> None:
|
||||
"""
|
||||
Callback of new trade data update.
|
||||
"""
|
||||
if trade.direction == Direction.LONG:
|
||||
self.pos[trade.vt_symbol] += trade.volume
|
||||
else:
|
||||
self.pos[trade.vt_symbol] -= trade.volume
|
||||
|
||||
def update_order(self, order: OrderData) -> None:
|
||||
"""
|
||||
Callback of new order data update.
|
||||
"""
|
||||
self.orders[order.vt_orderid] = order
|
||||
|
||||
if order.is_active():
|
||||
self.active_orderids.add(order.vt_orderid)
|
||||
elif order.vt_orderid in self.active_orderids:
|
||||
self.active_orderids.remove(order.vt_orderid)
|
||||
|
||||
def buy(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
|
||||
"""
|
||||
Send buy order to open a long position.
|
||||
"""
|
||||
return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume, lock)
|
||||
|
||||
def sell(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
|
||||
"""
|
||||
Send sell order to close a long position.
|
||||
"""
|
||||
return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume, lock)
|
||||
|
||||
def short(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
|
||||
"""
|
||||
Send short order to open as short position.
|
||||
"""
|
||||
return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume, lock)
|
||||
|
||||
def cover(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
|
||||
"""
|
||||
Send cover order to close a short position.
|
||||
"""
|
||||
return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume, lock)
|
||||
|
||||
def send_order(
|
||||
self,
|
||||
vt_symbol: str,
|
||||
direction: Direction,
|
||||
offset: Offset,
|
||||
price: float,
|
||||
volume: float,
|
||||
lock: bool = False
|
||||
) -> List[str]:
|
||||
"""
|
||||
Send a new order.
|
||||
"""
|
||||
if self.trading:
|
||||
vt_orderids = self.strategy_engine.send_order(
|
||||
self, vt_symbol, direction, offset, price, volume, lock
|
||||
)
|
||||
return vt_orderids
|
||||
else:
|
||||
return []
|
||||
|
||||
def cancel_order(self, vt_orderid: str) -> None:
|
||||
"""
|
||||
Cancel an existing order.
|
||||
"""
|
||||
if self.trading:
|
||||
self.strategy_engine.cancel_order(self, vt_orderid)
|
||||
|
||||
def cancel_all(self) -> None:
|
||||
"""
|
||||
Cancel all orders sent by strategy.
|
||||
"""
|
||||
for vt_orderid in list(self.active_orderids):
|
||||
self.cancel_order(vt_orderid)
|
||||
|
||||
def get_pos(self, vt_symbol: str) -> int:
|
||||
""""""
|
||||
return self.pos.get(vt_symbol, 0)
|
||||
|
||||
def get_order(self, vt_orderid: str) -> OrderData:
|
||||
""""""
|
||||
return self.orders.get(vt_orderid, None)
|
||||
|
||||
def get_all_active_orderids(self) -> List[OrderData]:
|
||||
""""""
|
||||
return list(self.active_orderids.values())
|
||||
|
||||
def write_log(self, msg: str) -> None:
|
||||
"""
|
||||
Write a log message.
|
||||
"""
|
||||
self.strategy_engine.write_log(msg, self)
|
||||
|
||||
def load_bars(self, days: int, interval: Interval = Interval.MINUTE) -> None:
|
||||
"""
|
||||
Load historical bar data for initializing strategy.
|
||||
"""
|
||||
self.strategy_engine.load_bars(self, days, interval)
|
||||
|
||||
def put_event(self) -> None:
|
||||
"""
|
||||
Put an strategy data event for ui update.
|
||||
"""
|
||||
if self.inited:
|
||||
self.strategy_engine.put_strategy_event(self)
|
||||
|
||||
def send_email(self, msg) -> None:
|
||||
"""
|
||||
Send email to default receiver.
|
||||
"""
|
||||
if self.inited:
|
||||
self.strategy_engine.send_email(msg, self)
|
1
vnpy/app/portfolio_strategy/ui/__init__.py
Normal file
1
vnpy/app/portfolio_strategy/ui/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .widget import PortfolioStrategyManager
|
BIN
vnpy/app/portfolio_strategy/ui/strategy.ico
Normal file
BIN
vnpy/app/portfolio_strategy/ui/strategy.ico
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
439
vnpy/app/portfolio_strategy/ui/widget.py
Normal file
439
vnpy/app/portfolio_strategy/ui/widget.py
Normal file
@ -0,0 +1,439 @@
|
||||
from typing import Dict
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from vnpy.trader.engine import MainEngine
|
||||
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
|
||||
from vnpy.trader.ui.widget import (
|
||||
MsgCell,
|
||||
TimeCell,
|
||||
BaseMonitor
|
||||
)
|
||||
from ..base import (
|
||||
APP_NAME,
|
||||
EVENT_PORTFOLIO_LOG,
|
||||
EVENT_PORTFOLIO_STRATEGY
|
||||
)
|
||||
from ..engine import StrategyEngine
|
||||
|
||||
|
||||
class PortfolioStrategyManager(QtWidgets.QWidget):
|
||||
""""""
|
||||
|
||||
signal_log = QtCore.pyqtSignal(Event)
|
||||
signal_strategy = QtCore.pyqtSignal(Event)
|
||||
|
||||
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.main_engine: MainEngine = main_engine
|
||||
self.event_engine: EventEngine = event_engine
|
||||
self.strategy_engine: StrategyEngine = main_engine.get_engine(APP_NAME)
|
||||
|
||||
self.managers: Dict[str, StrategyManager] = {}
|
||||
|
||||
self.init_ui()
|
||||
self.register_event()
|
||||
self.strategy_engine.init_engine()
|
||||
self.update_class_combo()
|
||||
|
||||
def init_ui(self) -> None:
|
||||
""""""
|
||||
self.setWindowTitle("组合策略")
|
||||
|
||||
# Create widgets
|
||||
self.class_combo = QtWidgets.QComboBox()
|
||||
|
||||
add_button = QtWidgets.QPushButton("添加策略")
|
||||
add_button.clicked.connect(self.add_strategy)
|
||||
|
||||
init_button = QtWidgets.QPushButton("全部初始化")
|
||||
init_button.clicked.connect(self.strategy_engine.init_all_strategies)
|
||||
|
||||
start_button = QtWidgets.QPushButton("全部启动")
|
||||
start_button.clicked.connect(self.strategy_engine.start_all_strategies)
|
||||
|
||||
stop_button = QtWidgets.QPushButton("全部停止")
|
||||
stop_button.clicked.connect(self.strategy_engine.stop_all_strategies)
|
||||
|
||||
clear_button = QtWidgets.QPushButton("清空日志")
|
||||
clear_button.clicked.connect(self.clear_log)
|
||||
|
||||
self.scroll_layout = QtWidgets.QVBoxLayout()
|
||||
self.scroll_layout.addStretch()
|
||||
|
||||
scroll_widget = QtWidgets.QWidget()
|
||||
scroll_widget.setLayout(self.scroll_layout)
|
||||
|
||||
scroll_area = QtWidgets.QScrollArea()
|
||||
scroll_area.setWidgetResizable(True)
|
||||
scroll_area.setWidget(scroll_widget)
|
||||
|
||||
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
|
||||
|
||||
# Set layout
|
||||
hbox1 = QtWidgets.QHBoxLayout()
|
||||
hbox1.addWidget(self.class_combo)
|
||||
hbox1.addWidget(add_button)
|
||||
hbox1.addStretch()
|
||||
hbox1.addWidget(init_button)
|
||||
hbox1.addWidget(start_button)
|
||||
hbox1.addWidget(stop_button)
|
||||
hbox1.addWidget(clear_button)
|
||||
|
||||
hbox2 = QtWidgets.QHBoxLayout()
|
||||
hbox2.addWidget(scroll_area)
|
||||
hbox2.addWidget(self.log_monitor)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addLayout(hbox1)
|
||||
vbox.addLayout(hbox2)
|
||||
|
||||
self.setLayout(vbox)
|
||||
|
||||
def update_class_combo(self):
|
||||
""""""
|
||||
self.class_combo.addItems(
|
||||
self.strategy_engine.get_all_strategy_class_names()
|
||||
)
|
||||
|
||||
def register_event(self):
|
||||
""""""
|
||||
self.signal_strategy.connect(self.process_strategy_event)
|
||||
|
||||
self.event_engine.register(
|
||||
EVENT_PORTFOLIO_STRATEGY, self.signal_strategy.emit
|
||||
)
|
||||
|
||||
def process_strategy_event(self, event):
|
||||
"""
|
||||
Update strategy status onto its monitor.
|
||||
"""
|
||||
data = event.data
|
||||
strategy_name = data["strategy_name"]
|
||||
|
||||
if strategy_name in self.managers:
|
||||
manager = self.managers[strategy_name]
|
||||
manager.update_data(data)
|
||||
else:
|
||||
manager = StrategyManager(self, self.strategy_engine, data)
|
||||
self.scroll_layout.insertWidget(0, manager)
|
||||
self.managers[strategy_name] = manager
|
||||
|
||||
def remove_strategy(self, strategy_name):
|
||||
""""""
|
||||
manager = self.managers.pop(strategy_name)
|
||||
manager.deleteLater()
|
||||
|
||||
def add_strategy(self):
|
||||
""""""
|
||||
class_name = str(self.class_combo.currentText())
|
||||
if not class_name:
|
||||
return
|
||||
|
||||
parameters = self.strategy_engine.get_strategy_class_parameters(class_name)
|
||||
editor = SettingEditor(parameters, class_name=class_name)
|
||||
n = editor.exec_()
|
||||
|
||||
if n == editor.Accepted:
|
||||
setting = editor.get_setting()
|
||||
vt_symbols = setting.pop("vt_symbols").split(",")
|
||||
strategy_name = setting.pop("strategy_name")
|
||||
|
||||
self.strategy_engine.add_strategy(
|
||||
class_name, strategy_name, vt_symbols, setting
|
||||
)
|
||||
|
||||
def clear_log(self):
|
||||
""""""
|
||||
self.log_monitor.setRowCount(0)
|
||||
|
||||
def show(self):
|
||||
""""""
|
||||
self.showMaximized()
|
||||
|
||||
|
||||
class StrategyManager(QtWidgets.QFrame):
|
||||
"""
|
||||
Manager for a strategy
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy_manager: PortfolioStrategyManager,
|
||||
strategy_engine: StrategyEngine,
|
||||
data: dict
|
||||
):
|
||||
""""""
|
||||
super().__init__()
|
||||
|
||||
self.strategy_manager = strategy_manager
|
||||
self.strategy_engine = strategy_engine
|
||||
|
||||
self.strategy_name = data["strategy_name"]
|
||||
self._data = data
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
self.setFixedHeight(300)
|
||||
self.setFrameShape(self.Box)
|
||||
self.setLineWidth(1)
|
||||
|
||||
self.init_button = QtWidgets.QPushButton("初始化")
|
||||
self.init_button.clicked.connect(self.init_strategy)
|
||||
|
||||
self.start_button = QtWidgets.QPushButton("启动")
|
||||
self.start_button.clicked.connect(self.start_strategy)
|
||||
self.start_button.setEnabled(False)
|
||||
|
||||
self.stop_button = QtWidgets.QPushButton("停止")
|
||||
self.stop_button.clicked.connect(self.stop_strategy)
|
||||
self.stop_button.setEnabled(False)
|
||||
|
||||
self.edit_button = QtWidgets.QPushButton("编辑")
|
||||
self.edit_button.clicked.connect(self.edit_strategy)
|
||||
|
||||
self.remove_button = QtWidgets.QPushButton("移除")
|
||||
self.remove_button.clicked.connect(self.remove_strategy)
|
||||
|
||||
strategy_name = self._data["strategy_name"]
|
||||
class_name = self._data["class_name"]
|
||||
author = self._data["author"]
|
||||
|
||||
label_text = (
|
||||
f"{strategy_name} - ({class_name} by {author})"
|
||||
)
|
||||
label = QtWidgets.QLabel(label_text)
|
||||
label.setAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
self.parameters_monitor = DataMonitor(self._data["parameters"])
|
||||
self.variables_monitor = DataMonitor(self._data["variables"])
|
||||
|
||||
hbox = QtWidgets.QHBoxLayout()
|
||||
hbox.addWidget(self.init_button)
|
||||
hbox.addWidget(self.start_button)
|
||||
hbox.addWidget(self.stop_button)
|
||||
hbox.addWidget(self.edit_button)
|
||||
hbox.addWidget(self.remove_button)
|
||||
|
||||
vbox = QtWidgets.QVBoxLayout()
|
||||
vbox.addWidget(label)
|
||||
vbox.addLayout(hbox)
|
||||
vbox.addWidget(self.parameters_monitor)
|
||||
vbox.addWidget(self.variables_monitor)
|
||||
self.setLayout(vbox)
|
||||
|
||||
def update_data(self, data: dict):
|
||||
""""""
|
||||
self._data = data
|
||||
|
||||
self.parameters_monitor.update_data(data["parameters"])
|
||||
self.variables_monitor.update_data(data["variables"])
|
||||
|
||||
# Update button status
|
||||
variables = data["variables"]
|
||||
inited = variables["inited"]
|
||||
trading = variables["trading"]
|
||||
|
||||
if not inited:
|
||||
return
|
||||
self.init_button.setEnabled(False)
|
||||
|
||||
if trading:
|
||||
self.start_button.setEnabled(False)
|
||||
self.stop_button.setEnabled(True)
|
||||
self.edit_button.setEnabled(False)
|
||||
self.remove_button.setEnabled(False)
|
||||
else:
|
||||
self.start_button.setEnabled(True)
|
||||
self.stop_button.setEnabled(False)
|
||||
self.edit_button.setEnabled(True)
|
||||
self.remove_button.setEnabled(True)
|
||||
|
||||
def init_strategy(self):
|
||||
""""""
|
||||
self.strategy_engine.init_strategy(self.strategy_name)
|
||||
|
||||
def start_strategy(self):
|
||||
""""""
|
||||
self.strategy_engine.start_strategy(self.strategy_name)
|
||||
|
||||
def stop_strategy(self):
|
||||
""""""
|
||||
self.strategy_engine.stop_strategy(self.strategy_name)
|
||||
|
||||
def edit_strategy(self):
|
||||
""""""
|
||||
strategy_name = self._data["strategy_name"]
|
||||
|
||||
parameters = self.strategy_engine.get_strategy_parameters(strategy_name)
|
||||
editor = SettingEditor(parameters, strategy_name=strategy_name)
|
||||
n = editor.exec_()
|
||||
|
||||
if n == editor.Accepted:
|
||||
setting = editor.get_setting()
|
||||
self.strategy_engine.edit_strategy(strategy_name, setting)
|
||||
|
||||
def remove_strategy(self):
|
||||
""""""
|
||||
result = self.strategy_engine.remove_strategy(self.strategy_name)
|
||||
|
||||
# Only remove strategy gui manager if it has been removed from engine
|
||||
if result:
|
||||
self.strategy_manager.remove_strategy(self.strategy_name)
|
||||
|
||||
|
||||
class DataMonitor(QtWidgets.QTableWidget):
|
||||
"""
|
||||
Table monitor for parameters and variables.
|
||||
"""
|
||||
|
||||
def __init__(self, data: dict):
|
||||
""""""
|
||||
super(DataMonitor, self).__init__()
|
||||
|
||||
self._data = data
|
||||
self.cells = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
labels = list(self._data.keys())
|
||||
self.setColumnCount(len(labels))
|
||||
self.setHorizontalHeaderLabels(labels)
|
||||
|
||||
self.setRowCount(1)
|
||||
self.verticalHeader().setSectionResizeMode(
|
||||
QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
self.verticalHeader().setVisible(False)
|
||||
self.setEditTriggers(self.NoEditTriggers)
|
||||
|
||||
for column, name in enumerate(self._data.keys()):
|
||||
value = self._data[name]
|
||||
|
||||
cell = QtWidgets.QTableWidgetItem(str(value))
|
||||
cell.setTextAlignment(QtCore.Qt.AlignCenter)
|
||||
|
||||
self.setItem(0, column, cell)
|
||||
self.cells[name] = cell
|
||||
|
||||
def update_data(self, data: dict):
|
||||
""""""
|
||||
for name, value in data.items():
|
||||
cell = self.cells[name]
|
||||
cell.setText(str(value))
|
||||
|
||||
|
||||
class LogMonitor(BaseMonitor):
|
||||
"""
|
||||
Monitor for log data.
|
||||
"""
|
||||
|
||||
event_type = EVENT_PORTFOLIO_LOG
|
||||
data_key = ""
|
||||
sorting = False
|
||||
|
||||
headers = {
|
||||
"time": {"display": "时间", "cell": TimeCell, "update": False},
|
||||
"msg": {"display": "信息", "cell": MsgCell, "update": False},
|
||||
}
|
||||
|
||||
def init_ui(self):
|
||||
"""
|
||||
Stretch last column.
|
||||
"""
|
||||
super(LogMonitor, self).init_ui()
|
||||
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
1, QtWidgets.QHeaderView.Stretch
|
||||
)
|
||||
|
||||
def insert_new_row(self, data):
|
||||
"""
|
||||
Insert a new row at the top of table.
|
||||
"""
|
||||
super().insert_new_row(data)
|
||||
self.resizeRowToContents(0)
|
||||
|
||||
|
||||
class SettingEditor(QtWidgets.QDialog):
|
||||
"""
|
||||
For creating new strategy and editing strategy parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, parameters: dict, strategy_name: str = "", class_name: str = ""
|
||||
):
|
||||
""""""
|
||||
super(SettingEditor, self).__init__()
|
||||
|
||||
self.parameters = parameters
|
||||
self.strategy_name = strategy_name
|
||||
self.class_name = class_name
|
||||
|
||||
self.edits = {}
|
||||
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
""""""
|
||||
form = QtWidgets.QFormLayout()
|
||||
|
||||
# Add vt_symbols and name edit if add new strategy
|
||||
if self.class_name:
|
||||
self.setWindowTitle(f"添加策略:{self.class_name}")
|
||||
button_text = "添加"
|
||||
parameters = {"strategy_name": "", "vt_symbols": ""}
|
||||
parameters.update(self.parameters)
|
||||
else:
|
||||
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
|
||||
button_text = "确定"
|
||||
parameters = self.parameters
|
||||
|
||||
for name, value in parameters.items():
|
||||
type_ = type(value)
|
||||
|
||||
edit = QtWidgets.QLineEdit(str(value))
|
||||
if type_ is int:
|
||||
validator = QtGui.QIntValidator()
|
||||
edit.setValidator(validator)
|
||||
elif type_ is float:
|
||||
validator = QtGui.QDoubleValidator()
|
||||
edit.setValidator(validator)
|
||||
|
||||
form.addRow(f"{name} {type_}", edit)
|
||||
|
||||
self.edits[name] = (edit, type_)
|
||||
|
||||
button = QtWidgets.QPushButton(button_text)
|
||||
button.clicked.connect(self.accept)
|
||||
form.addRow(button)
|
||||
|
||||
self.setLayout(form)
|
||||
|
||||
def get_setting(self):
|
||||
""""""
|
||||
setting = {}
|
||||
|
||||
if self.class_name:
|
||||
setting["class_name"] = self.class_name
|
||||
|
||||
for name, tp in self.edits.items():
|
||||
edit, type_ = tp
|
||||
value_text = edit.text()
|
||||
|
||||
if type_ == bool:
|
||||
if value_text == "True":
|
||||
value = True
|
||||
else:
|
||||
value = False
|
||||
else:
|
||||
value = type_(value_text)
|
||||
|
||||
setting[name] = value
|
||||
|
||||
return setting
|
@ -595,7 +595,7 @@ class CtaGridTrade(CtaComponent):
|
||||
x.type = type # 网格类型标签
|
||||
# self.open_prices = {} # 套利使用,开仓价格,symbol:price
|
||||
|
||||
def rebuild_grids(self, direction: Direction,
|
||||
def rebuild_grids(self, directions: list = [],
|
||||
upper_line: float = 0.0,
|
||||
down_line: float = 0.0,
|
||||
middle_line: float = 0.0,
|
||||
@ -608,7 +608,7 @@ class CtaGridTrade(CtaComponent):
|
||||
upRate , 上轨网格高度比率
|
||||
dnRate, 下轨网格高度比率
|
||||
"""
|
||||
self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(direction, upper_line, down_line))
|
||||
self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(directions, upper_line, down_line))
|
||||
|
||||
# 检查上下网格的高度比率,不能低于0.5
|
||||
if upper_rate < 0.5 or down_rate < 0.5:
|
||||
@ -616,7 +616,7 @@ class CtaGridTrade(CtaComponent):
|
||||
down_rate = max(0.5, down_rate)
|
||||
|
||||
# 重建下网格(移除未挂单、保留开仓得网格、在最低价之下才增加网格
|
||||
if direction == Direction.LONG:
|
||||
if len(directions) == 0 or Direction.LONG in directions:
|
||||
min_long_price = middle_line
|
||||
remove_grids = []
|
||||
opened_grids = []
|
||||
@ -636,8 +636,8 @@ class CtaGridTrade(CtaComponent):
|
||||
self.write_log(u'保留下网格[{}]'.format(opened_grids))
|
||||
|
||||
# 需要重建的剩余网格数量
|
||||
remainLots = len(self.dn_grids)
|
||||
lots = self.max_lots - remainLots
|
||||
remain_lots = len(self.dn_grids)
|
||||
lots = self.max_lots - remain_lots
|
||||
|
||||
down_line = min(down_line, min_long_price - self.grid_height * down_rate)
|
||||
self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, down_line))
|
||||
@ -651,13 +651,13 @@ class CtaGridTrade(CtaComponent):
|
||||
grid = CtaGrid(direction=Direction.LONG,
|
||||
open_price=open_price,
|
||||
close_price=close_price,
|
||||
volume=self.volume * self.get_volume_rate(remainLots + i))
|
||||
volume=self.volume * self.get_volume_rate(remain_lots + i))
|
||||
grid.reuse_count = reuse_count
|
||||
self.dn_grids.append(grid)
|
||||
self.write_log(u'重新拉下网格:[{0}==>{1}]'.format(down_line, down_line - lots * self.grid_height * down_rate))
|
||||
|
||||
# 重建上网格(移除未挂单、保留开仓得网格、在最高价之上才增加网格
|
||||
if direction == Direction.SHORT:
|
||||
if len(directions) == 0 or Direction.SHORT in directions:
|
||||
max_short_price = middle_line # 最高开空价
|
||||
remove_grids = [] # 移除的网格列表
|
||||
opened_grids = [] # 已开仓的网格列表
|
||||
@ -678,8 +678,8 @@ class CtaGridTrade(CtaComponent):
|
||||
self.write_log(u'保留上网格[{}]'.format(opened_grids))
|
||||
|
||||
# 需要重建的剩余网格数量
|
||||
remainLots = len(self.up_grids)
|
||||
lots = self.max_lots - remainLots
|
||||
remain_lots = len(self.up_grids)
|
||||
lots = self.max_lots - remain_lots
|
||||
upper_line = max(upper_line, max_short_price + self.grid_height * upper_rate)
|
||||
self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, upper_line))
|
||||
|
||||
@ -692,7 +692,7 @@ class CtaGridTrade(CtaComponent):
|
||||
grid = CtaGrid(direction=Direction.SHORT,
|
||||
open_price=open_price,
|
||||
close_price=close_price,
|
||||
volume=self.volume * self.get_volume_rate(remainLots + i))
|
||||
volume=self.volume * self.get_volume_rate(remain_lots + i))
|
||||
grid.reuse_count = reuse_count
|
||||
self.up_grids.append(grid)
|
||||
|
||||
|
@ -178,13 +178,20 @@ class CtaLineBar(object):
|
||||
|
||||
# (实时运行时,或者addbar小于bar得周期时,不包含最后一根Bar)
|
||||
self.open_array = np.zeros(self.max_hold_bars) # 与lineBar一致得开仓价清单
|
||||
self.open_array[:] = np.nan
|
||||
self.high_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最高价清单
|
||||
self.high_array[:] = np.nan
|
||||
self.low_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最低价清单
|
||||
self.low_array[:] = np.nan
|
||||
self.close_array = np.zeros(self.max_hold_bars) # 与lineBar一致得收盘价清单
|
||||
self.close_array[:] = np.nan
|
||||
|
||||
self.mid3_array = np.zeros(self.max_hold_bars) # 收盘价/最高/最低价 的平均价
|
||||
self.mid3_array[:] = np.nan
|
||||
self.mid4_array = np.zeros(self.max_hold_bars) # 收盘价*2/最高/最低价 的平均价
|
||||
self.mid4_array[:] = np.nan
|
||||
self.mid5_array = np.zeros(self.max_hold_bars) # 收盘价*2/开仓价/最高/最低价 的平均价
|
||||
self.mid5_array[:] = np.nan
|
||||
|
||||
self.export_filename = None
|
||||
self.export_fields = []
|
||||
@ -1263,7 +1270,8 @@ class CtaLineBar(object):
|
||||
# 2.计算前inputPreLen周期内(不包含当前周期)的Bar高点和低点
|
||||
preHigh = max(self.high_array[-count_len:])
|
||||
preLow = min(self.low_array[-count_len:])
|
||||
|
||||
if np.isnan(preHigh) or np.isnan(preLow):
|
||||
return
|
||||
# 保存
|
||||
if len(self.line_pre_high) > self.max_hold_bars:
|
||||
del self.line_pre_high[0]
|
||||
@ -1452,6 +1460,8 @@ class CtaLineBar(object):
|
||||
count_len = min(self.para_ma1_len, self.bar_len)
|
||||
|
||||
barMa1 = ta.MA(self.close_array[-count_len:], count_len)[-1]
|
||||
if np.isnan(barMa1):
|
||||
return
|
||||
barMa1 = round(barMa1, self.round_n)
|
||||
|
||||
if len(self.line_ma1) > self.max_hold_bars:
|
||||
@ -1470,6 +1480,8 @@ class CtaLineBar(object):
|
||||
if self.para_ma2_len > 0:
|
||||
count_len = min(self.para_ma2_len, self.bar_len)
|
||||
barMa2 = ta.MA(self.close_array[-count_len:], count_len)[-1]
|
||||
if np.isnan(barMa2):
|
||||
return
|
||||
barMa2 = round(barMa2, self.round_n)
|
||||
|
||||
if len(self.line_ma2) > self.max_hold_bars:
|
||||
@ -1488,6 +1500,8 @@ class CtaLineBar(object):
|
||||
if self.para_ma3_len > 0:
|
||||
count_len = min(self.para_ma3_len, self.bar_len)
|
||||
barMa3 = ta.MA(self.close_array[-count_len:], count_len)[-1]
|
||||
if np.isnan(barMa3):
|
||||
return
|
||||
barMa3 = round(barMa3, self.round_n)
|
||||
|
||||
if len(self.line_ma3) > self.max_hold_bars:
|
||||
@ -1685,6 +1699,8 @@ class CtaLineBar(object):
|
||||
|
||||
# 3、获取前InputN周期(不包含当前周期)的K线
|
||||
barEma1 = ta.EMA(self.close_array[-ema1_data_len:], count_len)[-1]
|
||||
if np.isnan(barEma1):
|
||||
return
|
||||
barEma1 = round(float(barEma1), self.round_n)
|
||||
|
||||
if len(self.line_ema1) > self.max_hold_bars:
|
||||
@ -1698,7 +1714,8 @@ class CtaLineBar(object):
|
||||
# 3、获取前InputN周期(不包含当前周期)的自适应均线
|
||||
|
||||
barEma2 = ta.EMA(self.close_array[-ema2_data_len:], count_len)[-1]
|
||||
|
||||
if np.isnan(barEma2):
|
||||
return
|
||||
barEma2 = round(float(barEma2), self.round_n)
|
||||
|
||||
if len(self.line_ema2) > self.max_hold_bars:
|
||||
@ -1711,6 +1728,8 @@ class CtaLineBar(object):
|
||||
|
||||
# 3、获取前InputN周期(不包含当前周期)的自适应均线
|
||||
barEma3 = ta.EMA(self.close_array[-ema3_data_len:], count_len)[-1]
|
||||
if np.isnan(barEma3):
|
||||
return
|
||||
barEma3 = round(float(barEma3), self.round_n)
|
||||
|
||||
if len(self.line_ema3) > self.max_hold_bars:
|
||||
@ -1737,6 +1756,8 @@ class CtaLineBar(object):
|
||||
|
||||
# 3、获取前InputN周期(不包含当前周期)的K线
|
||||
barEma1 = ta.EMA(np.append(self.close_array[-ema1_data_len:], [self.cur_price]), count_len)[-1]
|
||||
if np.isnan(barEma1):
|
||||
return
|
||||
self._rt_ema1 = round(float(barEma1), self.round_n)
|
||||
|
||||
# 计算第二条EMA均线
|
||||
@ -1746,7 +1767,8 @@ class CtaLineBar(object):
|
||||
# 3、获取前InputN周期(不包含当前周期)的自适应均线
|
||||
|
||||
barEma2 = ta.EMA(np.append(self.close_array[-ema2_data_len:], [self.cur_price]), count_len)[-1]
|
||||
|
||||
if np.isnan(barEma2):
|
||||
return
|
||||
self._rt_ema2 = round(float(barEma2), self.round_n)
|
||||
|
||||
# 计算第三条EMA均线
|
||||
@ -1755,6 +1777,8 @@ class CtaLineBar(object):
|
||||
|
||||
# 3、获取前InputN周期(不包含当前周期)的自适应均线
|
||||
barEma3 = ta.EMA(np.append(self.close_array[-ema3_data_len:], [self.cur_price]), count_len)[-1]
|
||||
if np.isnan(barEma3):
|
||||
return
|
||||
self._rt_ema3 = round(float(barEma3), self.round_n)
|
||||
|
||||
@property
|
||||
@ -2076,7 +2100,7 @@ class CtaLineBar(object):
|
||||
return
|
||||
|
||||
if self.para_boll_len > 0:
|
||||
if self.bar_len < min(7, self.para_boll_len):
|
||||
if self.bar_len < min(20, self.para_boll_len):
|
||||
self.write_log(u'数据未充分,当前Bar数据数量:{0},计算Boll需要:{1}'.
|
||||
format(len(self.line_bar), min(14, self.para_boll_len) + 1))
|
||||
else:
|
||||
@ -2086,6 +2110,9 @@ class CtaLineBar(object):
|
||||
upper_list, middle_list, lower_list = ta.BBANDS(self.close_array,
|
||||
timeperiod=bollLen, nbdevup=self.para_boll_std_rate,
|
||||
nbdevdn=self.para_boll_std_rate, matype=0)
|
||||
if np.isnan(upper_list[-1]):
|
||||
return
|
||||
|
||||
if len(self.line_boll_upper) > self.max_hold_bars:
|
||||
del self.line_boll_upper[0]
|
||||
if len(self.line_boll_middle) > self.max_hold_bars:
|
||||
@ -2144,6 +2171,8 @@ class CtaLineBar(object):
|
||||
upper_list, middle_list, lower_list = ta.BBANDS(self.close_array,
|
||||
timeperiod=boll2Len, nbdevup=self.para_boll2_std_rate,
|
||||
nbdevdn=self.para_boll2_std_rate, matype=0)
|
||||
if np.isnan(upper_list[-1]):
|
||||
return
|
||||
if len(self.line_boll2_upper) > self.max_hold_bars:
|
||||
del self.line_boll2_upper[0]
|
||||
if len(self.line_boll2_middle) > self.max_hold_bars:
|
||||
@ -2510,14 +2539,19 @@ class CtaLineBar(object):
|
||||
|
||||
hhv = max(self.high_array[-inputKdjLen:])
|
||||
llv = min(self.low_array[-inputKdjLen:])
|
||||
|
||||
if np.isnan(hhv) or np.isnan(llv):
|
||||
return
|
||||
if len(self.line_k) > 0:
|
||||
lastK = self.line_k[-1]
|
||||
if np.isnan(lastK):
|
||||
lastK = 0
|
||||
else:
|
||||
lastK = 0
|
||||
|
||||
if len(self.line_d) > 0:
|
||||
lastD = self.line_d[-1]
|
||||
if np.isnan(lastD):
|
||||
lastD = 0
|
||||
else:
|
||||
lastD = 0
|
||||
|
||||
@ -2620,14 +2654,20 @@ class CtaLineBar(object):
|
||||
|
||||
hhv = max(self.high_array[-data_len:])
|
||||
llv = min(self.low_array[-data_len:])
|
||||
if np.isnan(hhv) or np.isnan(llv):
|
||||
return
|
||||
|
||||
if len(self.line_k) > 0:
|
||||
lastK = self.line_k[-1]
|
||||
if np.isnan(lastK):
|
||||
lastK = 0
|
||||
else:
|
||||
lastK = 0
|
||||
|
||||
if len(self.line_d) > 0:
|
||||
lastD = self.line_d[-1]
|
||||
if np.isnan(lastD):
|
||||
lastD = 0
|
||||
else:
|
||||
lastD = 0
|
||||
|
||||
@ -2747,7 +2787,8 @@ class CtaLineBar(object):
|
||||
dif_list, dea_list, macd_list = ta.MACD(self.close_array[-2 * maxLen:], fastperiod=self.para_macd_fast_len,
|
||||
slowperiod=self.para_macd_slow_len,
|
||||
signalperiod=self.para_macd_signal_len)
|
||||
|
||||
if np.isnan(dif_list[-1]) or np.isnan(dea_list[-1]) or np.isnan(macd_list[-1]):
|
||||
return
|
||||
# dif, dea, macd = ta.MACDEXT(np.array(listClose, dtype=float),
|
||||
# fastperiod=self.inputMacdFastPeriodLen, fastmatype=1,
|
||||
# slowperiod=self.inputMacdSlowPeriodLen, slowmatype=1,
|
||||
@ -2868,6 +2909,9 @@ class CtaLineBar(object):
|
||||
fastperiod=self.para_macd_fast_len,
|
||||
slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len)
|
||||
|
||||
if np.isnan(dif[-1]) or np.isnan(dea[-1]) or np.isnan(macd[-1]):
|
||||
return
|
||||
|
||||
self._rt_dif = round(dif[-1], self.round_n) if len(dif) > 0 else None
|
||||
self._rt_dea = round(dea[-1], self.round_n) if len(dea) > 0 else None
|
||||
self._rt_macd = round(macd[-1] * 2, self.round_n) if len(macd) > 0 else None
|
||||
@ -3958,6 +4002,9 @@ class CtaLineBar(object):
|
||||
hhv = max(self.high_array[-bar_len:])
|
||||
llv = min(self.low_array[-bar_len:])
|
||||
|
||||
if np.isnan(hhv) or np.isnan(llv):
|
||||
return
|
||||
|
||||
self.cur_p192 = hhv - (hhv - llv) * 0.192
|
||||
self.cur_p382 = hhv - (hhv - llv) * 0.382
|
||||
self.cur_p500 = (hhv + llv) / 2
|
||||
|
@ -25,6 +25,7 @@ from .event import (
|
||||
)
|
||||
from .gateway import BaseGateway
|
||||
from .object import (
|
||||
Direction,
|
||||
Exchange,
|
||||
CancelRequest,
|
||||
LogData,
|
||||
@ -214,7 +215,7 @@ class MainEngine:
|
||||
"""
|
||||
# 自定义套利合约,交给算法引擎处理
|
||||
if self.algo_engine and req.exchange == Exchange.SPD:
|
||||
return self.algo_engine.send_algo_order(
|
||||
return self.algo_engine.send_spd_order(
|
||||
req=req,
|
||||
gateway_name=gateway_name)
|
||||
|
||||
@ -228,6 +229,11 @@ class MainEngine:
|
||||
"""
|
||||
Send cancel order request to a specific gateway.
|
||||
"""
|
||||
# 自定义套利合约,交给算法引擎处理
|
||||
if self.algo_engine and req.exchange == Exchange.SPD:
|
||||
return self.algo_engine.cancel_spd_order(
|
||||
req=req)
|
||||
|
||||
gateway = self.get_gateway(gateway_name)
|
||||
if gateway:
|
||||
return gateway.cancel_order(req)
|
||||
@ -422,13 +428,19 @@ class OmsEngine(BaseEngine):
|
||||
self.accounts: Dict[str, AccountData] = {}
|
||||
self.contracts: Dict[str, ContractData] = {}
|
||||
self.today_contracts: Dict[str, ContractData] = {}
|
||||
self.custom_contracts = {}
|
||||
|
||||
# 自定义合约
|
||||
self.custom_contracts = {} # vt_symbol: ContractData
|
||||
self.custom_settings = {} # symbol: dict
|
||||
self.symbol_spd_maping = {} # symbol: [spd_symbol]
|
||||
|
||||
self.prices = {}
|
||||
|
||||
self.active_orders: Dict[str, OrderData] = {}
|
||||
|
||||
self.add_function()
|
||||
self.register_event()
|
||||
self.load_contracts()
|
||||
|
||||
def __del__(self):
|
||||
"""保存缓存"""
|
||||
@ -445,6 +457,30 @@ class OmsEngine(BaseEngine):
|
||||
self.contracts = pickle.load(f)
|
||||
self.write_log(f'加载缓存合约字典:{contract_file_name}')
|
||||
|
||||
# 更新自定义合约
|
||||
custom_contracts = self.get_all_custom_contracts()
|
||||
for contract in custom_contracts.values():
|
||||
|
||||
# 更新合约缓存
|
||||
self.contracts.update({contract.symbol: contract})
|
||||
self.contracts.update({contract.vt_symbol: contract})
|
||||
self.today_contracts[contract.vt_symbol] = contract
|
||||
self.today_contracts[contract.symbol] = contract
|
||||
|
||||
# 获取自定义合约的主动腿/被动腿
|
||||
setting = self.custom_settings.get(contract.symbol, {})
|
||||
leg1_symbol = setting.get('leg1_symbol')
|
||||
leg2_symbol = setting.get('leg2_symbol')
|
||||
|
||||
# 构建映射关系
|
||||
for symbol in [leg1_symbol, leg2_symbol]:
|
||||
spd_mapping_list = self.symbol_spd_maping.get(symbol, [])
|
||||
|
||||
# 更新映射 symbol => spd_symbol
|
||||
if contract.symbol not in spd_mapping_list:
|
||||
spd_mapping_list.append(contract.symbol)
|
||||
self.symbol_spd_maping.update({symbol: spd_mapping_list})
|
||||
|
||||
def save_contracts(self) -> None:
|
||||
"""持久化合约对象到缓存文件"""
|
||||
import bz2
|
||||
@ -474,6 +510,7 @@ class OmsEngine(BaseEngine):
|
||||
self.main_engine.get_all_contracts = self.get_all_contracts
|
||||
self.main_engine.get_all_active_orders = self.get_all_active_orders
|
||||
self.main_engine.get_all_custom_contracts = self.get_all_custom_contracts
|
||||
self.main_engine.get_mapping_spd = self.get_mapping_spd
|
||||
self.main_engine.save_contracts = self.save_contracts
|
||||
|
||||
def register_event(self) -> None:
|
||||
@ -515,6 +552,76 @@ class OmsEngine(BaseEngine):
|
||||
position = event.data
|
||||
self.positions[position.vt_positionid] = position
|
||||
|
||||
def reverse_direction(self, direction):
|
||||
"""返回反向持仓"""
|
||||
if direction == Direction.LONG:
|
||||
return Direction.SHORT
|
||||
elif direction == Direction.SHORT:
|
||||
return Direction.LONG
|
||||
return direction
|
||||
|
||||
def create_spd_position_event(self, symbol, direction ):
|
||||
"""创建自定义品种对持仓信息"""
|
||||
spd_symbols = self.symbol_spd_maping.get(symbol, [])
|
||||
if not spd_symbols:
|
||||
return
|
||||
for spd_symbol in spd_symbols:
|
||||
spd_setting = self.custom_settings.get(spd_symbol, None)
|
||||
if not spd_setting:
|
||||
continue
|
||||
|
||||
leg1_symbol = spd_setting.get('leg1_symbol')
|
||||
leg2_symbol = spd_setting.get('leg2_symbol')
|
||||
leg1_contract = self.contracts.get(leg1_symbol)
|
||||
leg2_contract = self.contracts.get(leg2_symbol)
|
||||
spd_contract = self.contracts.get(spd_symbol)
|
||||
|
||||
if leg1_contract is None or leg2_contract is None:
|
||||
continue
|
||||
leg1_ratio = spd_setting.get('leg1_ratio', 1)
|
||||
leg2_ratio = spd_setting.get('leg2_ratio', 1)
|
||||
|
||||
# 找出leg1,leg2的持仓,并判断出spd的方向
|
||||
if leg1_symbol == symbol:
|
||||
k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{direction.value}"
|
||||
leg1_pos = self.positions.get(k1)
|
||||
k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{self.reverse_direction(direction).value}"
|
||||
leg2_pos = self.positions.get(k2)
|
||||
spd_direction = direction
|
||||
elif leg2_symbol == symbol:
|
||||
k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{self.reverse_direction(direction).value}"
|
||||
leg1_pos = self.positions.get(k1)
|
||||
k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{direction.value}"
|
||||
leg2_pos = self.positions.get(k2)
|
||||
spd_direction = self.reverse_direction(direction)
|
||||
else:
|
||||
continue
|
||||
|
||||
if leg1_pos is None or leg2_pos is None or leg1_pos.volume ==0 or leg2_pos.volume == 0:
|
||||
continue
|
||||
|
||||
# 根据leg1/leg2的volume ratio,计算出最小spd_volume
|
||||
spd_volume = min(int(leg1_pos.volume/leg1_ratio), int(leg2_pos.volume/leg2_ratio))
|
||||
if spd_volume <= 0:
|
||||
continue
|
||||
if spd_setting.get('is_ratio', False) and leg2_pos.price > 0:
|
||||
spd_price = 100 * (leg2_pos.price * leg1_ratio) / (leg2_pos.price * leg2_ratio)
|
||||
elif spd_setting.get('is_spread', False):
|
||||
spd_price = leg1_pos.price * leg1_ratio - leg2_pos.price * leg2_ratio
|
||||
else:
|
||||
spd_price = 0
|
||||
|
||||
spd_pos = PositionData(
|
||||
gateway_name=spd_contract.gateway_name,
|
||||
symbol=spd_symbol,
|
||||
exchange=Exchange.SPD,
|
||||
direction=spd_direction,
|
||||
volume=spd_volume,
|
||||
price=spd_price
|
||||
)
|
||||
event = Event(EVENT_POSITION, data=spd_pos)
|
||||
self.event_engine.put(event)
|
||||
|
||||
def process_account_event(self, event: Event) -> None:
|
||||
""""""
|
||||
account = event.data
|
||||
@ -624,16 +731,25 @@ class OmsEngine(BaseEngine):
|
||||
]
|
||||
return active_orders
|
||||
|
||||
def get_all_custom_contracts(self):
|
||||
def get_all_custom_contracts(self, rtn_setting=False):
|
||||
"""
|
||||
获取所有自定义合约
|
||||
:return:
|
||||
"""
|
||||
if rtn_setting:
|
||||
if len(self.custom_settings) == 0:
|
||||
c = CustomContract()
|
||||
self.custom_settings = c.get_config()
|
||||
return self.custom_settings
|
||||
|
||||
if len(self.custom_contracts) == 0:
|
||||
c = CustomContract()
|
||||
self.custom_contracts = c.get_contracts()
|
||||
return self.custom_contracts
|
||||
|
||||
def get_mapping_spd(self, symbol):
|
||||
"""根据主动腿/被动腿symbol,获取自定义套利对的symbol list"""
|
||||
return self.symbol_spd_maping.get(symbol, [])
|
||||
|
||||
class CustomContract(object):
|
||||
"""
|
||||
@ -668,6 +784,7 @@ class CustomContract(object):
|
||||
exchange=vn_exchange,
|
||||
name=setting.get('name', symbol),
|
||||
size=setting.get('size', 100),
|
||||
product=None,
|
||||
pricetick=setting.get('price_tick', 0.01),
|
||||
margin_rate=setting.get('margin_rate', 0.1)
|
||||
)
|
||||
|
@ -123,7 +123,7 @@ class BaseGateway(ABC):
|
||||
Tick event of a specific vt_symbol is also pushed.
|
||||
"""
|
||||
self.on_event(EVENT_TICK, tick)
|
||||
self.on_event(EVENT_TICK + tick.vt_symbol, tick)
|
||||
# self.on_event(EVENT_TICK + tick.vt_symbol, tick)
|
||||
|
||||
# 推送Bar
|
||||
kline = self.klines.get(tick.vt_symbol, None)
|
||||
@ -142,7 +142,7 @@ class BaseGateway(ABC):
|
||||
Trade event of a specific vt_symbol is also pushed.
|
||||
"""
|
||||
self.on_event(EVENT_TRADE, trade)
|
||||
self.on_event(EVENT_TRADE + trade.vt_symbol, trade)
|
||||
# self.on_event(EVENT_TRADE + trade.vt_symbol, trade)
|
||||
|
||||
def on_order(self, order: OrderData) -> None:
|
||||
"""
|
||||
@ -150,7 +150,7 @@ class BaseGateway(ABC):
|
||||
Order event of a specific vt_orderid is also pushed.
|
||||
"""
|
||||
self.on_event(EVENT_ORDER, order)
|
||||
self.on_event(EVENT_ORDER + order.vt_orderid, order)
|
||||
# self.on_event(EVENT_ORDER + order.vt_orderid, order)
|
||||
|
||||
def on_position(self, position: PositionData) -> None:
|
||||
"""
|
||||
@ -158,7 +158,7 @@ class BaseGateway(ABC):
|
||||
Position event of a specific vt_symbol is also pushed.
|
||||
"""
|
||||
self.on_event(EVENT_POSITION, position)
|
||||
self.on_event(EVENT_POSITION + position.vt_symbol, position)
|
||||
# self.on_event(EVENT_POSITION + position.vt_symbol, position)
|
||||
|
||||
def on_account(self, account: AccountData) -> None:
|
||||
"""
|
||||
@ -166,7 +166,7 @@ class BaseGateway(ABC):
|
||||
Account event of a specific vt_accountid is also pushed.
|
||||
"""
|
||||
self.on_event(EVENT_ACCOUNT, account)
|
||||
self.on_event(EVENT_ACCOUNT + account.vt_accountid, account)
|
||||
# self.on_event(EVENT_ACCOUNT + account.vt_accountid, account)
|
||||
|
||||
def on_log(self, log: LogData) -> None:
|
||||
"""
|
||||
|
@ -804,11 +804,14 @@ class BarGenerator:
|
||||
"""
|
||||
Generate the bar data and call callback immediately.
|
||||
"""
|
||||
self.bar.datetime = self.bar.datetime.replace(
|
||||
second=0, microsecond=0
|
||||
)
|
||||
self.on_bar(self.bar)
|
||||
bar = self.bar
|
||||
|
||||
if self.bar:
|
||||
bar.datetime = bar.datetime.replace(second=0, microsecond=0)
|
||||
self.on_bar(bar)
|
||||
|
||||
self.bar = None
|
||||
return bar
|
||||
|
||||
|
||||
class ArrayManager(object):
|
||||
@ -1217,7 +1220,6 @@ class ArrayManager(object):
|
||||
def aroon(
|
||||
self,
|
||||
n: int,
|
||||
dev: float,
|
||||
array: bool = False
|
||||
) -> Union[
|
||||
Tuple[np.ndarray, np.ndarray],
|
||||
|
Loading…
Reference in New Issue
Block a user