[Fix] bugs in cta backtesting
This commit is contained in:
parent
144ca19b08
commit
3618044b36
@ -208,6 +208,7 @@ class BacktestingEngine:
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
)
|
||||
self.history_data = [db_bar.to_bar() for db_bar in s]
|
||||
else:
|
||||
s = (
|
||||
DbTickData.select()
|
||||
@ -218,8 +219,7 @@ class BacktestingEngine:
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
)
|
||||
|
||||
self.history_data = list(s)
|
||||
self.history_data = [db_tick.to_tick() for db_tick in s]
|
||||
|
||||
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
||||
|
||||
@ -303,9 +303,7 @@ class BacktestingEngine:
|
||||
|
||||
# Calculate balance related time series data
|
||||
df["balance"] = df["net_pnl"].cumsum() + self.capital
|
||||
df["return"] = (np.log(df["balance"] - np.log(df["balance"].shift(1)))).fillna(
|
||||
0
|
||||
)
|
||||
df["return"] = np.log(df["balance"] / df["balance"].shift(1)).fillna(0)
|
||||
df["highlevel"] = (
|
||||
df["balance"].rolling(
|
||||
min_periods=1, window=len(df), center=False).max()
|
||||
@ -740,7 +738,7 @@ class BacktestingEngine:
|
||||
)
|
||||
|
||||
self.active_limit_orders[order.vt_orderid] = order
|
||||
self.limit_order_count[order.vt_orderid] = order
|
||||
self.limit_orders[order.vt_orderid] = order
|
||||
|
||||
return order.vt_orderid
|
||||
|
||||
@ -749,9 +747,9 @@ class BacktestingEngine:
|
||||
Cancel order by vt_orderid.
|
||||
"""
|
||||
if vt_orderid.startswith(STOPORDER_PREFIX):
|
||||
self.cancel_stop_order(vt_orderid)
|
||||
self.cancel_stop_order(strategy, vt_orderid)
|
||||
else:
|
||||
self.cancel_limit_order(vt_orderid)
|
||||
self.cancel_limit_order(strategy, vt_orderid)
|
||||
|
||||
def cancel_stop_order(self, strategy: CtaTemplate, vt_orderid: str):
|
||||
""""""
|
||||
|
@ -403,7 +403,7 @@ class CtaEngine(BaseEngine):
|
||||
# Query data from RQData by default, if not found, load from database.
|
||||
data = self.query_bar_from_rq(vt_symbol, interval, start, end)
|
||||
if not data:
|
||||
data = (
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol) &
|
||||
@ -413,6 +413,7 @@ class CtaEngine(BaseEngine):
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
|
||||
for bar in data:
|
||||
callback(bar)
|
||||
|
Loading…
Reference in New Issue
Block a user