[Add]backtesting function for cta strategy
This commit is contained in:
parent
2dc8057de0
commit
b99c5ff590
@ -7,3 +7,4 @@ numpy
|
||||
pandas
|
||||
matplotlib
|
||||
seaborn
|
||||
jupyter
|
@ -0,0 +1,6 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
54
tests/backtesting/turtle.ipynb
Normal file
54
tests/backtesting/turtle.ipynb
Normal file
@ -0,0 +1,54 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "SyntaxError",
|
||||
"evalue": "invalid syntax (backtesting.py, line 890)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n",
|
||||
" File \u001b[0;32m\"c:\\miniconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m3267\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n",
|
||||
"\u001b[1;36m File \u001b[1;32m\"<ipython-input-1-fa539d7a6775>\"\u001b[1;36m, line \u001b[1;32m1\u001b[1;36m, in \u001b[1;35m<module>\u001b[1;36m\u001b[0m\n\u001b[1;33m from vnpy.app.cta_strategy.backtesting import BacktestingEngine\u001b[0m\n",
|
||||
"\u001b[1;36m File \u001b[1;32m\"C:\\Github\\vnpy\\vnpy\\app\\cta_strategy\\backtesting.py\"\u001b[1;36m, line \u001b[1;32m890\u001b[0m\n\u001b[1;33m rate=rate,\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
|
||||
"from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import TurtleSignalStrategy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
29
tests/backtesting/turtle.py
Normal file
29
tests/backtesting/turtle.py
Normal file
@ -0,0 +1,29 @@
|
||||
#%%
|
||||
from vnpy.app.cta_strategy.backtesting import BacktestingEngine
|
||||
from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import (
|
||||
TurtleSignalStrategy,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
#%%
|
||||
engine = BacktestingEngine()
|
||||
engine.set_parameters(
|
||||
vt_symbol="IF88.CFFEX",
|
||||
interval="1m",
|
||||
start=datetime(2013, 1, 1),
|
||||
end=datetime(2019, 1, 30),
|
||||
rate=0,
|
||||
slippage=0,
|
||||
size=300,
|
||||
pricetick=0.2,
|
||||
capital=1_000_000,
|
||||
)
|
||||
|
||||
#%%
|
||||
engine.add_strategy(TurtleSignalStrategy, {})
|
||||
engine.load_data()
|
||||
engine.run_backtesting()
|
||||
df = engine.calculate_result()
|
||||
engine.calculate_statistics()
|
||||
engine.show_chart()
|
||||
|
@ -4,6 +4,7 @@ from typing import Callable
|
||||
from itertools import product
|
||||
import multiprocessing
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pandas import DataFrame
|
||||
@ -27,6 +28,60 @@ from .template import CtaTemplate
|
||||
sns.set_style("whitegrid")
|
||||
|
||||
|
||||
|
||||
class OptimizationSetting:
|
||||
"""
|
||||
Setting for runnning optimization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
self.params = {}
|
||||
self.target = ""
|
||||
|
||||
def add_parameter(
|
||||
self, name: str, start: float, end: float = None, step: float = None
|
||||
):
|
||||
""""""
|
||||
if not end and not step:
|
||||
self.params[name] = [start]
|
||||
return
|
||||
|
||||
if start >= end:
|
||||
print("参数优化起始点必须小于终止点")
|
||||
return
|
||||
|
||||
if step <= 0:
|
||||
print("参数优化步进必须大于0")
|
||||
return
|
||||
|
||||
value = start
|
||||
value_list = []
|
||||
|
||||
while value <= end:
|
||||
value_list.append(value)
|
||||
value += step
|
||||
|
||||
self.params[name] = value_list
|
||||
|
||||
def set_target(self, target: str):
|
||||
""""""
|
||||
self.target = target
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
keys = self.params.keys()
|
||||
values = self.params.values()
|
||||
products = list(product(*values))
|
||||
|
||||
settings = []
|
||||
for product in products:
|
||||
setting = dict(zip(keys, product))
|
||||
settings.append(setting)
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class BacktestingEngine:
|
||||
""""""
|
||||
|
||||
@ -108,15 +163,17 @@ class BacktestingEngine:
|
||||
pricetick: float,
|
||||
capital: int = 0,
|
||||
end: datetime = None,
|
||||
mode: BacktestingMode = None,
|
||||
mode: BacktestingMode = BacktestingMode.BAR,
|
||||
):
|
||||
""""""
|
||||
self.mode = mode # 1
|
||||
self.vt_symbol = vt_symbol # 2
|
||||
self.interval = interval
|
||||
self.rate = rate # 3
|
||||
self.slippage = slippage # 4
|
||||
self.size = size #
|
||||
self.pricetick = pricetick #
|
||||
self.start = start
|
||||
|
||||
self.symbol, exchange_str = self.vt_symbol.split(".")
|
||||
self.exchange = Exchange(exchange_str)
|
||||
@ -124,6 +181,9 @@ class BacktestingEngine:
|
||||
if capital:
|
||||
self.capital = capital
|
||||
|
||||
if end:
|
||||
self.end = end
|
||||
|
||||
if mode:
|
||||
self.mode = mode
|
||||
|
||||
@ -162,7 +222,7 @@ class BacktestingEngine:
|
||||
|
||||
self.history_data = list(s)
|
||||
|
||||
self.output("历史数据加载完成")
|
||||
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
||||
|
||||
def run_backtesting(self):
|
||||
""""""
|
||||
@ -209,7 +269,7 @@ class BacktestingEngine:
|
||||
for trade in self.trades.values():
|
||||
d = trade.datetime.date()
|
||||
daily_result = self.daily_results[d]
|
||||
d.add_trade(trade)
|
||||
daily_result.add_trade(trade)
|
||||
|
||||
# Calculate daily result by iteration.
|
||||
pre_close = 0
|
||||
@ -228,7 +288,7 @@ class BacktestingEngine:
|
||||
|
||||
for daily_result in self.daily_results.values():
|
||||
for key, value in daily_result.__dict__.items():
|
||||
results[key] = value
|
||||
results[key].append(value)
|
||||
|
||||
self.daily_df = DataFrame.from_dict(results).set_index("date")
|
||||
|
||||
@ -258,8 +318,8 @@ class BacktestingEngine:
|
||||
end_date = df.index[-1]
|
||||
|
||||
total_days = len(df)
|
||||
profit_days = len(df[df["netPnl"] > 0])
|
||||
loss_days = len(df[df["netPnl"] < 0])
|
||||
profit_days = len(df[df["net_pnl"] > 0])
|
||||
loss_days = len(df[df["net_pnl"] < 0])
|
||||
|
||||
end_balance = df["balance"].iloc[-1]
|
||||
max_drawdown = df["drawdown"].min()
|
||||
@ -426,9 +486,6 @@ class BacktestingEngine:
|
||||
|
||||
return result_values
|
||||
|
||||
return resultList
|
||||
|
||||
|
||||
def update_daily_close(self, price: float):
|
||||
""""""
|
||||
d = self.datetime.date()
|
||||
@ -476,7 +533,7 @@ class BacktestingEngine:
|
||||
long_best_price = long_cross_price
|
||||
short_best_price = short_cross_price
|
||||
|
||||
for order in self.active_limit_orders.values():
|
||||
for order in list(self.active_limit_orders.values()):
|
||||
# Push order update with status "not traded" (pending).
|
||||
if order.status == Status.SUBMITTING:
|
||||
order.status = Status.NOTTRADED
|
||||
@ -549,7 +606,7 @@ class BacktestingEngine:
|
||||
long_best_price = long_cross_price
|
||||
short_best_price = short_cross_price
|
||||
|
||||
for stop_order in self.active_stop_orders.values():
|
||||
for stop_order in list(self.active_stop_orders.values()):
|
||||
# Check whether stop order can be triggered.
|
||||
long_cross = (
|
||||
stop_order.direction == Direction.LONG
|
||||
@ -611,6 +668,8 @@ class BacktestingEngine:
|
||||
stop_order.vt_orderid = order.vt_orderid
|
||||
stop_order.status = StopOrderStatus.TRIGGERED
|
||||
|
||||
self.active_stop_orders.pop(stop_order.stop_orderid)
|
||||
|
||||
# Push update to strategy.
|
||||
self.strategy.on_stop_order(stop_order)
|
||||
self.strategy.on_order(order)
|
||||
@ -715,10 +774,12 @@ class BacktestingEngine:
|
||||
"""
|
||||
Cancel all orders, both limit and stop.
|
||||
"""
|
||||
for vt_orderid in self.active_limit_orders.keys():
|
||||
vt_orderids = list(self.active_limit_orders.keys())
|
||||
for vt_orderid in vt_orderids:
|
||||
self.cancel_limit_order(vt_orderid)
|
||||
|
||||
for vt_orderid in self.active_stop_orders.keys():
|
||||
stop_orderids = list(self.active_stop_orders.keys())
|
||||
for vt_orderid in stop_orderids:
|
||||
self.cancel_stop_order(vt_orderid)
|
||||
|
||||
def write_log(self, msg: str, strategy: CtaTemplate = None):
|
||||
@ -734,7 +795,7 @@ class BacktestingEngine:
|
||||
"""
|
||||
return self.engine_type
|
||||
|
||||
def put_put_strategy_event(self, strategy: CtaTemplate):
|
||||
def put_strategy_event(self, strategy: CtaTemplate):
|
||||
"""
|
||||
Put an event to update strategy status.
|
||||
"""
|
||||
@ -810,60 +871,6 @@ class DailyResult:
|
||||
self.net_pnl = self.total_pnl - self.commission - self.slippage
|
||||
|
||||
|
||||
class OptimizationSetting:
|
||||
"""
|
||||
Setting for runnning optimization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
""""""
|
||||
self.params = {}
|
||||
self.target = ""
|
||||
|
||||
def add_parameter(
|
||||
self, name: str, start: float, end: float = None, step: float = None
|
||||
):
|
||||
""""""
|
||||
if not end and not step:
|
||||
self.params[name] = [start]
|
||||
return
|
||||
|
||||
if start >= end:
|
||||
print("参数优化起始点必须小于终止点")
|
||||
return
|
||||
|
||||
if step <= 0:
|
||||
print("参数优化步进必须大于0")
|
||||
return
|
||||
|
||||
value = start
|
||||
value_list = []
|
||||
|
||||
while value <= end:
|
||||
value_list.append(value)
|
||||
value += step
|
||||
|
||||
self.params[name] = value_list
|
||||
|
||||
def set_target(self, target: str):
|
||||
""""""
|
||||
self.target = target
|
||||
|
||||
def generate_setting(self):
|
||||
""""""
|
||||
keys = self.params.keys()
|
||||
values = self.params.values()
|
||||
products = list(product(*values))
|
||||
|
||||
settings = []
|
||||
for product in products:
|
||||
setting = dict(zip(keys, product))
|
||||
settings.append(setting)
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
def optimize(
|
||||
target_name: str,
|
||||
strategy_class: CtaTemplate,
|
||||
@ -886,7 +893,7 @@ def optimize(
|
||||
engine.set_parameters(
|
||||
vt_symbol=vt_symbol,
|
||||
interval=interval,
|
||||
start=start
|
||||
start=start,
|
||||
rate=rate,
|
||||
slippage=slippage,
|
||||
size=size,
|
||||
|
@ -8,14 +8,14 @@ from enum import Enum
|
||||
from vnpy.trader.constant import Direction, Offset
|
||||
|
||||
APP_NAME = "CtaStrategy"
|
||||
STOPORDER_PREFIX = "STOP."
|
||||
STOPORDER_PREFIX = "STOP"
|
||||
|
||||
|
||||
class CtaOrderType(Enum):
|
||||
BUY = "买开"
|
||||
SELL = "买开"
|
||||
SHORT = "买开"
|
||||
COVER = "买开"
|
||||
SELL = "卖平"
|
||||
SHORT = "卖开"
|
||||
COVER = "买平"
|
||||
|
||||
|
||||
class StopOrderStatus(Enum):
|
||||
|
@ -38,7 +38,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
|
||||
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
|
||||
""""""
|
||||
super(DoubleMaStrategy, self).__init__(
|
||||
super(TurtleSignalStrategy, self).__init__(
|
||||
cta_engine, strategy_name, vt_symbol, setting
|
||||
)
|
||||
|
||||
@ -50,6 +50,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
Callback when strategy is inited.
|
||||
"""
|
||||
self.write_log("策略初始化")
|
||||
self.load_bar(20)
|
||||
|
||||
def on_start(self):
|
||||
"""
|
||||
@ -83,22 +84,22 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
|
||||
|
||||
if not self.pos:
|
||||
self.atr_value = self.am.atr(self.atr_value)
|
||||
self.atr_value = self.am.atr(self.atr_window)
|
||||
|
||||
self.long_entry = 0
|
||||
self.short_entry = 0
|
||||
self.long_stop = 0
|
||||
self.short_stop = 0
|
||||
|
||||
self.send_buy_orders(self.long_entry)
|
||||
self.send_short_orders(self.short_entry)
|
||||
elif self.pos < 0:
|
||||
self.send_buy_orders(self.entry_up)
|
||||
self.send_short_orders(self.entry_down)
|
||||
elif self.pos > 0:
|
||||
self.send_buy_orders(self.long_entry)
|
||||
|
||||
sell_price = max(self.long_stop, self.exit_down)
|
||||
self.sell(sell_price, abs(self.pos), True)
|
||||
|
||||
elif self.pos > 0:
|
||||
elif self.pos < 0:
|
||||
self.send_short_orders(self.short_entry)
|
||||
|
||||
cover_price = min(self.short_stop, self.exit_up)
|
||||
@ -110,7 +111,7 @@ class TurtleSignalStrategy(CtaTemplate):
|
||||
"""
|
||||
Callback of new trade data update.
|
||||
"""
|
||||
if trade.dierction == Direction.LONG:
|
||||
if trade.direction == Direction.LONG:
|
||||
self.long_entry = trade.price
|
||||
self.long_stop = self.long_entry - 2 * self.atr_value
|
||||
else:
|
||||
|
@ -169,9 +169,13 @@ class CtaTemplate(ABC):
|
||||
"""
|
||||
Send a new order.
|
||||
"""
|
||||
return self.cta_engine.send_order(
|
||||
if self.trading:
|
||||
vt_orderid = self.cta_engine.send_order(
|
||||
self, order_type, price, volume, stop
|
||||
)
|
||||
else:
|
||||
vt_orderid = ""
|
||||
return vt_orderid
|
||||
|
||||
def cancel_order(self, vt_orderid: str):
|
||||
"""
|
||||
|
@ -16,15 +16,6 @@ from ibapi.order_state import OrderState
|
||||
from ibapi.ticktype import TickType
|
||||
from ibapi.wrapper import EWrapper
|
||||
|
||||
from vnpy.trader.constant import (
|
||||
Currency,
|
||||
Direction,
|
||||
Exchange,
|
||||
OptionType,
|
||||
PriceType,
|
||||
Product,
|
||||
Status,
|
||||
)
|
||||
from vnpy.trader.gateway import BaseGateway
|
||||
from vnpy.trader.object import (
|
||||
AccountData,
|
||||
@ -414,7 +405,7 @@ class IbApi(EWrapper):
|
||||
pos = PositionData(
|
||||
symbol=contract.conId,
|
||||
exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
|
||||
direction=DIRECTION_NET,
|
||||
direction=Direction.NET,
|
||||
volume=position,
|
||||
price=averageCost,
|
||||
pnl=unrealizedPNL,
|
||||
@ -430,9 +421,7 @@ class IbApi(EWrapper):
|
||||
for account in self.accounts.values():
|
||||
self.gateway.on_account(copy(account))
|
||||
|
||||
def contractDetails(
|
||||
self, reqId: int, contractDetails: ContractDetails
|
||||
): # pylint: disable=invalid-name
|
||||
def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name
|
||||
"""
|
||||
Callback of contract data update.
|
||||
"""
|
||||
|
@ -1,13 +1,6 @@
|
||||
""""""
|
||||
|
||||
from peewee import (
|
||||
SqliteDatabase,
|
||||
Model,
|
||||
CharField,
|
||||
DateTimeField,
|
||||
FloatField,
|
||||
IntegerField,
|
||||
)
|
||||
from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField
|
||||
|
||||
from .constant import Exchange, Interval
|
||||
from .object import BarData, TickData
|
||||
@ -73,10 +66,10 @@ class DbBarData(Model):
|
||||
datetime=self.datetime,
|
||||
interval=Interval(self.interval),
|
||||
volume=self.volume,
|
||||
open_price=open_price,
|
||||
high_price=high_price,
|
||||
low_price=low_price,
|
||||
close_price=close_price,
|
||||
open_price=self.open_price,
|
||||
high_price=self.high_price,
|
||||
low_price=self.low_price,
|
||||
close_price=self.close_price,
|
||||
gateway_name=self.gateway_name,
|
||||
)
|
||||
return bar
|
||||
|
@ -2,7 +2,7 @@
|
||||
Event type string used in VN Trader.
|
||||
"""
|
||||
|
||||
from vnpy.event import EVENT_TIMER
|
||||
from vnpy.event import EVENT_TIMER # noqa
|
||||
|
||||
EVENT_TICK = "eTick."
|
||||
EVENT_TRADE = "eTrade."
|
||||
|
@ -6,7 +6,15 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from vnpy.event import Event, EventEngine
|
||||
from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT
|
||||
from .event import (
|
||||
EVENT_TICK,
|
||||
EVENT_TRADE,
|
||||
EVENT_ORDER,
|
||||
EVENT_POSITION,
|
||||
EVENT_ACCOUNT,
|
||||
EVENT_CONTRACT,
|
||||
EVENT_LOG,
|
||||
)
|
||||
from .object import (
|
||||
TickData,
|
||||
OrderData,
|
||||
@ -17,7 +25,7 @@ from .object import (
|
||||
LogData,
|
||||
OrderRequest,
|
||||
CancelRequest,
|
||||
SubscribeRequest
|
||||
SubscribeRequest,
|
||||
)
|
||||
|
||||
|
||||
|
@ -224,10 +224,10 @@ class ArrayManager(object):
|
||||
self.close_array[:-1] = self.close_array[1:]
|
||||
self.volume_array[:-1] = self.volume_array[1:]
|
||||
|
||||
self.open_array[-1] = bar.open
|
||||
self.high_array[-1] = bar.high
|
||||
self.low_array[-1] = bar.low
|
||||
self.close_array[-1] = bar.close
|
||||
self.open_array[-1] = bar.open_price
|
||||
self.high_array[-1] = bar.high_price
|
||||
self.low_array[-1] = bar.low_price
|
||||
self.close_array[-1] = bar.close_price
|
||||
self.volume_array[-1] = bar.volume
|
||||
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user