[Add]backtesting function for cta strategy

This commit is contained in:
vn.py 2019-01-29 15:35:37 +08:00
parent 2dc8057de0
commit b99c5ff590
13 changed files with 209 additions and 117 deletions

View File

@ -7,3 +7,4 @@ numpy
pandas
matplotlib
seaborn
jupyter

View File

@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,54 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid syntax (backtesting.py, line 890)",
"output_type": "error",
"traceback": [
"Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n",
" File \u001b[0;32m\"c:\\miniconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m3267\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n",
"\u001b[1;36m File \u001b[1;32m\"<ipython-input-1-fa539d7a6775>\"\u001b[1;36m, line \u001b[1;32m1\u001b[1;36m, in \u001b[1;35m<module>\u001b[1;36m\u001b[0m\n\u001b[1;33m from vnpy.app.cta_strategy.backtesting import BacktestingEngine\u001b[0m\n",
"\u001b[1;36m File \u001b[1;32m\"C:\\Github\\vnpy\\vnpy\\app\\cta_strategy\\backtesting.py\"\u001b[1;36m, line \u001b[1;32m890\u001b[0m\n\u001b[1;33m rate=rate,\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n"
]
}
],
"source": [
"from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n",
"from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import TurtleSignalStrategy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,29 @@
#%%
from vnpy.app.cta_strategy.backtesting import BacktestingEngine
from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import (
TurtleSignalStrategy,
)
from datetime import datetime
#%%
engine = BacktestingEngine()
engine.set_parameters(
vt_symbol="IF88.CFFEX",
interval="1m",
start=datetime(2013, 1, 1),
end=datetime(2019, 1, 30),
rate=0,
slippage=0,
size=300,
pricetick=0.2,
capital=1_000_000,
)
#%%
engine.add_strategy(TurtleSignalStrategy, {})
engine.load_data()
engine.run_backtesting()
df = engine.calculate_result()
engine.calculate_statistics()
engine.show_chart()

View File

@ -4,6 +4,7 @@ from typing import Callable
from itertools import product
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import DataFrame
@ -27,6 +28,60 @@ from .template import CtaTemplate
sns.set_style("whitegrid")
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list
def set_target(self, target: str):
""""""
self.target = target
def generate_setting(self):
""""""
keys = self.params.keys()
values = self.params.values()
products = list(product(*values))
settings = []
for product in products:
setting = dict(zip(keys, product))
settings.append(setting)
return settings
class BacktestingEngine:
""""""
@ -108,15 +163,17 @@ class BacktestingEngine:
pricetick: float,
capital: int = 0,
end: datetime = None,
mode: BacktestingMode = None,
mode: BacktestingMode = BacktestingMode.BAR,
):
""""""
self.mode = mode # 1
self.vt_symbol = vt_symbol # 2
self.interval = interval
self.rate = rate # 3
self.slippage = slippage # 4
self.size = size #
self.pricetick = pricetick #
self.start = start
self.symbol, exchange_str = self.vt_symbol.split(".")
self.exchange = Exchange(exchange_str)
@ -124,6 +181,9 @@ class BacktestingEngine:
if capital:
self.capital = capital
if end:
self.end = end
if mode:
self.mode = mode
@ -162,7 +222,7 @@ class BacktestingEngine:
self.history_data = list(s)
self.output("历史数据加载完成")
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
def run_backtesting(self):
""""""
@ -209,7 +269,7 @@ class BacktestingEngine:
for trade in self.trades.values():
d = trade.datetime.date()
daily_result = self.daily_results[d]
d.add_trade(trade)
daily_result.add_trade(trade)
# Calculate daily result by iteration.
pre_close = 0
@ -228,7 +288,7 @@ class BacktestingEngine:
for daily_result in self.daily_results.values():
for key, value in daily_result.__dict__.items():
results[key] = value
results[key].append(value)
self.daily_df = DataFrame.from_dict(results).set_index("date")
@ -258,8 +318,8 @@ class BacktestingEngine:
end_date = df.index[-1]
total_days = len(df)
profit_days = len(df[df["netPnl"] > 0])
loss_days = len(df[df["netPnl"] < 0])
profit_days = len(df[df["net_pnl"] > 0])
loss_days = len(df[df["net_pnl"] < 0])
end_balance = df["balance"].iloc[-1]
max_drawdown = df["drawdown"].min()
@ -418,7 +478,7 @@ class BacktestingEngine:
# Sort results and output
result_values = [result.get() for result in results]
result_values.sort(reverse=True, key=lambda result:result[1])
result_values.sort(reverse=True, key=lambda result: result[1])
for value in result_values:
msg = f"参数:{value[0]}, 目标:{value[1]}"
@ -426,9 +486,6 @@ class BacktestingEngine:
return result_values
return resultList
def update_daily_close(self, price: float):
""""""
d = self.datetime.date()
@ -476,7 +533,7 @@ class BacktestingEngine:
long_best_price = long_cross_price
short_best_price = short_cross_price
for order in self.active_limit_orders.values():
for order in list(self.active_limit_orders.values()):
# Push order update with status "not traded" (pending).
if order.status == Status.SUBMITTING:
order.status = Status.NOTTRADED
@ -549,7 +606,7 @@ class BacktestingEngine:
long_best_price = long_cross_price
short_best_price = short_cross_price
for stop_order in self.active_stop_orders.values():
for stop_order in list(self.active_stop_orders.values()):
# Check whether stop order can be triggered.
long_cross = (
stop_order.direction == Direction.LONG
@ -611,6 +668,8 @@ class BacktestingEngine:
stop_order.vt_orderid = order.vt_orderid
stop_order.status = StopOrderStatus.TRIGGERED
self.active_stop_orders.pop(stop_order.stop_orderid)
# Push update to strategy.
self.strategy.on_stop_order(stop_order)
self.strategy.on_order(order)
@ -715,10 +774,12 @@ class BacktestingEngine:
"""
Cancel all orders, both limit and stop.
"""
for vt_orderid in self.active_limit_orders.keys():
vt_orderids = list(self.active_limit_orders.keys())
for vt_orderid in vt_orderids:
self.cancel_limit_order(vt_orderid)
for vt_orderid in self.active_stop_orders.keys():
stop_orderids = list(self.active_stop_orders.keys())
for vt_orderid in stop_orderids:
self.cancel_stop_order(vt_orderid)
def write_log(self, msg: str, strategy: CtaTemplate = None):
@ -734,7 +795,7 @@ class BacktestingEngine:
"""
return self.engine_type
def put_put_strategy_event(self, strategy: CtaTemplate):
def put_strategy_event(self, strategy: CtaTemplate):
"""
Put an event to update strategy status.
"""
@ -810,60 +871,6 @@ class DailyResult:
self.net_pnl = self.total_pnl - self.commission - self.slippage
class OptimizationSetting:
"""
Setting for runnning optimization.
"""
def __init__(self):
""""""
self.params = {}
self.target = ""
def add_parameter(
self, name: str, start: float, end: float = None, step: float = None
):
""""""
if not end and not step:
self.params[name] = [start]
return
if start >= end:
print("参数优化起始点必须小于终止点")
return
if step <= 0:
print("参数优化步进必须大于0")
return
value = start
value_list = []
while value <= end:
value_list.append(value)
value += step
self.params[name] = value_list
def set_target(self, target: str):
""""""
self.target = target
def generate_setting(self):
""""""
keys = self.params.keys()
values = self.params.values()
products = list(product(*values))
settings = []
for product in products:
setting = dict(zip(keys, product))
settings.append(setting)
return settings
def optimize(
target_name: str,
strategy_class: CtaTemplate,
@ -886,7 +893,7 @@ def optimize(
engine.set_parameters(
vt_symbol=vt_symbol,
interval=interval,
start=start
start=start,
rate=rate,
slippage=slippage,
size=size,

View File

@ -8,14 +8,14 @@ from enum import Enum
from vnpy.trader.constant import Direction, Offset
APP_NAME = "CtaStrategy"
STOPORDER_PREFIX = "STOP."
STOPORDER_PREFIX = "STOP"
class CtaOrderType(Enum):
BUY = "买开"
SELL = "买开"
SHORT = ""
COVER = ""
SELL = "卖平"
SHORT = ""
COVER = ""
class StopOrderStatus(Enum):

View File

@ -38,7 +38,7 @@ class TurtleSignalStrategy(CtaTemplate):
def __init__(self, cta_engine, strategy_name, vt_symbol, setting):
""""""
super(DoubleMaStrategy, self).__init__(
super(TurtleSignalStrategy, self).__init__(
cta_engine, strategy_name, vt_symbol, setting
)
@ -50,6 +50,7 @@ class TurtleSignalStrategy(CtaTemplate):
Callback when strategy is inited.
"""
self.write_log("策略初始化")
self.load_bar(20)
def on_start(self):
"""
@ -83,22 +84,22 @@ class TurtleSignalStrategy(CtaTemplate):
self.exit_up, self.exit_down = self.am.donchian(self.exit_window)
if not self.pos:
self.atr_value = self.am.atr(self.atr_value)
self.atr_value = self.am.atr(self.atr_window)
self.long_entry = 0
self.short_entry = 0
self.long_stop = 0
self.short_stop = 0
self.send_buy_orders(self.long_entry)
self.send_short_orders(self.short_entry)
elif self.pos < 0:
self.send_buy_orders(self.entry_up)
self.send_short_orders(self.entry_down)
elif self.pos > 0:
self.send_buy_orders(self.long_entry)
sell_price = max(self.long_stop, self.exit_down)
self.sell(sell_price, abs(self.pos), True)
elif self.pos > 0:
elif self.pos < 0:
self.send_short_orders(self.short_entry)
cover_price = min(self.short_stop, self.exit_up)
@ -110,7 +111,7 @@ class TurtleSignalStrategy(CtaTemplate):
"""
Callback of new trade data update.
"""
if trade.dierction == Direction.LONG:
if trade.direction == Direction.LONG:
self.long_entry = trade.price
self.long_stop = self.long_entry - 2 * self.atr_value
else:

View File

@ -169,9 +169,13 @@ class CtaTemplate(ABC):
"""
Send a new order.
"""
return self.cta_engine.send_order(
self, order_type, price, volume, stop
)
if self.trading:
vt_orderid = self.cta_engine.send_order(
self, order_type, price, volume, stop
)
else:
vt_orderid = ""
return vt_orderid
def cancel_order(self, vt_orderid: str):
"""

View File

@ -16,15 +16,6 @@ from ibapi.order_state import OrderState
from ibapi.ticktype import TickType
from ibapi.wrapper import EWrapper
from vnpy.trader.constant import (
Currency,
Direction,
Exchange,
OptionType,
PriceType,
Product,
Status,
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
AccountData,
@ -414,7 +405,7 @@ class IbApi(EWrapper):
pos = PositionData(
symbol=contract.conId,
exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
direction=DIRECTION_NET,
direction=Direction.NET,
volume=position,
price=averageCost,
pnl=unrealizedPNL,
@ -430,9 +421,7 @@ class IbApi(EWrapper):
for account in self.accounts.values():
self.gateway.on_account(copy(account))
def contractDetails(
self, reqId: int, contractDetails: ContractDetails
): # pylint: disable=invalid-name
def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name
"""
Callback of contract data update.
"""

View File

@ -1,13 +1,6 @@
""""""
from peewee import (
SqliteDatabase,
Model,
CharField,
DateTimeField,
FloatField,
IntegerField,
)
from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField
from .constant import Exchange, Interval
from .object import BarData, TickData
@ -73,10 +66,10 @@ class DbBarData(Model):
datetime=self.datetime,
interval=Interval(self.interval),
volume=self.volume,
open_price=open_price,
high_price=high_price,
low_price=low_price,
close_price=close_price,
open_price=self.open_price,
high_price=self.high_price,
low_price=self.low_price,
close_price=self.close_price,
gateway_name=self.gateway_name,
)
return bar

View File

@ -2,7 +2,7 @@
Event type string used in VN Trader.
"""
from vnpy.event import EVENT_TIMER
from vnpy.event import EVENT_TIMER # noqa
EVENT_TICK = "eTick."
EVENT_TRADE = "eTrade."

View File

@ -6,7 +6,15 @@ from abc import ABC, abstractmethod
from typing import Any
from vnpy.event import Event, EventEngine
from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT
from .event import (
EVENT_TICK,
EVENT_TRADE,
EVENT_ORDER,
EVENT_POSITION,
EVENT_ACCOUNT,
EVENT_CONTRACT,
EVENT_LOG,
)
from .object import (
TickData,
OrderData,
@ -17,7 +25,7 @@ from .object import (
LogData,
OrderRequest,
CancelRequest,
SubscribeRequest
SubscribeRequest,
)

View File

@ -224,10 +224,10 @@ class ArrayManager(object):
self.close_array[:-1] = self.close_array[1:]
self.volume_array[:-1] = self.volume_array[1:]
self.open_array[-1] = bar.open
self.high_array[-1] = bar.high
self.low_array[-1] = bar.low
self.close_array[-1] = bar.close
self.open_array[-1] = bar.open_price
self.high_array[-1] = bar.high_price
self.low_array[-1] = bar.low_price
self.close_array[-1] = bar.close_price
self.volume_array[-1] = bar.volume
@property