diff --git a/vnpy/data/tq/downloader.py b/vnpy/data/tq/downloader.py index a32313c6..8ffa8578 100644 --- a/vnpy/data/tq/downloader.py +++ b/vnpy/data/tq/downloader.py @@ -6,14 +6,20 @@ # 2. 下载tick时,5档行情都下载 # 3. 五档行情变量调整适合vnpy的命名方式 -import os +import asyncio import csv +import os from datetime import date, datetime from typing import Union, List +import lzma + +import pandas import json from tqsdk.api import TqApi +from tqsdk.channel import TqChan from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time from tqsdk.diff import _get_obj +from tqsdk.tafunc import get_dividend_df, get_dividend_factor from tqsdk.utils import _generate_uuid def get_account_config(): @@ -32,15 +38,25 @@ def get_account_config(): return {} +DEAD_INS = {} + +# 价格相关的字段,需要 format 数据格式 +PRICE_KEYS = ["open", "high", "low", "close", "last_price", "highest", "lowest"] + [f"bid_price{i}" for i in range(1, 6)] + [f"ask_price{i}" for i in range(1, 6)] + + class DataDownloader: """ + 数据下载工具是 TqSdk 专业版中的功能,能让用户下载目前 TqSdk 提供的全部期货、期权和股票类的历史数据,下载数据支持 tick 级别精度和任意 kline 周期 + + 如果想使用数据下载工具,可以点击 `天勤量化专业版 `_ 申请使用或购买 + 历史数据下载器, 输出到csv文件 多合约按时间横向对齐 """ def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime], - end_dt: Union[date, datetime], csv_file_name: str) -> None: + end_dt: Union[date, datetime], csv_file_name: str, adj_type: Union[str, None] = None) -> None: """ 创建历史数据下载器实例 @@ -55,16 +71,18 @@ class DataDownloader: end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点 - csv_file_name (str): 输出csv的文件名 + csv_file_name (str): 输出 csv 的文件名 + + adj_type (str/None): 复权计算方式,默认值为 None。"F" 为前复权;"B" 为后复权;None 表示不复权。只对股票、基金合约有效。 Example:: from datetime import datetime, date from contextlib import closing - from tqsdk import TqApi, TqSim + from tqsdk import TqApi, TqAuth, TqSim from tqsdk.tools import DataDownloader - api = TqApi(TqSim()) + api = TqApi(auth=TqAuth("信易账户", "账户密码")) download_tasks = {} # 下载从 2018-01-01 到 2018-09-01 的 SR901 日线数据 download_tasks["SR_daily"] = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60, @@ -87,27 +105,31 @@ class DataDownloader: print("progress: ", { k:("%.2f%%" % v.get_progress()) for k,v in download_tasks.items() }) """ self._api = api + if not self._api._auth._has_feature("tq_dl"): + raise Exception("您的账户不支持下载历史数据功能,需要购买专业版本后使用。升级网址:https://account.shinnytech.com") if isinstance(start_dt, datetime): self._start_dt_nano = int(start_dt.timestamp() * 1e9) else: - self._start_dt_nano = _get_trading_day_start_time( - int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000) + self._start_dt_nano = _get_trading_day_start_time(int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000) if isinstance(end_dt, datetime): self._end_dt_nano = int(end_dt.timestamp() * 1e9) else: - self._end_dt_nano = _get_trading_day_end_time( - int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000) + self._end_dt_nano = _get_trading_day_end_time(int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000) self._current_dt_nano = self._start_dt_nano self._symbol_list = symbol_list if isinstance(symbol_list, list) else [symbol_list] - # 检查合约代码是否存在 - for symbol in self._symbol_list: - if (not self._api._stock) and symbol not in self._api._data.get("quotes", {}): - raise Exception("代码 %s 不存在, 请检查合约代码是否填写正确" % (symbol)) + # 下载合约超时时间(默认 30s),已下市的没有交易的合约,超时时间可以设置短一点(2s),用户不希望自己的程序因为没有下载到数据而中断 + self._timeout_seconds = 2 if any([symbol in DEAD_INS for symbol in self._symbol_list]) else 30 self._dur_nano = dur_sec * 1000000000 if self._dur_nano == 0 and len(self._symbol_list) != 1: raise Exception("Tick序列不支持多合约") + if adj_type not in [None, "F", "B", "FORWARD", "BACK"]: + raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)") + self._adj_type = adj_type[0] if adj_type else adj_type self._csv_file_name = csv_file_name - self._task = self._api.create_task(self._download_data()) + self._csv_header = self._get_headers() + self._dividend_cache = {} # 缓存合约对应的复权系数矩阵,每个合约只计算一次 + self._data_series = None + self._task = self._api.create_task(self._run()) def is_finished(self) -> bool: """ @@ -128,6 +150,142 @@ class DataDownloader: return 100.0 if self._task.done() else (self._current_dt_nano - self._start_dt_nano) / ( self._end_dt_nano - self._start_dt_nano) * 100 + def _get_data_series(self) -> pandas.DataFrame: + """ + 获取下载的 DataFrame 格式数据 + + todo: 在 utils 中增加工具函数,返回与 kline 一致的数据结构 + + Returns: + pandas.DataFrame/None: 下载的 klines 或者 ticks 数据,DataFrame 格式。下载完成前返回 None。 + + + Example:: + + from datetime import datetime, date + rom tqsdk import TqApi, TqAuth + from contextlib import closing + from tqsdk.tools import DataDownloader + + api = TqApi(auth=TqAuth("信易账户", "账户密码")) + # 下载从 2018-06-01 到 2018-09-01 的 SR901 日线数据 + download_task = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60, + start_dt=date(2018, 6, 1), end_dt=date(2018, 9, 1), csv_file_name="klines.csv") + # 使用with closing机制确保下载完成后释放对应的资源 + with closing(api): + while not download_task.is_finished(): + api.wait_update() + print(f"progress: {download_task.get_progress():.2} %") + print(download_task._get_data_series()) + """ + if not self._task.done(): + return None + if not self._data_series: + self._data_series = pandas.read_csv(self._csv_file_name) + return self._data_series + + async def _update_dividend_factor(self): + for s in self._dividend_cache: + # 对每个除权除息矩阵增加 factor 序列,为当日的复权因子 + df = self._dividend_cache[s]["df"] + df["pre_close"] = float('nan') # 初始化 pre_close 为 nan + between = df["datetime"].between(self._start_dt_nano, self._end_dt_nano) # 只需要开始时间~结束时间之间的复权因子 + for i in df[between].index: + chart_info = { + "aid": "set_chart", + "chart_id": _generate_uuid("PYSDK_downloader"), + "ins_list": s, + "duration": 86400 * 1000000000, + "view_width": 2, + "focus_datetime": int(df.iloc[i].datetime), + "focus_position": 1 + } + await self._api._send_chan.send(chart_info) + chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]]) + serial = _get_obj(self._api._data, ["klines", s, str(86400000000000)]) + try: + async with self._api.register_update_notify() as update_chan: + async for _ in update_chan: + if not (chart_info.items() <= _get_obj(chart, ["state"]).items()): + continue # 当前请求还没收齐回应, 不应继续处理 + left_id = chart.get("left_id", -1) + right_id = chart.get("right_id", -1) + if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True) or serial.get("last_id", -1) == -1: + continue # 定位信息还没收到, 或数据序列还没收到, 合约的数据是否收到 + last_item = serial["data"].get(str(left_id), {}) + # 复权时间点的昨收盘 + df.loc[i, 'pre_close'] = last_item['close'] if last_item.get('close') else float('nan') + break + finally: + await self._api._send_chan.send({ + "aid": "set_chart", + "chart_id": chart_info["chart_id"], + "ins_list": "", + "duration": 86400000000000, + "view_width": 2 + }) + df["factor"] = (df["pre_close"] - df["cash_dividend"]) / df["pre_close"] / (1 + df["stock_dividend"]) + df["factor"].fillna(1, inplace=True) + + async def _run(self): + self._quote_list = await self._api.get_quote_list(self._symbol_list) + # 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子 + if self._adj_type: + for quote in self._quote_list: + if quote.ins_class in ["STOCK", "FUND"]: + self._dividend_cache[quote.instrument_id] = { + "df": get_dividend_df(quote.stock_dividend_ratio, quote.cash_dividend_ratio), + "back_factor": 1.0 + } + # 前复权需要提前计算除权因子 + await self._update_dividend_factor() + self._data_chan = TqChan(self._api) + task = self._api.create_task(self._download_data()) + # cols 是复权需要重新计算的列名 + index_datetime_nano = self._csv_header.index("datetime_nano") + if self._dur_nano != 0: + cols = ["open", "high", "low", "close"] + else: + cols = ["last_price", "highest", "lowest"] + cols.extend(f"{x}{i}" for x in ["bid_price", "ask_price"] for i in range(1, 6)) + try: + with open(self._csv_file_name, 'w', newline='') as csvfile: + csv_writer = csv.writer(csvfile, dialect='excel') + csv_writer.writerow(self._csv_header) + last_dt = None + async for item in self._data_chan: + for quote in self._quote_list: + symbol = quote.instrument_id + if self._adj_type and quote.ins_class in ["STOCK", "FUND"]: + dividend_df = self._dividend_cache[symbol]["df"] + factor = 1 + if self._adj_type == "F": + gt = dividend_df["datetime"].gt(item[index_datetime_nano]) + if gt.any(): + factor = dividend_df[gt]["factor"].cumprod().iloc[-1] + elif self._adj_type == "B" and last_dt: + gt = dividend_df['datetime'].gt(last_dt) + if gt.any(): + index = dividend_df[gt].index[0] + if item[index_datetime_nano] >= dividend_df.loc[index, 'datetime']: + self._dividend_cache[symbol]["back_factor"] *= (1 / dividend_df[gt].loc[index, 'factor']) + factor = self._dividend_cache[symbol]["back_factor"] + last_dt = item[index_datetime_nano] + if factor != 1: + item = item.copy() + for c in cols: # datetime_nano + index = self._csv_header.index(f"{symbol}.{c}") + item[index] = item[index] * factor + csv_writer.writerow(item) + finally: + task.cancel() + await asyncio.gather(task, return_exceptions=True) + + async def _timeout_handle(self, timeout, chart): + await asyncio.sleep(timeout) + if chart.get("left_id", -1) == -1 and chart.get("right_id", -1) == -1: + self._task.cancel() + async def _download_data(self): """下载数据, 多合约横向按时间对齐""" chart_info = { @@ -139,98 +297,60 @@ class DataDownloader: "focus_datetime": self._start_dt_nano, "focus_position": 0, } - if len(self._symbol_list) == 1: - single_exchange, single_symbol = self._symbol_list[0].split('.') - else: - single_exchange, single_symbol = None, None # 还没有发送过任何请求, 先请求定位左端点 await self._api._send_chan.send(chart_info) chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]]) + # 增加一个 task,在 30s 后检查 chart 是否返回了左右 id 范围,如果没有就 cancel self._task,防止程序一直卡在那里 + timeout_task = self._api.create_task(self._timeout_handle(self._timeout_seconds, chart)) current_id = None # 当前数据指针 - csv_header = [] - data_cols = ["open", "high", "low", "close", "volume", "open_oi", "close_oi"] if self._dur_nano != 0 else \ - ["last_price", "highest", "lowest", "volume", - "amount", "open_interest", "upper_limit", "lower_limit", - "bid_price1", "bid_volume1", "ask_price1", "ask_volume1", - "bid_price2", "bid_volume2", "ask_price2", "ask_volume2", - "bid_price3", "bid_volume3", "ask_price3", "ask_volume3", - "bid_price4", "bid_volume4", "ask_price4", "ask_volume4", - "bid_price5", "bid_volume5", "ask_price5", "ask_volume5" - ] + data_cols = self._get_data_cols() serials = [] for symbol in self._symbol_list: path = ["klines", symbol, str(self._dur_nano)] if self._dur_nano != 0 else ["ticks", symbol] serial = _get_obj(self._api._data, path) serials.append(serial) try: - with open(self._csv_file_name, 'w', newline='') as csvfile: - csv_writer = csv.writer(csvfile, dialect='excel') - async with self._api.register_update_notify() as update_chan: - async for _ in update_chan: - if not (chart_info.items() <= _get_obj(chart, ["state"]).items()): - # 当前请求还没收齐回应, 不应继续处理 + async with self._api.register_update_notify() as update_chan: + async for _ in update_chan: + if not (chart_info.items() <= _get_obj(chart, ["state"]).items()): + # 当前请求还没收齐回应, 不应继续处理 + continue + left_id = chart.get("left_id", -1) + right_id = chart.get("right_id", -1) + if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True): + # 定位信息还没收到, 或数据序列还没收到 + continue + for serial in serials: + # 检查合约的数据是否收到 + if serial.get("last_id", -1) == -1: continue - left_id = chart.get("left_id", -1) - right_id = chart.get("right_id", -1) - if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True): - # 定位信息还没收到, 或数据序列还没收到 - continue - for serial in serials: - # 检查合约的数据是否收到 - if serial.get("last_id", -1) == -1: - continue - if current_id is None: - current_id = max(left_id, 0) - while current_id <= right_id: - item = serials[0]["data"].get(str(current_id), {}) - if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano: - # 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端 - return - if len(csv_header) == 0: - # 写入文件头 - csv_header = ["datetime"] - for symbol in self._symbol_list: - # 单一合约时,添加合约和交易所 - if single_exchange: - csv_header.extend(['symbol', 'exchange']) - - for col in data_cols: - if col.startswith('bid_') or col.startswith('ask_'): - col = col[:-1] + '_' + col[-1] - if len(self._symbol_list) > 2: - csv_header.append(symbol + "." + col) - else: - csv_header.append(col) - - csv_writer.writerow(csv_header) - row = [self._nano_to_str(item["datetime"])] - - # 单一合约时,添加合约和交易所 - if single_exchange: - row.extend([single_symbol, single_exchange]) - + if current_id is None: + current_id = max(left_id, 0) + while current_id <= right_id: + item = serials[0]["data"].get(str(current_id), {}) + if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano: + # 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端 + return + row = [self._nano_to_str(item["datetime"]), item["datetime"]] + for col in data_cols: + row.append(self._get_value(item, col, self._quote_list[0]["price_decs"])) + for i in range(1, len(self._symbol_list)): + symbol = self._symbol_list[i] + tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1) + k = {} if tid == -1 else serials[i]["data"].get(str(tid), {}) for col in data_cols: - row.append(self._get_value(item, col)) - for i in range(1, len(self._symbol_list)): - symbol = self._symbol_list[i] - tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1) - k = {} if tid == -1 else serials[i]["data"].get(str(tid), {}) - for col in data_cols: - row.append(self._get_value(k, col)) - # 抛弃盘前的脏数据 - if self._dur_nano == 0 and str(row[3]) == 'nan': - p = 1 - else: - csv_writer.writerow(row) - current_id += 1 - self._current_dt_nano = item["datetime"] - # 当前 id 已超出订阅范围, 需重新订阅后续数据 - chart_info.pop("focus_datetime", None) - chart_info.pop("focus_position", None) - chart_info["left_kline_id"] = current_id - await self._api._send_chan.send(chart_info) + row.append(self._get_value(k, col, self._quote_list[i]["price_decs"])) + await self._data_chan.send(row) + current_id += 1 + self._current_dt_nano = item["datetime"] + # 当前 id 已超出订阅范围, 需重新订阅后续数据 + chart_info.pop("focus_datetime", None) + chart_info.pop("focus_position", None) + chart_info["left_kline_id"] = current_id + await self._api._send_chan.send(chart_info) finally: # 释放chart资源 + await self._data_chan.close() await self._api._send_chan.send({ "aid": "set_chart", "chart_id": chart_info["chart_id"], @@ -238,18 +358,47 @@ class DataDownloader: "duration": self._dur_nano, "view_width": 2000, }) + timeout_task.cancel() + await timeout_task + + def _get_headers(self): + data_cols = self._get_data_cols() + data_cols = [col.replace('d_price','d_price_').replace('d_volume','d_volume_').replace('k_price','k_price_').replace('k_volume','k_volume_') for col in data_cols] + if len(self._symbol_list) > 1: + return ["datetime", "datetime_nano"] + [f"{symbol}.{col}" for symbol in self._symbol_list for col in data_cols] + else: + return ["datetime", "datetime_nano"] + data_cols + + + def _get_data_cols(self): + if self._dur_nano != 0: + return ["open", "high", "low", "close", "volume", "open_oi", "close_oi"] + else: + # ,"upper_limit","lower_limit" + cols = ["last_price", "highest", "lowest", "volume", "amount", "open_interest"] + price_range = 1 + for symbol in self._symbol_list: + if symbol.split('.')[0] in {"SHFE", "INE", "SSE", "SZSE"}: + price_range = 5 + break + for i in range(price_range): + cols.extend(f"{x}{i+1}" for x in ["bid_price", "bid_volume", "ask_price", "ask_volume"]) + return cols @staticmethod - def _get_value(obj, key): + def _get_value(obj, key, price_decs): if key not in obj: return "#N/A" if isinstance(obj[key], str): return float("nan") - return obj[key] + if key in PRICE_KEYS: + return round(obj[key], price_decs) + else: + return obj[key] @staticmethod def _nano_to_str(nano): dt = datetime.fromtimestamp(nano // 1000000000) s = dt.strftime('%Y-%m-%d %H:%M:%S') - s += '.' + str(int(nano % 1000000000)).zfill(9)[:3] + s += '.' + str(int(nano % 1000000000)).zfill(9) return s diff --git a/vnpy/data/tq/tianqin_data.py b/vnpy/data/tq/tianqin_data.py index f1124323..4aa91fd9 100644 --- a/vnpy/data/tq/tianqin_data.py +++ b/vnpy/data/tq/tianqin_data.py @@ -6,7 +6,7 @@ from contextlib import closing import os from datetime import datetime, timedelta from functools import lru_cache -from tqsdk import TqApi, TqSim +from tqsdk import TqApi, TqSim,TqAuth from vnpy.data.tq.downloader import DataDownloader from vnpy.trader.constant import ( Direction, @@ -22,6 +22,7 @@ from vnpy.trader.object import TickData, BarData from vnpy.trader.utility import extract_vt_symbol, get_trading_date import pandas as pd import csv +from vnpy.data.tq.downloader import get_account_config # pd.pandas.set_option('display.max_rows', None) # 设置最大显示行数,超过该值用省略号代替,为None时显示所有行。 # pd.pandas.set_option('display.max_columns', None) # 设置最大显示列数,超过该值用省略号代替,为None时显示所有列。 @@ -59,6 +60,7 @@ def to_tq_symbol(symbol: str, exchange: Exchange) -> str: """ TQSdk exchange first """ + count = 0 for count, word in enumerate(symbol): if word.isdigit(): break @@ -128,8 +130,9 @@ class TqFutureData(): def __init__(self, strategy=None): self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例 + auth_dict = get_account_config() - self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile") + self.api = TqApi(TqSim(),auth=TqAuth(auth_dict['user_name'],auth_dict['password'])) # url="wss://u.shinnytech.com/t/md/front/mobile" def get_tick_serial(self, vt_symbol: str): # 获取最新的8964个数据 tick的话就相当于只有50分钟左右