[Add]history data cache in cta backtesting to improve speed
This commit is contained in:
parent
1d6506c5f3
commit
b3961dbb84
@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine):
|
||||
self.write_log("已有回测在运行中,请等待完成")
|
||||
return False
|
||||
|
||||
self.write_log("-" * 40)
|
||||
self.thread = Thread(
|
||||
target=self.run_backtesting,
|
||||
args=(
|
||||
|
@ -2,6 +2,7 @@ from collections import defaultdict
|
||||
from datetime import date, datetime
|
||||
from typing import Callable
|
||||
from itertools import product
|
||||
from functools import lru_cache
|
||||
import multiprocessing
|
||||
|
||||
import numpy as np
|
||||
@ -197,28 +198,18 @@ class BacktestingEngine:
|
||||
self.output("开始加载历史数据")
|
||||
|
||||
if self.mode == BacktestingMode.BAR:
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == self.vt_symbol)
|
||||
& (DbBarData.interval == self.interval)
|
||||
& (DbBarData.datetime >= self.start)
|
||||
& (DbBarData.datetime <= self.end)
|
||||
self.history_data = load_bar_data(
|
||||
self.vt_symbol,
|
||||
self.interval,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
)
|
||||
self.history_data = [db_bar.to_bar() for db_bar in s]
|
||||
else:
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == self.vt_symbol)
|
||||
& (DbTickData.datetime >= self.start)
|
||||
& (DbTickData.datetime <= self.end)
|
||||
self.history_data = load_tick_data(
|
||||
self.vt_symbol,
|
||||
self.start,
|
||||
self.end
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
)
|
||||
self.history_data = [db_tick.to_tick() for db_tick in s]
|
||||
|
||||
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
||||
|
||||
@ -970,3 +961,45 @@ def optimize(
|
||||
|
||||
target_value = statistics[target_name]
|
||||
return (str(setting), target_value, statistics)
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_bar_data(
|
||||
vt_symbol: str,
|
||||
interval: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbBarData.select()
|
||||
.where(
|
||||
(DbBarData.vt_symbol == vt_symbol)
|
||||
& (DbBarData.interval == interval)
|
||||
& (DbBarData.datetime >= start)
|
||||
& (DbBarData.datetime <= end)
|
||||
)
|
||||
.order_by(DbBarData.datetime)
|
||||
)
|
||||
data = [db_bar.to_bar() for db_bar in s]
|
||||
return data
|
||||
|
||||
|
||||
@lru_cache(maxsize=10)
|
||||
def load_tick_data(
|
||||
vt_symbol: str,
|
||||
start: datetime,
|
||||
end: datetime
|
||||
):
|
||||
""""""
|
||||
s = (
|
||||
DbTickData.select()
|
||||
.where(
|
||||
(DbTickData.vt_symbol == vt_symbol)
|
||||
& (DbTickData.datetime >= start)
|
||||
& (DbTickData.datetime <= end)
|
||||
)
|
||||
.order_by(DbTickData.datetime)
|
||||
)
|
||||
data = [db_tick.db_tick() for db_tick in s]
|
||||
return data
|
Loading…
Reference in New Issue
Block a user