[update] 天勤数据下载适应新版
This commit is contained in:
parent
015efa9d32
commit
2c58c51e7f
@ -6,14 +6,20 @@
|
|||||||
# 2. 下载tick时,5档行情都下载
|
# 2. 下载tick时,5档行情都下载
|
||||||
# 3. 五档行情变量调整适合vnpy的命名方式
|
# 3. 五档行情变量调整适合vnpy的命名方式
|
||||||
|
|
||||||
import os
|
import asyncio
|
||||||
import csv
|
import csv
|
||||||
|
import os
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import Union, List
|
from typing import Union, List
|
||||||
|
import lzma
|
||||||
|
|
||||||
|
import pandas
|
||||||
import json
|
import json
|
||||||
from tqsdk.api import TqApi
|
from tqsdk.api import TqApi
|
||||||
|
from tqsdk.channel import TqChan
|
||||||
from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time
|
from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time
|
||||||
from tqsdk.diff import _get_obj
|
from tqsdk.diff import _get_obj
|
||||||
|
from tqsdk.tafunc import get_dividend_df, get_dividend_factor
|
||||||
from tqsdk.utils import _generate_uuid
|
from tqsdk.utils import _generate_uuid
|
||||||
|
|
||||||
def get_account_config():
|
def get_account_config():
|
||||||
@ -32,15 +38,25 @@ def get_account_config():
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
DEAD_INS = {}
|
||||||
|
|
||||||
|
# 价格相关的字段,需要 format 数据格式
|
||||||
|
PRICE_KEYS = ["open", "high", "low", "close", "last_price", "highest", "lowest"] + [f"bid_price{i}" for i in range(1, 6)] + [f"ask_price{i}" for i in range(1, 6)]
|
||||||
|
|
||||||
|
|
||||||
class DataDownloader:
|
class DataDownloader:
|
||||||
"""
|
"""
|
||||||
|
数据下载工具是 TqSdk 专业版中的功能,能让用户下载目前 TqSdk 提供的全部期货、期权和股票类的历史数据,下载数据支持 tick 级别精度和任意 kline 周期
|
||||||
|
|
||||||
|
如果想使用数据下载工具,可以点击 `天勤量化专业版 <https://www.shinnytech.com/tqsdk_professional/>`_ 申请使用或购买
|
||||||
|
|
||||||
历史数据下载器, 输出到csv文件
|
历史数据下载器, 输出到csv文件
|
||||||
|
|
||||||
多合约按时间横向对齐
|
多合约按时间横向对齐
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime],
|
def __init__(self, api: TqApi, symbol_list: Union[str, List[str]], dur_sec: int, start_dt: Union[date, datetime],
|
||||||
end_dt: Union[date, datetime], csv_file_name: str) -> None:
|
end_dt: Union[date, datetime], csv_file_name: str, adj_type: Union[str, None] = None) -> None:
|
||||||
"""
|
"""
|
||||||
创建历史数据下载器实例
|
创建历史数据下载器实例
|
||||||
|
|
||||||
@ -55,16 +71,18 @@ class DataDownloader:
|
|||||||
|
|
||||||
end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
|
end_dt (date/datetime): 结束时间, 如果类型为 date 则指的是交易日, 如果为 datetime 则指的是具体时间点
|
||||||
|
|
||||||
csv_file_name (str): 输出csv的文件名
|
csv_file_name (str): 输出 csv 的文件名
|
||||||
|
|
||||||
|
adj_type (str/None): 复权计算方式,默认值为 None。"F" 为前复权;"B" 为后复权;None 表示不复权。只对股票、基金合约有效。
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from tqsdk import TqApi, TqSim
|
from tqsdk import TqApi, TqAuth, TqSim
|
||||||
from tqsdk.tools import DataDownloader
|
from tqsdk.tools import DataDownloader
|
||||||
|
|
||||||
api = TqApi(TqSim())
|
api = TqApi(auth=TqAuth("信易账户", "账户密码"))
|
||||||
download_tasks = {}
|
download_tasks = {}
|
||||||
# 下载从 2018-01-01 到 2018-09-01 的 SR901 日线数据
|
# 下载从 2018-01-01 到 2018-09-01 的 SR901 日线数据
|
||||||
download_tasks["SR_daily"] = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
|
download_tasks["SR_daily"] = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
|
||||||
@ -87,27 +105,31 @@ class DataDownloader:
|
|||||||
print("progress: ", { k:("%.2f%%" % v.get_progress()) for k,v in download_tasks.items() })
|
print("progress: ", { k:("%.2f%%" % v.get_progress()) for k,v in download_tasks.items() })
|
||||||
"""
|
"""
|
||||||
self._api = api
|
self._api = api
|
||||||
|
if not self._api._auth._has_feature("tq_dl"):
|
||||||
|
raise Exception("您的账户不支持下载历史数据功能,需要购买专业版本后使用。升级网址:https://account.shinnytech.com")
|
||||||
if isinstance(start_dt, datetime):
|
if isinstance(start_dt, datetime):
|
||||||
self._start_dt_nano = int(start_dt.timestamp() * 1e9)
|
self._start_dt_nano = int(start_dt.timestamp() * 1e9)
|
||||||
else:
|
else:
|
||||||
self._start_dt_nano = _get_trading_day_start_time(
|
self._start_dt_nano = _get_trading_day_start_time(int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
|
||||||
int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
|
|
||||||
if isinstance(end_dt, datetime):
|
if isinstance(end_dt, datetime):
|
||||||
self._end_dt_nano = int(end_dt.timestamp() * 1e9)
|
self._end_dt_nano = int(end_dt.timestamp() * 1e9)
|
||||||
else:
|
else:
|
||||||
self._end_dt_nano = _get_trading_day_end_time(
|
self._end_dt_nano = _get_trading_day_end_time(int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
|
||||||
int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
|
|
||||||
self._current_dt_nano = self._start_dt_nano
|
self._current_dt_nano = self._start_dt_nano
|
||||||
self._symbol_list = symbol_list if isinstance(symbol_list, list) else [symbol_list]
|
self._symbol_list = symbol_list if isinstance(symbol_list, list) else [symbol_list]
|
||||||
# 检查合约代码是否存在
|
# 下载合约超时时间(默认 30s),已下市的没有交易的合约,超时时间可以设置短一点(2s),用户不希望自己的程序因为没有下载到数据而中断
|
||||||
for symbol in self._symbol_list:
|
self._timeout_seconds = 2 if any([symbol in DEAD_INS for symbol in self._symbol_list]) else 30
|
||||||
if (not self._api._stock) and symbol not in self._api._data.get("quotes", {}):
|
|
||||||
raise Exception("代码 %s 不存在, 请检查合约代码是否填写正确" % (symbol))
|
|
||||||
self._dur_nano = dur_sec * 1000000000
|
self._dur_nano = dur_sec * 1000000000
|
||||||
if self._dur_nano == 0 and len(self._symbol_list) != 1:
|
if self._dur_nano == 0 and len(self._symbol_list) != 1:
|
||||||
raise Exception("Tick序列不支持多合约")
|
raise Exception("Tick序列不支持多合约")
|
||||||
|
if adj_type not in [None, "F", "B", "FORWARD", "BACK"]:
|
||||||
|
raise Exception("adj_type 参数只支持 None (不复权) | 'F' (前复权) | 'B' (后复权)")
|
||||||
|
self._adj_type = adj_type[0] if adj_type else adj_type
|
||||||
self._csv_file_name = csv_file_name
|
self._csv_file_name = csv_file_name
|
||||||
self._task = self._api.create_task(self._download_data())
|
self._csv_header = self._get_headers()
|
||||||
|
self._dividend_cache = {} # 缓存合约对应的复权系数矩阵,每个合约只计算一次
|
||||||
|
self._data_series = None
|
||||||
|
self._task = self._api.create_task(self._run())
|
||||||
|
|
||||||
def is_finished(self) -> bool:
|
def is_finished(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -128,6 +150,142 @@ class DataDownloader:
|
|||||||
return 100.0 if self._task.done() else (self._current_dt_nano - self._start_dt_nano) / (
|
return 100.0 if self._task.done() else (self._current_dt_nano - self._start_dt_nano) / (
|
||||||
self._end_dt_nano - self._start_dt_nano) * 100
|
self._end_dt_nano - self._start_dt_nano) * 100
|
||||||
|
|
||||||
|
def _get_data_series(self) -> pandas.DataFrame:
|
||||||
|
"""
|
||||||
|
获取下载的 DataFrame 格式数据
|
||||||
|
|
||||||
|
todo: 在 utils 中增加工具函数,返回与 kline 一致的数据结构
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pandas.DataFrame/None: 下载的 klines 或者 ticks 数据,DataFrame 格式。下载完成前返回 None。
|
||||||
|
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from datetime import datetime, date
|
||||||
|
rom tqsdk import TqApi, TqAuth
|
||||||
|
from contextlib import closing
|
||||||
|
from tqsdk.tools import DataDownloader
|
||||||
|
|
||||||
|
api = TqApi(auth=TqAuth("信易账户", "账户密码"))
|
||||||
|
# 下载从 2018-06-01 到 2018-09-01 的 SR901 日线数据
|
||||||
|
download_task = DataDownloader(api, symbol_list="CZCE.SR901", dur_sec=24*60*60,
|
||||||
|
start_dt=date(2018, 6, 1), end_dt=date(2018, 9, 1), csv_file_name="klines.csv")
|
||||||
|
# 使用with closing机制确保下载完成后释放对应的资源
|
||||||
|
with closing(api):
|
||||||
|
while not download_task.is_finished():
|
||||||
|
api.wait_update()
|
||||||
|
print(f"progress: {download_task.get_progress():.2} %")
|
||||||
|
print(download_task._get_data_series())
|
||||||
|
"""
|
||||||
|
if not self._task.done():
|
||||||
|
return None
|
||||||
|
if not self._data_series:
|
||||||
|
self._data_series = pandas.read_csv(self._csv_file_name)
|
||||||
|
return self._data_series
|
||||||
|
|
||||||
|
async def _update_dividend_factor(self):
|
||||||
|
for s in self._dividend_cache:
|
||||||
|
# 对每个除权除息矩阵增加 factor 序列,为当日的复权因子
|
||||||
|
df = self._dividend_cache[s]["df"]
|
||||||
|
df["pre_close"] = float('nan') # 初始化 pre_close 为 nan
|
||||||
|
between = df["datetime"].between(self._start_dt_nano, self._end_dt_nano) # 只需要开始时间~结束时间之间的复权因子
|
||||||
|
for i in df[between].index:
|
||||||
|
chart_info = {
|
||||||
|
"aid": "set_chart",
|
||||||
|
"chart_id": _generate_uuid("PYSDK_downloader"),
|
||||||
|
"ins_list": s,
|
||||||
|
"duration": 86400 * 1000000000,
|
||||||
|
"view_width": 2,
|
||||||
|
"focus_datetime": int(df.iloc[i].datetime),
|
||||||
|
"focus_position": 1
|
||||||
|
}
|
||||||
|
await self._api._send_chan.send(chart_info)
|
||||||
|
chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
|
||||||
|
serial = _get_obj(self._api._data, ["klines", s, str(86400000000000)])
|
||||||
|
try:
|
||||||
|
async with self._api.register_update_notify() as update_chan:
|
||||||
|
async for _ in update_chan:
|
||||||
|
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
|
||||||
|
continue # 当前请求还没收齐回应, 不应继续处理
|
||||||
|
left_id = chart.get("left_id", -1)
|
||||||
|
right_id = chart.get("right_id", -1)
|
||||||
|
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True) or serial.get("last_id", -1) == -1:
|
||||||
|
continue # 定位信息还没收到, 或数据序列还没收到, 合约的数据是否收到
|
||||||
|
last_item = serial["data"].get(str(left_id), {})
|
||||||
|
# 复权时间点的昨收盘
|
||||||
|
df.loc[i, 'pre_close'] = last_item['close'] if last_item.get('close') else float('nan')
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
await self._api._send_chan.send({
|
||||||
|
"aid": "set_chart",
|
||||||
|
"chart_id": chart_info["chart_id"],
|
||||||
|
"ins_list": "",
|
||||||
|
"duration": 86400000000000,
|
||||||
|
"view_width": 2
|
||||||
|
})
|
||||||
|
df["factor"] = (df["pre_close"] - df["cash_dividend"]) / df["pre_close"] / (1 + df["stock_dividend"])
|
||||||
|
df["factor"].fillna(1, inplace=True)
|
||||||
|
|
||||||
|
async def _run(self):
|
||||||
|
self._quote_list = await self._api.get_quote_list(self._symbol_list)
|
||||||
|
# 如果存在 STOCK / FUND 并且 adj_type is not None, 这里需要提前准备下载时间段内的复权因子
|
||||||
|
if self._adj_type:
|
||||||
|
for quote in self._quote_list:
|
||||||
|
if quote.ins_class in ["STOCK", "FUND"]:
|
||||||
|
self._dividend_cache[quote.instrument_id] = {
|
||||||
|
"df": get_dividend_df(quote.stock_dividend_ratio, quote.cash_dividend_ratio),
|
||||||
|
"back_factor": 1.0
|
||||||
|
}
|
||||||
|
# 前复权需要提前计算除权因子
|
||||||
|
await self._update_dividend_factor()
|
||||||
|
self._data_chan = TqChan(self._api)
|
||||||
|
task = self._api.create_task(self._download_data())
|
||||||
|
# cols 是复权需要重新计算的列名
|
||||||
|
index_datetime_nano = self._csv_header.index("datetime_nano")
|
||||||
|
if self._dur_nano != 0:
|
||||||
|
cols = ["open", "high", "low", "close"]
|
||||||
|
else:
|
||||||
|
cols = ["last_price", "highest", "lowest"]
|
||||||
|
cols.extend(f"{x}{i}" for x in ["bid_price", "ask_price"] for i in range(1, 6))
|
||||||
|
try:
|
||||||
|
with open(self._csv_file_name, 'w', newline='') as csvfile:
|
||||||
|
csv_writer = csv.writer(csvfile, dialect='excel')
|
||||||
|
csv_writer.writerow(self._csv_header)
|
||||||
|
last_dt = None
|
||||||
|
async for item in self._data_chan:
|
||||||
|
for quote in self._quote_list:
|
||||||
|
symbol = quote.instrument_id
|
||||||
|
if self._adj_type and quote.ins_class in ["STOCK", "FUND"]:
|
||||||
|
dividend_df = self._dividend_cache[symbol]["df"]
|
||||||
|
factor = 1
|
||||||
|
if self._adj_type == "F":
|
||||||
|
gt = dividend_df["datetime"].gt(item[index_datetime_nano])
|
||||||
|
if gt.any():
|
||||||
|
factor = dividend_df[gt]["factor"].cumprod().iloc[-1]
|
||||||
|
elif self._adj_type == "B" and last_dt:
|
||||||
|
gt = dividend_df['datetime'].gt(last_dt)
|
||||||
|
if gt.any():
|
||||||
|
index = dividend_df[gt].index[0]
|
||||||
|
if item[index_datetime_nano] >= dividend_df.loc[index, 'datetime']:
|
||||||
|
self._dividend_cache[symbol]["back_factor"] *= (1 / dividend_df[gt].loc[index, 'factor'])
|
||||||
|
factor = self._dividend_cache[symbol]["back_factor"]
|
||||||
|
last_dt = item[index_datetime_nano]
|
||||||
|
if factor != 1:
|
||||||
|
item = item.copy()
|
||||||
|
for c in cols: # datetime_nano
|
||||||
|
index = self._csv_header.index(f"{symbol}.{c}")
|
||||||
|
item[index] = item[index] * factor
|
||||||
|
csv_writer.writerow(item)
|
||||||
|
finally:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def _timeout_handle(self, timeout, chart):
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
if chart.get("left_id", -1) == -1 and chart.get("right_id", -1) == -1:
|
||||||
|
self._task.cancel()
|
||||||
|
|
||||||
async def _download_data(self):
|
async def _download_data(self):
|
||||||
"""下载数据, 多合约横向按时间对齐"""
|
"""下载数据, 多合约横向按时间对齐"""
|
||||||
chart_info = {
|
chart_info = {
|
||||||
@ -139,98 +297,60 @@ class DataDownloader:
|
|||||||
"focus_datetime": self._start_dt_nano,
|
"focus_datetime": self._start_dt_nano,
|
||||||
"focus_position": 0,
|
"focus_position": 0,
|
||||||
}
|
}
|
||||||
if len(self._symbol_list) == 1:
|
|
||||||
single_exchange, single_symbol = self._symbol_list[0].split('.')
|
|
||||||
else:
|
|
||||||
single_exchange, single_symbol = None, None
|
|
||||||
# 还没有发送过任何请求, 先请求定位左端点
|
# 还没有发送过任何请求, 先请求定位左端点
|
||||||
await self._api._send_chan.send(chart_info)
|
await self._api._send_chan.send(chart_info)
|
||||||
chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
|
chart = _get_obj(self._api._data, ["charts", chart_info["chart_id"]])
|
||||||
|
# 增加一个 task,在 30s 后检查 chart 是否返回了左右 id 范围,如果没有就 cancel self._task,防止程序一直卡在那里
|
||||||
|
timeout_task = self._api.create_task(self._timeout_handle(self._timeout_seconds, chart))
|
||||||
current_id = None # 当前数据指针
|
current_id = None # 当前数据指针
|
||||||
csv_header = []
|
data_cols = self._get_data_cols()
|
||||||
data_cols = ["open", "high", "low", "close", "volume", "open_oi", "close_oi"] if self._dur_nano != 0 else \
|
|
||||||
["last_price", "highest", "lowest", "volume",
|
|
||||||
"amount", "open_interest", "upper_limit", "lower_limit",
|
|
||||||
"bid_price1", "bid_volume1", "ask_price1", "ask_volume1",
|
|
||||||
"bid_price2", "bid_volume2", "ask_price2", "ask_volume2",
|
|
||||||
"bid_price3", "bid_volume3", "ask_price3", "ask_volume3",
|
|
||||||
"bid_price4", "bid_volume4", "ask_price4", "ask_volume4",
|
|
||||||
"bid_price5", "bid_volume5", "ask_price5", "ask_volume5"
|
|
||||||
]
|
|
||||||
serials = []
|
serials = []
|
||||||
for symbol in self._symbol_list:
|
for symbol in self._symbol_list:
|
||||||
path = ["klines", symbol, str(self._dur_nano)] if self._dur_nano != 0 else ["ticks", symbol]
|
path = ["klines", symbol, str(self._dur_nano)] if self._dur_nano != 0 else ["ticks", symbol]
|
||||||
serial = _get_obj(self._api._data, path)
|
serial = _get_obj(self._api._data, path)
|
||||||
serials.append(serial)
|
serials.append(serial)
|
||||||
try:
|
try:
|
||||||
with open(self._csv_file_name, 'w', newline='') as csvfile:
|
async with self._api.register_update_notify() as update_chan:
|
||||||
csv_writer = csv.writer(csvfile, dialect='excel')
|
async for _ in update_chan:
|
||||||
async with self._api.register_update_notify() as update_chan:
|
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
|
||||||
async for _ in update_chan:
|
# 当前请求还没收齐回应, 不应继续处理
|
||||||
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
|
continue
|
||||||
# 当前请求还没收齐回应, 不应继续处理
|
left_id = chart.get("left_id", -1)
|
||||||
|
right_id = chart.get("right_id", -1)
|
||||||
|
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True):
|
||||||
|
# 定位信息还没收到, 或数据序列还没收到
|
||||||
|
continue
|
||||||
|
for serial in serials:
|
||||||
|
# 检查合约的数据是否收到
|
||||||
|
if serial.get("last_id", -1) == -1:
|
||||||
continue
|
continue
|
||||||
left_id = chart.get("left_id", -1)
|
if current_id is None:
|
||||||
right_id = chart.get("right_id", -1)
|
current_id = max(left_id, 0)
|
||||||
if (left_id == -1 and right_id == -1) or self._api._data.get("mdhis_more_data", True):
|
while current_id <= right_id:
|
||||||
# 定位信息还没收到, 或数据序列还没收到
|
item = serials[0]["data"].get(str(current_id), {})
|
||||||
continue
|
if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano:
|
||||||
for serial in serials:
|
# 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端
|
||||||
# 检查合约的数据是否收到
|
return
|
||||||
if serial.get("last_id", -1) == -1:
|
row = [self._nano_to_str(item["datetime"]), item["datetime"]]
|
||||||
continue
|
for col in data_cols:
|
||||||
if current_id is None:
|
row.append(self._get_value(item, col, self._quote_list[0]["price_decs"]))
|
||||||
current_id = max(left_id, 0)
|
for i in range(1, len(self._symbol_list)):
|
||||||
while current_id <= right_id:
|
symbol = self._symbol_list[i]
|
||||||
item = serials[0]["data"].get(str(current_id), {})
|
tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1)
|
||||||
if item.get("datetime", 0) == 0 or item["datetime"] > self._end_dt_nano:
|
k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
|
||||||
# 当前 id 已超出 last_id 或k线数据的时间已经超过用户限定的右端
|
|
||||||
return
|
|
||||||
if len(csv_header) == 0:
|
|
||||||
# 写入文件头
|
|
||||||
csv_header = ["datetime"]
|
|
||||||
for symbol in self._symbol_list:
|
|
||||||
# 单一合约时,添加合约和交易所
|
|
||||||
if single_exchange:
|
|
||||||
csv_header.extend(['symbol', 'exchange'])
|
|
||||||
|
|
||||||
for col in data_cols:
|
|
||||||
if col.startswith('bid_') or col.startswith('ask_'):
|
|
||||||
col = col[:-1] + '_' + col[-1]
|
|
||||||
if len(self._symbol_list) > 2:
|
|
||||||
csv_header.append(symbol + "." + col)
|
|
||||||
else:
|
|
||||||
csv_header.append(col)
|
|
||||||
|
|
||||||
csv_writer.writerow(csv_header)
|
|
||||||
row = [self._nano_to_str(item["datetime"])]
|
|
||||||
|
|
||||||
# 单一合约时,添加合约和交易所
|
|
||||||
if single_exchange:
|
|
||||||
row.extend([single_symbol, single_exchange])
|
|
||||||
|
|
||||||
for col in data_cols:
|
for col in data_cols:
|
||||||
row.append(self._get_value(item, col))
|
row.append(self._get_value(k, col, self._quote_list[i]["price_decs"]))
|
||||||
for i in range(1, len(self._symbol_list)):
|
await self._data_chan.send(row)
|
||||||
symbol = self._symbol_list[i]
|
current_id += 1
|
||||||
tid = serials[0].get("binding", {}).get(symbol, {}).get(str(current_id), -1)
|
self._current_dt_nano = item["datetime"]
|
||||||
k = {} if tid == -1 else serials[i]["data"].get(str(tid), {})
|
# 当前 id 已超出订阅范围, 需重新订阅后续数据
|
||||||
for col in data_cols:
|
chart_info.pop("focus_datetime", None)
|
||||||
row.append(self._get_value(k, col))
|
chart_info.pop("focus_position", None)
|
||||||
# 抛弃盘前的脏数据
|
chart_info["left_kline_id"] = current_id
|
||||||
if self._dur_nano == 0 and str(row[3]) == 'nan':
|
await self._api._send_chan.send(chart_info)
|
||||||
p = 1
|
|
||||||
else:
|
|
||||||
csv_writer.writerow(row)
|
|
||||||
current_id += 1
|
|
||||||
self._current_dt_nano = item["datetime"]
|
|
||||||
# 当前 id 已超出订阅范围, 需重新订阅后续数据
|
|
||||||
chart_info.pop("focus_datetime", None)
|
|
||||||
chart_info.pop("focus_position", None)
|
|
||||||
chart_info["left_kline_id"] = current_id
|
|
||||||
await self._api._send_chan.send(chart_info)
|
|
||||||
finally:
|
finally:
|
||||||
# 释放chart资源
|
# 释放chart资源
|
||||||
|
await self._data_chan.close()
|
||||||
await self._api._send_chan.send({
|
await self._api._send_chan.send({
|
||||||
"aid": "set_chart",
|
"aid": "set_chart",
|
||||||
"chart_id": chart_info["chart_id"],
|
"chart_id": chart_info["chart_id"],
|
||||||
@ -238,18 +358,47 @@ class DataDownloader:
|
|||||||
"duration": self._dur_nano,
|
"duration": self._dur_nano,
|
||||||
"view_width": 2000,
|
"view_width": 2000,
|
||||||
})
|
})
|
||||||
|
timeout_task.cancel()
|
||||||
|
await timeout_task
|
||||||
|
|
||||||
|
def _get_headers(self):
|
||||||
|
data_cols = self._get_data_cols()
|
||||||
|
data_cols = [col.replace('d_price','d_price_').replace('d_volume','d_volume_').replace('k_price','k_price_').replace('k_volume','k_volume_') for col in data_cols]
|
||||||
|
if len(self._symbol_list) > 1:
|
||||||
|
return ["datetime", "datetime_nano"] + [f"{symbol}.{col}" for symbol in self._symbol_list for col in data_cols]
|
||||||
|
else:
|
||||||
|
return ["datetime", "datetime_nano"] + data_cols
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data_cols(self):
|
||||||
|
if self._dur_nano != 0:
|
||||||
|
return ["open", "high", "low", "close", "volume", "open_oi", "close_oi"]
|
||||||
|
else:
|
||||||
|
# ,"upper_limit","lower_limit"
|
||||||
|
cols = ["last_price", "highest", "lowest", "volume", "amount", "open_interest"]
|
||||||
|
price_range = 1
|
||||||
|
for symbol in self._symbol_list:
|
||||||
|
if symbol.split('.')[0] in {"SHFE", "INE", "SSE", "SZSE"}:
|
||||||
|
price_range = 5
|
||||||
|
break
|
||||||
|
for i in range(price_range):
|
||||||
|
cols.extend(f"{x}{i+1}" for x in ["bid_price", "bid_volume", "ask_price", "ask_volume"])
|
||||||
|
return cols
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_value(obj, key):
|
def _get_value(obj, key, price_decs):
|
||||||
if key not in obj:
|
if key not in obj:
|
||||||
return "#N/A"
|
return "#N/A"
|
||||||
if isinstance(obj[key], str):
|
if isinstance(obj[key], str):
|
||||||
return float("nan")
|
return float("nan")
|
||||||
return obj[key]
|
if key in PRICE_KEYS:
|
||||||
|
return round(obj[key], price_decs)
|
||||||
|
else:
|
||||||
|
return obj[key]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _nano_to_str(nano):
|
def _nano_to_str(nano):
|
||||||
dt = datetime.fromtimestamp(nano // 1000000000)
|
dt = datetime.fromtimestamp(nano // 1000000000)
|
||||||
s = dt.strftime('%Y-%m-%d %H:%M:%S')
|
s = dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
s += '.' + str(int(nano % 1000000000)).zfill(9)[:3]
|
s += '.' + str(int(nano % 1000000000)).zfill(9)
|
||||||
return s
|
return s
|
||||||
|
@ -6,7 +6,7 @@ from contextlib import closing
|
|||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from tqsdk import TqApi, TqSim
|
from tqsdk import TqApi, TqSim,TqAuth
|
||||||
from vnpy.data.tq.downloader import DataDownloader
|
from vnpy.data.tq.downloader import DataDownloader
|
||||||
from vnpy.trader.constant import (
|
from vnpy.trader.constant import (
|
||||||
Direction,
|
Direction,
|
||||||
@ -22,6 +22,7 @@ from vnpy.trader.object import TickData, BarData
|
|||||||
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
|
from vnpy.trader.utility import extract_vt_symbol, get_trading_date
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import csv
|
import csv
|
||||||
|
from vnpy.data.tq.downloader import get_account_config
|
||||||
|
|
||||||
# pd.pandas.set_option('display.max_rows', None) # 设置最大显示行数,超过该值用省略号代替,为None时显示所有行。
|
# pd.pandas.set_option('display.max_rows', None) # 设置最大显示行数,超过该值用省略号代替,为None时显示所有行。
|
||||||
# pd.pandas.set_option('display.max_columns', None) # 设置最大显示列数,超过该值用省略号代替,为None时显示所有列。
|
# pd.pandas.set_option('display.max_columns', None) # 设置最大显示列数,超过该值用省略号代替,为None时显示所有列。
|
||||||
@ -59,6 +60,7 @@ def to_tq_symbol(symbol: str, exchange: Exchange) -> str:
|
|||||||
"""
|
"""
|
||||||
TQSdk exchange first
|
TQSdk exchange first
|
||||||
"""
|
"""
|
||||||
|
count = 0
|
||||||
for count, word in enumerate(symbol):
|
for count, word in enumerate(symbol):
|
||||||
if word.isdigit():
|
if word.isdigit():
|
||||||
break
|
break
|
||||||
@ -128,8 +130,9 @@ class TqFutureData():
|
|||||||
|
|
||||||
def __init__(self, strategy=None):
|
def __init__(self, strategy=None):
|
||||||
self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例
|
self.strategy = strategy # 传进来策略实例,这样可以写日志到策略实例
|
||||||
|
auth_dict = get_account_config()
|
||||||
|
|
||||||
self.api = TqApi(TqSim(), url="wss://u.shinnytech.com/t/md/front/mobile")
|
self.api = TqApi(TqSim(),auth=TqAuth(auth_dict['user_name'],auth_dict['password'])) # url="wss://u.shinnytech.com/t/md/front/mobile"
|
||||||
|
|
||||||
def get_tick_serial(self, vt_symbol: str):
|
def get_tick_serial(self, vt_symbol: str):
|
||||||
# 获取最新的8964个数据 tick的话就相当于只有50分钟左右
|
# 获取最新的8964个数据 tick的话就相当于只有50分钟左右
|
||||||
|
Loading…
Reference in New Issue
Block a user