[Add]backtesting function for cta strategy

This commit is contained in:
vn.py 2019-01-29 15:35:37 +08:00
parent 2dc8057de0
commit b99c5ff590
13 changed files with 209 additions and 117 deletions

View File

@ -7,3 +7,4 @@ numpy
pandas pandas
matplotlib matplotlib
seaborn seaborn
jupyter

View File

@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,54 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid syntax (backtesting.py, line 890)",
"output_type": "error",
"traceback": [
"Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n",
" File \u001b[0;32m\"c:\\miniconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m3267\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n",
"\u001b[1;36m File \u001b[1;32m\"<ipython-input-1-fa539d7a6775>\"\u001b[1;36m, line \u001b[1;32m1\u001b[1;36m, in \u001b[1;35m<module>\u001b[1;36m\u001b[0m\n\u001b[1;33m from vnpy.app.cta_strategy.backtesting import BacktestingEngine\u001b[0m\n",
"\u001b[1;36m File \u001b[1;32m\"C:\\Github\\vnpy\\vnpy\\app\\cta_strategy\\backtesting.py\"\u001b[1;36m, line \u001b[1;32m890\u001b[0m\n\u001b[1;33m rate=rate,\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n"
]
}
],
"source": [
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
"from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import TurtleSignalStrategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,29 @@
#%%
from vnpy.app.cta_strategy.backtesting import BacktestingEngine
from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import (
TurtleSignalStrategy,
)
from datetime import datetime
#%%
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol="IF88.CFFEX",
interval="1m",
start=datetime(2013, 1, 1),
end=datetime(2019, 1, 30),
rate=0,
slippage=0,
size=300,
pricetick=0.2,
capital=1_000_000,
)
#%%
engine.add_strategy(TurtleSignalStrategy, {})
engine.load_data()
engine.run_backtesting()
df = engine.calculate_result()
engine.calculate_statistics()
engine.show_chart()

View File

@ -4,6 +4,7 @@ from typing import Callable
from itertools import product from itertools import product
import multiprocessing import multiprocessing
import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
from pandas import DataFrame from pandas import DataFrame
@ -27,6 +28,60 @@ from .template import CtaTemplate
sns.set_style("whitegrid") sns.set_style("whitegrid")
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list
def set_target(self, target: str):
""""""
self.target = target
def generate_setting(self):
""""""
keys = self.params.keys()
values = self.params.values()
products = list(product(*values))
settings = []
for product in products:
setting = dict(zip(keys, product))
settings.append(setting)
return settings
class BacktestingEngine: class BacktestingEngine:
"""""" """"""
@ -108,15 +163,17 @@ class BacktestingEngine:
pricetick: float, pricetick: float,
capital: int = 0, capital: int = 0,
end: datetime = None, end: datetime = None,
mode: BacktestingMode = None, mode: BacktestingMode = BacktestingMode.BAR,
): ):
"""""" """"""
self.mode = mode # 1 self.mode = mode # 1
self.vt_symbol = vt_symbol # 2 self.vt_symbol = vt_symbol # 2
self.interval = interval
self.rate = rate # 3 self.rate = rate # 3
self.slippage = slippage # 4 self.slippage = slippage # 4
self.size = size # self.size = size #
self.pricetick = pricetick # self.pricetick = pricetick #
self.start = start
self.symbol, exchange_str = self.vt_symbol.split(".") self.symbol, exchange_str = self.vt_symbol.split(".")
self.exchange = Exchange(exchange_str) self.exchange = Exchange(exchange_str)
@ -124,6 +181,9 @@ class BacktestingEngine:
if capital: if capital:
self.capital = capital self.capital = capital
if end:
self.end = end
if mode: if mode:
self.mode = mode self.mode = mode
@ -162,7 +222,7 @@ class BacktestingEngine:
self.history_data = list(s) self.history_data = list(s)
self.output("历史数据加载完成") self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self): def run_backtesting(self):
"""""" """"""
@ -209,7 +269,7 @@ class BacktestingEngine:
for trade in self.trades.values(): for trade in self.trades.values():
d = trade.datetime.date() d = trade.datetime.date()
daily_result = self.daily_results[d] daily_result = self.daily_results[d]
d.add_trade(trade) daily_result.add_trade(trade)
# Calculate daily result by iteration. # Calculate daily result by iteration.
pre_close = 0 pre_close = 0
@ -228,7 +288,7 @@ class BacktestingEngine:
for daily_result in self.daily_results.values(): for daily_result in self.daily_results.values():
for key, value in daily_result.__dict__.items(): for key, value in daily_result.__dict__.items():
results[key] = value results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date") self.daily_df = DataFrame.from_dict(results).set_index("date")
@ -258,8 +318,8 @@ class BacktestingEngine:
end_date = df.index[-1] end_date = df.index[-1]
total_days = len(df) total_days = len(df)
profit_days = len(df[df["netPnl"] > 0]) profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["netPnl"] < 0]) loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1] end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min() max_drawdown = df["drawdown"].min()
@ -426,9 +486,6 @@ class BacktestingEngine:
return result_values return result_values
return resultList
def update_daily_close(self, price: float): def update_daily_close(self, price: float):
"""""" """"""
d = self.datetime.date() d = self.datetime.date()
@ -476,7 +533,7 @@ class BacktestingEngine:
long_best_price = long_cross_price long_best_price = long_cross_price
short_best_price = short_cross_price short_best_price = short_cross_price
for order in self.active_limit_orders.values(): for order in list(self.active_limit_orders.values()):
# Push order update with status "not traded" (pending). # Push order update with status "not traded" (pending).
if order.status == Status.SUBMITTING: if order.status == Status.SUBMITTING:
order.status = Status.NOTTRADED order.status = Status.NOTTRADED
@ -549,7 +606,7 @@ class BacktestingEngine:
long_best_price = long_cross_price long_best_price = long_cross_price
short_best_price = short_cross_price short_best_price = short_cross_price
for stop_order in self.active_stop_orders.values(): for stop_order in list(self.active_stop_orders.values()):
# Check whether stop order can be triggered. # Check whether stop order can be triggered.
long_cross = ( long_cross = (
stop_order.direction == Direction.LONG stop_order.direction == Direction.LONG
@ -611,6 +668,8 @@ class BacktestingEngine:
stop_order.vt_orderid = order.vt_orderid stop_order.vt_orderid = order.vt_orderid
stop_order.status = StopOrderStatus.TRIGGERED stop_order.status = StopOrderStatus.TRIGGERED
self.active_stop_orders.pop(stop_order.stop_orderid)
# Push update to strategy. # Push update to strategy.
self.strategy.on_stop_order(stop_order) self.strategy.on_stop_order(stop_order)
self.strategy.on_order(order) self.strategy.on_order(order)
@ -715,10 +774,12 @@ class BacktestingEngine:
""" """
Cancel all orders, both limit and stop. Cancel all orders, both limit and stop.
""" """
for vt_orderid in self.active_limit_orders.keys(): vt_orderids = list(self.active_limit_orders.keys())
for vt_orderid in vt_orderids:
self.cancel_limit_order(vt_orderid) self.cancel_limit_order(vt_orderid)
for vt_orderid in self.active_stop_orders.keys(): stop_orderids = list(self.active_stop_orders.keys())
for vt_orderid in stop_orderids:
self.cancel_stop_order(vt_orderid) self.cancel_stop_order(vt_orderid)
def write_log(self, msg: str, strategy: CtaTemplate = None): def write_log(self, msg: str, strategy: CtaTemplate = None):
@ -734,7 +795,7 @@ class BacktestingEngine:
""" """
return self.engine_type return self.engine_type
def put_put_strategy_event(self, strategy: CtaTemplate): def put_strategy_event(self, strategy: CtaTemplate):
""" """
Put an event to update strategy status. Put an event to update strategy status.
""" """
@ -810,60 +871,6 @@ class DailyResult:
self.net_pnl = self.total_pnl - self.commission - self.slippage self.net_pnl = self.total_pnl - self.commission - self.slippage
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list
def set_target(self, target: str):
""""""
self.target = target
def generate_setting(self):
""""""
keys = self.params.keys()
values = self.params.values()
products = list(product(*values))
settings = []
for product in products:
setting = dict(zip(keys, product))
settings.append(setting)
return settings
def optimize( def optimize(
target_name: str, target_name: str,
strategy_class: CtaTemplate, strategy_class: CtaTemplate,
@ -886,7 +893,7 @@ def optimize(
engine.set_parameters( engine.set_parameters(
vt_symbol=vt_symbol, vt_symbol=vt_symbol,
interval=interval, interval=interval,
start=start start=start,
rate=rate, rate=rate,
slippage=slippage, slippage=slippage,
size=size, size=size,

View File

@ -8,14 +8,14 @@ from enum import Enum
from vnpy.trader.constant import Direction, Offset from vnpy.trader.constant import Direction, Offset
APP_NAME = "CtaStrategy" APP_NAME = "CtaStrategy"
STOPORDER_PREFIX = "STOP." STOPORDER_PREFIX = "STOP"
class CtaOrderType(Enum): class CtaOrderType(Enum):
BUY = "买开" BUY = "买开"
SELL = "买开" SELL = "卖平"
SHORT = "" SHORT = ""
COVER = "" COVER = ""
class StopOrderStatus(Enum): class StopOrderStatus(Enum):

View File

@ -38,7 +38,7 @@ class TurtleSignalStrategy(CtaTemplate):
def __init__(self, cta_engine, strategy_name, vt_symbol, setting): def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
"""""" """"""
super(DoubleMaStrategy, self).__init__( super(TurtleSignalStrategy, self).__init__(
cta_engine, strategy_name, vt_symbol, setting cta_engine, strategy_name, vt_symbol, setting
) )
@ -50,6 +50,7 @@ class TurtleSignalStrategy(CtaTemplate):
Callback when strategy is inited. Callback when strategy is inited.
""" """
self.write_log("策略初始化") self.write_log("策略初始化")
self.load_bar(20)
def on_start(self): def on_start(self):
""" """
@ -83,22 +84,22 @@ class TurtleSignalStrategy(CtaTemplate):
self.exit_up, self.exit_down = self.am.donchian(self.exit_window) self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
if not self.pos: if not self.pos:
self.atr_value = self.am.atr(self.atr_value) self.atr_value = self.am.atr(self.atr_window)
self.long_entry = 0 self.long_entry = 0
self.short_entry = 0 self.short_entry = 0
self.long_stop = 0 self.long_stop = 0
self.short_stop = 0 self.short_stop = 0
self.send_buy_orders(self.long_entry) self.send_buy_orders(self.entry_up)
self.send_short_orders(self.short_entry) self.send_short_orders(self.entry_down)
elif self.pos < 0: elif self.pos > 0:
self.send_buy_orders(self.long_entry) self.send_buy_orders(self.long_entry)
sell_price = max(self.long_stop, self.exit_down) sell_price = max(self.long_stop, self.exit_down)
self.sell(sell_price, abs(self.pos), True) self.sell(sell_price, abs(self.pos), True)
elif self.pos > 0: elif self.pos < 0:
self.send_short_orders(self.short_entry) self.send_short_orders(self.short_entry)
cover_price = min(self.short_stop, self.exit_up) cover_price = min(self.short_stop, self.exit_up)
@ -110,7 +111,7 @@ class TurtleSignalStrategy(CtaTemplate):
""" """
Callback of new trade data update. Callback of new trade data update.
""" """
if trade.dierction == Direction.LONG: if trade.direction == Direction.LONG:
self.long_entry = trade.price self.long_entry = trade.price
self.long_stop = self.long_entry - 2 * self.atr_value self.long_stop = self.long_entry - 2 * self.atr_value
else: else:

View File

@ -169,9 +169,13 @@ class CtaTemplate(ABC):
""" """
Send a new order. Send a new order.
""" """
return self.cta_engine.send_order( if self.trading:
vt_orderid = self.cta_engine.send_order(
self, order_type, price, volume, stop self, order_type, price, volume, stop
) )
else:
vt_orderid = ""
return vt_orderid
def cancel_order(self, vt_orderid: str): def cancel_order(self, vt_orderid: str):
""" """

View File

@ -16,15 +16,6 @@ from ibapi.order_state import OrderState
from ibapi.ticktype import TickType from ibapi.ticktype import TickType
from ibapi.wrapper import EWrapper from ibapi.wrapper import EWrapper
from vnpy.trader.constant import (
Currency,
Direction,
Exchange,
OptionType,
PriceType,
Product,
Status,
)
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import ( from vnpy.trader.object import (
AccountData, AccountData,
@ -414,7 +405,7 @@ class IbApi(EWrapper):
pos = PositionData( pos = PositionData(
symbol=contract.conId, symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange), exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
direction=DIRECTION_NET, direction=Direction.NET,
volume=position, volume=position,
price=averageCost, price=averageCost,
pnl=unrealizedPNL, pnl=unrealizedPNL,
@ -430,9 +421,7 @@ class IbApi(EWrapper):
for account in self.accounts.values(): for account in self.accounts.values():
self.gateway.on_account(copy(account)) self.gateway.on_account(copy(account))
def contractDetails( def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name
self, reqId: int, contractDetails: ContractDetails
): # pylint: disable=invalid-name
""" """
Callback of contract data update. Callback of contract data update.
""" """

View File

@ -1,13 +1,6 @@
"""""" """"""
from peewee import ( from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField
SqliteDatabase,
Model,
CharField,
DateTimeField,
FloatField,
IntegerField,
)
from .constant import Exchange, Interval from .constant import Exchange, Interval
from .object import BarData, TickData from .object import BarData, TickData
@ -73,10 +66,10 @@ class DbBarData(Model):
datetime=self.datetime, datetime=self.datetime,
interval=Interval(self.interval), interval=Interval(self.interval),
volume=self.volume, volume=self.volume,
open_price=open_price, open_price=self.open_price,
high_price=high_price, high_price=self.high_price,
low_price=low_price, low_price=self.low_price,
close_price=close_price, close_price=self.close_price,
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
return bar return bar

View File

@ -2,7 +2,7 @@
Event type string used in VN Trader. Event type string used in VN Trader.
""" """
from vnpy.event import EVENT_TIMER from vnpy.event import EVENT_TIMER # noqa
EVENT_TICK = "eTick." EVENT_TICK = "eTick."
EVENT_TRADE = "eTrade." EVENT_TRADE = "eTrade."

View File

@ -6,7 +6,15 @@ from abc import ABC, abstractmethod
from typing import Any from typing import Any
from vnpy.event import Event, EventEngine from vnpy.event import Event, EventEngine
from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT from .event import (
EVENT_TICK,
EVENT_TRADE,
EVENT_ORDER,
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_CONTRACT,
EVENT_LOG,
)
from .object import ( from .object import (
TickData, TickData,
OrderData, OrderData,
@ -17,7 +25,7 @@ from .object import (
LogData, LogData,
OrderRequest, OrderRequest,
CancelRequest, CancelRequest,
SubscribeRequest SubscribeRequest,
) )

View File

@ -224,10 +224,10 @@ class ArrayManager(object):
self.close_array[:-1] = self.close_array[1:] self.close_array[:-1] = self.close_array[1:]
self.volume_array[:-1] = self.volume_array[1:] self.volume_array[:-1] = self.volume_array[1:]
self.open_array[-1] = bar.open self.open_array[-1] = bar.open_price
self.high_array[-1] = bar.high self.high_array[-1] = bar.high_price
self.low_array[-1] = bar.low self.low_array[-1] = bar.low_price
self.close_array[-1] = bar.close self.close_array[-1] = bar.close_price
self.volume_array[-1] = bar.volume self.volume_array[-1] = bar.volume
@property @property