Merge remote-tracking branch 'remotes/origin/DEV' into mysql
# Conflicts: # vnpy/trader/utility.py
This commit is contained in:
commit
2ef6e2fd00
@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine):
|
|||||||
self.write_log("已有回测在运行中,请等待完成")
|
self.write_log("已有回测在运行中,请等待完成")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
self.write_log("-" * 40)
|
||||||
self.thread = Thread(
|
self.thread = Thread(
|
||||||
target=self.run_backtesting,
|
target=self.run_backtesting,
|
||||||
args=(
|
args=(
|
||||||
|
@ -2,6 +2,7 @@ from collections import defaultdict
|
|||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
from functools import lru_cache
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -197,28 +198,18 @@ class BacktestingEngine:
|
|||||||
self.output("开始加载历史数据")
|
self.output("开始加载历史数据")
|
||||||
|
|
||||||
if self.mode == BacktestingMode.BAR:
|
if self.mode == BacktestingMode.BAR:
|
||||||
s = (
|
self.history_data = load_bar_data(
|
||||||
DbBarData.select()
|
self.vt_symbol,
|
||||||
.where(
|
self.interval,
|
||||||
(DbBarData.vt_symbol == self.vt_symbol)
|
self.start,
|
||||||
& (DbBarData.interval == self.interval)
|
self.end
|
||||||
& (DbBarData.datetime >= self.start)
|
|
||||||
& (DbBarData.datetime <= self.end)
|
|
||||||
)
|
|
||||||
.order_by(DbBarData.datetime)
|
|
||||||
)
|
)
|
||||||
self.history_data = [db_bar.to_bar() for db_bar in s]
|
|
||||||
else:
|
else:
|
||||||
s = (
|
self.history_data = load_tick_data(
|
||||||
DbTickData.select()
|
self.vt_symbol,
|
||||||
.where(
|
self.start,
|
||||||
(DbTickData.vt_symbol == self.vt_symbol)
|
self.end
|
||||||
& (DbTickData.datetime >= self.start)
|
|
||||||
& (DbTickData.datetime <= self.end)
|
|
||||||
)
|
|
||||||
.order_by(DbTickData.datetime)
|
|
||||||
)
|
)
|
||||||
self.history_data = [db_tick.to_tick() for db_tick in s]
|
|
||||||
|
|
||||||
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
|
||||||
|
|
||||||
@ -970,3 +961,45 @@ def optimize(
|
|||||||
|
|
||||||
target_value = statistics[target_name]
|
target_value = statistics[target_name]
|
||||||
return (str(setting), target_value, statistics)
|
return (str(setting), target_value, statistics)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def load_bar_data(
|
||||||
|
vt_symbol: str,
|
||||||
|
interval: str,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
):
|
||||||
|
""""""
|
||||||
|
s = (
|
||||||
|
DbBarData.select()
|
||||||
|
.where(
|
||||||
|
(DbBarData.vt_symbol == vt_symbol)
|
||||||
|
& (DbBarData.interval == interval)
|
||||||
|
& (DbBarData.datetime >= start)
|
||||||
|
& (DbBarData.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(DbBarData.datetime)
|
||||||
|
)
|
||||||
|
data = [db_bar.to_bar() for db_bar in s]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=10)
|
||||||
|
def load_tick_data(
|
||||||
|
vt_symbol: str,
|
||||||
|
start: datetime,
|
||||||
|
end: datetime
|
||||||
|
):
|
||||||
|
""""""
|
||||||
|
s = (
|
||||||
|
DbTickData.select()
|
||||||
|
.where(
|
||||||
|
(DbTickData.vt_symbol == vt_symbol)
|
||||||
|
& (DbTickData.datetime >= start)
|
||||||
|
& (DbTickData.datetime <= end)
|
||||||
|
)
|
||||||
|
.order_by(DbTickData.datetime)
|
||||||
|
)
|
||||||
|
data = [db_tick.db_tick() for db_tick in s]
|
||||||
|
return data
|
@ -553,7 +553,7 @@ class CtpTdApi(TdApi):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# For option only
|
# For option only
|
||||||
if data["OptionsType"]:
|
if contract.product == Product.OPTION:
|
||||||
contract.option_underlying = data["UnderlyingInstrID"],
|
contract.option_underlying = data["UnderlyingInstrID"],
|
||||||
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
|
||||||
contract.option_strike = data["StrikePrice"],
|
contract.option_strike = data["StrikePrice"],
|
||||||
|
@ -24,7 +24,7 @@ from .event import (
|
|||||||
from .gateway import BaseGateway
|
from .gateway import BaseGateway
|
||||||
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
|
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
|
||||||
from .setting import SETTINGS
|
from .setting import SETTINGS
|
||||||
from .utility import Singleton, get_folder_path
|
from .utility import get_folder_path
|
||||||
|
|
||||||
|
|
||||||
class MainEngine:
|
class MainEngine:
|
||||||
|
@ -13,25 +13,6 @@ import talib
|
|||||||
from .object import BarData, TickData
|
from .object import BarData, TickData
|
||||||
|
|
||||||
|
|
||||||
class Singleton(type):
|
|
||||||
"""
|
|
||||||
Singleton metaclass,
|
|
||||||
|
|
||||||
usage:
|
|
||||||
class A(metaclass=Singleton):
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
_instances = {}
|
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
|
||||||
""""""
|
|
||||||
if cls not in cls._instances:
|
|
||||||
cls._instances[cls] = super(Singleton, cls).__call__(
|
|
||||||
*args, **kwargs
|
|
||||||
)
|
|
||||||
return cls._instances[cls]
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_path(pattern: str):
|
def resolve_path(pattern: str):
|
||||||
env = dict(os.environ)
|
env = dict(os.environ)
|
||||||
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
env.update({"VNPY_TEMP": str(TEMP_DIR)})
|
||||||
|
Loading…
Reference in New Issue
Block a user