diff --git a/vnpy/data/tdx/tdx_future_data.py b/vnpy/data/tdx/tdx_future_data.py index 2ba0b974..e064e501 100644 --- a/vnpy/data/tdx/tdx_future_data.py +++ b/vnpy/data/tdx/tdx_future_data.py @@ -25,7 +25,7 @@ from pytdx.exhq import TdxExHq_API from vnpy.trader.constant import Exchange from vnpy.trader.object import BarData -from vnpy.trader.utility import (get_underlying_symbol, get_full_symbol, get_trading_date) +from vnpy.trader.utility import (get_underlying_symbol, get_full_symbol, get_trading_date, load_json, save_json) from vnpy.data.tdx.tdx_common import (TDX_FUTURE_HOSTS, PERIOD_MAPPING) @@ -67,6 +67,28 @@ ALL_MARKET_BEGIN_HOUR = 8 ALL_MARKET_END_HOUR = 16 +def get_cache_ip(): + """获取本地缓存的最快IP地址信息""" + config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tdx_config.json')) + return load_json(config_file_name) + + +def save_cache_ip(best_ip: dict): + """保存本地缓存的最快IP地址信息""" + config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tdx_config.json')) + save_json(filename=config_file_name, data=best_ip) + + +def get_tdx_marketid(symbol): + """普通合约/指数合约=》tdx合约所在市场id""" + underlying_symbol = get_underlying_symbol(symbol) + tdx_index_symbol = underlying_symbol.upper() + 'L9' + market_id = INIT_TDX_MARKET_MAP.get(tdx_index_symbol, None) + if market_id is None: + raise KeyError(f'{tdx_index_symbol}不存在INIT_TDX_MARKET_MAP中') + return market_id + + class TdxFutureData(object): # ---------------------------------------------------------------------- @@ -108,6 +130,9 @@ class TdxFutureData(object): # 选取最佳服务器 if is_reconnect or len(self.best_ip) == 0: + self.best_ip = get_cache_ip() + + if len(self.best_ip) == 0: self.best_ip = self.select_best_ip() self.api.connect(self.best_ip['ip'], self.best_ip['port']) @@ -120,9 +145,9 @@ class TdxFutureData(object): self.write_log(u'创建tdx连接, IP: {}/{}'.format(self.best_ip['ip'], self.best_ip['port'])) # print(u'创建tdx连接, IP: {}/{}'.format(self.best_ip['ip'], self.best_ip['port'])) self.connection_status = True - if not is_reconnect: - # 更新 symbol_exchange_dict , symbol_market_dict - self.qryInstrument() + # if not is_reconnect: + # 更新 symbol_exchange_dict , symbol_market_dict + # self.qryInstrument() except Exception as ex: self.write_log(u'连接服务器tdx异常:{},{}'.format(str(ex), traceback.format_exc())) return @@ -165,6 +190,7 @@ class TdxFutureData(object): self.write_log(u'选取 {}:{}'.format(best_future_ip['ip'], best_future_ip['port'])) # print(u'选取 {}:{}'.format(best_future_ip['ip'], best_future_ip['port'])) + save_cache_ip(best_future_ip) return best_future_ip # ---------------------------------------------------------------------- @@ -203,14 +229,11 @@ class TdxFutureData(object): ret_bars = [] tdx_symbol = symbol.upper().replace('_', '') tdx_symbol = tdx_symbol.replace('99', 'L9') + tdx_index_symbol = get_underlying_symbol(symbol) + 'L9' self.connect() if self.api is None: return False, ret_bars - if tdx_symbol not in self.symbol_exchange_dict.keys(): - self.write_error(u'{} 合约{}/{}不在下载清单中: {}' - .format(datetime.now(), symbol, tdx_symbol, self.symbol_exchange_dict.keys())) - return False, ret_bars if period not in PERIOD_MAPPING.keys(): self.write_error(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys()))) return False, ret_bars @@ -236,7 +259,7 @@ class TdxFutureData(object): while _start_date > qry_start_date: _res = self.api.get_instrument_bars( PERIOD_MAPPING[period], - self.symbol_market_dict.get(tdx_symbol, 0), + self.symbol_market_dict.get(tdx_index_symbol, 0), tdx_symbol, _pos, QSIZE) @@ -362,11 +385,11 @@ class TdxFutureData(object): if query_symbol != tdx_symbol: self.write_log('转换合约:{}=>{}'.format(tdx_symbol, query_symbol)) - index_symbol = short_symbol + 'L9' + tdx_index_symbol = short_symbol + 'L9' self.connect() if self.api is None: return 0 - market_id = self.symbol_market_dict.get(index_symbol, 0) + market_id = self.symbol_market_dict.get(tdx_index_symbol, 0) _res = self.api.get_instrument_quote(market_id, query_symbol) if not isinstance(_res, list): @@ -466,7 +489,12 @@ class TdxFutureData(object): max_data_size = sys.maxsize symbol = symbol.upper() if '99' in symbol: + # 查询的是指数合约 symbol = symbol.replace('99', 'L9') + tdx_index_symbol = symbol + else: + # 查询的是普通合约 + tdx_index_symbol = get_underlying_symbol(symbol).upper() + 'L9' self.connect() @@ -482,7 +510,7 @@ class TdxFutureData(object): while(True): _res = self.api.get_transaction_data( - market=self.symbol_market_dict.get(symbol, 0), + market=self.symbol_market_dict.get(tdx_index_symbol, 0), code=symbol, start=_pos, count=q_size) @@ -589,8 +617,12 @@ class TdxFutureData(object): max_data_size = sys.maxsize symbol = symbol.upper() if '99' in symbol: + # 查询的是指数合约 symbol = symbol.replace('99', 'L9') - + tdx_index_symbol = symbol + else: + # 查询的是普通合约 + tdx_index_symbol = get_underlying_symbol(symbol).upper() + 'L9' q_size = QSIZE * 5 # 每秒 2个, 10小时 max_data_size = 1000000 @@ -612,7 +644,7 @@ class TdxFutureData(object): while(True): _res = self.api.get_history_transaction_data( - market=self.symbol_market_dict.get(symbol, 0), + market=self.symbol_market_dict.get(tdx_index_symbol, 0), date=date, code=symbol, start=_pos,