From a5550d42c7d39c3f26ecc658cf181b7104728db8 Mon Sep 17 00:00:00 2001 From: "vn.py" Date: Tue, 16 Apr 2019 13:39:06 +0800 Subject: [PATCH] [Mod]complete test of new database module with CtaStrategyApp --- tests/backtesting/getdata.py | 41 ++++++++++++++++------------ vnpy/app/cta_strategy/backtesting.py | 2 +- vnpy/app/cta_strategy/engine.py | 17 +++++++++--- vnpy/trader/utility.py | 10 +++---- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/tests/backtesting/getdata.py b/tests/backtesting/getdata.py index 62d02465..d5b01fdf 100644 --- a/tests/backtesting/getdata.py +++ b/tests/backtesting/getdata.py @@ -2,7 +2,9 @@ from time import time import rqdatac as rq -from vnpy.trader.database import DbBarData +from vnpy.trader.object import BarData +from vnpy.trader.constant import Exchange, Interval +from vnpy.trader.database import database_manager USERNAME = "" PASSWORD = "" @@ -13,20 +15,18 @@ rq.init(USERNAME, PASSWORD, ("rqdatad-pro.ricequant.com", 16011)) def generate_bar_from_row(row, symbol, exchange): """""" - bar = DbBarData() - - bar.symbol = symbol - bar.exchange = exchange - bar.interval = "1m" - bar.open_price = row["open"] - bar.high_price = row["high"] - bar.low_price = row["low"] - bar.close_price = row["close"] - bar.volume = row["volume"] - bar.datetime = row.name.to_pydatetime() - bar.gateway_name = "DB" - bar.vt_symbol = f"{symbol}.{exchange}" - + bar = BarData( + symbol=symbol, + exchange=Exchange(exchange), + interval=Interval.MINUTE, + open_price=row["open"], + high_price=row["high"], + low_price=row["low"], + close_price=row["close"], + volume=row["volume"], + datetime=row.name.to_pydatetime(), + gateway_name="DB" + ) return bar @@ -37,13 +37,20 @@ def download_minute_bar(vt_symbol): start = time() - df = rq.get_price(symbol, frequency="1m", fields=FIELDS) + df = rq.get_price( + symbol, + frequency="1m", + fields=FIELDS, + start_date='20100416', + end_date='20190416' + ) bars = [] for ix, row in df.iterrows(): bar = generate_bar_from_row(row, symbol, exchange) bars.append(bar) - DbBarData.save_all(bars) + + database_manager.save_bar_data(bars) end = time() cost = (end - start) * 1000 diff --git a/vnpy/app/cta_strategy/backtesting.py b/vnpy/app/cta_strategy/backtesting.py index 01423bf5..d439f4b2 100644 --- a/vnpy/app/cta_strategy/backtesting.py +++ b/vnpy/app/cta_strategy/backtesting.py @@ -167,7 +167,7 @@ class BacktestingEngine: """""" self.mode = mode self.vt_symbol = vt_symbol - self.interval = interval + self.interval = Interval(interval) self.rate = rate self.slippage = slippage self.size = size diff --git a/vnpy/app/cta_strategy/engine.py b/vnpy/app/cta_strategy/engine.py index 7824ff2d..84939da3 100644 --- a/vnpy/app/cta_strategy/engine.py +++ b/vnpy/app/cta_strategy/engine.py @@ -35,7 +35,7 @@ from vnpy.trader.constant import ( Offset, Status ) -from vnpy.trader.utility import load_json, save_json +from vnpy.trader.utility import load_json, save_json, extract_vt_symbol from vnpy.trader.database import database_manager from vnpy.trader.setting import SETTINGS @@ -528,10 +528,14 @@ class CtaEngine(BaseEngine): return self.engine_type def load_bar( - self, symbol: str, exchange: Exchange, days: int, interval: Interval, + self, + vt_symbol: str, + days: int, + interval: Interval, callback: Callable[[BarData], None] ): """""" + symbol, exchange = extract_vt_symbol(vt_symbol) end = datetime.now() start = end - timedelta(days) @@ -549,9 +553,14 @@ class CtaEngine(BaseEngine): for bar in bars: callback(bar) - def load_tick(self, symbol: str, exchange: Exchange, days: int, - callback: Callable[[TickData], None]): + def load_tick( + self, + vt_symbol: str, + days: int, + callback: Callable[[TickData], None] + ): """""" + symbol, exchange = extract_vt_symbol(vt_symbol) end = datetime.now() start = end - timedelta(days) diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 69fa6991..89af4945 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -10,20 +10,18 @@ import numpy as np import talib from .object import BarData, TickData - -if TYPE_CHECKING: - from vnpy.trader.constant import Exchange +from .constant import Exchange def extract_vt_symbol(vt_symbol: str): """ :return: (symbol, exchange) """ - symbol, exchange = vt_symbol.split('.') - return symbol, exchange + symbol, exchange_str = vt_symbol.split('.') + return symbol, Exchange(exchange_str) -def generate_vt_symbol(symbol: str, exchange: "Exchange"): +def generate_vt_symbol(symbol: str, exchange: Exchange): return f'{symbol}.{exchange.value}'