From b99c5ff590c916cc011c0b1d069c4a0b94f5fde3 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Tue, 29 Jan 2019 15:35:37 +0800 Subject: [PATCH] [Add]backtesting function for cta strategy --- requirements.txt | 3 +- .../turtle-checkpoint.ipynb | 6 + tests/backtesting/turtle.ipynb | 54 +++++++ tests/backtesting/turtle.py | 29 ++++ vnpy/app/cta_strategy/backtesting.py | 147 +++++++++--------- vnpy/app/cta_strategy/base.py | 8 +- .../strategies/turtle_signal_strategy.py | 15 +- vnpy/app/cta_strategy/template.py | 10 +- vnpy/gateway/ib/ib_gateway.py | 15 +- vnpy/trader/database.py | 17 +- vnpy/trader/event.py | 2 +- vnpy/trader/gateway.py | 12 +- vnpy/trader/utility.py | 8 +- 13 files changed, 209 insertions(+), 117 deletions(-) create mode 100644 tests/backtesting/.ipynb_checkpoints/turtle-checkpoint.ipynb create mode 100644 tests/backtesting/turtle.ipynb create mode 100644 tests/backtesting/turtle.py diff --git a/requirements.txt b/requirements.txt index fa2f5a18..2ef247b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ peewee numpy pandas matplotlib -seaborn \ No newline at end of file +seaborn +jupyter \ No newline at end of file diff --git a/tests/backtesting/.ipynb_checkpoints/turtle-checkpoint.ipynb b/tests/backtesting/.ipynb_checkpoints/turtle-checkpoint.ipynb new file mode 100644 index 00000000..2fd64429 --- /dev/null +++ b/tests/backtesting/.ipynb_checkpoints/turtle-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/backtesting/turtle.ipynb b/tests/backtesting/turtle.ipynb new file mode 100644 index 00000000..5a54df76 --- /dev/null +++ b/tests/backtesting/turtle.ipynb @@ -0,0 +1,54 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (backtesting.py, line 890)", + "output_type": "error", + "traceback": [ + "Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n", + " File \u001b[0;32m\"c:\\miniconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m3267\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n", + "\u001b[1;36m File \u001b[1;32m\"\"\u001b[1;36m, line \u001b[1;32m1\u001b[1;36m, in \u001b[1;35m\u001b[1;36m\u001b[0m\n\u001b[1;33m from vnpy.app.cta_strategy.backtesting import BacktestingEngine\u001b[0m\n", + "\u001b[1;36m File \u001b[1;32m\"C:\\Github\\vnpy\\vnpy\\app\\cta_strategy\\backtesting.py\"\u001b[1;36m, line \u001b[1;32m890\u001b[0m\n\u001b[1;33m rate=rate,\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "from vnpy.app.cta_strategy.backtesting import BacktestingEngine\n", + "from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import TurtleSignalStrategy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/backtesting/turtle.py b/tests/backtesting/turtle.py new file mode 100644 index 00000000..f7bfa6ad --- /dev/null +++ b/tests/backtesting/turtle.py @@ -0,0 +1,29 @@ +#%% +from vnpy.app.cta_strategy.backtesting import BacktestingEngine +from vnpy.app.cta_strategy.strategies.turtle_signal_strategy import ( + TurtleSignalStrategy, +) +from datetime import datetime + +#%% +engine = BacktestingEngine() +engine.set_parameters( + vt_symbol="IF88.CFFEX", + interval="1m", + start=datetime(2013, 1, 1), + end=datetime(2019, 1, 30), + rate=0, + slippage=0, + size=300, + pricetick=0.2, + capital=1_000_000, +) + +#%% +engine.add_strategy(TurtleSignalStrategy, {}) +engine.load_data() +engine.run_backtesting() +df = engine.calculate_result() +engine.calculate_statistics() +engine.show_chart() + diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 049f3d67..7f6f5acf 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -4,6 +4,7 @@ from typing import Callable from itertools import product import multiprocessing +import numpy as np import matplotlib.pyplot as plt import seaborn as sns from pandas import DataFrame @@ -27,6 +28,60 @@ from .template import CtaTemplate sns.set_style("whitegrid") + +class OptimizationSetting: + """ + Setting for runnning optimization. + """ + + def __init__(self): + """""" + self.params = {} + self.target = "" + + def add_parameter( + self, name: str, start: float, end: float = None, step: float = None + ): + """""" + if not end and not step: + self.params[name] = [start] + return + + if start >= end: + print("参数优化起始点必须小于终止点") + return + + if step <= 0: + print("参数优化步进必须大于0") + return + + value = start + value_list = [] + + while value <= end: + value_list.append(value) + value += step + + self.params[name] = value_list + + def set_target(self, target: str): + """""" + self.target = target + + def generate_setting(self): + """""" + keys = self.params.keys() + values = self.params.values() + products = list(product(*values)) + + settings = [] + for product in products: + setting = dict(zip(keys, product)) + settings.append(setting) + + return settings + + class BacktestingEngine: """""" @@ -108,21 +163,26 @@ class BacktestingEngine: pricetick: float, capital: int = 0, end: datetime = None, - mode: BacktestingMode = None, + mode: BacktestingMode = BacktestingMode.BAR, ): """""" self.mode = mode # 1 self.vt_symbol = vt_symbol # 2 + self.interval = interval self.rate = rate # 3 self.slippage = slippage # 4 self.size = size # self.pricetick = pricetick # + self.start = start self.symbol, exchange_str = self.vt_symbol.split(".") self.exchange = Exchange(exchange_str) if capital: self.capital = capital + + if end: + self.end = end if mode: self.mode = mode @@ -162,7 +222,7 @@ class BacktestingEngine: self.history_data = list(s) - self.output("历史数据加载完成") + self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") def run_backtesting(self): """""" @@ -209,7 +269,7 @@ class BacktestingEngine: for trade in self.trades.values(): d = trade.datetime.date() daily_result = self.daily_results[d] - d.add_trade(trade) + daily_result.add_trade(trade) # Calculate daily result by iteration. pre_close = 0 @@ -228,7 +288,7 @@ class BacktestingEngine: for daily_result in self.daily_results.values(): for key, value in daily_result.__dict__.items(): - results[key] = value + results[key].append(value) self.daily_df = DataFrame.from_dict(results).set_index("date") @@ -258,8 +318,8 @@ class BacktestingEngine: end_date = df.index[-1] total_days = len(df) - profit_days = len(df[df["netPnl"] > 0]) - loss_days = len(df[df["netPnl"] < 0]) + profit_days = len(df[df["net_pnl"] > 0]) + loss_days = len(df[df["net_pnl"] < 0]) end_balance = df["balance"].iloc[-1] max_drawdown = df["drawdown"].min() @@ -418,7 +478,7 @@ class BacktestingEngine: # Sort results and output result_values = [result.get() for result in results] - result_values.sort(reverse=True, key=lambda result:result[1]) + result_values.sort(reverse=True, key=lambda result: result[1]) for value in result_values: msg = f"参数:{value[0]}, 目标:{value[1]}" @@ -426,9 +486,6 @@ class BacktestingEngine: return result_values - return resultList - - def update_daily_close(self, price: float): """""" d = self.datetime.date() @@ -476,7 +533,7 @@ class BacktestingEngine: long_best_price = long_cross_price short_best_price = short_cross_price - for order in self.active_limit_orders.values(): + for order in list(self.active_limit_orders.values()): # Push order update with status "not traded" (pending). if order.status == Status.SUBMITTING: order.status = Status.NOTTRADED @@ -549,7 +606,7 @@ class BacktestingEngine: long_best_price = long_cross_price short_best_price = short_cross_price - for stop_order in self.active_stop_orders.values(): + for stop_order in list(self.active_stop_orders.values()): # Check whether stop order can be triggered. long_cross = ( stop_order.direction == Direction.LONG @@ -611,6 +668,8 @@ class BacktestingEngine: stop_order.vt_orderid = order.vt_orderid stop_order.status = StopOrderStatus.TRIGGERED + self.active_stop_orders.pop(stop_order.stop_orderid) + # Push update to strategy. self.strategy.on_stop_order(stop_order) self.strategy.on_order(order) @@ -715,10 +774,12 @@ class BacktestingEngine: """ Cancel all orders, both limit and stop. """ - for vt_orderid in self.active_limit_orders.keys(): + vt_orderids = list(self.active_limit_orders.keys()) + for vt_orderid in vt_orderids: self.cancel_limit_order(vt_orderid) - for vt_orderid in self.active_stop_orders.keys(): + stop_orderids = list(self.active_stop_orders.keys()) + for vt_orderid in stop_orderids: self.cancel_stop_order(vt_orderid) def write_log(self, msg: str, strategy: CtaTemplate = None): @@ -734,7 +795,7 @@ class BacktestingEngine: """ return self.engine_type - def put_put_strategy_event(self, strategy: CtaTemplate): + def put_strategy_event(self, strategy: CtaTemplate): """ Put an event to update strategy status. """ @@ -810,60 +871,6 @@ class DailyResult: self.net_pnl = self.total_pnl - self.commission - self.slippage -class OptimizationSetting: - """ - Setting for runnning optimization. - """ - - def __init__(self): - """""" - self.params = {} - self.target = "" - - def add_parameter( - self, name: str, start: float, end: float = None, step: float = None - ): - """""" - if not end and not step: - self.params[name] = [start] - return - - if start >= end: - print("参数优化起始点必须小于终止点") - return - - if step <= 0: - print("参数优化步进必须大于0") - return - - value = start - value_list = [] - - while value <= end: - value_list.append(value) - value += step - - self.params[name] = value_list - - def set_target(self, target: str): - """""" - self.target = target - - def generate_setting(self): - """""" - keys = self.params.keys() - values = self.params.values() - products = list(product(*values)) - - settings = [] - for product in products: - setting = dict(zip(keys, product)) - settings.append(setting) - - return settings - - - def optimize( target_name: str, strategy_class: CtaTemplate, @@ -886,7 +893,7 @@ def optimize( engine.set_parameters( vt_symbol=vt_symbol, interval=interval, - start=start + start=start, rate=rate, slippage=slippage, size=size, diff --git a/vnpy/app/cta_strategy/base.py b/vnpy/app/cta_strategy/base.py index 351e72b5..4ee50b70 100644 --- a/vnpy/app/cta_strategy/base.py +++ b/vnpy/app/cta_strategy/base.py @@ -8,14 +8,14 @@ from enum import Enum from vnpy.trader.constant import Direction, Offset APP_NAME = "CtaStrategy" -STOPORDER_PREFIX = "STOP." +STOPORDER_PREFIX = "STOP" class CtaOrderType(Enum): BUY = "买开" - SELL = "买开" - SHORT = "买开" - COVER = "买开" + SELL = "卖平" + SHORT = "卖开" + COVER = "买平" class StopOrderStatus(Enum): diff --git a/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py b/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py index b5a1044f..8ee0a0c4 100644 --- a/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py +++ b/vnpy/app/cta_strategy/strategies/turtle_signal_strategy.py @@ -38,7 +38,7 @@ class TurtleSignalStrategy(CtaTemplate): def __init__(self, cta_engine, strategy_name, vt_symbol, setting): """""" - super(DoubleMaStrategy, self).__init__( + super(TurtleSignalStrategy, self).__init__( cta_engine, strategy_name, vt_symbol, setting ) @@ -50,6 +50,7 @@ class TurtleSignalStrategy(CtaTemplate): Callback when strategy is inited. """ self.write_log("策略初始化") + self.load_bar(20) def on_start(self): """ @@ -83,22 +84,22 @@ class TurtleSignalStrategy(CtaTemplate): self.exit_up, self.exit_down = self.am.donchian(self.exit_window) if not self.pos: - self.atr_value = self.am.atr(self.atr_value) + self.atr_value = self.am.atr(self.atr_window) self.long_entry = 0 self.short_entry = 0 self.long_stop = 0 self.short_stop = 0 - self.send_buy_orders(self.long_entry) - self.send_short_orders(self.short_entry) - elif self.pos < 0: + self.send_buy_orders(self.entry_up) + self.send_short_orders(self.entry_down) + elif self.pos > 0: self.send_buy_orders(self.long_entry) sell_price = max(self.long_stop, self.exit_down) self.sell(sell_price, abs(self.pos), True) - elif self.pos > 0: + elif self.pos < 0: self.send_short_orders(self.short_entry) cover_price = min(self.short_stop, self.exit_up) @@ -110,7 +111,7 @@ class TurtleSignalStrategy(CtaTemplate): """ Callback of new trade data update. """ - if trade.dierction == Direction.LONG: + if trade.direction == Direction.LONG: self.long_entry = trade.price self.long_stop = self.long_entry - 2 * self.atr_value else: diff --git a/vnpy/app/cta_strategy/template.py b/vnpy/app/cta_strategy/template.py index c6dda25b..318c3c11 100644 --- a/vnpy/app/cta_strategy/template.py +++ b/vnpy/app/cta_strategy/template.py @@ -169,9 +169,13 @@ class CtaTemplate(ABC): """ Send a new order. """ - return self.cta_engine.send_order( - self, order_type, price, volume, stop - ) + if self.trading: + vt_orderid = self.cta_engine.send_order( + self, order_type, price, volume, stop + ) + else: + vt_orderid = "" + return vt_orderid def cancel_order(self, vt_orderid: str): """ diff --git a/vnpy/gateway/ib/ib_gateway.py b/vnpy/gateway/ib/ib_gateway.py index 5204fa33..75caaf56 100644 --- a/vnpy/gateway/ib/ib_gateway.py +++ b/vnpy/gateway/ib/ib_gateway.py @@ -16,15 +16,6 @@ from ibapi.order_state import OrderState from ibapi.ticktype import TickType from ibapi.wrapper import EWrapper -from vnpy.trader.constant import ( - Currency, - Direction, - Exchange, - OptionType, - PriceType, - Product, - Status, -) from vnpy.trader.gateway import BaseGateway from vnpy.trader.object import ( AccountData, @@ -414,7 +405,7 @@ class IbApi(EWrapper): pos = PositionData( symbol=contract.conId, exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange), - direction=DIRECTION_NET, + direction=Direction.NET, volume=position, price=averageCost, pnl=unrealizedPNL, @@ -430,9 +421,7 @@ class IbApi(EWrapper): for account in self.accounts.values(): self.gateway.on_account(copy(account)) - def contractDetails( - self, reqId: int, contractDetails: ContractDetails - ): # pylint: disable=invalid-name + def contractDetails(self, reqId: int, contractDetails: ContractDetails): # pylint: disable=invalid-name """ Callback of contract data update. """ diff --git a/vnpy/trader/database.py b/vnpy/trader/database.py index 9ecf3023..97027e9d 100644 --- a/vnpy/trader/database.py +++ b/vnpy/trader/database.py @@ -1,13 +1,6 @@ """""" -from peewee import ( - SqliteDatabase, - Model, - CharField, - DateTimeField, - FloatField, - IntegerField, -) +from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField from .constant import Exchange, Interval from .object import BarData, TickData @@ -73,10 +66,10 @@ class DbBarData(Model): datetime=self.datetime, interval=Interval(self.interval), volume=self.volume, - open_price=open_price, - high_price=high_price, - low_price=low_price, - close_price=close_price, + open_price=self.open_price, + high_price=self.high_price, + low_price=self.low_price, + close_price=self.close_price, gateway_name=self.gateway_name, ) return bar diff --git a/vnpy/trader/event.py b/vnpy/trader/event.py index 6c92f948..83902ec2 100644 --- a/vnpy/trader/event.py +++ b/vnpy/trader/event.py @@ -2,7 +2,7 @@ Event type string used in VN Trader. """ -from vnpy.event import EVENT_TIMER +from vnpy.event import EVENT_TIMER # noqa EVENT_TICK = "eTick." EVENT_TRADE = "eTrade." diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 4308a528..ca49cf12 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -6,7 +6,15 @@ from abc import ABC, abstractmethod from typing import Any from vnpy.event import Event, EventEngine -from .event import EVENT_ACCOUNT, EVENT_CONTRACT, EVENT_LOG, EVENT_CONTRACT +from .event import ( + EVENT_TICK, + EVENT_TRADE, + EVENT_ORDER, + EVENT_POSITION, + EVENT_ACCOUNT, + EVENT_CONTRACT, + EVENT_LOG, +) from .object import ( TickData, OrderData, @@ -17,7 +25,7 @@ from .object import ( LogData, OrderRequest, CancelRequest, - SubscribeRequest + SubscribeRequest, ) diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 39262068..0a8198a1 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -224,10 +224,10 @@ class ArrayManager(object): self.close_array[:-1] = self.close_array[1:] self.volume_array[:-1] = self.volume_array[1:] - self.open_array[-1] = bar.open - self.high_array[-1] = bar.high - self.low_array[-1] = bar.low - self.close_array[-1] = bar.close + self.open_array[-1] = bar.open_price + self.high_array[-1] = bar.high_price + self.low_array[-1] = bar.low_price + self.close_array[-1] = bar.close_price self.volume_array[-1] = bar.volume @property