From 3618044b36a63b51c858807f2e462f037c3bdf74 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Mon, 18 Feb 2019 13:42:16 +0800 Subject: [PATCH] [Fix] bugs in cta backtesting --- vnpy/app/cta_strategy/backtesting.py | 14 ++++++-------- vnpy/app/cta_strategy/engine.py | 3 ++- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index e8b1f9f8..a8664662 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -208,6 +208,7 @@ class BacktestingEngine: ) .order_by(DbBarData.datetime) ) + self.history_data = [db_bar.to_bar() for db_bar in s] else: s = ( DbTickData.select() @@ -218,8 +219,7 @@ class BacktestingEngine: ) .order_by(DbTickData.datetime) ) - - self.history_data = list(s) + self.history_data = [db_tick.to_tick() for db_tick in s] self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") @@ -303,9 +303,7 @@ class BacktestingEngine: # Calculate balance related time series data df["balance"] = df["net_pnl"].cumsum() + self.capital - df["return"] = (np.log(df["balance"] - np.log(df["balance"].shift(1)))).fillna( - 0 - ) + df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0) df["highlevel"] = ( df["balance"].rolling( min_periods=1, window=len(df), center=False).max() @@ -740,7 +738,7 @@ class BacktestingEngine: ) self.active_limit_orders[order.vt_orderid] = order - self.limit_order_count[order.vt_orderid] = order + self.limit_orders[order.vt_orderid] = order return order.vt_orderid @@ -749,9 +747,9 @@ class BacktestingEngine: Cancel order by vt_orderid. """ if vt_orderid.startswith(STOPORDER_PREFIX): - self.cancel_stop_order(vt_orderid) + self.cancel_stop_order(strategy, vt_orderid) else: - self.cancel_limit_order(vt_orderid) + self.cancel_limit_order(strategy, vt_orderid) def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str): """""" diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 1abb44d9..770bf081 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -403,7 +403,7 @@ class CtaEngine(BaseEngine): # Query data from RQData by default, if not found, load from database. data = self.query_bar_from_rq(vt_symbol, interval, start, end) if not data: - data = ( + s = ( DbBarData.select() .where( (DbBarData.vt_symbol == vt_symbol) & @@ -413,6 +413,7 @@ class CtaEngine(BaseEngine): ) .order_by(DbBarData.datetime) ) + data = [db_bar.to_bar() for db_bar in s] for bar in data: callback(bar)