Merge remote-tracking branch 'remotes/origin/DEV' into mysql

# Conflicts:
#	vnpy/trader/utility.py
This commit is contained in:
nanoric 2019-04-11 05:31:08 -04:00
commit 2ef6e2fd00
5 changed files with 55 additions and 40 deletions

View File

@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine):
self.write_log("已有回测在运行中,请等待完成") self.write_log("已有回测在运行中,请等待完成")
return False return False
self.write_log("-" * 40)
self.thread = Thread( self.thread = Thread(
target=self.run_backtesting, target=self.run_backtesting,
args=( args=(

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from datetime import date, datetime from datetime import date, datetime
from typing import Callable from typing import Callable
from itertools import product from itertools import product
from functools import lru_cache
import multiprocessing import multiprocessing
import numpy as np import numpy as np
@ -197,28 +198,18 @@ class BacktestingEngine:
self.output("开始加载历史数据") self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR: if self.mode == BacktestingMode.BAR:
s = ( self.history_data = load_bar_data(
DbBarData.select() self.vt_symbol,
.where( self.interval,
(DbBarData.vt_symbol == self.vt_symbol) self.start,
& (DbBarData.interval == self.interval) self.end
& (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
) )
self.history_data = [db_bar.to_bar() for db_bar in s]
else: else:
s = ( self.history_data = load_tick_data(
DbTickData.select() self.vt_symbol,
.where( self.start,
(DbTickData.vt_symbol == self.vt_symbol) self.end
& (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
) )
self.history_data = [db_tick.to_tick() for db_tick in s]
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}") self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
@ -970,3 +961,45 @@ def optimize(
target_value = statistics[target_name] target_value = statistics[target_name]
return (str(setting), target_value, statistics) return (str(setting), target_value, statistics)
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -553,7 +553,7 @@ class CtpTdApi(TdApi):
) )
# For option only # For option only
if data["OptionsType"]: if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"], contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None), contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"], contract.option_strike = data["StrikePrice"],

View File

@ -24,7 +24,7 @@ from .event import (
from .gateway import BaseGateway from .gateway import BaseGateway
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
from .setting import SETTINGS from .setting import SETTINGS
from .utility import Singleton, get_folder_path from .utility import get_folder_path
class MainEngine: class MainEngine:

View File

@ -13,25 +13,6 @@ import talib
from .object import BarData, TickData from .object import BarData, TickData
class Singleton(type):
"""
Singleton metaclass,
usage:
class A(metaclass=Singleton):
...
"""
_instances = {}
def __call__(cls, *args, **kwargs):
""""""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def resolve_path(pattern: str): def resolve_path(pattern: str):
env = dict(os.environ) env = dict(os.environ)
env.update({"VNPY_TEMP": str(TEMP_DIR)}) env.update({"VNPY_TEMP": str(TEMP_DIR)})