Merge remote-tracking branch 'remotes/origin/DEV' into mysql

# Conflicts:
#	vnpy/trader/utility.py
This commit is contained in:
nanoric 2019-04-11 05:31:08 -04:00
commit 2ef6e2fd00
5 changed files with 55 additions and 40 deletions

View File

@ -165,6 +165,7 @@ class BacktesterEngine(BaseEngine):
self.write_log("已有回测在运行中,请等待完成")
return False
self.write_log("-" * 40)
self.thread = Thread(
target=self.run_backtesting,
args=(

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from datetime import date, datetime
from typing import Callable
from itertools import product
from functools import lru_cache
import multiprocessing
import numpy as np
@ -197,28 +198,18 @@ class BacktestingEngine:
self.output("开始加载历史数据")
if self.mode == BacktestingMode.BAR:
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == self.vt_symbol)
& (DbBarData.interval == self.interval)
& (DbBarData.datetime >= self.start)
& (DbBarData.datetime <= self.end)
)
.order_by(DbBarData.datetime)
self.history_data = load_bar_data(
self.vt_symbol,
self.interval,
self.start,
self.end
)
self.history_data = [db_bar.to_bar() for db_bar in s]
else:
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == self.vt_symbol)
& (DbTickData.datetime >= self.start)
& (DbTickData.datetime <= self.end)
)
.order_by(DbTickData.datetime)
self.history_data = load_tick_data(
self.vt_symbol,
self.start,
self.end
)
self.history_data = [db_tick.to_tick() for db_tick in s]
self.output(f"历史数据加载完成,数据量:{len(self.history_data)}")
@ -970,3 +961,45 @@ def optimize(
target_value = statistics[target_name]
return (str(setting), target_value, statistics)
@lru_cache(maxsize=10)
def load_bar_data(
vt_symbol: str,
interval: str,
start: datetime,
end: datetime
):
""""""
s = (
DbBarData.select()
.where(
(DbBarData.vt_symbol == vt_symbol)
& (DbBarData.interval == interval)
& (DbBarData.datetime >= start)
& (DbBarData.datetime <= end)
)
.order_by(DbBarData.datetime)
)
data = [db_bar.to_bar() for db_bar in s]
return data
@lru_cache(maxsize=10)
def load_tick_data(
vt_symbol: str,
start: datetime,
end: datetime
):
""""""
s = (
DbTickData.select()
.where(
(DbTickData.vt_symbol == vt_symbol)
& (DbTickData.datetime >= start)
& (DbTickData.datetime <= end)
)
.order_by(DbTickData.datetime)
)
data = [db_tick.db_tick() for db_tick in s]
return data

View File

@ -553,7 +553,7 @@ class CtpTdApi(TdApi):
)
# For option only
if data["OptionsType"]:
if contract.product == Product.OPTION:
contract.option_underlying = data["UnderlyingInstrID"],
contract.option_type = OPTIONTYPE_CTP2VT.get(data["OptionsType"], None),
contract.option_strike = data["StrikePrice"],

View File

@ -24,7 +24,7 @@ from .event import (
from .gateway import BaseGateway
from .object import CancelRequest, LogData, OrderRequest, SubscribeRequest
from .setting import SETTINGS
from .utility import Singleton, get_folder_path
from .utility import get_folder_path
class MainEngine:

View File

@ -13,25 +13,6 @@ import talib
from .object import BarData, TickData
class Singleton(type):
"""
Singleton metaclass,
usage:
class A(metaclass=Singleton):
...
"""
_instances = {}
def __call__(cls, *args, **kwargs):
""""""
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(
*args, **kwargs
)
return cls._instances[cls]
def resolve_path(pattern: str):
env = dict(os.environ)
env.update({"VNPY_TEMP": str(TEMP_DIR)})