[增强功能] 添加vnpy原版组合引擎,修改gateway,取消多余event,提高性能;更新K线组件/网格组件

This commit is contained in:
msincenselee 2020-05-18 16:57:58 +08:00
parent 57448173a6
commit fa22e74a8a
20 changed files with 2637 additions and 107 deletions

View File

@ -330,10 +330,13 @@ class AccountRecorder(BaseEngine):
# self.write_log(u'记录委托日志:{}'.format(order.__dict__))
if len(order.sys_orderid) == 0:
# 未有系统的委托编号,不做持久化
return
order.sys_orderid = order.orderid
dt = getattr(order, 'datetime')
if not dt:
order_date = datetime.now().strftime('%Y-%m-%d')
if len(order.time) > 0 and '.' not in order.time:
dt = datetime.strptime(f'{order_date} {order.time}', '%Y-%m-%d %H:%M:%S')
order.datetime = dt
else:
order_date = dt.strftime('%Y-%m-%d')

View File

@ -6,11 +6,12 @@ from functools import lru_cache
from vnpy.event import EventEngine, Event
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.event import (
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE)
EVENT_TICK, EVENT_TIMER, EVENT_ORDER, EVENT_TRADE, EVENT_POSITION)
from vnpy.trader.constant import (Direction, Offset, OrderType, Status)
from vnpy.trader.object import (SubscribeRequest, OrderRequest, LogData, CancelRequest)
from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path
from vnpy.trader.utility import load_json, save_json, round_to, get_folder_path,print_dict
from vnpy.trader.util_logger import setup_logger, logging
from vnpy.trader.converter import OffsetConverter
from .template import AlgoTemplate
@ -35,13 +36,13 @@ class AlgoEngine(BaseEngine):
self.symbol_algo_map = {}
self.orderid_algo_map = {}
self.algo_vtorderid_order_map = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
self.spd_orders = {} # 记录外部发起的算法交易委托编号,便于通过算法引擎撤单
self.algo_templates = {}
self.algo_settings = {}
self.algo_loggers = {} # algo_name: logger
self.offset_converter = OffsetConverter(self.main_engine)
self.load_algo_template()
self.register_event()
@ -60,6 +61,7 @@ class AlgoEngine(BaseEngine):
from .algos.grid_algo import GridAlgo
from .algos.dma_algo import DmaAlgo
from .algos.arbitrage_algo import ArbitrageAlgo
from .algos.spread_algo_v2 import SpreadAlgoV2
self.add_algo_template(TwapAlgo)
self.add_algo_template(IcebergAlgo)
@ -69,6 +71,7 @@ class AlgoEngine(BaseEngine):
self.add_algo_template(GridAlgo)
self.add_algo_template(DmaAlgo)
self.add_algo_template(ArbitrageAlgo)
self.add_algo_template(SpreadAlgoV2)
def add_algo_template(self, template: AlgoTemplate):
""""""
@ -93,6 +96,7 @@ class AlgoEngine(BaseEngine):
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
self.event_engine.register(EVENT_POSITION, self.process_position_event)
def process_tick_event(self, event: Event):
""""""
@ -114,7 +118,7 @@ class AlgoEngine(BaseEngine):
def process_trade_event(self, event: Event):
""""""
trade = event.data
self.offset_converter.update_trade(trade)
algo = self.orderid_algo_map.get(trade.vt_orderid, None)
if algo:
algo.update_trade(trade)
@ -122,11 +126,17 @@ class AlgoEngine(BaseEngine):
def process_order_event(self, event: Event):
""""""
order = event.data
self.offset_converter.update_order(order)
algo = self.orderid_algo_map.get(order.vt_orderid, None)
if algo:
algo.update_order(order)
def process_position_event(self, event: Event):
""""""
position = event.data
self.offset_converter.update_position(position)
def start_algo(self, setting: dict):
""""""
template_name = setting["template_name"]
@ -144,6 +154,7 @@ class AlgoEngine(BaseEngine):
if algo:
algo.stop()
self.algos.pop(algo_name)
return True
def stop_all(self):
""""""
@ -154,7 +165,7 @@ class AlgoEngine(BaseEngine):
""""""
contract = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f'订阅行情失败,找不到合约:{vt_symbol}', algo)
self.write_log(msg=f'订阅行情失败,找不到合约:{vt_symbol}', algo_name=algo.algo_name)
return
algos = self.symbol_algo_map.setdefault(vt_symbol, set())
@ -186,9 +197,9 @@ class AlgoEngine(BaseEngine):
volume = round_to(volume, contract.min_volume)
if not volume:
return ""
return []
req = OrderRequest(
original_req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
direction=direction,
@ -197,83 +208,88 @@ class AlgoEngine(BaseEngine):
price=price,
offset=offset
)
vt_orderid = self.main_engine.send_order(req, contract.gateway_name)
req_list = self.offset_converter.convert_order_request(req=original_req, lock=False, gateway_name=contract.gateway_name)
vt_orderids = []
for req in req_list:
vt_orderid = self.main_engine.send_order(req, contract.gateway_name)
if not vt_orderid:
continue
self.orderid_algo_map[vt_orderid] = algo
return vt_orderid
vt_orderids.append(vt_orderid)
self.offset_converter.update_order_request(req, vt_orderid, contract.gateway_name)
self.orderid_algo_map[vt_orderid] = algo
return vt_orderids
def cancel_order(self, algo: AlgoTemplate, vt_orderid: str):
""""""
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_log(f"委托撤单失败,找不到委托:{vt_orderid}", algo)
return
self.write_log(msg=f"委托撤单失败,找不到委托:{vt_orderid}", algo_name=algo.algo_name)
return False
req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name)
return self.main_engine.cancel_order(req, order.gateway_name)
def send_algo_order(self, req: OrderRequest, gateway_name: str):
"""发送算法交易指令"""
self.write_log(u'创建算法交易,gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
def send_spd_order(self, req: OrderRequest, gateway_name: str):
"""发送SPD算法交易指令"""
self.write_log(u'[SPD算法交易],gateway_name:{},strategy_name:{},vt_symbol:{},price:{},volume:{}'
.format(gateway_name, req.strategy_name, req.vt_symbol, req.price, req.volume))
# 创建算法实例,由算法引擎启动
trade_command = ''
if req.direction == Direction.LONG and req.offset == Offset.OPEN:
trade_command = 'Buy'
elif req.direction == Direction.SHORT and req.offset == Offset.OPEN:
trade_command = 'Short'
elif req.direction == Direction.SHORT and req.offset != Offset.OPEN:
trade_command = 'Sell'
elif req.direction == Direction.LONG and req.offset != Offset.OPEN:
trade_command = 'Cover'
all_custom_contracts = self.main_engine.get_all_custom_contracts()
contract_setting = all_custom_contracts.get(req.vt_symbol, {})
algo_setting = {
'templateName': u'SpreadTrading套利',
custom_settings = self.main_engine.get_all_custom_contracts(rtn_setting=True)
contract = custom_settings.get(req.symbol, {})
setting = {
'template_name': u'SpreadAlgoV2',
'order_vt_symbol': req.vt_symbol,
'order_command': trade_command,
'order_direction': req.direction,
'order_offset': req.offset,
'order_price': req.price,
'order_volume': req.volume,
'timer_interval': 60 * 60 * 24,
'strategy_name': req.strategy_name,
'gateway_name': gateway_name
}
algo_setting.update(contract_setting)
# 更新算法配置
setting.update(contract)
# 算法引擎
algo_name = self.start_algo(algo_setting)
self.write_log(u'send_algo_order(): start_algo {}={}'.format(algo_name, str(algo_setting)))
algo_name = self.start_algo(setting)
self.write_log(f'[SPD算法交易]: 实例id: {algo_name}, 配置:{print_dict(setting)}')
# 创建一个Order事件
# 创建一个Order事件, 正在提交
order = req.create_order_data(orderid=algo_name, gateway_name=gateway_name)
order.orderTime = datetime.now().strftime('%H:%M:%S.%f')
order.datetime = datetime.now()
order.time = order.datetime.strftime('%H:%M:%S.%f')
order.status = Status.SUBMITTING
event1 = Event(type=EVENT_ORDER, data=order)
self.event_engine.put(event1)
# 登记在本地的算法委托字典中
self.algo_vtorderid_order_map.update({order.vt_orderid: order})
self.spd_orders.update({order.orderid: order})
return order.vt_orderid
def is_algo_order(self, req: CancelRequest, gateway_name: str):
def get_spd_order(self, orderid):
"""返回spd委托单"""
return self.spd_orders.get(orderid, None)
def is_spd_order(self, req: CancelRequest):
"""是否为外部算法委托单"""
vt_orderid = '.'.join([req.orderid, gateway_name])
if vt_orderid in self.algo_vtorderid_order_map:
if req.orderid in self.spd_orders:
return True
else:
return False
def cancel_algo_order(self, req: CancelRequest, gateway_name: str):
def cancel_spd_order(self, req: CancelRequest):
"""外部算法单撤单"""
vt_orderid = '.'.join([req.orderid, gateway_name])
order = self.algo_vtorderid_order_map.get(vt_orderid, None)
order = self.spd_orders.get(req.orderid, None)
if not order:
self.write_error(f'{vt_orderid}不在算法引擎中,撤单失败')
self.write_error(f'{req.orderid}不在算法引擎中,撤单失败')
return False
algo = self.algos.get(req.orderid, None)
@ -368,6 +384,19 @@ class AlgoEngine(BaseEngine):
return contract
def get_position(self, vt_symbol: str, direction: Direction, gateway_name: str = ''):
""" 查询合约在账号的持仓,需要指定方向"""
if len(gateway_name) == 0:
contract = self.main_engine.get_contract(vt_symbol)
if contract and contract.gateway_name:
gateway_name = contract.gateway_name
vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}"
return self.main_engine.get_position(vt_position_id)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
return self.offset_converter.get_position_holding(vt_symbol, gateway_name)
def write_log(self, msg: str, algo_name: str = None, level: int = logging.INFO):
"""增强版写日志"""
if algo_name:

View File

@ -154,7 +154,7 @@ class AlgoTemplate:
def cancel_order(self, vt_orderid: str):
""""""
self.algo_engine.cancel_order(self, vt_orderid)
return self.algo_engine.cancel_order(self, vt_orderid)
def cancel_all(self):
""""""

View File

@ -333,8 +333,8 @@ class BackTestingEngine(object):
self.fix_commission.update({vt_symbol: rate})
def get_commission_rate(self, vt_symbol: str):
""" 获取保证金比例,缺省万分之一"""
return self.commission_rate.get(vt_symbol, float(0.00001))
""" 获取保证金比例,缺省千分之2"""
return self.commission_rate.get(vt_symbol, float(0.0004))
def get_fix_commission(self, vt_symbol: str):
return self.fix_commission.get(vt_symbol, 0)
@ -561,7 +561,7 @@ class BackTestingEngine(object):
margin_rate = symbol_data.get('margin_rate', 0.1)
self.set_margin_rate(symbol, margin_rate)
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0001)))
self.set_commission_rate(symbol, symbol_data.get('commission_rate', float(0.0004)))
self.set_contract(
symbol=symbol,
@ -2232,7 +2232,7 @@ class TradingResult(object):
self.volume = volume # 交易数量(+/-代表方向)
self.group_id = group_id # 主交易ID针对多手平仓
self.turnover = (self.open_price + self.exit_price) * abs(volume) * margin_rate # 成交金额(实际保证金金额)
self.turnover = (self.open_price + self.exit_price) * abs(volume) # 成交金额(实际保证金金额)
if fix_commission > 0:
self.commission = fix_commission * abs(self.volume)
else:

View File

@ -42,6 +42,7 @@ from vnpy.trader.event import (
)
from vnpy.trader.constant import (
Direction,
Exchange,
OrderType,
Offset,
Status
@ -349,12 +350,27 @@ class CtaEngine(BaseEngine):
# Update GUI
self.put_strategy_event(strategy)
if self.engine_config.get('trade_2_wx', False):
accountid = self.engine_config.get('accountid', '-')
d = {
'account': accountid,
'strategy': strategy_name,
'symbol': trade.symbol,
'action': f'{trade.direction.value} {trade.offset.value}',
'price': str(trade.price),
'volume': trade.volume,
'remark': f'{accountid}:{strategy_name}',
'timestamp': trade.time
}
send_wx_msg(content=d, target=accountid)
def process_position_event(self, event: Event):
""""""
position = event.data
self.offset_converter.update_position(position)
def check_unsubscribed_symbols(self):
"""检查未订阅合约"""
@ -1666,7 +1682,8 @@ class CtaEngine(BaseEngine):
gateway_names = self.main_engine.get_all_gateway_names()
gateway_name = gateway_names[0] if len(gateway_names) > 0 else ""
symbol, exchange = extract_vt_symbol(vt_symbol)
self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange), gateway_name=gateway_name)
self.main_engine.subscribe(req=SubscribeRequest(symbol=symbol, exchange=exchange),
gateway_name=gateway_name)
if volume > 0 and tick:
contract = self.main_engine.get_contract(vt_symbol)
req = OrderRequest(
@ -1821,9 +1838,13 @@ class CtaEngine(BaseEngine):
if level in [logging.CRITICAL, logging.ERROR, logging.WARNING]:
print(f"{strategy_name}: {msg}" if strategy_name else msg, file=sys.stderr)
def write_error(self, msg: str, strategy_name: str = ''):
if level in [logging.CRITICAL, logging.WARN, logging.WARNING]:
send_wx_msg(content=f"{strategy_name}: {msg}" if strategy_name else msg)
def write_error(self, msg: str, strategy_name: str = '', level: int = logging.ERROR):
"""写入错误日志"""
self.write_log(msg=msg, strategy_name=strategy_name, level=logging.ERROR)
self.write_log(msg=msg, strategy_name=strategy_name, level=level)
def send_email(self, msg: str, strategy: CtaTemplate = None):
"""

View File

@ -0,0 +1,23 @@
from pathlib import Path
from vnpy.trader.app import BaseApp
from vnpy.trader.constant import Direction
from vnpy.trader.object import TickData, BarData, TradeData, OrderData
from vnpy.trader.utility import BarGenerator, ArrayManager
from .base import APP_NAME
from .engine import StrategyEngine
from .template import StrategyTemplate
from .backtesting import BacktestingEngine
class PortfolioStrategyApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = "组合策略"
engine_class = StrategyEngine
widget_name = "PortfolioStrategyManager"
icon_name = "strategy.ico"

View File

@ -0,0 +1,821 @@
from collections import defaultdict
from datetime import date, datetime, timedelta
from typing import Dict, List, Set, Tuple
from functools import lru_cache
import traceback
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame
from vnpy.trader.constant import Direction, Offset, Interval, Status
from vnpy.trader.database import database_manager
from vnpy.trader.object import OrderData, TradeData, BarData
from vnpy.trader.utility import round_to, extract_vt_symbol
from .template import StrategyTemplate
# Set seaborn style
sns.set_style("whitegrid")
INTERVAL_DELTA_MAP = {
Interval.MINUTE: timedelta(minutes=1),
Interval.HOUR: timedelta(hours=1),
Interval.DAILY: timedelta(days=1),
}
class BacktestingEngine:
""""""
gateway_name = "BACKTESTING"
def __init__(self):
""""""
self.vt_symbols: List[str] = []
self.start: datetime = None
self.end: datetime = None
self.rates: Dict[str, float] = 0
self.slippages: Dict[str, float] = 0
self.sizes: Dict[str, float] = 1
self.priceticks: Dict[str, float] = 0
self.capital: float = 1_000_000
self.strategy: StrategyTemplate = None
self.bars: Dict[str, BarData] = {}
self.datetime: datetime = None
self.interval: Interval = None
self.days: int = 0
self.history_data: Dict[Tuple, BarData] = {}
self.dts: Set[datetime] = set()
self.limit_order_count = 0
self.limit_orders = {}
self.active_limit_orders = {}
self.trade_count = 0
self.trades = {}
self.logs = []
self.daily_results = {}
self.daily_df = None
def clear_data(self) -> None:
"""
Clear all data of last backtesting.
"""
self.strategy = None
self.bars = {}
self.datetime = None
self.limit_order_count = 0
self.limit_orders.clear()
self.active_limit_orders.clear()
self.trade_count = 0
self.trades.clear()
self.logs.clear()
self.daily_results.clear()
self.daily_df = None
def set_parameters(
self,
vt_symbols: List[str],
interval: Interval,
start: datetime,
rates: Dict[str, float],
slippages: Dict[str, float],
sizes: Dict[str, float],
priceticks: Dict[str, float],
capital: int = 0,
end: datetime = None
) -> None:
""""""
self.vt_symbols = vt_symbols
self.interval = interval
self.rates = rates
self.slippages = slippages
self.sizes = sizes
self.priceticks = priceticks
self.start = start
self.end = end
self.capital = capital
def add_strategy(self, strategy_class: type, setting: dict) -> None:
""""""
self.strategy = strategy_class(
self, strategy_class.__name__, self.vt_symbols, setting
)
def load_data(self) -> None:
""""""
self.output("开始加载历史数据")
if not self.end:
self.end = datetime.now()
if self.start >= self.end:
self.output("起始日期必须小于结束日期")
return
# Clear previously loaded history data
self.history_data.clear()
self.dts.clear()
# Load 30 days of data each time and allow for progress update
progress_delta = timedelta(days=30)
total_delta = self.end - self.start
interval_delta = INTERVAL_DELTA_MAP[self.interval]
for vt_symbol in self.vt_symbols:
start = self.start
end = self.start + progress_delta
progress = 0
data_count = 0
while start < self.end:
end = min(end, self.end) # Make sure end time stays within set range
data = load_bar_data(
vt_symbol,
self.interval,
start,
end
)
for bar in data:
self.dts.add(bar.datetime)
self.history_data[(bar.datetime, vt_symbol)] = bar
data_count += 1
progress += progress_delta / total_delta
progress = min(progress, 1)
progress_bar = "#" * int(progress * 10)
self.output(f"{vt_symbol}加载进度:{progress_bar} [{progress:.0%}]")
start = end + interval_delta
end += (progress_delta + interval_delta)
self.output(f"{vt_symbol}历史数据加载完成,数据量:{data_count}")
self.output("所有历史数据加载完成")
def run_backtesting(self) -> None:
""""""
self.strategy.on_init()
# Generate sorted datetime list
dts = list(self.dts)
dts.sort()
# Use the first [days] of history data for initializing strategy
day_count = 0
ix = 0
for ix, dt in enumerate(dts):
if self.datetime and dt.day != self.datetime.day:
day_count += 1
if day_count >= self.days:
break
try:
self.new_bars(dt)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
self.strategy.inited = True
self.output("策略初始化完成")
self.strategy.on_start()
self.strategy.trading = True
self.output("开始回放历史数据")
# Use the rest of history data for running backtesting
for dt in dts[ix:]:
try:
self.new_bars(dt)
except Exception:
self.output("触发异常,回测终止")
self.output(traceback.format_exc())
return
self.output("历史数据回放结束")
def calculate_result(self) -> None:
""""""
self.output("开始计算逐日盯市盈亏")
if not self.trades:
self.output("成交记录为空,无法计算")
return
# Add trade data into daily reuslt.
for trade in self.trades.values():
d = trade.datetime.date()
daily_result = self.daily_results[d]
daily_result.add_trade(trade)
# Calculate daily result by iteration.
pre_closes = {}
start_poses = {}
for daily_result in self.daily_results.values():
daily_result.calculate_pnl(
pre_closes,
start_poses,
self.sizes,
self.rates,
self.slippages,
)
pre_closes = daily_result.close_prices
start_poses = daily_result.end_poses
# Generate dataframe
results = defaultdict(list)
for daily_result in self.daily_results.values():
fields = [
"date", "trade_count", "turnover",
"commission", "slippage", "trading_pnl",
"holding_pnl", "total_pnl", "net_pnl"
]
for key in fields:
value = getattr(daily_result, key)
results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date")
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None, output=True) -> None:
""""""
self.output("开始计算策略统计指标")
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
# Set all statistics to 0 if no trade.
start_date = ""
end_date = ""
total_days = 0
profit_days = 0
loss_days = 0
end_balance = 0
max_drawdown = 0
max_ddpercent = 0
max_drawdown_duration = 0
total_net_pnl = 0
daily_net_pnl = 0
total_commission = 0
daily_commission = 0
total_slippage = 0
daily_slippage = 0
total_turnover = 0
daily_turnover = 0
total_trade_count = 0
daily_trade_count = 0
total_return = 0
annual_return = 0
daily_return = 0
return_std = 0
sharpe_ratio = 0
return_drawdown_ratio = 0
else:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
df["highlevel"] = (
df["balance"].rolling(
min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
# Calculate statistics value
start_date = df.index[0]
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
max_ddpercent = df["ddpercent"].min()
max_drawdown_end = df["drawdown"].idxmin()
if isinstance(max_drawdown_end, date):
max_drawdown_start = df["balance"][:max_drawdown_end].idxmax()
max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days
else:
max_drawdown_duration = 0
total_net_pnl = df["net_pnl"].sum()
daily_net_pnl = total_net_pnl / total_days
total_commission = df["commission"].sum()
daily_commission = total_commission / total_days
total_slippage = df["slippage"].sum()
daily_slippage = total_slippage / total_days
total_turnover = df["turnover"].sum()
daily_turnover = total_turnover / total_days
total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days
total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * 240
daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100
if return_std:
sharpe_ratio = daily_return / return_std * np.sqrt(240)
else:
sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"最长回撤天数: \t{max_drawdown_duration}")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = {
"start_date": start_date,
"end_date": end_date,
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,
"max_drawdown_duration": max_drawdown_duration,
"total_net_pnl": total_net_pnl,
"daily_net_pnl": daily_net_pnl,
"total_commission": total_commission,
"daily_commission": daily_commission,
"total_slippage": total_slippage,
"daily_slippage": daily_slippage,
"total_turnover": total_turnover,
"daily_turnover": daily_turnover,
"total_trade_count": total_trade_count,
"daily_trade_count": daily_trade_count,
"total_return": total_return,
"annual_return": annual_return,
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
}
# Filter potential error infinite value
for key, value in statistics.items():
if value in (np.inf, -np.inf):
value = 0
statistics[key] = np.nan_to_num(value)
self.output("策略统计指标计算完成")
return statistics
def show_chart(self, df: DataFrame = None) -> None:
""""""
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
return
plt.figure(figsize=(10, 16))
balance_plot = plt.subplot(4, 1, 1)
balance_plot.set_title("Balance")
df["balance"].plot(legend=True)
drawdown_plot = plt.subplot(4, 1, 2)
drawdown_plot.set_title("Drawdown")
drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
pnl_plot = plt.subplot(4, 1, 3)
pnl_plot.set_title("Daily Pnl")
df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
distribution_plot = plt.subplot(4, 1, 4)
distribution_plot.set_title("Daily Pnl Distribution")
df["net_pnl"].hist(bins=50)
plt.show()
def update_daily_close(self, bars: Dict[str, BarData], dt: datetime) -> None:
""""""
d = dt.date()
close_prices = {}
for bar in bars.values():
close_prices[bar.vt_symbol] = bar.close_price
daily_result = self.daily_results.get(d, None)
if daily_result:
daily_result.update_close_prices(close_prices)
else:
self.daily_results[d] = PortfolioDailyResult(d, close_prices)
def new_bars(self, dt: datetime) -> None:
""""""
self.datetime = dt
self.bars.clear()
for vt_symbol in self.vt_symbols:
bar = self.history_data.get((dt, vt_symbol), None)
if bar:
self.bars[vt_symbol] = bar
else:
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
self.output(f"数据缺失:{dt_str} {vt_symbol}")
self.cross_limit_order()
self.strategy.on_bars(self.bars)
self.update_daily_close(self.bars, dt)
def cross_limit_order(self) -> None:
"""
Cross limit order with last bar/tick data.
"""
for order in list(self.active_limit_orders.values()):
bar = self.bars[order.vt_symbol]
long_cross_price = bar.low_price
short_cross_price = bar.high_price
long_best_price = bar.open_price
short_best_price = bar.open_price
# Push order update with status "not traded" (pending).
if order.status == Status.SUBMITTING:
order.status = Status.NOTTRADED
self.strategy.update_order(order)
# Check whether limit orders can be filled.
long_cross = (
order.direction == Direction.LONG
and order.price >= long_cross_price
and long_cross_price > 0
)
short_cross = (
order.direction == Direction.SHORT
and order.price <= short_cross_price
and short_cross_price > 0
)
if not long_cross and not short_cross:
continue
# Push order update with status "all traded" (filled).
order.traded = order.volume
order.status = Status.ALLTRADED
self.strategy.update_order(order)
self.active_limit_orders.pop(order.vt_orderid)
# Push trade update
self.trade_count += 1
if long_cross:
trade_price = min(order.price, long_best_price)
else:
trade_price = max(order.price, short_best_price)
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
offset=order.offset,
price=trade_price,
volume=order.volume,
time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name,
)
trade.datetime = self.datetime
self.strategy.update_trade(trade)
self.trades[trade.vt_tradeid] = trade
def load_bars(
self,
strategy: StrategyTemplate,
days: int,
interval: Interval
) -> None:
""""""
self.days = days
def send_order(
self,
strategy: StrategyTemplate,
vt_symbol: str,
direction: Direction,
offset: Offset,
price: float,
volume: float,
lock: bool
) -> List[str]:
""""""
price = round_to(price, self.priceticks[vt_symbol])
symbol, exchange = extract_vt_symbol(vt_symbol)
self.limit_order_count += 1
order = OrderData(
symbol=symbol,
exchange=exchange,
orderid=str(self.limit_order_count),
direction=direction,
offset=offset,
price=price,
volume=volume,
status=Status.SUBMITTING,
gateway_name=self.gateway_name,
)
order.datetime = self.datetime
self.active_limit_orders[order.vt_orderid] = order
self.limit_orders[order.vt_orderid] = order
return [order.vt_orderid]
def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str) -> None:
"""
Cancel order by vt_orderid.
"""
if vt_orderid not in self.active_limit_orders:
return
order = self.active_limit_orders.pop(vt_orderid)
order.status = Status.CANCELLED
self.strategy.update_order(order)
def write_log(self, msg: str, strategy: StrategyTemplate = None) -> None:
"""
Write log message.
"""
msg = f"{self.datetime}\t{msg}"
self.logs.append(msg)
def send_email(self, msg: str, strategy: StrategyTemplate = None) -> None:
"""
Send email to default receiver.
"""
pass
def sync_strategy_data(self, strategy: StrategyTemplate) -> None:
"""
Sync strategy data into json file.
"""
pass
def put_strategy_event(self, strategy: StrategyTemplate) -> None:
"""
Put an event to update strategy status.
"""
pass
def output(self, msg) -> None:
"""
Output message of backtesting engine.
"""
print(f"{datetime.now()}\t{msg}")
def get_all_trades(self) -> List[TradeData]:
"""
Return all trade data of current backtesting result.
"""
return list(self.trades.values())
def get_all_orders(self) -> List[OrderData]:
"""
Return all limit order data of current backtesting result.
"""
return list(self.limit_orders.values())
def get_all_daily_results(self) -> List["PortfolioDailyResult"]:
"""
Return all daily result data.
"""
return list(self.daily_results.values())
class ContractDailyResult:
""""""
def __init__(self, result_date: date, close_price: float):
""""""
self.date: date = result_date
self.close_price: float = close_price
self.pre_close: float = 0
self.trades: List[TradeData] = []
self.trade_count: int = 0
self.start_pos: float = 0
self.end_pos: float = 0
self.turnover: float = 0
self.commission: float = 0
self.slippage: float = 0
self.trading_pnl: float = 0
self.holding_pnl: float = 0
self.total_pnl: float = 0
self.net_pnl: float = 0
def add_trade(self, trade: TradeData) -> None:
""""""
self.trades.append(trade)
def calculate_pnl(
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float
) -> None:
""""""
# If no pre_close provided on the first day,
# use value 1 to avoid zero division error
if pre_close:
self.pre_close = pre_close
else:
self.pre_close = 1
# Holding pnl is the pnl from holding position at day start
self.start_pos = start_pos
self.end_pos = start_pos
self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size
# Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades)
for trade in self.trades:
if trade.direction == Direction.LONG:
pos_change = trade.volume
else:
pos_change = -trade.volume
self.end_pos += pos_change
turnover = trade.volume * size * trade.price
self.trading_pnl += pos_change * (self.close_price - trade.price) * size
self.slippage += trade.volume * size * slippage
self.turnover += turnover
self.commission += turnover * rate
# Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage
def update_close_price(self, close_price: float) -> None:
""""""
self.close_price = close_price
class PortfolioDailyResult:
""""""
def __init__(self, result_date: date, close_prices: Dict[str, float]):
""""""
self.date: date = result_date
self.close_prices: Dict[str, float] = close_prices
self.pre_closes: Dict[str, float] = {}
self.start_poses: Dict[str, float] = {}
self.end_poses: Dict[str, float] = {}
self.contract_results: Dict[str, ContractDailyResult] = {}
for vt_symbol, close_price in close_prices.items():
self.contract_results[vt_symbol] = ContractDailyResult(result_date, close_price)
self.trade_count: int = 0
self.turnover: float = 0
self.commission: float = 0
self.slippage: float = 0
self.trading_pnl: float = 0
self.holding_pnl: float = 0
self.total_pnl: float = 0
self.net_pnl: float = 0
def add_trade(self, trade: TradeData) -> None:
""""""
contract_result = self.contract_results[trade.vt_symbol]
contract_result.add_trade(trade)
def calculate_pnl(
self,
pre_closes: Dict[str, float],
start_poses: Dict[str, float],
sizes: Dict[str, float],
rates: Dict[str, float],
slippages: Dict[str, float],
) -> None:
""""""
self.pre_closes = pre_closes
for vt_symbol, contract_result in self.contract_results.items():
contract_result.calculate_pnl(
pre_closes.get(vt_symbol, 0),
start_poses.get(vt_symbol, 0),
sizes[vt_symbol],
rates[vt_symbol],
slippages[vt_symbol]
)
self.trade_count += contract_result.trade_count
self.turnover += contract_result.turnover
self.commission += contract_result.commission
self.slippage += contract_result.slippage
self.trading_pnl += contract_result.trading_pnl
self.holding_pnl += contract_result.holding_pnl
self.total_pnl += contract_result.total_pnl
self.net_pnl += contract_result.net_pnl
self.end_poses[vt_symbol] = contract_result.end_pos
def update_close_prices(self, close_prices: Dict[str, float]) -> None:
""""""
self.close_prices = close_prices
for vt_symbol, close_price in close_prices.items():
contract_result = self.contract_results[vt_symbol]
contract_result.update_close_price(close_price)
@lru_cache(maxsize=999)
def load_bar_data(
vt_symbol: str,
interval: Interval,
start: datetime,
end: datetime
):
""""""
symbol, exchange = extract_vt_symbol(vt_symbol)
return database_manager.load_bar_data(
symbol, exchange, interval, start, end
)

View File

@ -0,0 +1,17 @@
"""
Defines constants and objects used in PortfolioStrategy App.
"""
from enum import Enum
APP_NAME = "PortfolioStrategy"
class EngineType(Enum):
LIVE = "实盘"
BACKTESTING = "回测"
EVENT_PORTFOLIO_LOG = "ePortfolioLog"
EVENT_PORTFOLIO_STRATEGY = "ePortfolioStrategy"

View File

@ -0,0 +1,624 @@
""""""
import importlib
import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set, Tuple, Type, Any, Callable
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.object import (
OrderRequest,
SubscribeRequest,
HistoryRequest,
LogData,
TickData,
OrderData,
TradeData,
PositionData,
BarData,
ContractData
)
from vnpy.trader.event import (
EVENT_TICK,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION
)
from vnpy.trader.constant import (
Direction,
OrderType,
Interval,
Exchange,
Offset
)
from vnpy.trader.utility import load_json, save_json, extract_vt_symbol, round_to
from vnpy.trader.database import database_manager
from vnpy.trader.rqdata import rqdata_client
from vnpy.trader.converter import OffsetConverter
from .base import (
APP_NAME,
EVENT_PORTFOLIO_LOG,
EVENT_PORTFOLIO_STRATEGY
)
from .template import StrategyTemplate
class StrategyEngine(BaseEngine):
""""""
setting_filename = "portfolio_strategy_setting.json"
data_filename = "portfolio_strategy_data.json"
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.strategy_data: Dict[str, Dict] = {}
self.classes: Dict[str, Type[StrategyTemplate]] = {}
self.strategies: Dict[str, StrategyTemplate] = {}
self.symbol_strategy_map: Dict[str, List[StrategyTemplate]] = defaultdict(list)
self.orderid_strategy_map: Dict[str, StrategyTemplate] = {}
self.init_executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=1)
self.vt_tradeids: Set[str] = set()
self.offset_converter: OffsetConverter = OffsetConverter(self.main_engine)
def init_engine(self):
"""
"""
self.init_rqdata()
self.load_strategy_class()
self.load_strategy_setting()
self.load_strategy_data()
self.register_event()
self.write_log("组合策略引擎初始化成功")
def close(self):
""""""
self.stop_all_strategies()
def register_event(self):
""""""
self.event_engine.register(EVENT_TICK, self.process_tick_event)
self.event_engine.register(EVENT_ORDER, self.process_order_event)
self.event_engine.register(EVENT_TRADE, self.process_trade_event)
self.event_engine.register(EVENT_POSITION, self.process_position_event)
def init_rqdata(self):
"""
Init RQData client.
"""
result = rqdata_client.init()
if result:
self.write_log("RQData数据接口初始化成功")
def query_bar_from_rq(
self, symbol: str, exchange: Exchange, interval: Interval, start: datetime, end: datetime
):
"""
Query bar data from RQData.
"""
req = HistoryRequest(
symbol=symbol,
exchange=exchange,
interval=interval,
start=start,
end=end
)
data = rqdata_client.query_history(req)
return data
def process_tick_event(self, event: Event):
""""""
tick: TickData = event.data
strategies = self.symbol_strategy_map[tick.vt_symbol]
if not strategies:
return
for strategy in strategies:
if strategy.inited:
self.call_strategy_func(strategy, strategy.on_tick, tick)
def process_order_event(self, event: Event):
""""""
order: OrderData = event.data
self.offset_converter.update_order(order)
strategy = self.orderid_strategy_map.get(order.vt_orderid, None)
if not strategy:
return
self.call_strategy_func(strategy, strategy.update_order, order)
def process_trade_event(self, event: Event):
""""""
trade: TradeData = event.data
# Filter duplicate trade push
if trade.vt_tradeid in self.vt_tradeids:
return
self.vt_tradeids.add(trade.vt_tradeid)
self.offset_converter.update_trade(trade)
strategy = self.orderid_strategy_map.get(trade.vt_orderid, None)
if not strategy:
return
self.call_strategy_func(strategy, strategy.update_trade, trade)
def process_position_event(self, event: Event):
""""""
position: PositionData = event.data
self.offset_converter.update_position(position)
def send_order(
self,
strategy: StrategyTemplate,
vt_symbol: str,
direction: Direction,
offset: Offset,
price: float,
volume: float,
lock: bool
):
"""
Send a new order to server.
"""
contract: ContractData = self.main_engine.get_contract(vt_symbol)
if not contract:
self.write_log(f"委托失败,找不到合约:{vt_symbol}", strategy)
return ""
# Round order price and volume to nearest incremental value
price = round_to(price, contract.pricetick)
volume = round_to(volume, contract.min_volume)
# Create request and send order.
original_req = OrderRequest(
symbol=contract.symbol,
exchange=contract.exchange,
direction=direction,
offset=offset,
type=OrderType.LIMIT,
price=price,
volume=volume,
)
# Convert with offset converter
req_list = self.offset_converter.convert_order_request(original_req, lock)
# Send Orders
vt_orderids = []
for req in req_list:
req.reference = strategy.strategy_name # Add strategy name as order reference
vt_orderid = self.main_engine.send_order(
req, contract.gateway_name)
# Check if sending order successful
if not vt_orderid:
continue
vt_orderids.append(vt_orderid)
self.offset_converter.update_order_request(req, vt_orderid)
# Save relationship between orderid and strategy.
self.orderid_strategy_map[vt_orderid] = strategy
return vt_orderids
def cancel_order(self, strategy: StrategyTemplate, vt_orderid: str):
"""
"""
order = self.main_engine.get_order(vt_orderid)
if not order:
self.write_log(f"撤单失败,找不到委托{vt_orderid}", strategy)
return
req = order.create_cancel_request()
self.main_engine.cancel_order(req, order.gateway_name)
def load_bars(self, strategy: StrategyTemplate, days: int, interval: Interval):
""""""
vt_symbols = strategy.vt_symbols
dts: Set[datetime] = set()
history_data: Dict[Tuple, BarData] = {}
# Load data from rqdata/gateway/database
for vt_symbol in vt_symbols:
data = self.load_bar(vt_symbol, days, interval)
for bar in data:
dts.add(bar.datetime)
history_data[(bar.datetime, vt_symbol)] = bar
# Convert data structure and push to strategy
dts = list(dts)
dts.sort()
for dt in dts:
bars = {}
for vt_symbol in vt_symbols:
bar = history_data.get((dt, vt_symbol), None)
if bar:
bars[vt_symbol] = bar
else:
dt_str = dt.strftime("%Y-%m-%d %H:%M:%S")
self.write_log(f"数据缺失:{dt_str} {vt_symbol}", strategy)
self.call_strategy_func(strategy, strategy.on_bars, bars)
def load_bar(self, vt_symbol: str, days: int, interval: Interval) -> List[BarData]:
""""""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now()
start = end - timedelta(days)
contract: ContractData = self.main_engine.get_contract(vt_symbol)
data = []
# Query bars from gateway if available
if contract and contract.history_data:
req = HistoryRequest(
symbol=symbol,
exchange=exchange,
interval=interval,
start=start,
end=end
)
data = self.main_engine.query_history(req, contract.gateway_name)
# Try to query bars from RQData, if not found, load from database.
else:
data = self.query_bar_from_rq(symbol, exchange, interval, start, end)
if not data:
data = database_manager.load_bar_data(
symbol=symbol,
exchange=exchange,
interval=interval,
start=start,
end=end,
)
return data
def call_strategy_func(
self, strategy: StrategyTemplate, func: Callable, params: Any = None
):
"""
Call function of a strategy and catch any exception raised.
"""
try:
if params:
func(params)
else:
func()
except Exception:
strategy.trading = False
strategy.inited = False
msg = f"触发异常已停止\n{traceback.format_exc()}"
self.write_log(msg, strategy)
def add_strategy(
self, class_name: str, strategy_name: str, vt_symbols: str, setting: dict
):
"""
Add a new strategy.
"""
if strategy_name in self.strategies:
self.write_log(f"创建策略失败,存在重名{strategy_name}")
return
strategy_class = self.classes.get(class_name, None)
if not strategy_class:
self.write_log(f"创建策略失败,找不到策略类{class_name}")
return
strategy = strategy_class(self, strategy_name, vt_symbols, setting)
self.strategies[strategy_name] = strategy
# Add vt_symbol to strategy map.
for vt_symbol in vt_symbols:
strategies = self.symbol_strategy_map[vt_symbol]
strategies.append(strategy)
self.save_strategy_setting()
self.put_strategy_event(strategy)
def init_strategy(self, strategy_name: str):
"""
Init a strategy.
"""
self.init_executor.submit(self._init_strategy, strategy_name)
def _init_strategy(self, strategy_name: str):
"""
Init strategies in queue.
"""
strategy = self.strategies[strategy_name]
if strategy.inited:
self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
return
self.write_log(f"{strategy_name}开始执行初始化")
# Call on_init function of strategy
self.call_strategy_func(strategy, strategy.on_init)
# Restore strategy data(variables)
data = self.strategy_data.get(strategy_name, None)
if data:
for name in strategy.variables:
value = data.get(name, None)
if value:
setattr(strategy, name, value)
# Subscribe market data
for vt_symbol in strategy.vt_symbols:
contract: ContractData = self.main_engine.get_contract(vt_symbol)
if contract:
req = SubscribeRequest(
symbol=contract.symbol, exchange=contract.exchange)
self.main_engine.subscribe(req, contract.gateway_name)
else:
self.write_log(f"行情订阅失败,找不到合约{vt_symbol}", strategy)
# Put event to update init completed status.
strategy.inited = True
self.put_strategy_event(strategy)
self.write_log(f"{strategy_name}初始化完成")
def start_strategy(self, strategy_name: str):
"""
Start a strategy.
"""
strategy = self.strategies[strategy_name]
if not strategy.inited:
self.write_log(f"策略{strategy.strategy_name}启动失败,请先初始化")
return
if strategy.trading:
self.write_log(f"{strategy_name}已经启动,请勿重复操作")
return
self.call_strategy_func(strategy, strategy.on_start)
strategy.trading = True
self.put_strategy_event(strategy)
def stop_strategy(self, strategy_name: str):
"""
Stop a strategy.
"""
strategy = self.strategies[strategy_name]
if not strategy.trading:
return
# Call on_stop function of the strategy
self.call_strategy_func(strategy, strategy.on_stop)
# Change trading status of strategy to False
strategy.trading = False
# Cancel all orders of the strategy
strategy.cancel_all()
# Sync strategy variables to data file
self.sync_strategy_data(strategy)
# Update GUI
self.put_strategy_event(strategy)
def edit_strategy(self, strategy_name: str, setting: dict):
"""
Edit parameters of a strategy.
"""
strategy = self.strategies[strategy_name]
strategy.update_setting(setting)
self.save_strategy_setting()
self.put_strategy_event(strategy)
def remove_strategy(self, strategy_name: str):
"""
Remove a strategy.
"""
strategy = self.strategies[strategy_name]
if strategy.trading:
self.write_log(f"策略{strategy.strategy_name}移除失败,请先停止")
return
# Remove from symbol strategy map
for vt_symbol in strategy.vt_symbols:
strategies = self.symbol_strategy_map[vt_symbol]
strategies.remove(strategy)
# Remove from vt_orderid strategy map
for vt_orderid in strategy.active_orderids:
if vt_orderid in self.orderid_strategy_map:
self.orderid_strategy_map.pop(vt_orderid)
# Remove from strategies
self.strategies.pop(strategy_name)
self.save_strategy_setting()
return True
def load_strategy_class(self):
"""
Load strategy class from source code.
"""
path1 = Path(__file__).parent.joinpath("strategies")
self.load_strategy_class_from_folder(path1, "vnpy.app.portfolio_strategy.strategies")
path2 = Path.cwd().joinpath("strategies")
self.load_strategy_class_from_folder(path2, "strategies")
def load_strategy_class_from_folder(self, path: Path, module_name: str = ""):
"""
Load strategy class from certain folder.
"""
for dirpath, dirnames, filenames in os.walk(str(path)):
for filename in filenames:
if filename.endswith(".py"):
strategy_module_name = ".".join(
[module_name, filename.replace(".py", "")])
elif filename.endswith(".pyd"):
strategy_module_name = ".".join(
[module_name, filename.split(".")[0]])
else:
continue
self.load_strategy_class_from_module(strategy_module_name)
def load_strategy_class_from_module(self, module_name: str):
"""
Load strategy class from module file.
"""
try:
module = importlib.import_module(module_name)
for name in dir(module):
value = getattr(module, name)
if (isinstance(value, type) and issubclass(value, StrategyTemplate) and value is not StrategyTemplate):
self.classes[value.__name__] = value
except: # noqa
msg = f"策略文件{module_name}加载失败,触发异常:\n{traceback.format_exc()}"
self.write_log(msg)
def load_strategy_data(self):
"""
Load strategy data from json file.
"""
self.strategy_data = load_json(self.data_filename)
def sync_strategy_data(self, strategy: StrategyTemplate):
"""
Sync strategy data into json file.
"""
data = strategy.get_variables()
data.pop("inited") # Strategy status (inited, trading) should not be synced.
data.pop("trading")
self.strategy_data[strategy.strategy_name] = data
save_json(self.data_filename, self.strategy_data)
def get_all_strategy_class_names(self):
"""
Return names of strategy classes loaded.
"""
return list(self.classes.keys())
def get_strategy_class_parameters(self, class_name: str):
"""
Get default parameters of a strategy class.
"""
strategy_class = self.classes[class_name]
parameters = {}
for name in strategy_class.parameters:
parameters[name] = getattr(strategy_class, name)
return parameters
def get_strategy_parameters(self, strategy_name):
"""
Get parameters of a strategy.
"""
strategy = self.strategies[strategy_name]
return strategy.get_parameters()
def init_all_strategies(self):
"""
"""
for strategy_name in self.strategies.keys():
self.init_strategy(strategy_name)
def start_all_strategies(self):
"""
"""
for strategy_name in self.strategies.keys():
self.start_strategy(strategy_name)
def stop_all_strategies(self):
"""
"""
for strategy_name in self.strategies.keys():
self.stop_strategy(strategy_name)
def load_strategy_setting(self):
"""
Load setting file.
"""
strategy_setting = load_json(self.setting_filename)
for strategy_name, strategy_config in strategy_setting.items():
self.add_strategy(
strategy_config["class_name"],
strategy_name,
strategy_config["vt_symbols"],
strategy_config["setting"]
)
def save_strategy_setting(self):
"""
Save setting file.
"""
strategy_setting = {}
for name, strategy in self.strategies.items():
strategy_setting[name] = {
"class_name": strategy.__class__.__name__,
"vt_symbols": strategy.vt_symbols,
"setting": strategy.get_parameters()
}
save_json(self.setting_filename, strategy_setting)
def put_strategy_event(self, strategy: StrategyTemplate):
"""
Put an event to update strategy status.
"""
data = strategy.get_data()
event = Event(EVENT_PORTFOLIO_STRATEGY, data)
self.event_engine.put(event)
def write_log(self, msg: str, strategy: StrategyTemplate = None):
"""
Create cta engine log event.
"""
if strategy:
msg = f"{strategy.strategy_name}: {msg}"
log = LogData(msg=msg, gateway_name="CtaStrategy")
event = Event(type=EVENT_PORTFOLIO_LOG, data=log)
self.event_engine.put(event)
def send_email(self, msg: str, strategy: StrategyTemplate = None):
"""
Send email to default receiver.
"""
if strategy:
subject = f"{strategy.strategy_name}"
else:
subject = "组合策略引擎"
self.main_engine.send_email(subject, msg)

View File

@ -0,0 +1,139 @@
from typing import List, Dict
from vnpy.app.portfolio_strategy import StrategyTemplate, StrategyEngine
from vnpy.trader.utility import BarGenerator, ArrayManager
from vnpy.trader.object import TickData, BarData
class TrendFollowingStrategy(StrategyTemplate):
""""""
author = "用Python的交易员"
atr_window = 22
atr_ma_window = 10
rsi_window = 5
rsi_entry = 16
trailing_percent = 0.8
fixed_size = 1
atr_value = 0
atr_ma = 0
rsi_value = 0
rsi_buy = 0
rsi_sell = 0
intra_trade_high = 0
intra_trade_low = 0
parameters = [
"atr_window",
"atr_ma_window",
"rsi_window",
"rsi_entry",
"trailing_percent",
"fixed_size"
]
variables = [
"atr_value",
"atr_ma",
"rsi_value",
"rsi_buy",
"rsi_sell"
]
def __init__(
self,
strategy_engine: StrategyEngine,
strategy_name: str,
vt_symbols: List[str],
setting: dict
):
""""""
super().__init__(strategy_engine, strategy_name, vt_symbols, setting)
self.vt_symbol = vt_symbols[0]
self.bg = BarGenerator(self.on_bar)
self.am = ArrayManager()
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
self.rsi_buy = 50 + self.rsi_entry
self.rsi_sell = 50 - self.rsi_entry
self.load_bars(10)
def on_start(self):
"""
Callback when strategy is started.
"""
self.write_log("策略启动")
def on_stop(self):
"""
Callback when strategy is stopped.
"""
self.write_log("策略停止")
def on_tick(self, tick: TickData):
"""
Callback of new tick data update.
"""
self.bg.update_tick(tick)
def on_bar(self, bar: BarData):
"""
Callback of new bar data update.
"""
bars = {bar.vt_symbol: bar}
self.on_bars(bars)
def on_bars(self, bars: Dict[str, BarData]):
""""""
self.cancel_all()
bar = bars[self.vt_symbol]
am = self.am
am.update_bar(bar)
if not am.inited:
return
atr_array = am.atr(self.atr_window, array=True)
self.atr_value = atr_array[-1]
self.atr_ma = atr_array[-self.atr_ma_window:].mean()
self.rsi_value = am.rsi(self.rsi_window)
pos = self.get_pos(self.vt_symbol)
if pos == 0:
self.intra_trade_high = bar.high_price
self.intra_trade_low = bar.low_price
if self.atr_value > self.atr_ma:
if self.rsi_value > self.rsi_buy:
self.buy(self.vt_symbol, bar.close_price + 5, self.fixed_size)
elif self.rsi_value < self.rsi_sell:
self.short(self.vt_symbol, bar.close_price - 5, self.fixed_size)
elif pos > 0:
self.intra_trade_high = max(self.intra_trade_high, bar.high_price)
self.intra_trade_low = bar.low_price
long_stop = self.intra_trade_high * (1 - self.trailing_percent / 100)
if bar.close_price <= long_stop:
self.sell(self.vt_symbol, bar.close_price - 5, abs(pos))
elif pos < 0:
self.intra_trade_low = min(self.intra_trade_low, bar.low_price)
self.intra_trade_high = bar.high_price
short_stop = self.intra_trade_low * (1 + self.trailing_percent / 100)
if bar.close_price >= short_stop:
self.cover(self.vt_symbol, bar.close_price + 5, abs(pos))
self.put_event()

View File

@ -0,0 +1,247 @@
""""""
from abc import ABC
from copy import copy
from typing import Dict, Set, List, TYPE_CHECKING
from collections import defaultdict
from vnpy.trader.constant import Interval, Direction, Offset
from vnpy.trader.object import BarData, TickData, OrderData, TradeData
from vnpy.trader.utility import virtual
if TYPE_CHECKING:
from .engine import StrategyEngine
class StrategyTemplate(ABC):
""""""
author = ""
parameters = []
variables = []
def __init__(
self,
strategy_engine: "StrategyEngine",
strategy_name: str,
vt_symbols: List[str],
setting: dict,
):
""""""
self.strategy_engine: "StrategyEngine" = strategy_engine
self.strategy_name: str = strategy_name
self.vt_symbols: List[str] = vt_symbols
self.inited: bool = False
self.trading: bool = False
self.pos: Dict[str, int] = defaultdict(int)
self.orders: Dict[str, OrderData] = {}
self.active_orderids: Set[str] = set()
# Copy a new variables list here to avoid duplicate insert when multiple
# strategy instances are created with the same strategy class.
self.variables: Dict = copy(self.variables)
self.variables.insert(0, "inited")
self.variables.insert(1, "trading")
self.update_setting(setting)
def update_setting(self, setting: dict) -> None:
"""
Update strategy parameter wtih value in setting dict.
"""
for name in self.parameters:
if name in setting:
setattr(self, name, setting[name])
@classmethod
def get_class_parameters(cls) -> Dict:
"""
Get default parameters dict of strategy class.
"""
class_parameters = {}
for name in cls.parameters:
class_parameters[name] = getattr(cls, name)
return class_parameters
def get_parameters(self) -> Dict:
"""
Get strategy parameters dict.
"""
strategy_parameters = {}
for name in self.parameters:
strategy_parameters[name] = getattr(self, name)
return strategy_parameters
def get_variables(self) -> Dict:
"""
Get strategy variables dict.
"""
strategy_variables = {}
for name in self.variables:
strategy_variables[name] = getattr(self, name)
return strategy_variables
def get_data(self) -> Dict:
"""
Get strategy data.
"""
strategy_data = {
"strategy_name": self.strategy_name,
"vt_symbols": self.vt_symbols,
"class_name": self.__class__.__name__,
"author": self.author,
"parameters": self.get_parameters(),
"variables": self.get_variables(),
}
return strategy_data
@virtual
def on_init(self) -> None:
"""
Callback when strategy is inited.
"""
pass
@virtual
def on_start(self) -> None:
"""
Callback when strategy is started.
"""
pass
@virtual
def on_stop(self) -> None:
"""
Callback when strategy is stopped.
"""
pass
@virtual
def on_tick(self, tick: TickData) -> None:
"""
Callback of new tick data update.
"""
pass
@virtual
def on_bars(self, bars: Dict[str, BarData]) -> None:
"""
Callback of new bar data update.
"""
pass
def update_trade(self, trade: TradeData) -> None:
"""
Callback of new trade data update.
"""
if trade.direction == Direction.LONG:
self.pos[trade.vt_symbol] += trade.volume
else:
self.pos[trade.vt_symbol] -= trade.volume
def update_order(self, order: OrderData) -> None:
"""
Callback of new order data update.
"""
self.orders[order.vt_orderid] = order
if order.is_active():
self.active_orderids.add(order.vt_orderid)
elif order.vt_orderid in self.active_orderids:
self.active_orderids.remove(order.vt_orderid)
def buy(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
"""
Send buy order to open a long position.
"""
return self.send_order(vt_symbol, Direction.LONG, Offset.OPEN, price, volume, lock)
def sell(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
"""
Send sell order to close a long position.
"""
return self.send_order(vt_symbol, Direction.SHORT, Offset.CLOSE, price, volume, lock)
def short(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
"""
Send short order to open as short position.
"""
return self.send_order(vt_symbol, Direction.SHORT, Offset.OPEN, price, volume, lock)
def cover(self, vt_symbol: str, price: float, volume: float, lock: bool = False) -> List[str]:
"""
Send cover order to close a short position.
"""
return self.send_order(vt_symbol, Direction.LONG, Offset.CLOSE, price, volume, lock)
def send_order(
self,
vt_symbol: str,
direction: Direction,
offset: Offset,
price: float,
volume: float,
lock: bool = False
) -> List[str]:
"""
Send a new order.
"""
if self.trading:
vt_orderids = self.strategy_engine.send_order(
self, vt_symbol, direction, offset, price, volume, lock
)
return vt_orderids
else:
return []
def cancel_order(self, vt_orderid: str) -> None:
"""
Cancel an existing order.
"""
if self.trading:
self.strategy_engine.cancel_order(self, vt_orderid)
def cancel_all(self) -> None:
"""
Cancel all orders sent by strategy.
"""
for vt_orderid in list(self.active_orderids):
self.cancel_order(vt_orderid)
def get_pos(self, vt_symbol: str) -> int:
""""""
return self.pos.get(vt_symbol, 0)
def get_order(self, vt_orderid: str) -> OrderData:
""""""
return self.orders.get(vt_orderid, None)
def get_all_active_orderids(self) -> List[OrderData]:
""""""
return list(self.active_orderids.values())
def write_log(self, msg: str) -> None:
"""
Write a log message.
"""
self.strategy_engine.write_log(msg, self)
def load_bars(self, days: int, interval: Interval = Interval.MINUTE) -> None:
"""
Load historical bar data for initializing strategy.
"""
self.strategy_engine.load_bars(self, days, interval)
def put_event(self) -> None:
"""
Put an strategy data event for ui update.
"""
if self.inited:
self.strategy_engine.put_strategy_event(self)
def send_email(self, msg) -> None:
"""
Send email to default receiver.
"""
if self.inited:
self.strategy_engine.send_email(msg, self)

View File

@ -0,0 +1 @@
from .widget import PortfolioStrategyManager

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

View File

@ -0,0 +1,439 @@
from typing import Dict
from vnpy.event import Event, EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import QtCore, QtGui, QtWidgets
from vnpy.trader.ui.widget import (
MsgCell,
TimeCell,
BaseMonitor
)
from ..base import (
APP_NAME,
EVENT_PORTFOLIO_LOG,
EVENT_PORTFOLIO_STRATEGY
)
from ..engine import StrategyEngine
class PortfolioStrategyManager(QtWidgets.QWidget):
""""""
signal_log = QtCore.pyqtSignal(Event)
signal_strategy = QtCore.pyqtSignal(Event)
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__()
self.main_engine: MainEngine = main_engine
self.event_engine: EventEngine = event_engine
self.strategy_engine: StrategyEngine = main_engine.get_engine(APP_NAME)
self.managers: Dict[str, StrategyManager] = {}
self.init_ui()
self.register_event()
self.strategy_engine.init_engine()
self.update_class_combo()
def init_ui(self) -> None:
""""""
self.setWindowTitle("组合策略")
# Create widgets
self.class_combo = QtWidgets.QComboBox()
add_button = QtWidgets.QPushButton("添加策略")
add_button.clicked.connect(self.add_strategy)
init_button = QtWidgets.QPushButton("全部初始化")
init_button.clicked.connect(self.strategy_engine.init_all_strategies)
start_button = QtWidgets.QPushButton("全部启动")
start_button.clicked.connect(self.strategy_engine.start_all_strategies)
stop_button = QtWidgets.QPushButton("全部停止")
stop_button.clicked.connect(self.strategy_engine.stop_all_strategies)
clear_button = QtWidgets.QPushButton("清空日志")
clear_button.clicked.connect(self.clear_log)
self.scroll_layout = QtWidgets.QVBoxLayout()
self.scroll_layout.addStretch()
scroll_widget = QtWidgets.QWidget()
scroll_widget.setLayout(self.scroll_layout)
scroll_area = QtWidgets.QScrollArea()
scroll_area.setWidgetResizable(True)
scroll_area.setWidget(scroll_widget)
self.log_monitor = LogMonitor(self.main_engine, self.event_engine)
# Set layout
hbox1 = QtWidgets.QHBoxLayout()
hbox1.addWidget(self.class_combo)
hbox1.addWidget(add_button)
hbox1.addStretch()
hbox1.addWidget(init_button)
hbox1.addWidget(start_button)
hbox1.addWidget(stop_button)
hbox1.addWidget(clear_button)
hbox2 = QtWidgets.QHBoxLayout()
hbox2.addWidget(scroll_area)
hbox2.addWidget(self.log_monitor)
vbox = QtWidgets.QVBoxLayout()
vbox.addLayout(hbox1)
vbox.addLayout(hbox2)
self.setLayout(vbox)
def update_class_combo(self):
""""""
self.class_combo.addItems(
self.strategy_engine.get_all_strategy_class_names()
)
def register_event(self):
""""""
self.signal_strategy.connect(self.process_strategy_event)
self.event_engine.register(
EVENT_PORTFOLIO_STRATEGY, self.signal_strategy.emit
)
def process_strategy_event(self, event):
"""
Update strategy status onto its monitor.
"""
data = event.data
strategy_name = data["strategy_name"]
if strategy_name in self.managers:
manager = self.managers[strategy_name]
manager.update_data(data)
else:
manager = StrategyManager(self, self.strategy_engine, data)
self.scroll_layout.insertWidget(0, manager)
self.managers[strategy_name] = manager
def remove_strategy(self, strategy_name):
""""""
manager = self.managers.pop(strategy_name)
manager.deleteLater()
def add_strategy(self):
""""""
class_name = str(self.class_combo.currentText())
if not class_name:
return
parameters = self.strategy_engine.get_strategy_class_parameters(class_name)
editor = SettingEditor(parameters, class_name=class_name)
n = editor.exec_()
if n == editor.Accepted:
setting = editor.get_setting()
vt_symbols = setting.pop("vt_symbols").split(",")
strategy_name = setting.pop("strategy_name")
self.strategy_engine.add_strategy(
class_name, strategy_name, vt_symbols, setting
)
def clear_log(self):
""""""
self.log_monitor.setRowCount(0)
def show(self):
""""""
self.showMaximized()
class StrategyManager(QtWidgets.QFrame):
"""
Manager for a strategy
"""
def __init__(
self,
strategy_manager: PortfolioStrategyManager,
strategy_engine: StrategyEngine,
data: dict
):
""""""
super().__init__()
self.strategy_manager = strategy_manager
self.strategy_engine = strategy_engine
self.strategy_name = data["strategy_name"]
self._data = data
self.init_ui()
def init_ui(self):
""""""
self.setFixedHeight(300)
self.setFrameShape(self.Box)
self.setLineWidth(1)
self.init_button = QtWidgets.QPushButton("初始化")
self.init_button.clicked.connect(self.init_strategy)
self.start_button = QtWidgets.QPushButton("启动")
self.start_button.clicked.connect(self.start_strategy)
self.start_button.setEnabled(False)
self.stop_button = QtWidgets.QPushButton("停止")
self.stop_button.clicked.connect(self.stop_strategy)
self.stop_button.setEnabled(False)
self.edit_button = QtWidgets.QPushButton("编辑")
self.edit_button.clicked.connect(self.edit_strategy)
self.remove_button = QtWidgets.QPushButton("移除")
self.remove_button.clicked.connect(self.remove_strategy)
strategy_name = self._data["strategy_name"]
class_name = self._data["class_name"]
author = self._data["author"]
label_text = (
f"{strategy_name} - ({class_name} by {author})"
)
label = QtWidgets.QLabel(label_text)
label.setAlignment(QtCore.Qt.AlignCenter)
self.parameters_monitor = DataMonitor(self._data["parameters"])
self.variables_monitor = DataMonitor(self._data["variables"])
hbox = QtWidgets.QHBoxLayout()
hbox.addWidget(self.init_button)
hbox.addWidget(self.start_button)
hbox.addWidget(self.stop_button)
hbox.addWidget(self.edit_button)
hbox.addWidget(self.remove_button)
vbox = QtWidgets.QVBoxLayout()
vbox.addWidget(label)
vbox.addLayout(hbox)
vbox.addWidget(self.parameters_monitor)
vbox.addWidget(self.variables_monitor)
self.setLayout(vbox)
def update_data(self, data: dict):
""""""
self._data = data
self.parameters_monitor.update_data(data["parameters"])
self.variables_monitor.update_data(data["variables"])
# Update button status
variables = data["variables"]
inited = variables["inited"]
trading = variables["trading"]
if not inited:
return
self.init_button.setEnabled(False)
if trading:
self.start_button.setEnabled(False)
self.stop_button.setEnabled(True)
self.edit_button.setEnabled(False)
self.remove_button.setEnabled(False)
else:
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
self.edit_button.setEnabled(True)
self.remove_button.setEnabled(True)
def init_strategy(self):
""""""
self.strategy_engine.init_strategy(self.strategy_name)
def start_strategy(self):
""""""
self.strategy_engine.start_strategy(self.strategy_name)
def stop_strategy(self):
""""""
self.strategy_engine.stop_strategy(self.strategy_name)
def edit_strategy(self):
""""""
strategy_name = self._data["strategy_name"]
parameters = self.strategy_engine.get_strategy_parameters(strategy_name)
editor = SettingEditor(parameters, strategy_name=strategy_name)
n = editor.exec_()
if n == editor.Accepted:
setting = editor.get_setting()
self.strategy_engine.edit_strategy(strategy_name, setting)
def remove_strategy(self):
""""""
result = self.strategy_engine.remove_strategy(self.strategy_name)
# Only remove strategy gui manager if it has been removed from engine
if result:
self.strategy_manager.remove_strategy(self.strategy_name)
class DataMonitor(QtWidgets.QTableWidget):
"""
Table monitor for parameters and variables.
"""
def __init__(self, data: dict):
""""""
super(DataMonitor, self).__init__()
self._data = data
self.cells = {}
self.init_ui()
def init_ui(self):
""""""
labels = list(self._data.keys())
self.setColumnCount(len(labels))
self.setHorizontalHeaderLabels(labels)
self.setRowCount(1)
self.verticalHeader().setSectionResizeMode(
QtWidgets.QHeaderView.Stretch
)
self.verticalHeader().setVisible(False)
self.setEditTriggers(self.NoEditTriggers)
for column, name in enumerate(self._data.keys()):
value = self._data[name]
cell = QtWidgets.QTableWidgetItem(str(value))
cell.setTextAlignment(QtCore.Qt.AlignCenter)
self.setItem(0, column, cell)
self.cells[name] = cell
def update_data(self, data: dict):
""""""
for name, value in data.items():
cell = self.cells[name]
cell.setText(str(value))
class LogMonitor(BaseMonitor):
"""
Monitor for log data.
"""
event_type = EVENT_PORTFOLIO_LOG
data_key = ""
sorting = False
headers = {
"time": {"display": "时间", "cell": TimeCell, "update": False},
"msg": {"display": "信息", "cell": MsgCell, "update": False},
}
def init_ui(self):
"""
Stretch last column.
"""
super(LogMonitor, self).init_ui()
self.horizontalHeader().setSectionResizeMode(
1, QtWidgets.QHeaderView.Stretch
)
def insert_new_row(self, data):
"""
Insert a new row at the top of table.
"""
super().insert_new_row(data)
self.resizeRowToContents(0)
class SettingEditor(QtWidgets.QDialog):
"""
For creating new strategy and editing strategy parameters.
"""
def __init__(
self, parameters: dict, strategy_name: str = "", class_name: str = ""
):
""""""
super(SettingEditor, self).__init__()
self.parameters = parameters
self.strategy_name = strategy_name
self.class_name = class_name
self.edits = {}
self.init_ui()
def init_ui(self):
""""""
form = QtWidgets.QFormLayout()
# Add vt_symbols and name edit if add new strategy
if self.class_name:
self.setWindowTitle(f"添加策略:{self.class_name}")
button_text = "添加"
parameters = {"strategy_name": "", "vt_symbols": ""}
parameters.update(self.parameters)
else:
self.setWindowTitle(f"参数编辑:{self.strategy_name}")
button_text = "确定"
parameters = self.parameters
for name, value in parameters.items():
type_ = type(value)
edit = QtWidgets.QLineEdit(str(value))
if type_ is int:
validator = QtGui.QIntValidator()
edit.setValidator(validator)
elif type_ is float:
validator = QtGui.QDoubleValidator()
edit.setValidator(validator)
form.addRow(f"{name} {type_}", edit)
self.edits[name] = (edit, type_)
button = QtWidgets.QPushButton(button_text)
button.clicked.connect(self.accept)
form.addRow(button)
self.setLayout(form)
def get_setting(self):
""""""
setting = {}
if self.class_name:
setting["class_name"] = self.class_name
for name, tp in self.edits.items():
edit, type_ = tp
value_text = edit.text()
if type_ == bool:
if value_text == "True":
value = True
else:
value = False
else:
value = type_(value_text)
setting[name] = value
return setting

View File

@ -595,7 +595,7 @@ class CtaGridTrade(CtaComponent):
x.type = type # 网格类型标签
# self.open_prices = {} # 套利使用开仓价格symbol:price
def rebuild_grids(self, direction: Direction,
def rebuild_grids(self, directions: list = [],
upper_line: float = 0.0,
down_line: float = 0.0,
middle_line: float = 0.0,
@ -608,7 +608,7 @@ class CtaGridTrade(CtaComponent):
upRate , 上轨网格高度比率
dnRate 下轨网格高度比率
"""
self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(direction, upper_line, down_line))
self.write_log(u'重新拉网:direction:{},upline:{},dnline:{}'.format(directions, upper_line, down_line))
# 检查上下网格的高度比率不能低于0.5
if upper_rate < 0.5 or down_rate < 0.5:
@ -616,7 +616,7 @@ class CtaGridTrade(CtaComponent):
down_rate = max(0.5, down_rate)
# 重建下网格(移除未挂单、保留开仓得网格、在最低价之下才增加网格
if direction == Direction.LONG:
if len(directions) == 0 or Direction.LONG in directions:
min_long_price = middle_line
remove_grids = []
opened_grids = []
@ -636,8 +636,8 @@ class CtaGridTrade(CtaComponent):
self.write_log(u'保留下网格[{}]'.format(opened_grids))
# 需要重建的剩余网格数量
remainLots = len(self.dn_grids)
lots = self.max_lots - remainLots
remain_lots = len(self.dn_grids)
lots = self.max_lots - remain_lots
down_line = min(down_line, min_long_price - self.grid_height * down_rate)
self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, down_line))
@ -651,13 +651,13 @@ class CtaGridTrade(CtaComponent):
grid = CtaGrid(direction=Direction.LONG,
open_price=open_price,
close_price=close_price,
volume=self.volume * self.get_volume_rate(remainLots + i))
volume=self.volume * self.get_volume_rate(remain_lots + i))
grid.reuse_count = reuse_count
self.dn_grids.append(grid)
self.write_log(u'重新拉下网格:[{0}==>{1}]'.format(down_line, down_line - lots * self.grid_height * down_rate))
# 重建上网格(移除未挂单、保留开仓得网格、在最高价之上才增加网格
if direction == Direction.SHORT:
if len(directions) == 0 or Direction.SHORT in directions:
max_short_price = middle_line # 最高开空价
remove_grids = [] # 移除的网格列表
opened_grids = [] # 已开仓的网格列表
@ -678,8 +678,8 @@ class CtaGridTrade(CtaComponent):
self.write_log(u'保留上网格[{}]'.format(opened_grids))
# 需要重建的剩余网格数量
remainLots = len(self.up_grids)
lots = self.max_lots - remainLots
remain_lots = len(self.up_grids)
lots = self.max_lots - remain_lots
upper_line = max(upper_line, max_short_price + self.grid_height * upper_rate)
self.write_log(u'需要重建的网格数量:{0},起点:{1}'.format(lots, upper_line))
@ -692,7 +692,7 @@ class CtaGridTrade(CtaComponent):
grid = CtaGrid(direction=Direction.SHORT,
open_price=open_price,
close_price=close_price,
volume=self.volume * self.get_volume_rate(remainLots + i))
volume=self.volume * self.get_volume_rate(remain_lots + i))
grid.reuse_count = reuse_count
self.up_grids.append(grid)

View File

@ -178,13 +178,20 @@ class CtaLineBar(object):
# (实时运行时或者addbar小于bar得周期时不包含最后一根Bar
self.open_array = np.zeros(self.max_hold_bars) # 与lineBar一致得开仓价清单
self.open_array[:] = np.nan
self.high_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最高价清单
self.high_array[:] = np.nan
self.low_array = np.zeros(self.max_hold_bars) # 与lineBar一致得最低价清单
self.low_array[:] = np.nan
self.close_array = np.zeros(self.max_hold_bars) # 与lineBar一致得收盘价清单
self.close_array[:] = np.nan
self.mid3_array = np.zeros(self.max_hold_bars) # 收盘价/最高/最低价 的平均价
self.mid3_array[:] = np.nan
self.mid4_array = np.zeros(self.max_hold_bars) # 收盘价*2/最高/最低价 的平均价
self.mid4_array[:] = np.nan
self.mid5_array = np.zeros(self.max_hold_bars) # 收盘价*2/开仓价/最高/最低价 的平均价
self.mid5_array[:] = np.nan
self.export_filename = None
self.export_fields = []
@ -1263,7 +1270,8 @@ class CtaLineBar(object):
# 2.计算前inputPreLen周期内(不包含当前周期的Bar高点和低点
preHigh = max(self.high_array[-count_len:])
preLow = min(self.low_array[-count_len:])
if np.isnan(preHigh) or np.isnan(preLow):
return
# 保存
if len(self.line_pre_high) > self.max_hold_bars:
del self.line_pre_high[0]
@ -1452,6 +1460,8 @@ class CtaLineBar(object):
count_len = min(self.para_ma1_len, self.bar_len)
barMa1 = ta.MA(self.close_array[-count_len:], count_len)[-1]
if np.isnan(barMa1):
return
barMa1 = round(barMa1, self.round_n)
if len(self.line_ma1) > self.max_hold_bars:
@ -1470,6 +1480,8 @@ class CtaLineBar(object):
if self.para_ma2_len > 0:
count_len = min(self.para_ma2_len, self.bar_len)
barMa2 = ta.MA(self.close_array[-count_len:], count_len)[-1]
if np.isnan(barMa2):
return
barMa2 = round(barMa2, self.round_n)
if len(self.line_ma2) > self.max_hold_bars:
@ -1488,6 +1500,8 @@ class CtaLineBar(object):
if self.para_ma3_len > 0:
count_len = min(self.para_ma3_len, self.bar_len)
barMa3 = ta.MA(self.close_array[-count_len:], count_len)[-1]
if np.isnan(barMa3):
return
barMa3 = round(barMa3, self.round_n)
if len(self.line_ma3) > self.max_hold_bars:
@ -1685,6 +1699,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期的K线
barEma1 = ta.EMA(self.close_array[-ema1_data_len:], count_len)[-1]
if np.isnan(barEma1):
return
barEma1 = round(float(barEma1), self.round_n)
if len(self.line_ema1) > self.max_hold_bars:
@ -1698,7 +1714,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期)的自适应均线
barEma2 = ta.EMA(self.close_array[-ema2_data_len:], count_len)[-1]
if np.isnan(barEma2):
return
barEma2 = round(float(barEma2), self.round_n)
if len(self.line_ema2) > self.max_hold_bars:
@ -1711,6 +1728,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期)的自适应均线
barEma3 = ta.EMA(self.close_array[-ema3_data_len:], count_len)[-1]
if np.isnan(barEma3):
return
barEma3 = round(float(barEma3), self.round_n)
if len(self.line_ema3) > self.max_hold_bars:
@ -1737,6 +1756,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期的K线
barEma1 = ta.EMA(np.append(self.close_array[-ema1_data_len:], [self.cur_price]), count_len)[-1]
if np.isnan(barEma1):
return
self._rt_ema1 = round(float(barEma1), self.round_n)
# 计算第二条EMA均线
@ -1746,7 +1767,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期)的自适应均线
barEma2 = ta.EMA(np.append(self.close_array[-ema2_data_len:], [self.cur_price]), count_len)[-1]
if np.isnan(barEma2):
return
self._rt_ema2 = round(float(barEma2), self.round_n)
# 计算第三条EMA均线
@ -1755,6 +1777,8 @@ class CtaLineBar(object):
# 3、获取前InputN周期(不包含当前周期)的自适应均线
barEma3 = ta.EMA(np.append(self.close_array[-ema3_data_len:], [self.cur_price]), count_len)[-1]
if np.isnan(barEma3):
return
self._rt_ema3 = round(float(barEma3), self.round_n)
@property
@ -2076,7 +2100,7 @@ class CtaLineBar(object):
return
if self.para_boll_len > 0:
if self.bar_len < min(7, self.para_boll_len):
if self.bar_len < min(20, self.para_boll_len):
self.write_log(u'数据未充分,当前Bar数据数量{0}计算Boll需要{1}'.
format(len(self.line_bar), min(14, self.para_boll_len) + 1))
else:
@ -2086,6 +2110,9 @@ class CtaLineBar(object):
upper_list, middle_list, lower_list = ta.BBANDS(self.close_array,
timeperiod=bollLen, nbdevup=self.para_boll_std_rate,
nbdevdn=self.para_boll_std_rate, matype=0)
if np.isnan(upper_list[-1]):
return
if len(self.line_boll_upper) > self.max_hold_bars:
del self.line_boll_upper[0]
if len(self.line_boll_middle) > self.max_hold_bars:
@ -2144,6 +2171,8 @@ class CtaLineBar(object):
upper_list, middle_list, lower_list = ta.BBANDS(self.close_array,
timeperiod=boll2Len, nbdevup=self.para_boll2_std_rate,
nbdevdn=self.para_boll2_std_rate, matype=0)
if np.isnan(upper_list[-1]):
return
if len(self.line_boll2_upper) > self.max_hold_bars:
del self.line_boll2_upper[0]
if len(self.line_boll2_middle) > self.max_hold_bars:
@ -2510,14 +2539,19 @@ class CtaLineBar(object):
hhv = max(self.high_array[-inputKdjLen:])
llv = min(self.low_array[-inputKdjLen:])
if np.isnan(hhv) or np.isnan(llv):
return
if len(self.line_k) > 0:
lastK = self.line_k[-1]
if np.isnan(lastK):
lastK = 0
else:
lastK = 0
if len(self.line_d) > 0:
lastD = self.line_d[-1]
if np.isnan(lastD):
lastD = 0
else:
lastD = 0
@ -2620,14 +2654,20 @@ class CtaLineBar(object):
hhv = max(self.high_array[-data_len:])
llv = min(self.low_array[-data_len:])
if np.isnan(hhv) or np.isnan(llv):
return
if len(self.line_k) > 0:
lastK = self.line_k[-1]
if np.isnan(lastK):
lastK = 0
else:
lastK = 0
if len(self.line_d) > 0:
lastD = self.line_d[-1]
if np.isnan(lastD):
lastD = 0
else:
lastD = 0
@ -2747,7 +2787,8 @@ class CtaLineBar(object):
dif_list, dea_list, macd_list = ta.MACD(self.close_array[-2 * maxLen:], fastperiod=self.para_macd_fast_len,
slowperiod=self.para_macd_slow_len,
signalperiod=self.para_macd_signal_len)
if np.isnan(dif_list[-1]) or np.isnan(dea_list[-1]) or np.isnan(macd_list[-1]):
return
# dif, dea, macd = ta.MACDEXT(np.array(listClose, dtype=float),
# fastperiod=self.inputMacdFastPeriodLen, fastmatype=1,
# slowperiod=self.inputMacdSlowPeriodLen, slowmatype=1,
@ -2868,6 +2909,9 @@ class CtaLineBar(object):
fastperiod=self.para_macd_fast_len,
slowperiod=self.para_macd_slow_len, signalperiod=self.para_macd_signal_len)
if np.isnan(dif[-1]) or np.isnan(dea[-1]) or np.isnan(macd[-1]):
return
self._rt_dif = round(dif[-1], self.round_n) if len(dif) > 0 else None
self._rt_dea = round(dea[-1], self.round_n) if len(dea) > 0 else None
self._rt_macd = round(macd[-1] * 2, self.round_n) if len(macd) > 0 else None
@ -3958,6 +4002,9 @@ class CtaLineBar(object):
hhv = max(self.high_array[-bar_len:])
llv = min(self.low_array[-bar_len:])
if np.isnan(hhv) or np.isnan(llv):
return
self.cur_p192 = hhv - (hhv - llv) * 0.192
self.cur_p382 = hhv - (hhv - llv) * 0.382
self.cur_p500 = (hhv + llv) / 2

View File

@ -25,6 +25,7 @@ from .event import (
)
from .gateway import BaseGateway
from .object import (
Direction,
Exchange,
CancelRequest,
LogData,
@ -214,7 +215,7 @@ class MainEngine:
"""
# 自定义套利合约,交给算法引擎处理
if self.algo_engine and req.exchange == Exchange.SPD:
return self.algo_engine.send_algo_order(
return self.algo_engine.send_spd_order(
req=req,
gateway_name=gateway_name)
@ -228,6 +229,11 @@ class MainEngine:
"""
Send cancel order request to a specific gateway.
"""
# 自定义套利合约,交给算法引擎处理
if self.algo_engine and req.exchange == Exchange.SPD:
return self.algo_engine.cancel_spd_order(
req=req)
gateway = self.get_gateway(gateway_name)
if gateway:
return gateway.cancel_order(req)
@ -422,13 +428,19 @@ class OmsEngine(BaseEngine):
self.accounts: Dict[str, AccountData] = {}
self.contracts: Dict[str, ContractData] = {}
self.today_contracts: Dict[str, ContractData] = {}
self.custom_contracts = {}
# 自定义合约
self.custom_contracts = {} # vt_symbol: ContractData
self.custom_settings = {} # symbol: dict
self.symbol_spd_maping = {} # symbol: [spd_symbol]
self.prices = {}
self.active_orders: Dict[str, OrderData] = {}
self.add_function()
self.register_event()
self.load_contracts()
def __del__(self):
"""保存缓存"""
@ -445,6 +457,30 @@ class OmsEngine(BaseEngine):
self.contracts = pickle.load(f)
self.write_log(f'加载缓存合约字典:{contract_file_name}')
# 更新自定义合约
custom_contracts = self.get_all_custom_contracts()
for contract in custom_contracts.values():
# 更新合约缓存
self.contracts.update({contract.symbol: contract})
self.contracts.update({contract.vt_symbol: contract})
self.today_contracts[contract.vt_symbol] = contract
self.today_contracts[contract.symbol] = contract
# 获取自定义合约的主动腿/被动腿
setting = self.custom_settings.get(contract.symbol, {})
leg1_symbol = setting.get('leg1_symbol')
leg2_symbol = setting.get('leg2_symbol')
# 构建映射关系
for symbol in [leg1_symbol, leg2_symbol]:
spd_mapping_list = self.symbol_spd_maping.get(symbol, [])
# 更新映射 symbol => spd_symbol
if contract.symbol not in spd_mapping_list:
spd_mapping_list.append(contract.symbol)
self.symbol_spd_maping.update({symbol: spd_mapping_list})
def save_contracts(self) -> None:
"""持久化合约对象到缓存文件"""
import bz2
@ -474,6 +510,7 @@ class OmsEngine(BaseEngine):
self.main_engine.get_all_contracts = self.get_all_contracts
self.main_engine.get_all_active_orders = self.get_all_active_orders
self.main_engine.get_all_custom_contracts = self.get_all_custom_contracts
self.main_engine.get_mapping_spd = self.get_mapping_spd
self.main_engine.save_contracts = self.save_contracts
def register_event(self) -> None:
@ -515,6 +552,76 @@ class OmsEngine(BaseEngine):
position = event.data
self.positions[position.vt_positionid] = position
def reverse_direction(self, direction):
"""返回反向持仓"""
if direction == Direction.LONG:
return Direction.SHORT
elif direction == Direction.SHORT:
return Direction.LONG
return direction
def create_spd_position_event(self, symbol, direction ):
"""创建自定义品种对持仓信息"""
spd_symbols = self.symbol_spd_maping.get(symbol, [])
if not spd_symbols:
return
for spd_symbol in spd_symbols:
spd_setting = self.custom_settings.get(spd_symbol, None)
if not spd_setting:
continue
leg1_symbol = spd_setting.get('leg1_symbol')
leg2_symbol = spd_setting.get('leg2_symbol')
leg1_contract = self.contracts.get(leg1_symbol)
leg2_contract = self.contracts.get(leg2_symbol)
spd_contract = self.contracts.get(spd_symbol)
if leg1_contract is None or leg2_contract is None:
continue
leg1_ratio = spd_setting.get('leg1_ratio', 1)
leg2_ratio = spd_setting.get('leg2_ratio', 1)
# 找出leg1leg2的持仓并判断出spd的方向
if leg1_symbol == symbol:
k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{direction.value}"
leg1_pos = self.positions.get(k1)
k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{self.reverse_direction(direction).value}"
leg2_pos = self.positions.get(k2)
spd_direction = direction
elif leg2_symbol == symbol:
k1 = f"{leg1_contract.gateway_name}.{leg1_contract.vt_symbol}.{self.reverse_direction(direction).value}"
leg1_pos = self.positions.get(k1)
k2 = f"{leg2_contract.gateway_name}.{leg2_contract.vt_symbol}.{direction.value}"
leg2_pos = self.positions.get(k2)
spd_direction = self.reverse_direction(direction)
else:
continue
if leg1_pos is None or leg2_pos is None or leg1_pos.volume ==0 or leg2_pos.volume == 0:
continue
# 根据leg1/leg2的volume ratio计算出最小spd_volume
spd_volume = min(int(leg1_pos.volume/leg1_ratio), int(leg2_pos.volume/leg2_ratio))
if spd_volume <= 0:
continue
if spd_setting.get('is_ratio', False) and leg2_pos.price > 0:
spd_price = 100 * (leg2_pos.price * leg1_ratio) / (leg2_pos.price * leg2_ratio)
elif spd_setting.get('is_spread', False):
spd_price = leg1_pos.price * leg1_ratio - leg2_pos.price * leg2_ratio
else:
spd_price = 0
spd_pos = PositionData(
gateway_name=spd_contract.gateway_name,
symbol=spd_symbol,
exchange=Exchange.SPD,
direction=spd_direction,
volume=spd_volume,
price=spd_price
)
event = Event(EVENT_POSITION, data=spd_pos)
self.event_engine.put(event)
def process_account_event(self, event: Event) -> None:
""""""
account = event.data
@ -624,16 +731,25 @@ class OmsEngine(BaseEngine):
]
return active_orders
def get_all_custom_contracts(self):
def get_all_custom_contracts(self, rtn_setting=False):
"""
获取所有自定义合约
:return:
"""
if rtn_setting:
if len(self.custom_settings) == 0:
c = CustomContract()
self.custom_settings = c.get_config()
return self.custom_settings
if len(self.custom_contracts) == 0:
c = CustomContract()
self.custom_contracts = c.get_contracts()
return self.custom_contracts
def get_mapping_spd(self, symbol):
"""根据主动腿/被动腿symbol获取自定义套利对的symbol list"""
return self.symbol_spd_maping.get(symbol, [])
class CustomContract(object):
"""
@ -668,6 +784,7 @@ class CustomContract(object):
exchange=vn_exchange,
name=setting.get('name', symbol),
size=setting.get('size', 100),
product=None,
pricetick=setting.get('price_tick', 0.01),
margin_rate=setting.get('margin_rate', 0.1)
)

View File

@ -123,7 +123,7 @@ class BaseGateway(ABC):
Tick event of a specific vt_symbol is also pushed.
"""
self.on_event(EVENT_TICK, tick)
self.on_event(EVENT_TICK + tick.vt_symbol, tick)
# self.on_event(EVENT_TICK + tick.vt_symbol, tick)
# 推送Bar
kline = self.klines.get(tick.vt_symbol, None)
@ -142,7 +142,7 @@ class BaseGateway(ABC):
Trade event of a specific vt_symbol is also pushed.
"""
self.on_event(EVENT_TRADE, trade)
self.on_event(EVENT_TRADE + trade.vt_symbol, trade)
# self.on_event(EVENT_TRADE + trade.vt_symbol, trade)
def on_order(self, order: OrderData) -> None:
"""
@ -150,7 +150,7 @@ class BaseGateway(ABC):
Order event of a specific vt_orderid is also pushed.
"""
self.on_event(EVENT_ORDER, order)
self.on_event(EVENT_ORDER + order.vt_orderid, order)
# self.on_event(EVENT_ORDER + order.vt_orderid, order)
def on_position(self, position: PositionData) -> None:
"""
@ -158,7 +158,7 @@ class BaseGateway(ABC):
Position event of a specific vt_symbol is also pushed.
"""
self.on_event(EVENT_POSITION, position)
self.on_event(EVENT_POSITION + position.vt_symbol, position)
# self.on_event(EVENT_POSITION + position.vt_symbol, position)
def on_account(self, account: AccountData) -> None:
"""
@ -166,7 +166,7 @@ class BaseGateway(ABC):
Account event of a specific vt_accountid is also pushed.
"""
self.on_event(EVENT_ACCOUNT, account)
self.on_event(EVENT_ACCOUNT + account.vt_accountid, account)
# self.on_event(EVENT_ACCOUNT + account.vt_accountid, account)
def on_log(self, log: LogData) -> None:
"""

View File

@ -674,11 +674,11 @@ class BarGenerator:
"""
def __init__(
self,
on_bar: Callable,
window: int = 0,
on_window_bar: Callable = None,
interval: Interval = Interval.MINUTE
self,
on_bar: Callable,
window: int = 0,
on_window_bar: Callable = None,
interval: Interval = Interval.MINUTE
):
"""Constructor"""
self.bar: BarData = None
@ -804,11 +804,14 @@ class BarGenerator:
"""
Generate the bar data and call callback immediately.
"""
self.bar.datetime = self.bar.datetime.replace(
second=0, microsecond=0
)
self.on_bar(self.bar)
bar = self.bar
if self.bar:
bar.datetime = bar.datetime.replace(second=0, microsecond=0)
self.on_bar(bar)
self.bar = None
return bar
class ArrayManager(object):
@ -1067,11 +1070,11 @@ class ArrayManager(object):
return result[-1]
def macd(
self,
fast_period: int,
slow_period: int,
signal_period: int,
array: bool = False
self,
fast_period: int,
slow_period: int,
signal_period: int,
array: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray, np.ndarray],
Tuple[float, float, float]
@ -1159,10 +1162,10 @@ class ArrayManager(object):
return result[-1]
def boll(
self,
n: int,
dev: float,
array: bool = False
self,
n: int,
dev: float,
array: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray],
Tuple[float, float]
@ -1179,10 +1182,10 @@ class ArrayManager(object):
return up, down
def keltner(
self,
n: int,
dev: float,
array: bool = False
self,
n: int,
dev: float,
array: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray],
Tuple[float, float]
@ -1199,7 +1202,7 @@ class ArrayManager(object):
return up, down
def donchian(
self, n: int, array: bool = False
self, n: int, array: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray],
Tuple[float, float]
@ -1215,10 +1218,9 @@ class ArrayManager(object):
return up[-1], down[-1]
def aroon(
self,
n: int,
dev: float,
array: bool = False
self,
n: int,
array: bool = False
) -> Union[
Tuple[np.ndarray, np.ndarray],
Tuple[float, float]