[Add]history data cache in cta backtesting to improve speed

This commit is contained in:
vn.py 2019-04-11 16:56:19 +08:00
parent 1d6506c5f3
commit b3961dbb84
2 changed files with 53 additions and 19 deletions

View File

@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine):
self.write_log("已有回测在运行中,请等待完成") self.write_log("已有回测在运行中,请等待完成")
return False return False
self.write_log("-" * 40)
self.thread = Thread( self.thread = Thread(
target=self.run_backtesting, target=self.run_backtesting,
args=( args=(

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from datetime import date, datetime from datetime import date, datetime
from typing import Callable from typing import Callable
from itertools import product from itertools import product
from functools import lru_cache
import multiprocessing import multiprocessing
import numpy as np import numpy as np
@ -197,28 +198,18 @@ class BacktestingEngine:
self.output("开始加载历史数据") self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR: if self.mode == BacktestingMode.BAR:
s = ( self.history_data = load_bar_data(
DbBarData.select() self.vt_symbol,
.where( self.interval,
(DbBarData.vt_symbol == self.vt_symbol) self.start,
& (DbBarData.interval == self.interval) self.end
& (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
) )
self.history_data = [db_bar.to_bar() for db_bar in s]
else: else:
s = ( self.history_data = load_tick_data(
DbTickData.select() self.vt_symbol,
.where( self.start,
(DbTickData.vt_symbol == self.vt_symbol) self.end
& (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
) )
self.history_data = [db_tick.to_tick() for db_tick in s]
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
@ -970,3 +961,45 @@ def optimize(
target_value = statistics[target_name] target_value = statistics[target_name]
return (str(setting), target_value, statistics) return (str(setting), target_value, statistics)
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data