[Mod]complete test of new database module with CtaStrategyApp

This commit is contained in:
vn.py 2019-04-16 13:39:06 +08:00
parent 1d527bc36e
commit a5550d42c7
4 changed files with 42 additions and 28 deletions

View File

@ -2,7 +2,9 @@ from time import time
import rqdatac as rq import rqdatac as rq
from vnpy.trader.database import DbBarData from vnpy.trader.object import BarData
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.database import database_manager
USERNAME = "" USERNAME = ""
PASSWORD = "" PASSWORD = ""
@ -13,20 +15,18 @@ rq.init(USERNAME, PASSWORD, ("rqdatad-pro.ricequant.com", 16011))
def generate_bar_from_row(row, symbol, exchange): def generate_bar_from_row(row, symbol, exchange):
"""""" """"""
bar = DbBarData() bar = BarData(
symbol=symbol,
bar.symbol = symbol exchange=Exchange(exchange),
bar.exchange = exchange interval=Interval.MINUTE,
bar.interval = "1m" open_price=row["open"],
bar.open_price = row["open"] high_price=row["high"],
bar.high_price = row["high"] low_price=row["low"],
bar.low_price = row["low"] close_price=row["close"],
bar.close_price = row["close"] volume=row["volume"],
bar.volume = row["volume"] datetime=row.name.to_pydatetime(),
bar.datetime = row.name.to_pydatetime() gateway_name="DB"
bar.gateway_name = "DB" )
bar.vt_symbol = f"{symbol}.{exchange}"
return bar return bar
@ -37,13 +37,20 @@ def download_minute_bar(vt_symbol):
start = time() start = time()
df = rq.get_price(symbol, frequency="1m", fields=FIELDS) df = rq.get_price(
symbol,
frequency="1m",
fields=FIELDS,
start_date='20100416',
end_date='20190416'
)
bars = [] bars = []
for ix, row in df.iterrows(): for ix, row in df.iterrows():
bar = generate_bar_from_row(row, symbol, exchange) bar = generate_bar_from_row(row, symbol, exchange)
bars.append(bar) bars.append(bar)
DbBarData.save_all(bars)
database_manager.save_bar_data(bars)
end = time() end = time()
cost = (end - start) * 1000 cost = (end - start) * 1000

View File

@ -167,7 +167,7 @@ class BacktestingEngine:
"""""" """"""
self.mode = mode self.mode = mode
self.vt_symbol = vt_symbol self.vt_symbol = vt_symbol
self.interval = interval self.interval = Interval(interval)
self.rate = rate self.rate = rate
self.slippage = slippage self.slippage = slippage
self.size = size self.size = size

View File

@ -35,7 +35,7 @@ from vnpy.trader.constant import (
Offset, Offset,
Status Status
) )
from vnpy.trader.utility import load_json, save_json from vnpy.trader.utility import load_json, save_json, extract_vt_symbol
from vnpy.trader.database import database_manager from vnpy.trader.database import database_manager
from vnpy.trader.setting import SETTINGS from vnpy.trader.setting import SETTINGS
@ -528,10 +528,14 @@ class CtaEngine(BaseEngine):
return self.engine_type return self.engine_type
def load_bar( def load_bar(
self, symbol: str, exchange: Exchange, days: int, interval: Interval, self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable[[BarData], None] callback: Callable[[BarData], None]
): ):
"""""" """"""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now() end = datetime.now()
start = end - timedelta(days) start = end - timedelta(days)
@ -549,9 +553,14 @@ class CtaEngine(BaseEngine):
for bar in bars: for bar in bars:
callback(bar) callback(bar)
def load_tick(self, symbol: str, exchange: Exchange, days: int, def load_tick(
callback: Callable[[TickData], None]): self,
vt_symbol: str,
days: int,
callback: Callable[[TickData], None]
):
"""""" """"""
symbol, exchange = extract_vt_symbol(vt_symbol)
end = datetime.now() end = datetime.now()
start = end - timedelta(days) start = end - timedelta(days)

View File

@ -10,20 +10,18 @@ import numpy as np
import talib import talib
from .object import BarData, TickData from .object import BarData, TickData
from .constant import Exchange
if TYPE_CHECKING:
from vnpy.trader.constant import Exchange
def extract_vt_symbol(vt_symbol: str): def extract_vt_symbol(vt_symbol: str):
""" """
:return: (symbol, exchange) :return: (symbol, exchange)
""" """
symbol, exchange = vt_symbol.split('.') symbol, exchange_str = vt_symbol.split('.')
return symbol, exchange return symbol, Exchange(exchange_str)
def generate_vt_symbol(symbol: str, exchange: "Exchange"): def generate_vt_symbol(symbol: str, exchange: Exchange):
return f'{symbol}.{exchange.value}' return f'{symbol}.{exchange.value}'