Merge pull request #2185 from vnpy/dev-spread-backtesting

Dev spread backtesting
This commit is contained in:
vn.py 2019-11-10 17:55:58 +09:00 committed by GitHub
commit a56c37b3af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1636 additions and 11 deletions

File diff suppressed because one or more lines are too long

View File

@ -628,8 +628,6 @@ class CtaEngine(BaseEngine):
""" """
strategy = self.strategies[strategy_name] strategy = self.strategies[strategy_name]
print(datetime.now(), strategy_name, strategy.vt_symbol)
if strategy.inited: if strategy.inited:
self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作") self.write_log(f"{strategy_name}已经完成初始化,禁止重复操作")
return return

View File

@ -3,7 +3,9 @@ from pathlib import Path
from vnpy.trader.app import BaseApp from vnpy.trader.app import BaseApp
from vnpy.trader.object import ( from vnpy.trader.object import (
OrderData, OrderData,
TradeData TradeData,
TickData,
BarData
) )
from .engine import ( from .engine import (

View File

@ -0,0 +1,685 @@
from collections import defaultdict
from datetime import date, datetime
from typing import Callable, Type
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame
from vnpy.trader.constant import (Direction, Offset, Exchange,
Interval, Status)
from vnpy.trader.object import TradeData, BarData, TickData
from .template import SpreadStrategyTemplate, SpreadAlgoTemplate
from .base import SpreadData, BacktestingMode, load_bar_data, load_tick_data
sns.set_style("whitegrid")
class BacktestingEngine:
""""""
gateway_name = "BACKTESTING"
def __init__(self):
""""""
self.spread: SpreadData = None
self.start = None
self.end = None
self.rate = 0
self.slippage = 0
self.size = 1
self.pricetick = 0
self.capital = 1_000_000
self.mode = BacktestingMode.BAR
self.strategy_class: Type[SpreadStrategyTemplate] = None
self.strategy: SpreadStrategyTemplate = None
self.tick: TickData = None
self.bar: BarData = None
self.datetime = None
self.interval = None
self.days = 0
self.callback = None
self.history_data = []
self.algo_count = 0
self.algos = {}
self.active_algos = {}
self.trade_count = 0
self.trades = {}
self.logs = []
self.daily_results = {}
self.daily_df = None
def output(self, msg):
"""
Output message of backtesting engine.
"""
print(f"{datetime.now()}\t{msg}")
def clear_data(self):
"""
Clear all data of last backtesting.
"""
self.strategy = None
self.tick = None
self.bar = None
self.datetime = None
self.algo_count = 0
self.algos.clear()
self.active_algos.clear()
self.trade_count = 0
self.trades.clear()
self.logs.clear()
self.daily_results.clear()
def set_parameters(
self,
spread: SpreadData,
interval: Interval,
start: datetime,
rate: float,
slippage: float,
size: float,
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = BacktestingMode.BAR
):
""""""
self.spread = spread
self.interval = Interval(interval)
self.rate = rate
self.slippage = slippage
self.size = size
self.pricetick = pricetick
self.start = start
self.capital = capital
self.end = end
self.mode = mode
def add_strategy(self, strategy_class: type, setting: dict):
""""""
self.strategy_class = strategy_class
self.strategy = strategy_class(
self,
strategy_class.__name__,
self.spread,
setting
)
def load_data(self):
""""""
self.output("开始加载历史数据")
if not self.end:
self.end = datetime.now()
if self.start >= self.end:
self.output("起始日期必须小于结束日期")
return
if self.mode == BacktestingMode.BAR:
self.history_data = load_bar_data(
self.spread,
self.interval,
self.start,
self.end,
self.pricetick
)
else:
self.history_datas = load_tick_data(
self.spread,
self.start,
self.end
)
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self):
""""""
if self.mode == BacktestingMode.BAR:
func = self.new_bar
else:
func = self.new_tick
self.strategy.on_init()
# Use the first [days] of history data for initializing strategy
day_count = 0
ix = 0
for ix, data in enumerate(self.history_data):
if self.datetime and data.datetime.day != self.datetime.day:
day_count += 1
if day_count >= self.days:
break
self.datetime = data.datetime
self.callback(data)
self.strategy.inited = True
self.output("策略初始化完成")
self.strategy.on_start()
self.strategy.trading = True
self.output("开始回放历史数据")
# Use the rest of history data for running backtesting
for data in self.history_data[ix:]:
func(data)
self.output("历史数据回放结束")
def calculate_result(self):
""""""
self.output("开始计算逐日盯市盈亏")
if not self.trades:
self.output("成交记录为空,无法计算")
return
# Add trade data into daily reuslt.
for trade in self.trades.values():
d = trade.datetime.date()
daily_result = self.daily_results[d]
daily_result.add_trade(trade)
# Calculate daily result by iteration.
pre_close = 0
start_pos = 0
for daily_result in self.daily_results.values():
daily_result.calculate_pnl(
pre_close,
start_pos,
self.size,
self.rate,
self.slippage
)
pre_close = daily_result.close_price
start_pos = daily_result.end_pos
# Generate dataframe
results = defaultdict(list)
for daily_result in self.daily_results.values():
for key, value in daily_result.__dict__.items():
results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date")
self.output("逐日盯市盈亏计算完成")
return self.daily_df
def calculate_statistics(self, df: DataFrame = None, output=True):
""""""
self.output("开始计算策略统计指标")
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
# Set all statistics to 0 if no trade.
start_date = ""
end_date = ""
total_days = 0
profit_days = 0
loss_days = 0
end_balance = 0
max_drawdown = 0
max_ddpercent = 0
max_drawdown_duration = 0
total_net_pnl = 0
daily_net_pnl = 0
total_commission = 0
daily_commission = 0
total_slippage = 0
daily_slippage = 0
total_turnover = 0
daily_turnover = 0
total_trade_count = 0
daily_trade_count = 0
total_return = 0
annual_return = 0
daily_return = 0
return_std = 0
sharpe_ratio = 0
return_drawdown_ratio = 0
else:
# Calculate balance related time series data
df["balance"] = df["net_pnl"].cumsum() + self.capital
df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
df["highlevel"] = (
df["balance"].rolling(
min_periods=1, window=len(df), center=False).max()
)
df["drawdown"] = df["balance"] - df["highlevel"]
df["ddpercent"] = df["drawdown"] / df["highlevel"] * 100
# Calculate statistics value
start_date = df.index[0]
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
max_ddpercent = df["ddpercent"].min()
max_drawdown_end = df["drawdown"].idxmin()
max_drawdown_start = df["balance"][:max_drawdown_end].argmax()
max_drawdown_duration = (max_drawdown_end - max_drawdown_start).days
total_net_pnl = df["net_pnl"].sum()
daily_net_pnl = total_net_pnl / total_days
total_commission = df["commission"].sum()
daily_commission = total_commission / total_days
total_slippage = df["slippage"].sum()
daily_slippage = total_slippage / total_days
total_turnover = df["turnover"].sum()
daily_turnover = total_turnover / total_days
total_trade_count = df["trade_count"].sum()
daily_trade_count = total_trade_count / total_days
total_return = (end_balance / self.capital - 1) * 100
annual_return = total_return / total_days * 240
daily_return = df["return"].mean() * 100
return_std = df["return"].std() * 100
if return_std:
sharpe_ratio = daily_return / return_std * np.sqrt(240)
else:
sharpe_ratio = 0
return_drawdown_ratio = -total_return / max_ddpercent
# Output
if output:
self.output("-" * 30)
self.output(f"首个交易日:\t{start_date}")
self.output(f"最后交易日:\t{end_date}")
self.output(f"总交易日:\t{total_days}")
self.output(f"盈利交易日:\t{profit_days}")
self.output(f"亏损交易日:\t{loss_days}")
self.output(f"起始资金:\t{self.capital:,.2f}")
self.output(f"结束资金:\t{end_balance:,.2f}")
self.output(f"总收益率:\t{total_return:,.2f}%")
self.output(f"年化收益:\t{annual_return:,.2f}%")
self.output(f"最大回撤: \t{max_drawdown:,.2f}")
self.output(f"百分比最大回撤: {max_ddpercent:,.2f}%")
self.output(f"最长回撤天数: \t{max_drawdown_duration}")
self.output(f"总盈亏:\t{total_net_pnl:,.2f}")
self.output(f"总手续费:\t{total_commission:,.2f}")
self.output(f"总滑点:\t{total_slippage:,.2f}")
self.output(f"总成交金额:\t{total_turnover:,.2f}")
self.output(f"总成交笔数:\t{total_trade_count}")
self.output(f"日均盈亏:\t{daily_net_pnl:,.2f}")
self.output(f"日均手续费:\t{daily_commission:,.2f}")
self.output(f"日均滑点:\t{daily_slippage:,.2f}")
self.output(f"日均成交金额:\t{daily_turnover:,.2f}")
self.output(f"日均成交笔数:\t{daily_trade_count}")
self.output(f"日均收益率:\t{daily_return:,.2f}%")
self.output(f"收益标准差:\t{return_std:,.2f}%")
self.output(f"Sharpe Ratio\t{sharpe_ratio:,.2f}")
self.output(f"收益回撤比:\t{return_drawdown_ratio:,.2f}")
statistics = {
"start_date": start_date,
"end_date": end_date,
"total_days": total_days,
"profit_days": profit_days,
"loss_days": loss_days,
"capital": self.capital,
"end_balance": end_balance,
"max_drawdown": max_drawdown,
"max_ddpercent": max_ddpercent,
"max_drawdown_duration": max_drawdown_duration,
"total_net_pnl": total_net_pnl,
"daily_net_pnl": daily_net_pnl,
"total_commission": total_commission,
"daily_commission": daily_commission,
"total_slippage": total_slippage,
"daily_slippage": daily_slippage,
"total_turnover": total_turnover,
"daily_turnover": daily_turnover,
"total_trade_count": total_trade_count,
"daily_trade_count": daily_trade_count,
"total_return": total_return,
"annual_return": annual_return,
"daily_return": daily_return,
"return_std": return_std,
"sharpe_ratio": sharpe_ratio,
"return_drawdown_ratio": return_drawdown_ratio,
}
return statistics
def show_chart(self, df: DataFrame = None):
""""""
# Check DataFrame input exterior
if df is None:
df = self.daily_df
# Check for init DataFrame
if df is None:
return
plt.figure(figsize=(10, 16))
balance_plot = plt.subplot(4, 1, 1)
balance_plot.set_title("Balance")
df["balance"].plot(legend=True)
drawdown_plot = plt.subplot(4, 1, 2)
drawdown_plot.set_title("Drawdown")
drawdown_plot.fill_between(range(len(df)), df["drawdown"].values)
pnl_plot = plt.subplot(4, 1, 3)
pnl_plot.set_title("Daily Pnl")
df["net_pnl"].plot(kind="bar", legend=False, grid=False, xticks=[])
distribution_plot = plt.subplot(4, 1, 4)
distribution_plot.set_title("Daily Pnl Distribution")
df["net_pnl"].hist(bins=50)
plt.show()
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
daily_result = self.daily_results.get(d, None)
if daily_result:
daily_result.close_price = price
else:
self.daily_results[d] = DailyResult(d, price)
def new_bar(self, bar: BarData):
""""""
self.bar = bar
self.datetime = bar.datetime
self.cross_algo()
self.strategy.on_spread_bar(bar)
self.update_daily_close(bar.close_price)
def new_tick(self, tick: TickData):
""""""
self.tick = tick
self.datetime = tick.datetime
self.cross_algo()
self.spread.bid_price = tick.bid_price_1
self.spread.bid_volume = tick.bid_volume_1
self.spread.ask_price = tick.ask_price_1
self.spread.ask_volume = tick.ask_volume_1
self.strategy.on_spread_data()
self.update_daily_close(tick.last_price)
def cross_algo(self):
"""
Cross limit order with last bar/tick data.
"""
if self.mode == BacktestingMode.BAR:
long_cross_price = self.bar.close_price
short_cross_price = self.bar.close_price
else:
long_cross_price = self.tick.ask_price_1
short_cross_price = self.tick.bid_price_1
for algo in list(self.active_algos.values()):
# Check whether limit orders can be filled.
long_cross = (
algo.direction == Direction.LONG
and algo.price >= long_cross_price
and long_cross_price > 0
)
short_cross = (
algo.direction == Direction.SHORT
and algo.price <= short_cross_price
and short_cross_price > 0
)
if not long_cross and not short_cross:
continue
# Push order udpate with status "all traded" (filled).
algo.traded = algo.volume
algo.status = Status.ALLTRADED
self.strategy.update_spread_algo(algo)
self.active_algos.pop(algo.algoid)
# Push trade update
self.trade_count += 1
if long_cross:
trade_price = long_cross_price
pos_change = algo.volume
else:
trade_price = short_cross_price
pos_change = -algo.volume
trade = TradeData(
symbol=self.spread.name,
exchange=Exchange.LOCAL,
orderid=algo.algoid,
tradeid=str(self.trade_count),
direction=algo.direction,
offset=algo.offset,
price=trade_price,
volume=algo.volume,
time=self.datetime.strftime("%H:%M:%S"),
gateway_name=self.gateway_name,
)
trade.datetime = self.datetime
self.spread.net_pos += pos_change
self.strategy.on_spread_pos()
self.trades[trade.vt_tradeid] = trade
def load_bar(
self, spread: SpreadData, days: int, interval: Interval, callback: Callable
):
""""""
self.days = days
self.callback = callback
def load_tick(self, spread: SpreadData, days: int, callback: Callable):
""""""
self.days = days
self.callback = callback
def start_algo(
self,
strategy: SpreadStrategyTemplate,
spread_name: str,
direction: Direction,
offset: Offset,
price: float,
volume: float,
payup: int,
interval: int,
lock: bool
) -> str:
""""""
self.algo_count += 1
algoid = str(self.algo_count)
algo = SpreadAlgoTemplate(
self,
algoid,
self.spread,
direction,
offset,
price,
volume,
payup,
interval,
lock
)
self.algos[algoid] = algo
self.active_algos[algoid] = algo
return algoid
def stop_algo(
self,
strategy: SpreadStrategyTemplate,
algoid: str
):
""""""
if algoid not in self.active_algos:
return
algo = self.active_algos.pop(algoid)
algo.status = Status.CANCELLED
self.strategy.update_spread_algo(algo)
def send_order(
self,
strategy: SpreadStrategyTemplate,
direction: Direction,
offset: Offset,
price: float,
volume: float,
stop: bool,
lock: bool
):
""""""
pass
def cancel_order(self, strategy: SpreadStrategyTemplate, vt_orderid: str):
"""
Cancel order by vt_orderid.
"""
pass
def write_strategy_log(self, strategy: SpreadStrategyTemplate, msg: str):
"""
Write log message.
"""
msg = f"{self.datetime}\t{msg}"
self.logs.append(msg)
def send_email(self, msg: str, strategy: SpreadStrategyTemplate = None):
"""
Send email to default receiver.
"""
pass
def put_strategy_event(self, strategy: SpreadStrategyTemplate):
"""
Put an event to update strategy status.
"""
pass
def write_algo_log(self, algo: SpreadAlgoTemplate, msg: str):
""""""
pass
class DailyResult:
""""""
def __init__(self, date: date, close_price: float):
""""""
self.date = date
self.close_price = close_price
self.pre_close = 0
self.trades = []
self.trade_count = 0
self.start_pos = 0
self.end_pos = 0
self.turnover = 0
self.commission = 0
self.slippage = 0
self.trading_pnl = 0
self.holding_pnl = 0
self.total_pnl = 0
self.net_pnl = 0
def add_trade(self, trade: TradeData):
""""""
self.trades.append(trade)
def calculate_pnl(
self,
pre_close: float,
start_pos: float,
size: int,
rate: float,
slippage: float
):
""""""
# If no pre_close provided on the first day,
# use value 1 to avoid zero division error
if pre_close:
self.pre_close = pre_close
else:
self.pre_close = 1
# Holding pnl is the pnl from holding position at day start
self.start_pos = start_pos
self.end_pos = start_pos
self.holding_pnl = self.start_pos * (self.close_price - self.pre_close) * size
# Trading pnl is the pnl from new trade during the day
self.trade_count = len(self.trades)
for trade in self.trades:
if trade.direction == Direction.LONG:
pos_change = trade.volume
else:
pos_change = -trade.volume
self.end_pos += pos_change
turnover = trade.volume * size * trade.price
self.trading_pnl += pos_change * \
(self.close_price - trade.price) * size
self.slippage += trade.volume * size * slippage
self.turnover += turnover
self.commission += turnover * rate
# Net pnl takes account of commission and slippage cost
self.total_pnl = self.trading_pnl + self.holding_pnl
self.net_pnl = self.total_pnl - self.commission - self.slippage

View File

@ -1,9 +1,14 @@
from typing import Dict, List from typing import Dict, List
from datetime import datetime from datetime import datetime
from enum import Enum
from functools import lru_cache
from vnpy.trader.object import TickData, PositionData, TradeData, ContractData from vnpy.trader.object import (
from vnpy.trader.constant import Direction, Offset, Exchange TickData, PositionData, TradeData, ContractData, BarData
from vnpy.trader.utility import floor_to, ceil_to )
from vnpy.trader.constant import Direction, Offset, Exchange, Interval
from vnpy.trader.utility import floor_to, ceil_to, round_to, extract_vt_symbol
from vnpy.trader.database import database_manager
EVENT_SPREAD_DATA = "eSpreadData" EVENT_SPREAD_DATA = "eSpreadData"
@ -347,3 +352,78 @@ def calculate_inverse_volume(
if not price: if not price:
return 0 return 0
return original_volume * size / price return original_volume * size / price
class BacktestingMode(Enum):
BAR = 1
TICK = 2
@lru_cache(maxsize=999)
def load_bar_data(
spread: SpreadData,
interval: Interval,
start: datetime,
end: datetime,
pricetick: float = 0
):
""""""
# Load bar data of each spread leg
leg_bars: Dict[str, Dict] = {}
for vt_symbol in spread.legs.keys():
symbol, exchange = extract_vt_symbol(vt_symbol)
bar_data: List[BarData] = database_manager.load_bar_data(
symbol, exchange, interval, start, end
)
bars: Dict[datetime, BarData] = {bar.datetime: bar for bar in bar_data}
leg_bars[vt_symbol] = bars
# Calculate spread bar data
spread_bars: List[BarData] = []
for dt in bars.keys():
spread_price = 0
spread_available = True
for leg in spread.legs.values():
leg_bar = leg_bars[leg.vt_symbol].get(dt, None)
if leg_bar:
price_multiplier = spread.price_multipliers[leg.vt_symbol]
spread_price += price_multiplier * leg_bar.close_price
else:
spread_available = False
if spread_available:
if pricetick:
spread_price = round_to(spread_price, pricetick)
spread_bar = BarData(
symbol=spread.name,
exchange=exchange.LOCAL,
datetime=dt,
interval=interval,
open_price=spread_price,
high_price=spread_price,
low_price=spread_price,
close_price=spread_price,
gateway_name="SPREAD",
)
spread_bars.append(spread_bar)
return spread_bars
@lru_cache(maxsize=999)
def load_tick_data(
spread: SpreadData,
start: datetime,
end: datetime
):
""""""
return database_manager.load_tick_data(
spread.name, Exchange.LOCAL, start, end
)

View File

@ -5,6 +5,7 @@ from typing import List, Dict, Set, Callable, Any, Type
from collections import defaultdict from collections import defaultdict
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
from datetime import datetime, timedelta
from vnpy.event import EventEngine, Event from vnpy.event import EventEngine, Event
from vnpy.trader.engine import BaseEngine, MainEngine from vnpy.trader.engine import BaseEngine, MainEngine
@ -17,14 +18,17 @@ from vnpy.trader.object import (
TickData, ContractData, LogData, TickData, ContractData, LogData,
SubscribeRequest, OrderRequest SubscribeRequest, OrderRequest
) )
from vnpy.trader.constant import Direction, Offset, OrderType from vnpy.trader.constant import (
Direction, Offset, OrderType, Interval
)
from vnpy.trader.converter import OffsetConverter from vnpy.trader.converter import OffsetConverter
from .base import ( from .base import (
LegData, SpreadData, LegData, SpreadData,
EVENT_SPREAD_DATA, EVENT_SPREAD_POS, EVENT_SPREAD_DATA, EVENT_SPREAD_POS,
EVENT_SPREAD_ALGO, EVENT_SPREAD_LOG, EVENT_SPREAD_ALGO, EVENT_SPREAD_LOG,
EVENT_SPREAD_STRATEGY EVENT_SPREAD_STRATEGY,
load_bar_data, load_tick_data
) )
from .template import SpreadAlgoTemplate, SpreadStrategyTemplate from .template import SpreadAlgoTemplate, SpreadStrategyTemplate
from .algo import SpreadTakerAlgo from .algo import SpreadTakerAlgo
@ -1024,3 +1028,25 @@ class SpreadStrategyEngine:
subject = "价差策略引擎" subject = "价差策略引擎"
self.main_engine.send_email(subject, msg) self.main_engine.send_email(subject, msg)
def load_bar(
self, spread: SpreadData, days: int, interval: Interval, callback: Callable
):
""""""
end = datetime.now()
start = end - timedelta(days)
bars = load_bar_data(spread, interval, start, end)
for bar in bars:
callback(bar)
def load_tick(self, spread: SpreadData, days: int, callback: Callable):
""""""
end = datetime.now()
start = end - timedelta(days)
ticks = load_tick_data(spread, start, end)
for tick in ticks:
callback(tick)

View File

@ -0,0 +1,176 @@
from vnpy.trader.utility import BarGenerator, ArrayManager
from vnpy.app.spread_trading import (
SpreadStrategyTemplate,
SpreadAlgoTemplate,
SpreadData,
OrderData,
TradeData,
TickData,
BarData
)
class StatisticalArbitrageStrategy(SpreadStrategyTemplate):
""""""
author = "用Python的交易员"
boll_window = 20
boll_dev = 2
max_pos = 10
payup = 10
interval = 5
spread_pos = 0.0
boll_up = 0.0
boll_down = 0.0
boll_mid = 0.0
parameters = [
"boll_window",
"boll_dev",
"max_pos",
"payup",
"interval"
]
variables = [
"spread_pos",
"boll_up",
"boll_down",
"boll_mid"
]
def __init__(
self,
strategy_engine,
strategy_name: str,
spread: SpreadData,
setting: dict
):
""""""
super().__init__(
strategy_engine, strategy_name, spread, setting
)
self.bg = BarGenerator(self.on_spread_bar)
self.am = ArrayManager()
def on_init(self):
"""
Callback when strategy is inited.
"""
self.write_log("策略初始化")
self.load_bar(10)
def on_start(self):
"""
Callback when strategy is started.
"""
self.write_log("策略启动")
def on_stop(self):
"""
Callback when strategy is stopped.
"""
self.write_log("策略停止")
self.put_event()
def on_spread_data(self):
"""
Callback when spread price is updated.
"""
tick = self.get_spread_tick()
self.on_spread_tick(tick)
def on_spread_tick(self, tick: TickData):
"""
Callback when new spread tick data is generated.
"""
self.bg.update_tick(tick)
def on_spread_bar(self, bar: BarData):
"""
Callback when spread bar data is generated.
"""
self.am.update_bar(bar)
if not self.am.inited:
return
self.boll_mid = self.am.sma(self.boll_window)
self.boll_up, self.boll_down = self.am.boll(
self.boll_window, self.boll_dev)
if not self.spread_pos:
if bar.close_price >= self.boll_up:
self.start_short_algo(
bar.close_price - 10,
self.max_pos,
payup=self.payup,
interval=self.interval
)
elif bar.close_price <= self.boll_down:
self.start_long_algo(
bar.close_price + 10,
self.max_pos,
payup=self.payup,
interval=self.interval
)
elif self.spread_pos < 0:
if bar.close_price <= self.boll_mid:
self.start_long_algo(
bar.close_price + 10,
abs(self.spread_pos),
payup=self.payup,
interval=self.interval
)
else:
if bar.close_price >= self.boll_mid:
self.start_short_algo(
bar.close_price - 10,
abs(self.spread_pos),
payup=self.payup,
interval=self.interval
)
def on_spread_pos(self):
"""
Callback when spread position is updated.
"""
self.spread_pos = self.get_spread_pos()
self.put_event()
def on_spread_algo(self, algo: SpreadAlgoTemplate):
"""
Callback when algo status is updated.
"""
pass
def on_order(self, order: OrderData):
"""
Callback when order status is updated.
"""
pass
def on_trade(self, trade: TradeData):
"""
Callback when new trade data is received.
"""
pass
def stop_open_algos(self):
""""""
if self.buy_algoid:
self.stop_algo(self.buy_algoid)
if self.short_algoid:
self.stop_algo(self.short_algoid)
def stop_close_algos(self):
""""""
if self.sell_algoid:
self.stop_algo(self.sell_algoid)
if self.cover_algoid:
self.stop_algo(self.cover_algoid)

View File

@ -1,10 +1,12 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Set from typing import Dict, List, Set, Callable
from copy import copy from copy import copy
from vnpy.trader.object import TickData, TradeData, OrderData, ContractData from vnpy.trader.object import (
from vnpy.trader.constant import Direction, Status, Offset TickData, TradeData, OrderData, ContractData, BarData
)
from vnpy.trader.constant import Direction, Status, Offset, Interval
from vnpy.trader.utility import virtual, floor_to, ceil_to, round_to from vnpy.trader.utility import virtual, floor_to, ceil_to, round_to
from .base import SpreadData, calculate_inverse_volume from .base import SpreadData, calculate_inverse_volume
@ -434,6 +436,20 @@ class SpreadStrategyTemplate:
""" """
pass pass
@virtual
def on_spread_tick(self, tick: TickData):
"""
Callback when new spread tick data is generated.
"""
pass
@virtual
def on_spread_bar(self, bar: BarData):
"""
Callback when new spread bar data is generated.
"""
pass
@virtual @virtual
def on_spread_pos(self): def on_spread_pos(self):
""" """
@ -635,3 +651,23 @@ class SpreadStrategyTemplate:
""" """
if self.inited: if self.inited:
self.strategy_engine.send_email(msg, self) self.strategy_engine.send_email(msg, self)
def load_bar(
self,
days: int,
interval: Interval = Interval.MINUTE,
callback: Callable = None,
):
"""
Load historical bar data for initializing strategy.
"""
if not callback:
callback = self.on_spread_bar
self.strategy_engine.load_bar(self.spread, days, interval, callback)
def load_tick(self, days: int):
"""
Load historical tick data for initializing strategy.
"""
self.strategy_engine.load_tick(self.spread, days, self.on_spread_tick)