# flake8: noqa
"""
Download TDX (Tongdaxin) stock bars (1-minute & daily) => <vnpy project dir>/bar_data/
Shanghai stocks => SSE sub-directory
Shenzhen stocks => SZSE sub-directory
Reworked to run in multi-process mode
"""
|
|
import os
|
|
import sys
|
|
import csv
|
|
import json
|
|
from collections import OrderedDict
|
|
import pandas as pd
|
|
from multiprocessing import Pool
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from copy import copy
|
|
|
|
# Resolve the vnpy project root (two directories above this script) and make
# sure it is importable before any `vnpy.*` import is attempted.
vnpy_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if vnpy_root not in sys.path:
    sys.path.append(vnpy_root)

# Environment flag set before the vnpy imports that follow this block.
os.environ["VNPY_TESTING"] = "1"
|
|
|
|
from vnpy.data.tdx.tdx_stock_data import *
|
|
from vnpy.data.common import resample_bars_file
|
|
from vnpy.trader.utility import load_json
|
|
from vnpy.trader.utility import get_csv_last_dt
|
|
from vnpy.trader.util_wechat import send_wx_msg
|
|
|
|
# Start date used for first-time downloads
# (each year of history takes roughly a few minutes).
start_date = '20160101'

# Folder where the downloaded bar files are stored.
bar_data_folder = os.path.abspath(os.path.join(vnpy_root, 'bar_data'))

# Create the TDX API object.
api_01 = TdxStockData()

# Extra funds whose data should be downloaded as well.
stock_list = load_json('stock_list.json')

# Force-refresh the cache, then take the symbol table from the API object.
api_01.cache_config()
symbol_dict = api_01.symbol_dict

# A single-worker thread pool (thread_executor / thread_tasks) used to be
# created here; it is currently disabled.
|
|
|
|
|
|
def refill(symbol_info):
    """
    Download the bars of one symbol/period and store them in a CSV file.

    The target file is ``<bar_data_folder>/<exchange>/<code>_<period>.csv``.
    If the file does not exist, or its last timestamp cannot be read, the
    whole history since ``start_date`` is written. Otherwise only bars newer
    than the file's last timestamp are appended.

    :param symbol_info: dict with the keys 'code', 'exchange', 'period' and
        'progress'; 'name' is optional and only used in log output.
    :return: None
    """
    period = symbol_info['period']
    progress = symbol_info['progress']
    stock_code = symbol_info['code']
    stock_name = symbol_info.get('name')

    # Anything that is not SZSE is treated as SSE.
    if symbol_info['exchange'] == 'SZSE':
        exchange_name = '深交所'
        exchange = Exchange.SZSE
    else:
        exchange_name = '上交所'
        exchange = Exchange.SSE

    print(f'开始更新:{exchange_name}/{stock_name}, 代码:{stock_code}')

    bar_file_folder = os.path.abspath(os.path.join(bar_data_folder, f'{exchange.value}'))
    # exist_ok=True: several pool workers may try to create the same folder
    # at the same time, a separate exists()-check would race.
    os.makedirs(bar_file_folder, exist_ok=True)

    # CSV data file name, e.g. 1min -> 1m, 1day -> 1d, 1hour -> 1h
    p_name = period.replace('min', 'm').replace('day', 'd').replace('hour', 'h')
    bar_file_path = os.path.abspath(os.path.join(bar_file_folder, f'{stock_code}_{p_name}.csv'))

    # If the file exists, take the timestamp of its last record.
    if os.path.exists(bar_file_path):
        last_dt = get_csv_last_dt(bar_file_path)
    else:
        last_dt = None

    if last_dt:
        # Re-request from one day before the last record; bars that are
        # already stored are filtered out again before appending.
        start_dt = last_dt - timedelta(days=1)
        print(f'文件{bar_file_path}存在,最后时间:{start_dt}')
    else:
        start_dt = datetime.strptime(start_date, '%Y%m%d')
        print(f'文件{bar_file_path}不存在,或读取最后记录错误,开始时间:{start_date}')

    d1 = datetime.now()
    result, bars = api_01.get_bars(symbol=stock_code,
                                   period=period,
                                   callback=None,
                                   start_dt=start_dt,
                                   return_bar=False)
    if not result or len(bars) == 0:
        return

    if last_dt is None:
        # Brand-new data: [dict] => DataFrame => CSV (overwrites the file)
        data_df = pd.DataFrame(bars)
        data_df.set_index('datetime', inplace=True)
        data_df = data_df.sort_index()
        print(data_df.tail())
        data_df.to_csv(bar_file_path, index=True, encoding='utf8')

        # Whole elapsed time in milliseconds (not just the sub-second part).
        elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000)
        print(f'{progress}% 首次更新{stock_code} {stock_name}数据 {elapsed_ms} 毫秒=> 文件{bar_file_path}')
    else:
        # Incremental update: reuse the header row of the existing file so
        # the appended columns line up with it.
        with open(bar_file_path, 'r', encoding='utf8') as f:
            headers = next(csv.reader(f), [])

        # newline='' lets the csv module control the line endings itself
        # (otherwise rows end in '\r\r\n' on Windows).
        with open(bar_file_path, 'a', encoding='utf8', newline='') as csv_write_file:
            writer = csv.DictWriter(f=csv_write_file, fieldnames=headers, dialect='excel',
                                    extrasaction='ignore')
            # Write only the bars that are newer than the last stored bar.
            for bar in bars:
                if bar['datetime'] <= last_dt:
                    continue
                writer.writerow(bar)

        elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000)
        print(f'{progress}%,更新{stock_code} {stock_name} 数据 {elapsed_ms}毫秒 => 文件{bar_file_path}, 最后记录:{bars[-1]}')

    # NOTE: generating the 5/15/30-minute files from fresh 1min data through
    # a thread pool used to be triggered here; that hook is disabled.
|
|
|
|
|
|
def resample(vt_symbol, x_mins=None):
    """
    Update the multi-period bar files of one symbol.

    :param vt_symbol: symbol in "code.exchange" form
    :param x_mins: minute periods handed to ``resample_bars_file``;
        defaults to [5, 15, 30] when omitted
    :return: None
    """
    # Build the default per call instead of sharing one mutable list object
    # across all invocations.
    if x_mins is None:
        x_mins = [5, 15, 30]

    d1 = datetime.now()
    out_files, err_msg = resample_bars_file(vt_symbol=vt_symbol,
                                            x_mins=x_mins)
    # Whole elapsed time in milliseconds (not just the sub-second part).
    elapsed_ms = round((datetime.now() - d1).total_seconds() * 1000)

    # Truthiness check also tolerates err_msg being None.
    if err_msg:
        print(err_msg, file=sys.stderr)

    if out_files:
        print(f'{elapsed_ms}毫秒,生成 =>{out_files}')
|
|
|
|
|
|
if __name__ == '__main__':

    # Download the data of all stocks: one task per (period, symbol) pair.
    tasks = []
    for period in ['1min', '5min', '15min', '30min', '1hour', '1day']:
        for symbol_info in symbol_dict.values():
            # Copy so every task carries its own 'period' / 'progress'.
            info = copy(symbol_info)
            stock_code = info['code']

            # Stocks / convertible bonds, or anything listed in stock_list.
            if info.get('stock_type') in ['stock_cn', 'cb_cn'] \
                    or stock_code in stock_list:
                info['period'] = period
                tasks.append(info)

    # Stamp each task with its position in the queue as a percentage.
    total_tasks = len(tasks)
    for num_progress, task in enumerate(tasks, start=1):
        task['progress'] = round(100 * num_progress / total_tasks, 2)

    # The context manager terminates the workers if map() raises; on the
    # normal path the pool is closed and joined as before.
    with Pool(12) as p:
        p.map(refill, tasks)
        p.close()
        p.join()

    msg = 'tdx股票数据补充完毕: num_stocks={}'.format(total_tasks)
    send_wx_msg(content=msg)
    # NOTE(review): hard exit that skips interpreter cleanup — presumably to
    # avoid hanging on leftover background threads; confirm before changing.
    os._exit(0)
|