[update] 天勤数据下载适应新版

This commit is contained in:
msincenselee 2022-01-18 10:52:14 +08:00
parent 015efa9d32
commit 2c58c51e7f
2 changed files with 249 additions and 97 deletions

View File

@ -6,14 +6,20 @@
# 2. 下载tick时5档行情都下载
# 3. 五档行情变量调整适合vnpy的命名方式
import os
import asyncio
import csv
import os
from datetime import date, datetime
from typing import Union, List
import lzma
import pandas
import json
from tqsdk.api import TqApi
from tqsdk.channel import TqChan
from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time
from tqsdk.diff import _get_obj
from tqsdk.tafunc import get_dividend_df, get_dividend_factor
from tqsdk.utils import _generate_uuid
def get_account_config():
@ -32,15 +38,25 @@ def get_account_config():
return {}
DEAD_INS = {}
# 价格相关的字段,需要 format 数据格式
PRICE_KEYS = ["open", "high", "low", "close", "last_price", "highest", "lowest"] + [f"bid_price{i}" for i in range(1, 6)] + [f"ask_price{i}" for i in range(1, 6)]
class DataDownloader:
"""
数据下载工具是 TqSdk 专业版中的功能能让用户下载目前 TqSdk 提供的全部期货期权和股票类的历史数据下载数据支持 tick 级别精度和任意 kline 周期
如果想使用数据下载工具可以点击 `天勤量化专业版 <https://www.shinnytech.com/tqsdk_professional/>`_ 申请使用或购买
历史数据下载器, 输出到csv文件
多合约按时间横向对齐
"""
def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime],
end_dt: Union[date, datetime], csv_file_name: str) -> None:
end_dt: Union[date, datetime], csv_file_name: str, adj_type: Union[str, None] = None) -> None:
"""
创建历史数据下载器实例
@ -55,16 +71,18 @@ class DataDownloader:
end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
csv_file_name (str): 输出csv的文件名
csv_file_name (str): 输出 csv 的文件名
adj_type (str/None): 复权计算方式默认值为 None"F" 为前复权"B" 为后复权None 表示不复权只对股票基金合约有效
Example::
from datetime import datetime, date
from contextlib import closing
from tqsdk import TqApi, TqSim
from tqsdk import TqApi, TqAuth, TqSim
from tqsdk.tools import DataDownloader
api = TqApi(TqSim())
api = TqApi(auth=TqAuth("信易账户", "账户密码"))
download_tasks = {}
# 下载从 2018-01-01 到 2018-09-01 的 SR901 日线数据
download_tasks["SR_daily"] = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
@ -87,27 +105,31 @@ class DataDownloader:
print("progress: ", { k:("%.2f%%" % v.get_progress()) for k,v in download_tasks.items() })
"""
self._api = api
if not self._api._auth._has_feature("tq_dl"):
raise Exception("您的账户不支持下载历史数据功能需要购买专业版本后使用。升级网址https://account.shinnytech.com")
if isinstance(start_dt, datetime):
self._start_dt_nano = int(start_dt.timestamp() * 1e9)
else:
self._start_dt_nano = _get_trading_day_start_time(
int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
self._start_dt_nano = _get_trading_day_start_time(int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
if isinstance(end_dt, datetime):
self._end_dt_nano = int(end_dt.timestamp() * 1e9)
else:
self._end_dt_nano = _get_trading_day_end_time(
int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
self._end_dt_nano = _get_trading_day_end_time(int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
self._current_dt_nano = self._start_dt_nano
self._symbol_list = symbol_list if isinstance(symbol_list, list) else [symbol_list]
# 检查合约代码是否存在
for symbol in self._symbol_list:
if (not self._api._stock) and symbol not in self._api._data.get("quotes", {}):
raise Exception("代码 %s 不存在, 请检查合约代码是否填写正确" % (symbol))
# 下载合约超时时间(默认 30s已下市的没有交易的合约超时时间可以设置短一点2s用户不希望自己的程序因为没有下载到数据而中断
self._timeout_seconds = 2 if any([symbol in DEAD_INS for symbol in self._symbol_list]) else 30
self._dur_nano = dur_sec * 1000000000
if self._dur_nano == 0 and len(self._symbol_list) != 1:
raise Exception("Tick序列不支持多合约")
if adj_type not in [None, "F", "B", "FORWARD", "BACK"]:
raise Exception("adj_type 参数只支持 None (不复权) 'F' (前复权) 'B' (后复权)")
self._adj_type = adj_type[0] if adj_type else adj_type
self._csv_file_name = csv_file_name
self._task = self._api.create_task(self._download_data())
self._csv_header = self._get_headers()
self._dividend_cache = {} # 缓存合约对应的复权系数矩阵,每个合约只计算一次
self._data_series = None
self._task = self._api.create_task(self._run())
def is_finished(self) -> bool:
"""
@ -128,6 +150,142 @@ class DataDownloader:
return 100.0 if self._task.done() else (self._current_dt_nano - self._start_dt_nano) / (
self._end_dt_nano - self._start_dt_nano) * 100
def _get_data_series(self) -> pandas.DataFrame:
"""
获取下载的 DataFrame 格式数据
todo: utils 中增加工具函数返回与 kline 一致的数据结构
Returns:
pandas.DataFrame/None: 下载的 klines 或者 ticks 数据DataFrame 格式下载完成前返回 None
Example::
from datetime import datetime, date
rom tqsdk import TqApi, TqAuth
from contextlib import closing
from tqsdk.tools import DataDownloader
api = TqApi(auth=TqAuth("信易账户", "账户密码"))
# 下载从 2018-06-01 到 2018-09-01 的 SR901 日线数据
download_task = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
start_dt=date(2018, 6, 1), end_dt=date(2018, 9, 1), csv_file_name="klines.csv")
# 使用with closing机制确保下载完成后释放对应的资源
with closing(api):
while not download_task.is_finished():
api.wait_update()
print(f"progress: {download_task.get_progress():.2} %")
print(download_task._get_data_series())
"""
if not self._task.done():
return None
if not self._data_series:
self._data_series = pandas.read_csv(self._csv_file_name)
return self._data_series
async def _update_dividend_factor(self):
for s in self._dividend_cache:
# 对每个除权除息矩阵增加 factor 序列,为当日的复权因子
df = self._dividend_cache[s]["df"]
df["pre_close"] = float('nan') # 初始化 pre_close 为 nan
between = df["datetime"].between(self._start_dt_nano, self._end_dt_nano) # 只需要开始时间~结束时间之间的复权因子
for i in df[between].index:
chart_info = {
"aid": "set_chart",
"chart_id": _generate_uuid("PYSDK_downloader"),
"ins_list": s,
"duration": 86400 * 1000000000,
"view_width": 2,
"focus_datetime": int(df.iloc[i].datetime),
"focus_position": 1
}
await self._api._send_chan.send(chart_info)
chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
serial = _get_obj(self._api._data, ["klines", s, str(86400000000000)])
try:
async with self._api.register_update_notify() as update_chan:
async for _ in update_chan:
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
continue # 当前请求还没收齐回应, 不应继续处理
left_id = chart.get("left_id", -1)
right_id = chart.get("right_id", -1)
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True) or serial.get("last_id", -1) == -1:
continue # 定位信息还没收到, 或数据序列还没收到, 合约的数据是否收到
last_item = serial["data"].get(str(left_id), {})
# 复权时间点的昨收盘
df.loc[i, 'pre_close'] = last_item['close'] if last_item.get('close') else float('nan')
break
finally:
await self._api._send_chan.send({
"aid": "set_chart",
"chart_id": chart_info["chart_id"],
"ins_list": "",
"duration": 86400000000000,
"view_width": 2
})
df["factor"] = (df["pre_close"] - df["cash_dividend"]) / df["pre_close"] / (1 + df["stock_dividend"])
df["factor"].fillna(1, inplace=True)
async def _run(self):
self._quote_list = await self._api.get_quote_list(self._symbol_list)
# 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
if self._adj_type:
for quote in self._quote_list:
if quote.ins_class in ["STOCK", "FUND"]:
self._dividend_cache[quote.instrument_id] = {
"df": get_dividend_df(quote.stock_dividend_ratio, quote.cash_dividend_ratio),
"back_factor": 1.0
}
# 前复权需要提前计算除权因子
await self._update_dividend_factor()
self._data_chan = TqChan(self._api)
task = self._api.create_task(self._download_data())
# cols 是复权需要重新计算的列名
index_datetime_nano = self._csv_header.index("datetime_nano")
if self._dur_nano != 0:
cols = ["open", "high", "low", "close"]
else:
cols = ["last_price", "highest", "lowest"]
cols.extend(f"{x}{i}" for x in ["bid_price", "ask_price"] for i in range(1, 6))
try:
with open(self._csv_file_name, 'w', newline='') as csvfile:
csv_writer = csv.writer(csvfile, dialect='excel')
csv_writer.writerow(self._csv_header)
last_dt = None
async for item in self._data_chan:
for quote in self._quote_list:
symbol = quote.instrument_id
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
dividend_df = self._dividend_cache[symbol]["df"]
factor = 1
if self._adj_type == "F":
gt = dividend_df["datetime"].gt(item[index_datetime_nano])
if gt.any():
factor = dividend_df[gt]["factor"].cumprod().iloc[-1]
elif self._adj_type == "B" and last_dt:
gt = dividend_df['datetime'].gt(last_dt)
if gt.any():
index = dividend_df[gt].index[0]
if item[index_datetime_nano] >= dividend_df.loc[index, 'datetime']:
self._dividend_cache[symbol]["back_factor"] *= (1 / dividend_df[gt].loc[index, 'factor'])
factor = self._dividend_cache[symbol]["back_factor"]
last_dt = item[index_datetime_nano]
if factor != 1:
item = item.copy()
for c in cols: # datetime_nano
index = self._csv_header.index(f"{symbol}.{c}")
item[index] = item[index] * factor
csv_writer.writerow(item)
finally:
task.cancel()
await asyncio.gather(task, return_exceptions=True)
async def _timeout_handle(self, timeout, chart):
await asyncio.sleep(timeout)
if chart.get("left_id", -1) == -1 and chart.get("right_id", -1) == -1:
self._task.cancel()
async def _download_data(self):
"""下载数据, 多合约横向按时间对齐"""
chart_info = {
@ -139,98 +297,60 @@ class DataDownloader:
"focus_datetime": self._start_dt_nano,
"focus_position": 0,
}
if len(self._symbol_list) == 1:
single_exchange, single_symbol = self._symbol_list[0].split('.')
else:
single_exchange, single_symbol = None, None
# 还没有发送过任何请求, 先请求定位左端点
await self._api._send_chan.send(chart_info)
chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
# 增加一个 task在 30s 后检查 chart 是否返回了左右 id 范围,如果没有就 cancel self._task防止程序一直卡在那里
timeout_task = self._api.create_task(self._timeout_handle(self._timeout_seconds, chart))
current_id = None # 当前数据指针
csv_header = []
data_cols = ["open", "high", "low", "close", "volume", "open_oi", "close_oi"] if self._dur_nano != 0 else \
["last_price", "highest", "lowest", "volume",
"amount", "open_interest", "upper_limit", "lower_limit",
"bid_price1", "bid_volume1", "ask_price1", "ask_volume1",
"bid_price2", "bid_volume2", "ask_price2", "ask_volume2",
"bid_price3", "bid_volume3", "ask_price3", "ask_volume3",
"bid_price4", "bid_volume4", "ask_price4", "ask_volume4",
"bid_price5", "bid_volume5", "ask_price5", "ask_volume5"
]
data_cols = self._get_data_cols()
serials = []
for symbol in self._symbol_list:
path = ["klines", symbol, str(self._dur_nano)] if self._dur_nano != 0 else ["ticks", symbol]
serial = _get_obj(self._api._data, path)
serials.append(serial)
try:
with open(self._csv_file_name, 'w', newline='') as csvfile:
csv_writer = csv.writer(csvfile, dialect='excel')
async with self._api.register_update_notify() as update_chan:
async for _ in update_chan:
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
# 当前请求还没收齐回应, 不应继续处理
async with self._api.register_update_notify() as update_chan:
async for _ in update_chan:
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
# 当前请求还没收齐回应, 不应继续处理
continue
left_id = chart.get("left_id", -1)
right_id = chart.get("right_id", -1)
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True):
# 定位信息还没收到, 或数据序列还没收到
continue
for serial in serials:
# 检查合约的数据是否收到
if serial.get("last_id", -1) == -1:
continue
left_id = chart.get("left_id", -1)
right_id = chart.get("right_id", -1)
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True):
# 定位信息还没收到, 或数据序列还没收到
continue
for serial in serials:
# 检查合约的数据是否收到
if serial.get("last_id", -1) == -1:
continue
if current_id is None:
current_id = max(left_id, 0)
while current_id <= right_id:
item = serials[0]["data"].get(str(current_id), {})
if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano:
# 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端
return
if len(csv_header) == 0:
# 写入文件头
csv_header = ["datetime"]
for symbol in self._symbol_list:
# 单一合约时,添加合约和交易所
if single_exchange:
csv_header.extend(['symbol', 'exchange'])
for col in data_cols:
if col.startswith('bid_') or col.startswith('ask_'):
col = col[:-1] + '_' + col[-1]
if len(self._symbol_list) > 2:
csv_header.append(symbol + "." + col)
else:
csv_header.append(col)
csv_writer.writerow(csv_header)
row = [self._nano_to_str(item["datetime"])]
# 单一合约时,添加合约和交易所
if single_exchange:
row.extend([single_symbol, single_exchange])
if current_id is None:
current_id = max(left_id, 0)
while current_id <= right_id:
item = serials[0]["data"].get(str(current_id), {})
if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano:
# 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端
return
row = [self._nano_to_str(item["datetime"]), item["datetime"]]
for col in data_cols:
row.append(self._get_value(item, col, self._quote_list[0]["price_decs"]))
for i in range(1, len(self._symbol_list)):
symbol = self._symbol_list[i]
tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1)
k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
for col in data_cols:
row.append(self._get_value(item, col))
for i in range(1, len(self._symbol_list)):
symbol = self._symbol_list[i]
tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1)
k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
for col in data_cols:
row.append(self._get_value(k, col))
# 抛弃盘前的脏数据
if self._dur_nano == 0 and str(row[3]) == 'nan':
p = 1
else:
csv_writer.writerow(row)
current_id += 1
self._current_dt_nano = item["datetime"]
# 当前 id 已超出订阅范围, 需重新订阅后续数据
chart_info.pop("focus_datetime", None)
chart_info.pop("focus_position", None)
chart_info["left_kline_id"] = current_id
await self._api._send_chan.send(chart_info)
row.append(self._get_value(k, col, self._quote_list[i]["price_decs"]))
await self._data_chan.send(row)
current_id += 1
self._current_dt_nano = item["datetime"]
# 当前 id 已超出订阅范围, 需重新订阅后续数据
chart_info.pop("focus_datetime", None)
chart_info.pop("focus_position", None)
chart_info["left_kline_id"] = current_id
await self._api._send_chan.send(chart_info)
finally:
# 释放chart资源
await self._data_chan.close()
await self._api._send_chan.send({
"aid": "set_chart",
"chart_id": chart_info["chart_id"],
@ -238,18 +358,47 @@ class DataDownloader:
"duration": self._dur_nano,
"view_width": 2000,
})
timeout_task.cancel()
await timeout_task
def _get_headers(self):
data_cols = self._get_data_cols()
data_cols = [col.replace('d_price','d_price_').replace('d_volume','d_volume_').replace('k_price','k_price_').replace('k_volume','k_volume_') for col in data_cols]
if len(self._symbol_list) > 1:
return ["datetime", "datetime_nano"] + [f"{symbol}.{col}" for symbol in self._symbol_list for col in data_cols]
else:
return ["datetime", "datetime_nano"] + data_cols
def _get_data_cols(self):
if self._dur_nano != 0:
return ["open", "high", "low", "close", "volume", "open_oi", "close_oi"]
else:
# ,"upper_limit","lower_limit"
cols = ["last_price", "highest", "lowest", "volume", "amount", "open_interest"]
price_range = 1
for symbol in self._symbol_list:
if symbol.split('.')[0] in {"SHFE", "INE", "SSE", "SZSE"}:
price_range = 5
break
for i in range(price_range):
cols.extend(f"{x}{i+1}" for x in ["bid_price", "bid_volume", "ask_price", "ask_volume"])
return cols
@staticmethod
def _get_value(obj, key):
def _get_value(obj, key, price_decs):
if key not in obj:
return "#N/A"
if isinstance(obj[key], str):
return float("nan")
return obj[key]
if key in PRICE_KEYS:
return round(obj[key], price_decs)
else:
return obj[key]
@staticmethod
def _nano_to_str(nano):
dt = datetime.fromtimestamp(nano // 1000000000)
s = dt.strftime('%Y-%m-%d %H:%M:%S')
s += '.' + str(int(nano % 1000000000)).zfill(9)[:3]
s += '.' + str(int(nano % 1000000000)).zfill(9)
return s

View File

@ -6,7 +6,7 @@ from contextlib import closing
import os
from datetime import datetime, timedelta
from functools import lru_cache
from tqsdk import TqApi, TqSim
from tqsdk import TqApi, TqSim,TqAuth
from vnpy.data.tq.downloader import DataDownloader
from vnpy.trader.constant import (
Direction,
@ -22,6 +22,7 @@ from vnpy.trader.object import TickData, BarData
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
import pandas as pd
import csv
from vnpy.data.tq.downloader import get_account_config
# pd.pandas.set_option('display.max_rows', None) # 设置最大显示行数超过该值用省略号代替为None时显示所有行。
# pd.pandas.set_option('display.max_columns', None) # 设置最大显示列数超过该值用省略号代替为None时显示所有列。
@ -59,6 +60,7 @@ def to_tq_symbol(symbol: str, exchange: Exchange) -> str:
"""
TQSdk exchange first
"""
count = 0
for count, word in enumerate(symbol):
if word.isdigit():
break
@ -128,8 +130,9 @@ class TqFutureData():
def __init__(self, strategy=None):
self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例
auth_dict = get_account_config()
self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile")
self.api = TqApi(TqSim(),auth=TqAuth(auth_dict['user_name'],auth_dict['password'])) # url="wss://u.shinnytech.com/t/md/front/mobile"
def get_tick_serial(self, vt_symbol: str):
# 获取最新的8964个数据 tick的话就相当于只有50分钟左右