diff --git a/vnpy/app/cta_backtester/engine.py b/vnpy/app/cta_backtester/engine.py index adcab59d..4dc80589 100644 --- a/vnpy/app/cta_backtester/engine.py +++ b/vnpy/app/cta_backtester/engine.py @@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine): self.write_log("已有回测在运行中,请等待完成") return False + self.write_log("-" * 40) self.thread = Thread( target=self.run_backtesting, args=( diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 5712f798..c1ba8e67 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -2,6 +2,7 @@ from collections import defaultdict from datetime import date, datetime from typing import Callable from itertools import product +from functools import lru_cache import multiprocessing import numpy as np @@ -197,28 +198,18 @@ class BacktestingEngine: self.output("开始加载历史数据") if self.mode == BacktestingMode.BAR: - s = ( - DbBarData.select() - .where( - (DbBarData.vt_symbol == self.vt_symbol) - & (DbBarData.interval == self.interval) - & (DbBarData.datetime >= self.start) - & (DbBarData.datetime <= self.end) - ) - .order_by(DbBarData.datetime) + self.history_data = load_bar_data( + self.vt_symbol, + self.interval, + self.start, + self.end ) - self.history_data = [db_bar.to_bar() for db_bar in s] else: - s = ( - DbTickData.select() - .where( - (DbTickData.vt_symbol == self.vt_symbol) - & (DbTickData.datetime >= self.start) - & (DbTickData.datetime <= self.end) - ) - .order_by(DbTickData.datetime) + self.history_data = load_tick_data( + self.vt_symbol, + self.start, + self.end ) - self.history_data = [db_tick.to_tick() for db_tick in s] self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") @@ -970,3 +961,45 @@ def optimize( target_value = statistics[target_name] return (str(setting), target_value, statistics) + + +@lru_cache(maxsize=10) +def load_bar_data( + vt_symbol: str, + interval: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbBarData.select() + .where( + (DbBarData.vt_symbol == vt_symbol) + & (DbBarData.interval == interval) + & (DbBarData.datetime >= start) + & (DbBarData.datetime <= end) + ) + .order_by(DbBarData.datetime) + ) + data = [db_bar.to_bar() for db_bar in s] + return data + + +@lru_cache(maxsize=10) +def load_tick_data( + vt_symbol: str, + start: datetime, + end: datetime +): + """""" + s = ( + DbTickData.select() + .where( + (DbTickData.vt_symbol == vt_symbol) + & (DbTickData.datetime >= start) + & (DbTickData.datetime <= end) + ) + .order_by(DbTickData.datetime) + ) + data = [db_tick.db_tick() for db_tick in s] + return data \ No newline at end of file