vnpy/prod/jobs/refill_tdx_stock_bars.py

203 lines
6.4 KiB
Python

# flake8: noqa
"""
下载通达信股票合约1分钟&日线bar => vnpy项目目录/bar_data/
上海股票 => SSE子目录
深圳股票 => SZSE子目录
修改为多进程模式
"""
import os
import sys
import csv
import json
from collections import OrderedDict
import pandas as pd
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor
from copy import copy
vnpy_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
if vnpy_root not in sys.path:
sys.path.append(vnpy_root)
os.environ["VNPY_TESTING"] = "1"
from vnpy.data.tdx.tdx_stock_data import *
from vnpy.data.common import resample_bars_file
from vnpy.trader.utility import load_json
from vnpy.trader.utility import get_csv_last_dt
from vnpy.trader.util_wechat import send_wx_msg
# Root folder for the downloaded bar csv files (refill() adds one sub-folder per exchange)
bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))
# Earliest date requested when a stock has no usable csv file yet, format YYYYMMDD
# (original note: every year of history takes roughly a few minutes to download)
start_date = '20160101'
# Create the TDX stock data API object
api_01 = TdxStockData()
# Extra codes to download in addition to the stock / convertible-bond types
# (the original comment describes this as a list of funds)
stock_list = load_json('stock_list.json')
# Force a refresh of the API's cached configuration
api_01.cache_config()
# Symbol table exposed by the API; the __main__ block iterates it to build the task list
symbol_dict = api_01.symbol_dict
# Disabled: thread pool that appears to have been intended for background resampling
# thread_executor = ThreadPoolExecutor(max_workers=1)
# thread_tasks = []
def refill(symbol_info):
    """
    Download the bars of one stock for one period and store them as csv.

    The target file is <bar_data_folder>/<exchange>/<code>_<period>.csv.
    When no last bar time can be obtained for that file (it does not exist,
    or reading its last record failed) the history since `start_date` is
    requested and the file is (re)written from scratch. Otherwise bars are
    requested from one day before the last stored bar and only those newer
    than the last stored bar are appended.

    :param symbol_info: mapping with the keys 'code', 'exchange', 'period'
        and 'progress' (a percentage that is only used in log output);
        'name' is optional.
    :return: None
    """
    period = symbol_info['period']
    progress = symbol_info['progress']
    stock_code = symbol_info['code']
    stock_name = symbol_info.get('name')

    # Anything that is not explicitly SZSE is treated as SSE
    if symbol_info['exchange'] == 'SZSE':
        exchange_name = '深交所'
        exchange = Exchange.SZSE
    else:
        exchange_name = '上交所'
        exchange = Exchange.SSE

    print(f'开始更新:{exchange_name}/{stock_name}, 代码:{stock_code}')

    bar_file_folder = os.path.abspath(os.path.join(bar_data_folder, f'{exchange.value}'))
    # exist_ok avoids a FileExistsError when several pool workers try to
    # create the same exchange folder at the same moment
    os.makedirs(bar_file_folder, exist_ok=True)

    # csv file name, e.g. 1min -> 1m, 1hour -> 1h, 1day -> 1d
    p_name = period.replace('min', 'm').replace('day', 'd').replace('hour', 'h')
    bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{stock_code}_{p_name}.csv'))

    # Time of the last stored bar, if the file exists
    if os.path.exists(bar_file_path):
        last_dt = get_csv_last_dt(bar_file_path)
    else:
        last_dt = None

    if last_dt:
        # Re-request one extra day; bars already stored are filtered out below
        start_dt = last_dt - timedelta(days=1)
        print(f'文件{bar_file_path}存在,最后时间:{last_dt}')
    else:
        start_dt = datetime.strptime(start_date, '%Y%m%d')
        print(f'文件{bar_file_path}不存在,或读取最后记录错误,开始时间:{start_date}')

    d1 = datetime.now()
    result, bars = api_01.get_bars(symbol=stock_code,
                                   period=period,
                                   callback=None,
                                   start_dt=start_dt,
                                   return_bar=False)
    if not result or len(bars) == 0:
        return

    if last_dt is None:
        # Brand-new data: [dict] => dataframe => csv
        data_df = pd.DataFrame(bars)
        data_df.set_index('datetime', inplace=True)
        data_df = data_df.sort_index()
        print(data_df.tail())
        data_df.to_csv(bar_file_path, index=True, encoding='utf8')

        # total_seconds() covers the whole duration; .microseconds would only
        # be the sub-second remainder
        elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000, 2)
        print(f'{progress}% 首次更新{stock_code} {stock_name}数据 {elapsed_ms} 毫秒=> 文件{bar_file_path}')
    else:
        # Incremental update: reuse the header row of the existing file so the
        # appended rows have the same column order
        with open(bar_file_path, "r", encoding='utf8') as f:
            headers = next(csv.reader(f), [])

        # Append every bar that is newer than the last stored bar.
        # newline='' lets the csv module control line endings, as its
        # documentation requires (otherwise Windows would emit '\r\r\n').
        with open(bar_file_path, 'a', encoding='utf8', newline='') as csvWriteFile:
            writer = csv.DictWriter(f=csvWriteFile, fieldnames=headers, dialect='excel',
                                    extrasaction='ignore')
            for bar in bars:
                if bar['datetime'] <= last_dt:
                    continue
                writer.writerow(bar)

        elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000, 2)
        print(f'{progress}%,更新{stock_code} {stock_name} 数据 {elapsed_ms}毫秒 => 文件{bar_file_path}, 最后记录:{bars[-1]}')

    # NOTE: resampling into other periods is not triggered from here; every
    # period is downloaded as its own task.
def resample(vt_symbol, x_mins=None):
    """
    Update the multi-period bar files of one symbol.

    Delegates to resample_bars_file() and reports its outcome: the error
    message (if any) goes to stderr, the generated files (if any) are printed
    together with the elapsed time in milliseconds.

    :param vt_symbol: code.exchange
    :param x_mins: minute periods to generate; defaults to [5, 15, 30]
    :return: None
    """
    # A fresh list per call, so the callee can never alter a shared default
    if x_mins is None:
        x_mins = [5, 15, 30]

    d1 = datetime.now()
    out_files, err_msg = resample_bars_file(vt_symbol=vt_symbol,
                                            x_mins=x_mins)
    # total_seconds() covers the whole duration; .microseconds would only be
    # the sub-second remainder
    elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000, 2)

    if err_msg:
        print(err_msg, file=sys.stderr)

    if out_files:
        print(f'{elapsed_ms}毫秒,生成 =>{out_files}')
if __name__ == '__main__':
    # Download the bars of every selected stock for every period, in parallel.
    periods = ['1min', '5min', '15min', '30min', '1hour', '1day']
    wanted_types = ('stock_cn', 'cb_cn')

    # Selection: stocks / convertible bonds, or codes listed in stock_list
    selected = []
    for symbol in symbol_dict.keys():
        info = symbol_dict[symbol]
        stock_code = info['code']
        if info.get('stock_type') in wanted_types or stock_code in stock_list:
            selected.append(info)

    # One task per (period, stock); period-major order
    tasks = []
    for period in periods:
        for info in selected:
            # Shallow copy so each task carries its own 'period' / 'progress'
            task = copy(info)
            task['period'] = period
            tasks.append(task)

    # Progress percentage is fixed by the position in the task list; it is
    # only used for log output inside refill()
    total_tasks = len(tasks)
    for num_progress, task in enumerate(tasks, start=1):
        task['progress'] = round(100 * num_progress / total_tasks, 2)

    # map() blocks until every task has finished; the context manager then
    # releases the worker processes even if a task raised
    with Pool(12) as p:
        p.map(refill, tasks)

    # Report the number of distinct stocks (not stocks x periods)
    msg = 'tdx股票数据补充完毕: num_stocks={}'.format(len(selected))
    send_wx_msg(content=msg)

    # Hard exit without interpreter cleanup - presumably so that background
    # threads left by the API cannot keep the process alive (TODO confirm)
    os._exit(0)