[update] Adapt the TqSdk (天勤) data download to the new version
This commit is contained in:
parent 015efa9d32
commit 2c58c51e7f
@@ -6,14 +6,20 @@
# 2. When downloading ticks, also download all 5 levels of market depth
# 3. Rename the 5-level depth quote fields to match vnpy's naming convention

import os
import asyncio
import csv
import os
from datetime import date, datetime
from typing import Union, List
import lzma

import pandas
import json
from tqsdk.api import TqApi
from tqsdk.channel import TqChan
from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time
from tqsdk.diff import _get_obj
from tqsdk.tafunc import get_dividend_df, get_dividend_factor
from tqsdk.utils import _generate_uuid

def get_account_config():
@@ -32,15 +38,25 @@ def get_account_config():

    return {}

DEAD_INS = {}

# Price-related fields whose values need to be formatted (rounded to the contract's price decimals)
PRICE_KEYS = ["open", "high", "low", "close", "last_price", "highest", "lowest"] + [f"bid_price{i}" for i in range(1, 6)] + [f"ask_price{i}" for i in range(1, 6)]


class DataDownloader:
    """
    The data download tool is a feature of TqSdk Professional. It lets users download historical data for all futures, options and stock instruments currently provided by TqSdk, with tick-level precision or any kline period.

    To use the data download tool, visit `天勤量化专业版 <https://www.shinnytech.com/tqsdk_professional/>`_ to apply for a trial or to purchase the professional edition.

    Historical data downloader; writes the output to a csv file.

    Multiple contracts are aligned horizontally by timestamp.
    """

    def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime],
                 end_dt: Union[date, datetime], csv_file_name: str) -> None:
                 end_dt: Union[date, datetime], csv_file_name: str, adj_type: Union[str, None] = None) -> None:
        """
        Create an instance of the historical data downloader.

@@ -57,14 +73,16 @@ class DataDownloader:

            csv_file_name (str): name of the output csv file

            adj_type (str/None): price adjustment method, default None. "F" means forward adjustment, "B" means backward adjustment, None means no adjustment. Only effective for stock and fund contracts.

        Example::

            from datetime import datetime, date
            from contextlib import closing
            from tqsdk import TqApi, TqSim
            from tqsdk import TqApi, TqAuth, TqSim
            from tqsdk.tools import DataDownloader

            api = TqApi(TqSim())
            api = TqApi(auth=TqAuth("信易账户", "账户密码"))
            download_tasks = {}
            # Download SR901 daily klines from 2018-01-01 to 2018-09-01
            download_tasks["SR_daily"] = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
@@ -87,27 +105,31 @@ class DataDownloader:
            print("progress: ", { k:("%.2f%%" % v.get_progress()) for k,v in download_tasks.items() })
        """
        self._api = api
        if not self._api._auth._has_feature("tq_dl"):
            raise Exception("您的账户不支持下载历史数据功能,需要购买专业版本后使用。升级网址:https://account.shinnytech.com")
        if isinstance(start_dt, datetime):
            self._start_dt_nano = int(start_dt.timestamp() * 1e9)
        else:
            self._start_dt_nano = _get_trading_day_start_time(
                int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
            self._start_dt_nano = _get_trading_day_start_time(int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
        if isinstance(end_dt, datetime):
            self._end_dt_nano = int(end_dt.timestamp() * 1e9)
        else:
            self._end_dt_nano = _get_trading_day_end_time(
                int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
            self._end_dt_nano = _get_trading_day_end_time(int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
        self._current_dt_nano = self._start_dt_nano
        self._symbol_list = symbol_list if isinstance(symbol_list, list) else [symbol_list]
        # Check that the contract symbols exist
        for symbol in self._symbol_list:
            if (not self._api._stock) and symbol not in self._api._data.get("quotes", {}):
                raise Exception("代码 %s 不存在, 请检查合约代码是否填写正确" % (symbol))
        # Per-contract download timeout (30s by default). For delisted contracts with no trading, a shorter timeout (2s) is used so the user's program is not stuck waiting for data that will never arrive.
        self._timeout_seconds = 2 if any([symbol in DEAD_INS for symbol in self._symbol_list]) else 30
        self._dur_nano = dur_sec * 1000000000
        if self._dur_nano == 0 and len(self._symbol_list) != 1:
            raise Exception("Tick序列不支持多合约")
        if adj_type not in [None, "F", "B", "FORWARD", "BACK"]:
            raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)")
        self._adj_type = adj_type[0] if adj_type else adj_type
        self._csv_file_name = csv_file_name
        self._task = self._api.create_task(self._download_data())
        self._csv_header = self._get_headers()
        self._dividend_cache = {}  # cache of each contract's dividend/adjustment factor table, computed only once per contract
        self._data_series = None
        self._task = self._api.create_task(self._run())
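As a rough usage sketch of the new adj_type parameter (not part of this commit's diff; the stock symbol "SSE.600000" is only an illustrative placeholder), downloading forward-adjusted daily bars could look like this:

from datetime import date
from contextlib import closing
from tqsdk import TqApi, TqAuth
from tqsdk.tools import DataDownloader

api = TqApi(auth=TqAuth("信易账户", "账户密码"))
# adj_type="F" requests forward-adjusted prices; only meaningful for stock/fund contracts
task = DataDownloader(api, symbol_list="SSE.600000", dur_sec=24 * 60 * 60,
                      start_dt=date(2020, 1, 1), end_dt=date(2020, 9, 1),
                      csv_file_name="stock_daily_fq.csv", adj_type="F")
with closing(api):
    while not task.is_finished():
        api.wait_update()
        print("progress: %.2f%%" % task.get_progress())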

    def is_finished(self) -> bool:
        """
@@ -128,6 +150,142 @@ class DataDownloader:
        return 100.0 if self._task.done() else (self._current_dt_nano - self._start_dt_nano) / (
                self._end_dt_nano - self._start_dt_nano) * 100

    def _get_data_series(self) -> pandas.DataFrame:
        """
        Return the downloaded data as a pandas.DataFrame.

        todo: add a helper in utils that returns the same data structure as klines

        Returns:
            pandas.DataFrame/None: the downloaded klines or ticks as a DataFrame. Returns None until the download has finished.


        Example::

            from datetime import datetime, date
            from tqsdk import TqApi, TqAuth
            from contextlib import closing
            from tqsdk.tools import DataDownloader

            api = TqApi(auth=TqAuth("信易账户", "账户密码"))
            # Download SR901 daily klines from 2018-06-01 to 2018-09-01
            download_task = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
                start_dt=date(2018, 6, 1), end_dt=date(2018, 9, 1), csv_file_name="klines.csv")
            # Use with closing() to make sure resources are released after the download completes
            with closing(api):
                while not download_task.is_finished():
                    api.wait_update()
                    print(f"progress: {download_task.get_progress():.2f} %")
                print(download_task._get_data_series())
        """
        if not self._task.done():
            return None
        if self._data_series is None:
            self._data_series = pandas.read_csv(self._csv_file_name)
        return self._data_series

    async def _update_dividend_factor(self):
        for s in self._dividend_cache:
            # Add a "factor" column to each dividend table: the adjustment factor effective on that day
            df = self._dividend_cache[s]["df"]
            df["pre_close"] = float('nan')  # initialize pre_close to nan
            between = df["datetime"].between(self._start_dt_nano, self._end_dt_nano)  # only the factors between the start and end time are needed
            for i in df[between].index:
                chart_info = {
                    "aid": "set_chart",
                    "chart_id": _generate_uuid("PYSDK_downloader"),
                    "ins_list": s,
                    "duration": 86400 * 1000000000,
                    "view_width": 2,
                    "focus_datetime": int(df.iloc[i].datetime),
                    "focus_position": 1
                }
                await self._api._send_chan.send(chart_info)
                chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
                serial = _get_obj(self._api._data, ["klines", s, str(86400000000000)])
                try:
                    async with self._api.register_update_notify() as update_chan:
                        async for _ in update_chan:
                            if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
                                continue  # the response to the current request is not complete yet; do not continue processing
                            left_id = chart.get("left_id", -1)
                            right_id = chart.get("right_id", -1)
                            if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True) or serial.get("last_id", -1) == -1:
                                continue  # chart position info or the data series has not been received yet
                            last_item = serial["data"].get(str(left_id), {})
                            # previous close at the ex-dividend time point
                            df.loc[i, 'pre_close'] = last_item['close'] if last_item.get('close') else float('nan')
                            break
                finally:
                    await self._api._send_chan.send({
                        "aid": "set_chart",
                        "chart_id": chart_info["chart_id"],
                        "ins_list": "",
                        "duration": 86400000000000,
                        "view_width": 2
                    })
            df["factor"] = (df["pre_close"] - df["cash_dividend"]) / df["pre_close"] / (1 + df["stock_dividend"])
            df["factor"].fillna(1, inplace=True)

    async def _run(self):
        self._quote_list = await self._api.get_quote_list(self._symbol_list)
        # If there are STOCK / FUND contracts and adj_type is not None, prepare the adjustment factors for the download period in advance
        if self._adj_type:
            for quote in self._quote_list:
                if quote.ins_class in ["STOCK", "FUND"]:
                    self._dividend_cache[quote.instrument_id] = {
                        "df": get_dividend_df(quote.stock_dividend_ratio, quote.cash_dividend_ratio),
                        "back_factor": 1.0
                    }
            # Forward adjustment needs the ex-dividend factors computed up front
            await self._update_dividend_factor()
        self._data_chan = TqChan(self._api)
        task = self._api.create_task(self._download_data())
        # cols are the column names that have to be recalculated when adjusting prices
        index_datetime_nano = self._csv_header.index("datetime_nano")
        if self._dur_nano != 0:
            cols = ["open", "high", "low", "close"]
        else:
            cols = ["last_price", "highest", "lowest"]
            cols.extend(f"{x}{i}" for x in ["bid_price", "ask_price"] for i in range(1, 6))
        try:
            with open(self._csv_file_name, 'w', newline='') as csvfile:
                csv_writer = csv.writer(csvfile, dialect='excel')
                csv_writer.writerow(self._csv_header)
                last_dt = None
                async for item in self._data_chan:
                    for quote in self._quote_list:
                        symbol = quote.instrument_id
                        if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
                            dividend_df = self._dividend_cache[symbol]["df"]
                            factor = 1
                            if self._adj_type == "F":
                                gt = dividend_df["datetime"].gt(item[index_datetime_nano])
                                if gt.any():
                                    factor = dividend_df[gt]["factor"].cumprod().iloc[-1]
                            elif self._adj_type == "B" and last_dt:
                                gt = dividend_df['datetime'].gt(last_dt)
                                if gt.any():
                                    index = dividend_df[gt].index[0]
                                    if item[index_datetime_nano] >= dividend_df.loc[index, 'datetime']:
                                        self._dividend_cache[symbol]["back_factor"] *= (1 / dividend_df[gt].loc[index, 'factor'])
                                    factor = self._dividend_cache[symbol]["back_factor"]
                            last_dt = item[index_datetime_nano]
                            if factor != 1:
                                item = item.copy()
                                for c in cols:  # datetime_nano
                                    index = self._csv_header.index(f"{symbol}.{c}")
                                    item[index] = item[index] * factor
                    csv_writer.writerow(item)
        finally:
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)
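To make the forward-adjustment branch above concrete, here is a small self-contained sketch of the same pandas lookup on a toy dividend table (the timestamps and factors are invented; the column names follow the dividend_df used above):

import pandas as pd

# One row per ex-dividend day, with the per-day adjustment factor computed by _update_dividend_factor()
dividend_df = pd.DataFrame({
    "datetime": [1_000, 2_000, 3_000],  # ex-dividend timestamps in nanoseconds (made up)
    "factor": [0.95, 0.90, 0.98],
})

bar_time = 1_500  # datetime_nano of the row being written

# Forward adjustment ("F"): multiply the bar by the factors of every ex-dividend day after it
gt = dividend_df["datetime"].gt(bar_time)
factor = dividend_df[gt]["factor"].cumprod().iloc[-1] if gt.any() else 1
print(factor)  # 0.882 (0.90 * 0.98), up to float rounding

Backward adjustment ("B") works the other way round in the code above: back_factor starts at 1.0 and is divided by each day's factor as the download crosses an ex-dividend day, so earlier rows keep their original prices and later rows are rescaled.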

    async def _timeout_handle(self, timeout, chart):
        await asyncio.sleep(timeout)
        if chart.get("left_id", -1) == -1 and chart.get("right_id", -1) == -1:
            self._task.cancel()

    async def _download_data(self):
        """Download the data; multiple contracts are aligned horizontally by timestamp."""
        chart_info = {
@@ -139,32 +297,19 @@ class DataDownloader:
            "focus_datetime": self._start_dt_nano,
            "focus_position": 0,
        }
        if len(self._symbol_list) == 1:
            single_exchange, single_symbol = self._symbol_list[0].split('.')
        else:
            single_exchange, single_symbol = None, None
        # No request has been sent yet; first ask the server to locate the left endpoint
        await self._api._send_chan.send(chart_info)
        chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
        # Add a task that checks after 30s whether the chart returned the left/right id range; if not, cancel self._task so the program does not hang forever
        timeout_task = self._api.create_task(self._timeout_handle(self._timeout_seconds, chart))
        current_id = None  # current data pointer
        csv_header = []
        data_cols = ["open", "high", "low", "close", "volume", "open_oi", "close_oi"] if self._dur_nano != 0 else \
            ["last_price", "highest", "lowest", "volume",
             "amount", "open_interest", "upper_limit", "lower_limit",
             "bid_price1", "bid_volume1", "ask_price1", "ask_volume1",
             "bid_price2", "bid_volume2", "ask_price2", "ask_volume2",
             "bid_price3", "bid_volume3", "ask_price3", "ask_volume3",
             "bid_price4", "bid_volume4", "ask_price4", "ask_volume4",
             "bid_price5", "bid_volume5", "ask_price5", "ask_volume5"
             ]
        data_cols = self._get_data_cols()
        serials = []
        for symbol in self._symbol_list:
            path = ["klines", symbol, str(self._dur_nano)] if self._dur_nano != 0 else ["ticks", symbol]
            serial = _get_obj(self._api._data, path)
            serials.append(serial)
        try:
            with open(self._csv_file_name, 'w', newline='') as csvfile:
                csv_writer = csv.writer(csvfile, dialect='excel')
                async with self._api.register_update_notify() as update_chan:
                    async for _ in update_chan:
                        if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
@@ -186,42 +331,16 @@ class DataDownloader:
                            if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano:
                                # the current id is past last_id, or the kline time is past the user-specified right endpoint
                                return
                            if len(csv_header) == 0:
                                # write the file header
                                csv_header = ["datetime"]
                                for symbol in self._symbol_list:
                                    # when there is a single contract, add the symbol and exchange columns
                                    if single_exchange:
                                        csv_header.extend(['symbol', 'exchange'])

                            row = [self._nano_to_str(item["datetime"]), item["datetime"]]
                                    for col in data_cols:
                                        if col.startswith('bid_') or col.startswith('ask_'):
                                            col = col[:-1] + '_' + col[-1]
                                        if len(self._symbol_list) > 2:
                                            csv_header.append(symbol + "." + col)
                                        else:
                                            csv_header.append(col)

                                csv_writer.writerow(csv_header)
                            row = [self._nano_to_str(item["datetime"])]

                            # when there is a single contract, add the symbol and exchange columns
                            if single_exchange:
                                row.extend([single_symbol, single_exchange])

                            for col in data_cols:
                                row.append(self._get_value(item, col))
                                row.append(self._get_value(item, col, self._quote_list[0]["price_decs"]))
                            for i in range(1, len(self._symbol_list)):
                                symbol = self._symbol_list[i]
                                tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1)
                                k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
                                for col in data_cols:
                                    row.append(self._get_value(k, col))
                            # discard dirty pre-open data
                            if self._dur_nano == 0 and str(row[3]) == 'nan':
                                p = 1
                            else:
                                csv_writer.writerow(row)
                                    row.append(self._get_value(k, col, self._quote_list[i]["price_decs"]))
                            await self._data_chan.send(row)
                            current_id += 1
                            self._current_dt_nano = item["datetime"]
                        # the current id is beyond the subscribed range; need to re-subscribe for the following data
@@ -231,6 +350,7 @@
                        await self._api._send_chan.send(chart_info)
        finally:
            # release the chart resource
            await self._data_chan.close()
            await self._api._send_chan.send({
                "aid": "set_chart",
                "chart_id": chart_info["chart_id"],
@@ -238,18 +358,47 @@
                "duration": self._dur_nano,
                "view_width": 2000,
            })
            timeout_task.cancel()
            await timeout_task

    def _get_headers(self):
        data_cols = self._get_data_cols()
        data_cols = [col.replace('d_price','d_price_').replace('d_volume','d_volume_').replace('k_price','k_price_').replace('k_volume','k_volume_') for col in data_cols]
        if len(self._symbol_list) > 1:
            return ["datetime", "datetime_nano"] + [f"{symbol}.{col}" for symbol in self._symbol_list for col in data_cols]
        else:
            return ["datetime", "datetime_nano"] + data_cols


    def _get_data_cols(self):
        if self._dur_nano != 0:
            return ["open", "high", "low", "close", "volume", "open_oi", "close_oi"]
        else:
            # ,"upper_limit","lower_limit"
            cols = ["last_price", "highest", "lowest", "volume", "amount", "open_interest"]
            price_range = 1
            for symbol in self._symbol_list:
                if symbol.split('.')[0] in {"SHFE", "INE", "SSE", "SZSE"}:
                    price_range = 5
                    break
            for i in range(price_range):
                cols.extend(f"{x}{i+1}" for x in ["bid_price", "bid_volume", "ask_price", "ask_volume"])
            return cols
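Putting _get_data_cols and _get_headers together, the CSV header for a single tick-level contract on one of the five-level-depth exchanges above comes out roughly as follows (a standalone sketch reproducing the same string manipulation):

cols = ["last_price", "highest", "lowest", "volume", "amount", "open_interest"]
for i in range(5):
    cols.extend(f"{x}{i+1}" for x in ["bid_price", "bid_volume", "ask_price", "ask_volume"])
# the replace() chain turns bid_price1 into bid_price_1 etc., matching vnpy's TickData field names
cols = [c.replace('d_price', 'd_price_').replace('d_volume', 'd_volume_').replace('k_price', 'k_price_').replace('k_volume', 'k_volume_') for c in cols]
header = ["datetime", "datetime_nano"] + cols
print(header[:8])    # ['datetime', 'datetime_nano', 'last_price', 'highest', 'lowest', 'volume', 'amount', 'open_interest']
print(header[8:12])  # ['bid_price_1', 'bid_volume_1', 'ask_price_1', 'ask_volume_1']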

    @staticmethod
    def _get_value(obj, key):
    def _get_value(obj, key, price_decs):
        if key not in obj:
            return "#N/A"
        if isinstance(obj[key], str):
            return float("nan")
        if key in PRICE_KEYS:
            return round(obj[key], price_decs)
        else:
            return obj[key]

    @staticmethod
    def _nano_to_str(nano):
        dt = datetime.fromtimestamp(nano // 1000000000)
        s = dt.strftime('%Y-%m-%d %H:%M:%S')
        s += '.' + str(int(nano % 1000000000)).zfill(9)[:3]
        s += '.' + str(int(nano % 1000000000)).zfill(9)
        return s
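A quick standalone check of the timestamp formatting above (the new line keeps all nine sub-second digits instead of truncating to milliseconds; the sample value is arbitrary, and UTC is used here only to make the output deterministic, whereas the method itself uses local time):

from datetime import datetime, timezone

nano = 1546300800123456789  # arbitrary example timestamp
dt = datetime.fromtimestamp(nano // 1000000000, tz=timezone.utc)
s = dt.strftime('%Y-%m-%d %H:%M:%S') + '.' + str(int(nano % 1000000000)).zfill(9)
print(s)  # 2019-01-01 00:00:00.123456789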
@@ -6,7 +6,7 @@ from contextlib import closing
import os
from datetime import datetime, timedelta
from functools import lru_cache
from tqsdk import TqApi, TqSim
from tqsdk import TqApi, TqSim,TqAuth
from vnpy.data.tq.downloader import DataDownloader
from vnpy.trader.constant import (
    Direction,
@@ -22,6 +22,7 @@ from vnpy.trader.object import TickData, BarData
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
import pandas as pd
import csv
from vnpy.data.tq.downloader import get_account_config

# pd.pandas.set_option('display.max_rows', None)  # Max rows to display; rows beyond this are replaced with an ellipsis; None shows all rows.
# pd.pandas.set_option('display.max_columns', None)  # Max columns to display; columns beyond this are replaced with an ellipsis; None shows all columns.
@@ -59,6 +60,7 @@ def to_tq_symbol(symbol: str, exchange: Exchange) -> str:
    """
    TQSdk exchange first
    """
    count = 0
    for count, word in enumerate(symbol):
        if word.isdigit():
            break
@@ -128,8 +130,9 @@ class TqFutureData():

    def __init__(self, strategy=None):
        self.strategy = strategy  # the strategy instance is passed in so that logs can be written to it
        auth_dict = get_account_config()

        self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile")
        self.api = TqApi(TqSim(),auth=TqAuth(auth_dict['user_name'],auth_dict['password']))  # url="wss://u.shinnytech.com/t/md/front/mobile"

    def get_tick_serial(self, vt_symbol: str):
        # fetch the latest 8964 records; for ticks that is only about 50 minutes of data
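Note on the second file: TqFutureData.__init__ now builds TqAuth from get_account_config(), so that helper (shown only as a stub returning {} in the first hunk) must supply the 'user_name' and 'password' keys. A minimal sketch of such a helper follows; the file location and JSON layout are assumptions for illustration, not part of this commit:

import json
import os

def get_account_config() -> dict:
    # Assumed location and format; the commit only requires the returned dict
    # to contain 'user_name' and 'password' for TqAuth.
    config_path = os.path.expanduser("~/.vntrader/tq_account.json")
    if not os.path.exists(config_path):
        return {}
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)  # e.g. {"user_name": "your_shinny_account", "password": "your_password"}

With such a file in place, the constructor in the diff can run as written: TqApi(TqSim(), auth=TqAuth(auth_dict['user_name'], auth_dict['password'])).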