diff --git a/vnpy/data/tdx/tdx_common.py b/vnpy/data/tdx/tdx_common.py index c453a082..6d45f434 100644 --- a/vnpy/data/tdx/tdx_common.py +++ b/vnpy/data/tdx/tdx_common.py @@ -1,6 +1,11 @@ # encoding: UTF-8 +import sys +import os +import pickle +import bz2 from functools import lru_cache +from logging import INFO, ERROR @lru_cache() @@ -58,3 +63,37 @@ TDX_FUTURE_HOSTS = [ {"ip": '218.80.248.229', 'port': 7721, "name": "备用服务器1"}, {"ip": '124.74.236.94', 'port': 7721, "name": "备用服务器2"}, {'ip': '58.246.109.27', 'port': 7721, "name": "备用服务器3"}] + + +def get_cache_config(config_file_name): + """获取本地缓存的配置地址信息""" + config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), config_file_name)) + config = {} + if not os.path.exists(config_file_name): + return config + with bz2.BZ2File(config_file_name, 'rb') as f: + config = pickle.load(f) + return config + return config + +def save_cache_config(data: dict, config_file_name ): + """保存本地缓存的配置地址信息""" + config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), config_file_name)) + + with bz2.BZ2File(config_file_name, 'wb') as f: + pickle.dump(data, f) + + +class FakeStrategy(object): + """制作一个假得策略,用于测试""" + def write_log(self, content, level=INFO): + if level == INFO: + print(content) + else: + print(content, file=sys.stderr) + def write_error(self, content): + + self.write_log(content, level=ERROR) + + def display_bar(self, bar, bar_is_completed=True, freq=1): + print(u'{} {}'.format(bar.vtSymbol, bar.datetime)) diff --git a/vnpy/data/tdx/tdx_future_data.py b/vnpy/data/tdx/tdx_future_data.py index e064e501..428fb16b 100644 --- a/vnpy/data/tdx/tdx_future_data.py +++ b/vnpy/data/tdx/tdx_future_data.py @@ -26,7 +26,7 @@ from pytdx.exhq import TdxExHq_API from vnpy.trader.constant import Exchange from vnpy.trader.object import BarData from vnpy.trader.utility import (get_underlying_symbol, get_full_symbol, get_trading_date, load_json, save_json) -from vnpy.data.tdx.tdx_common import 
(TDX_FUTURE_HOSTS, PERIOD_MAPPING) +from vnpy.data.tdx.tdx_common import (lru_cache, TDX_FUTURE_HOSTS, PERIOD_MAPPING) # 每个周期包含多少分钟 (估算值, 没考虑夜盘和10:15的影响) @@ -78,7 +78,7 @@ def save_cache_ip(best_ip: dict): config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tdx_config.json')) save_json(filename=config_file_name, data=best_ip) - +@lru_cache() def get_tdx_marketid(symbol): """普通合约/指数合约=》tdx合约所在市场id""" underlying_symbol = get_underlying_symbol(symbol) @@ -153,7 +153,9 @@ class TdxFutureData(object): return # ---------------------------------------------------------------------- - def ping(self, ip, port=7709): + def ping(self, + ip: str, + port: int = 7709): """ ping行情服务器 :param ip: @@ -219,7 +221,12 @@ class TdxFutureData(object): self.symbol_market_dict.update({tdx_symbol: tdx_market_id}) # ---------------------------------------------------------------------- - def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None): + def get_bars(self, + symbol: str, + period: str, + callback, + bar_freq: int = 1, + start_dt: datetime = None): """ 返回k线数据 symbol:合约 @@ -238,7 +245,7 @@ class TdxFutureData(object): self.write_error(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) return False, ret_bars - # tdx_period = PERIOD_MAPPING.get(period) + tdx_period = PERIOD_MAPPING.get(period) if start_dt is None: self.write_log(u'没有设置开始时间,缺省为10天前') @@ -258,7 +265,7 @@ class TdxFutureData(object): _pos = 0 while _start_date > qry_start_date: _res = self.api.get_instrument_bars( - PERIOD_MAPPING[period], + tdx_period, self.symbol_market_dict.get(tdx_index_symbol, 0), tdx_symbol, _pos, @@ -334,10 +341,10 @@ class TdxFutureData(object): add_bar.date = row['date'] add_bar.time = row['time'] add_bar.trading_date = row['trading_date'] - add_bar.open = float(row['open']) - add_bar.high = float(row['high']) - add_bar.low = float(row['low']) - add_bar.close = float(row['close']) + add_bar.open_price = 
float(row['open']) + add_bar.high_price = float(row['high']) + add_bar.low_price = float(row['low']) + add_bar.close_price = float(row['close']) add_bar.volume = float(row['volume']) add_bar.openInterest = float(row['open_interest']) except Exception as ex: @@ -372,7 +379,7 @@ class TdxFutureData(object): self.connect(is_reconnect=True) return False, ret_bars - def get_price(self, symbol): + def get_price(self, symbol: str): """获取最新价格""" tdx_symbol = symbol.upper().replace('_', '') @@ -508,7 +515,7 @@ class TdxFutureData(object): _datas = [] _pos = 0 - while(True): + while True: _res = self.api.get_transaction_data( market=self.symbol_market_dict.get(tdx_index_symbol, 0), code=symbol, @@ -601,18 +608,22 @@ class TdxFutureData(object): return None - def get_history_transaction_data(self, symbol, date, cache_folder=None): + def get_history_transaction_data(self, + symbol: str, + trading_date, + cache_folder:str = None): """获取当某一交易日的历史成交记录""" ret_datas = [] - if isinstance(date, datetime): - date = date.strftime('%Y%m%d') - if isinstance(date, str): - date = int(date) + # trading_date, 转换为数字类型得日期 + if isinstance(trading_date, datetime): + trading_date = trading_date.strftime('%Y%m%d') + if isinstance(trading_date, str): + trading_date = int(trading_date.replace('-', '')) self.connect() cache_symbol = symbol - cache_date = str(date) + cache_date = str(trading_date) max_data_size = sys.maxsize symbol = symbol.upper() @@ -634,18 +645,18 @@ class TdxFutureData(object): self.write_log(u'使用缓存文件') return True, buffer_data - self.write_log(u'开始下载{} 历史{}分笔数据'.format(date, symbol)) + self.write_log(u'开始下载{} 历史{}分笔数据'.format(trading_date, symbol)) cur_trading_date = get_trading_date() - if date == int(cur_trading_date.replace('-', '')): + if trading_date == int(cur_trading_date.replace('-', '')): return self.get_transaction_data(symbol) try: _datas = [] _pos = 0 - while(True): + while True: _res = self.api.get_history_transaction_data( 
market=self.symbol_market_dict.get(tdx_index_symbol, 0), - date=date, + date=trading_date, code=symbol, start=_pos, count=q_size) @@ -689,7 +700,7 @@ class TdxFutureData(object): break if len(_datas) == 0: - self.write_error(u'{}分笔成交数据获取为空'.format(date)) + self.write_error(u'{}分笔成交数据获取为空'.format(trading_date)) return False, _datas # 缓存文件 @@ -708,20 +719,8 @@ class TdxFutureData(object): return False, ret_datas -class FakeStrategy(object): - - def write_log(self, content, level=INFO): - if level == INFO: - print(content) - else: - print(content, file=sys.stderr) - - def display_bar(self, bar, bar_is_completed=True, freq=1): - print(u'{} {}'.format(bar.vtSymbol, bar.datetime)) - - if __name__ == "__main__": - + from vnpy.data.tdx.tdx_common import FakeStrategy t1 = FakeStrategy() t2 = FakeStrategy() # 创建API对象 @@ -780,6 +779,6 @@ if __name__ == "__main__": # print(r) # 获取历史分时数据 - ret, result = api_01.get_history_transaction_data('J99', '20191009') + ret, result = api_01.get_history_transaction_data('rb1905', '20190109') for r in result[0:10] + result[-10:]: print(r) diff --git a/vnpy/data/tdx/tdx_stock_data.py b/vnpy/data/tdx/tdx_stock_data.py index 42e56d0e..0375005c 100644 --- a/vnpy/data/tdx/tdx_stock_data.py +++ b/vnpy/data/tdx/tdx_stock_data.py @@ -16,12 +16,18 @@ import pickle import bz2 import traceback from datetime import datetime, timedelta -from logging import ERROR, INFO +from logging import ERROR from pytdx.hq import TdxHq_API +from pytdx.params import TDXParams from pandas import to_datetime from vnpy.trader.object import BarData -from vnpy.data.tdx.tdx_common import PERIOD_MAPPING, get_tdx_market_code +from vnpy.trader.constant import Exchange +from vnpy.data.tdx.tdx_common import ( + PERIOD_MAPPING, + get_tdx_market_code, + get_cache_config, + save_cache_config) # 每个周期包含多少分钟 NUM_MINUTE_MAPPING = {} @@ -30,16 +36,31 @@ NUM_MINUTE_MAPPING['5min'] = 5 NUM_MINUTE_MAPPING['15min'] = 15 NUM_MINUTE_MAPPING['30min'] = 30 NUM_MINUTE_MAPPING['1hour'] = 60 
-NUM_MINUTE_MAPPING['1day'] = 60 * 5.5 # 股票,收盘时间是15:00,开盘是9:30 +NUM_MINUTE_MAPPING['1day'] = 60 * 5.5 # 股票,收盘时间是15:00,开盘是9:30 # 常量 QSIZE = 800 +STOCK_CONFIG_FILE = 'tdx_stock_config.pkb2' + +# 通达信 <=> 交易所代码 映射 +TDX_VN_STOCK_MARKET_MAP = { + TDXParams.MARKET_SH: Exchange.SSE, # 1: 上交所 + TDXParams.MARKET_SZ: Exchange.SZSE # 0: 深交所 +} +VN_TDX_STOCK_MARKET_MAP = {v: k for k, v in TDX_VN_STOCK_MARKET_MAP.items()} + +# 通达信 <=> rq交易所代码 映射 +TDX_RQ_STOCK_MARKET_MAP = { + TDXParams.MARKET_SH: 'XSHG', # 1: 上交所 + TDXParams.MARKET_SZ: 'XSHE' # 0: 深交所 +} +RQ_TDX_STOCK_MARKET_MAP = {v: k for k, v in TDX_RQ_STOCK_MARKET_MAP.items()} + + class TdxStockData(object): - best_ip = None - symbol_exchange_dict = {} # tdx合约与vn交易所的字典 - symbol_market_dict = {} # tdx合约与tdx市场的字典 + # ---------------------------------------------------------------------- def __init__(self, strategy=None): @@ -52,24 +73,35 @@ class TdxStockData(object): self.connection_status = False # 连接状态 self.strategy = strategy + self.best_ip = None + self.symbol_exchange_dict = {} # tdx合约与vn交易所的字典 + self.symbol_market_dict = {} # tdx合约与tdx市场的字典 - self.connect() + self.config = get_cache_config(STOCK_CONFIG_FILE) + self.symbol_dict = self.config.get('symbol_dict', {}) + self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7)) + + if len(self.symbol_dict) == 0 or self.cache_time < datetime.now() - timedelta(days=1): + self.cache_config() def write_log(self, content): + """记录日志""" if self.strategy: self.strategy.write_log(content) else: print(content) def write_error(self, content): + """记录错误""" if self.strategy: self.strategy.write_log(content, level=ERROR) else: print(content, file=sys.stderr) - def connect(self): + def connect(self, is_reconnect: bool = False): """ 连接API + :param:is_reconnect, 是否重新连接 :return: """ # 创建api连接对象实例 @@ -79,24 +111,74 @@ class TdxStockData(object): self.api = TdxHq_API(heartbeat=True, auto_retry=True, raise_exception=True) # 选取最佳服务器 - if TdxStockData.best_ip is None: + 
if is_reconnect or self.best_ip is None: + self.best_ip = self.config.get('best_ip', {}) + + if len(self.best_ip) == 0: from pytdx.util.best_ip import select_best_ip - TdxStockData.best_ip = select_best_ip() + self.best_ip = select_best_ip() + self.config.update({'best_ip': self.best_ip}) + save_cache_config(self.config, STOCK_CONFIG_FILE) self.api.connect(self.best_ip.get('ip'), self.best_ip.get('port')) self.write_log(f'创建tdx连接, : {self.best_ip}') - TdxStockData.connection_status = True + self.connection_status = True except Exception as ex: self.write_log(u'连接服务器tdx异常:{},{}'.format(str(ex), traceback.format_exc())) return def disconnect(self): + """断开连接""" if self.api is not None: self.api = None + def cache_config(self): + """缓存所有股票的清单""" + for market_id in range(2): + print('get market_id:{}'.format(market_id)) + security_list = self.get_security_list(market_id) + if len(security_list) == 0: + continue + for security in security_list: + tdx_symbol = security.get('code', None) + if tdx_symbol: + self.symbol_dict.update({f'{tdx_symbol}_{market_id}': security}) + + self.config.update({'symbol_dict': self.symbol_dict, 'cache_time': datetime.now()}) + save_cache_config(data=self.config, config_file_name=STOCK_CONFIG_FILE) + + def get_security_list(self, market_id: int = 0): + """ + 获取市场代码 + :param: market_id: 1,上交所 , 0, 深交所 + :return: + """ + if self.api is None: + self.connect() + + start = 0 + results = [] + # 接口有数据量连续,循环获取,直至取不到结果为止 + while True: + try: + result = self.api.get_security_list(market_id, start) + except Exception: + break + if len(result) > 0: + start += len(result) + else: + break + results.extend(result) + + return results # ---------------------------------------------------------------------- - def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None): + def get_bars(self, + symbol: str, + period: str, + callback=None, + bar_freq: int = 1, + start_dt: datetime = None): """ 返回k线数据 symbol:股票 000001.XG @@ 
-104,11 +186,15 @@ class TdxStockData(object): """ if self.api is None: self.connect() + ret_bars = [] + if self.api is None: + return False, [] - # 新版一劳永逸偷懒写法zzz + # symbol => tdx_code, market_id if '.' in symbol: tdx_code, market_str = symbol.split('.') - market_code = 1 if market_str.upper() == 'XSHG' else 0 + # 1, 上交所 , 0, 深交所 + market_code = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0 self.symbol_exchange_dict.update({tdx_code: symbol}) # tdx合约与vn交易所的字典 self.symbol_market_dict.update({tdx_code: market_code}) # tdx合约与tdx市场的字典 else: @@ -117,37 +203,30 @@ class TdxStockData(object): self.symbol_exchange_dict.update({symbol: symbol}) # tdx合约与vn交易所的字典 self.symbol_market_dict.update({symbol: market_code}) # tdx合约与tdx市场的字典 - # https://github.com/rainx/pytdx/issues/33 - # 0 - 深圳, 1 - 上海 - - ret_bars = [] - + # period => tdx_period if period not in PERIOD_MAPPING.keys(): self.write_error(u'{} 周期{}不在下载清单中: {}' .format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) # print(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) return False, ret_bars - - if self.api is None: - return False, ret_bars - tdx_period = PERIOD_MAPPING.get(period) + # start_dt => qry_start_dt & qry_end_dt if start_dt is None: self.write_log(u'没有设置开始时间,缺省为10天前') qry_start_date = datetime.now() - timedelta(days=10) start_dt = qry_start_date else: qry_start_date = start_dt - end_date = datetime.now() - if qry_start_date > end_date: - qry_start_date = end_date + qry_end_date = datetime.now() + if qry_start_date > qry_end_date: + qry_start_date = qry_end_date self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.' 
- .format(datetime.now(), tdx_code, tdx_period, qry_start_date, end_date)) + .format(datetime.now(), tdx_code, tdx_period, qry_start_date, qry_end_date)) try: - _start_date = end_date + _start_date = qry_end_date _bars = [] _pos = 0 while _start_date > qry_start_date: @@ -189,7 +268,8 @@ class TdxStockData(object): return False, ret_bars # 通达信是以bar的结束时间标记的,vnpy是以bar开始时间标记的,所以要扣减bar本身的分钟数 - data['datetime'] = data['datetime'].apply(lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1))) + data['datetime'] = data['datetime'].apply( + lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1))) data['trading_date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d'))) data['date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d'))) data['time'] = data['datetime'].apply(lambda x: (x.strftime('%H:%M:%S'))) @@ -203,10 +283,10 @@ class TdxStockData(object): add_bar.date = row['date'] add_bar.time = row['time'] add_bar.trading_date = row['trading_date'] - add_bar.open = float(row['open']) - add_bar.high = float(row['high']) - add_bar.low = float(row['low']) - add_bar.close = float(row['close']) + add_bar.open_price = float(row['open']) + add_bar.high_price = float(row['high']) + add_bar.low_price = float(row['low']) + add_bar.close_price = float(row['close']) add_bar.volume = float(row['volume']) except Exception as ex: self.write_error('error when convert bar:{},ex:{},t:{}' @@ -237,11 +317,15 @@ class TdxStockData(object): self.write_error('exception in get:{},{},{}'.format(tdx_code, str(ex), traceback.format_exc())) # print('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc())) self.write_log(u'重置连接') - TdxStockData.api = None - self.connect() + self.api = None + self.connect(is_reconnect=True) return False, ret_bars - def save_cache(self, cache_folder, cache_symbol, cache_date, data_list): + def save_cache(self, + cache_folder: str, + cache_symbol: str, + cache_date: str, + data_list: list): """保存文件到缓存""" 
os.makedirs(cache_folder, exist_ok=True) @@ -257,7 +341,10 @@ class TdxStockData(object): pickle.dump(data_list, f) self.write_log(u'缓存成功:{}'.format(save_file)) - def load_cache(self, cache_folder, cache_symbol, cache_date): + def load_cache(self, + cache_folder: str, + cache_symbol: str, + cache_date: str): """加载缓存数据""" if not os.path.exists(cache_folder): self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder)) @@ -277,22 +364,32 @@ class TdxStockData(object): return None - def get_history_transaction_data(self, symbol, date, cache_folder=None): - """获取当某一交易日的历史成交记录""" + def get_history_transaction_data(self, + symbol: str, + trading_date, + cache_folder: str = None): + """ + 获取当某一交易日的历史成交记录 + :param symbol: 查询合约 xxxxxx.交易所 + :param trading_date: 可以是日期参数,或者字符串参数,支持 2019-01-01 或 20190101格式 + :param cache_folder: + :return: + """ ret_datas = [] - if isinstance(date, datetime): - date = date.strftime('%Y%m%d') - if isinstance(date, str): - date = int(date) + # trading_date ,转为为查询数字类型 + if isinstance(trading_date, datetime): + trading_date = trading_date.strftime('%Y%m%d') + if isinstance(trading_date, str): + trading_date = int(trading_date.replace('-', '')) cache_symbol = symbol - cache_date = str(date) + cache_date = str(trading_date) max_data_size = sys.maxsize # symbol.exchange => tdx_code market_code if '.' 
in symbol: tdx_code, market_str = symbol.split('.') - market_code = 1 if market_str.upper() == 'XSHG' else 0 + market_code = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0 self.symbol_exchange_dict.update({tdx_code: symbol}) # tdx合约与vn交易所的字典 self.symbol_market_dict.update({tdx_code: market_code}) # tdx合约与tdx市场的字典 else: @@ -310,27 +407,29 @@ class TdxStockData(object): if buffer_data: return True, buffer_data - self.write_log(u'开始下载{} 历史{}分笔数据'.format(date, symbol)) + self.write_log(u'开始下载{} 历史{}分笔数据'.format(trading_date, symbol)) is_today = False - if date == int(datetime.now().strftime('%Y%m%d')): + if trading_date == int(datetime.now().strftime('%Y%m%d')): is_today = True try: _datas = [] _pos = 0 - while(True): + while True: if is_today: + # Today: query the live transaction endpoint _res = self.api.get_transaction_data( market=self.symbol_market_dict[symbol], code=symbol, start=_pos, count=q_size) else: + # Past day: query the historical endpoint (its keyword is `date`) _res = self.api.get_history_transaction_data( market=self.symbol_market_dict[symbol], - date=date, + date=trading_date, code=symbol, start=_pos, count=q_size) @@ -341,7 +440,7 @@ class TdxStockData(object): if _res is not None and len(_res) > 0: self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}条' - .format(date, _res[0]['time'], _res[-1]['time'], len(_res), _pos)) + .format(trading_date, _res[0]['time'], _res[-1]['time'], len(_res), _pos)) else: break @@ -349,11 +448,11 @@ class TdxStockData(object): break if len(_datas) == 0: - self.write_error(u'{}分笔成交数据获取为空'.format(date)) + self.write_error(u'{}分笔成交数据获取为空'.format(trading_date)) return False, _datas for d in _datas: - dt = datetime.strptime(str(date) + ' ' + d.get('time'), '%Y%m%d %H:%M') + dt = datetime.strptime(str(trading_date) + ' ' + d.get('time'), '%Y%m%d %H:%M') if last_dt is None or last_dt < dt: last_dt = dt else: @@ -372,26 +471,34 @@ class TdxStockData(object): return True, _datas except Exception as ex: - self.write_error('exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), 
traceback.format_exc())) + self.write_error( + 'exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc())) return False, ret_datas if __name__ == "__main__": - class T(object): - def write_log(self, content, level=INFO): - if level == INFO: - print(content) - else: - print(content, file=sys.stderr) - - def display_bar(self, bar, bar_is_completed=True, freq=1): - print(u'{} {}'.format(bar.vtSymbol, bar.datetime)) - - t1 = T() - t2 = T() + from vnpy.data.tdx.tdx_common import FakeStrategy + import json + t1 = FakeStrategy() + t2 = FakeStrategy() # 创建API对象 api_01 = TdxStockData(t1) + # List the securities of each market and print the convertible bonds + for market_id in range(2): + print('get market_id:{}'.format(market_id)) + security_list = api_01.get_security_list(market_id) + if len(security_list) == 0: + continue + for security in security_list: + if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''): + str_security = json.dumps(security, indent=1, ensure_ascii=False) + print('market_id:{},{}'.format(market_id, str_security)) + + # str_markets = json.dumps(security_list, indent=1, ensure_ascii=False) + # print(u'{}'.format(str_markets)) + + # 获取历史分钟线 # api_01.get_bars('002024', period='1hour', callback=t1.display_bar)