[增强] 增加缓存最快服务器,修复股票价格小数位bug

This commit is contained in:
msincenselee 2019-12-24 20:25:00 +08:00
parent c59b508275
commit 0426820e78
3 changed files with 247 additions and 102 deletions

View File

@ -1,6 +1,11 @@
# encoding: UTF-8
import sys
import os
import pickle
import bz2
from functools import lru_cache
from logging import INFO, ERROR
@lru_cache()
@ -58,3 +63,37 @@ TDX_FUTURE_HOSTS = [
{"ip": '218.80.248.229', 'port': 7721, "name": "备用服务器1"},
{"ip": '124.74.236.94', 'port': 7721, "name": "备用服务器2"},
{'ip': '58.246.109.27', 'port': 7721, "name": "备用服务器3"}]
def get_cache_config(config_file_name):
"""获取本地缓存的配置地址信息"""
config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), config_file_name))
config = {}
if not os.path.exists(config_file_name):
return config
with bz2.BZ2File(config_file_name, 'rb') as f:
config = pickle.load(f)
return config
return config
def save_cache_config(data: dict, config_file_name ):
"""保存本地缓存的配置地址信息"""
config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), config_file_name))
with bz2.BZ2File(config_file_name, 'wb') as f:
pickle.dump(data, f)
class FakeStrategy(object):
"""制作一个假得策略,用于测试"""
def write_log(self, content, level=INFO):
if level == INFO:
print(content)
else:
print(content, file=sys.stderr)
def write_error(self, content):
self.write_log(content, level=ERROR)
def display_bar(self, bar, bar_is_completed=True, freq=1):
print(u'{} {}'.format(bar.vtSymbol, bar.datetime))

View File

@ -26,7 +26,7 @@ from pytdx.exhq import TdxExHq_API
from vnpy.trader.constant import Exchange
from vnpy.trader.object import BarData
from vnpy.trader.utility import (get_underlying_symbol, get_full_symbol, get_trading_date, load_json, save_json)
from vnpy.data.tdx.tdx_common import (TDX_FUTURE_HOSTS, PERIOD_MAPPING)
from vnpy.data.tdx.tdx_common import (lru_cache, TDX_FUTURE_HOSTS, PERIOD_MAPPING)
# 每个周期包含多少分钟 (估算值, 没考虑夜盘和10:15的影响)
@ -78,7 +78,7 @@ def save_cache_ip(best_ip: dict):
config_file_name = os.path.abspath(os.path.join(os.path.dirname(__file__), 'tdx_config.json'))
save_json(filename=config_file_name, data=best_ip)
@lru_cache()
def get_tdx_marketid(symbol):
"""普通合约/指数合约=》tdx合约所在市场id"""
underlying_symbol = get_underlying_symbol(symbol)
@ -153,7 +153,9 @@ class TdxFutureData(object):
return
# ----------------------------------------------------------------------
def ping(self, ip, port=7709):
def ping(self,
ip: str,
port: int = 7709):
"""
ping行情服务器
:param ip:
@ -219,7 +221,12 @@ class TdxFutureData(object):
self.symbol_market_dict.update({tdx_symbol: tdx_market_id})
# ----------------------------------------------------------------------
def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None):
def get_bars(self,
symbol: str,
period: str,
callback,
bar_freq: int = 1,
start_dt: datetime = None):
"""
返回k线数据
symbol合约
@ -238,7 +245,7 @@ class TdxFutureData(object):
self.write_error(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
return False, ret_bars
# tdx_period = PERIOD_MAPPING.get(period)
tdx_period = PERIOD_MAPPING.get(period)
if start_dt is None:
self.write_log(u'没有设置开始时间缺省为10天前')
@ -258,7 +265,7 @@ class TdxFutureData(object):
_pos = 0
while _start_date > qry_start_date:
_res = self.api.get_instrument_bars(
PERIOD_MAPPING[period],
tdx_period,
self.symbol_market_dict.get(tdx_index_symbol, 0),
tdx_symbol,
_pos,
@ -334,10 +341,10 @@ class TdxFutureData(object):
add_bar.date = row['date']
add_bar.time = row['time']
add_bar.trading_date = row['trading_date']
add_bar.open = float(row['open'])
add_bar.high = float(row['high'])
add_bar.low = float(row['low'])
add_bar.close = float(row['close'])
add_bar.open_price = float(row['open'])
add_bar.high_price = float(row['high'])
add_bar.low_price = float(row['low'])
add_bar.close_price = float(row['close'])
add_bar.volume = float(row['volume'])
add_bar.openInterest = float(row['open_interest'])
except Exception as ex:
@ -372,7 +379,7 @@ class TdxFutureData(object):
self.connect(is_reconnect=True)
return False, ret_bars
def get_price(self, symbol):
def get_price(self, symbol: str):
"""获取最新价格"""
tdx_symbol = symbol.upper().replace('_', '')
@ -508,7 +515,7 @@ class TdxFutureData(object):
_datas = []
_pos = 0
while(True):
while True:
_res = self.api.get_transaction_data(
market=self.symbol_market_dict.get(tdx_index_symbol, 0),
code=symbol,
@ -601,18 +608,22 @@ class TdxFutureData(object):
return None
def get_history_transaction_data(self, symbol, date, cache_folder=None):
def get_history_transaction_data(self,
symbol: str,
trading_date,
cache_folder:str = None):
"""获取当某一交易日的历史成交记录"""
ret_datas = []
if isinstance(date, datetime):
date = date.strftime('%Y%m%d')
if isinstance(date, str):
date = int(date)
# trading_date, 转换为数字类型得日期
if isinstance(trading_date, datetime):
trading_date = trading_date.strftime('%Y%m%d')
if isinstance(trading_date, str):
trading_date = int(trading_date.replace('-', ''))
self.connect()
cache_symbol = symbol
cache_date = str(date)
cache_date = str(trading_date)
max_data_size = sys.maxsize
symbol = symbol.upper()
@ -634,18 +645,18 @@ class TdxFutureData(object):
self.write_log(u'使用缓存文件')
return True, buffer_data
self.write_log(u'开始下载{} 历史{}分笔数据'.format(date, symbol))
self.write_log(u'开始下载{} 历史{}分笔数据'.format(trading_date, symbol))
cur_trading_date = get_trading_date()
if date == int(cur_trading_date.replace('-', '')):
if trading_date == int(cur_trading_date.replace('-', '')):
return self.get_transaction_data(symbol)
try:
_datas = []
_pos = 0
while(True):
while True:
_res = self.api.get_history_transaction_data(
market=self.symbol_market_dict.get(tdx_index_symbol, 0),
date=date,
trading_date=trading_date,
code=symbol,
start=_pos,
count=q_size)
@ -689,7 +700,7 @@ class TdxFutureData(object):
break
if len(_datas) == 0:
self.write_error(u'{}分笔成交数据获取为空'.format(date))
self.write_error(u'{}分笔成交数据获取为空'.format(trading_date))
return False, _datas
# 缓存文件
@ -708,20 +719,8 @@ class TdxFutureData(object):
return False, ret_datas
class FakeStrategy(object):
def write_log(self, content, level=INFO):
if level == INFO:
print(content)
else:
print(content, file=sys.stderr)
def display_bar(self, bar, bar_is_completed=True, freq=1):
print(u'{} {}'.format(bar.vtSymbol, bar.datetime))
if __name__ == "__main__":
from .tdx_common import FakeStrategy
t1 = FakeStrategy()
t2 = FakeStrategy()
# 创建API对象
@ -780,6 +779,6 @@ if __name__ == "__main__":
# print(r)
# 获取历史分时数据
ret, result = api_01.get_history_transaction_data('J99', '20191009')
ret, result = api_01.get_history_transaction_data('rb1905', '20190109')
for r in result[0:10] + result[-10:]:
print(r)

View File

@ -16,12 +16,18 @@ import pickle
import bz2
import traceback
from datetime import datetime, timedelta
from logging import ERROR, INFO
from logging import ERROR
from pytdx.hq import TdxHq_API
from pytdx.params import TDXParams
from pandas import to_datetime
from vnpy.trader.object import BarData
from vnpy.data.tdx.tdx_common import PERIOD_MAPPING, get_tdx_market_code
from vnpy.trader.constant import Exchange
from vnpy.data.tdx.tdx_common import (
PERIOD_MAPPING,
get_tdx_market_code,
get_cache_config,
save_cache_config)
# 每个周期包含多少分钟
NUM_MINUTE_MAPPING = {}
@ -30,16 +36,31 @@ NUM_MINUTE_MAPPING['5min'] = 5
NUM_MINUTE_MAPPING['15min'] = 15
NUM_MINUTE_MAPPING['30min'] = 30
NUM_MINUTE_MAPPING['1hour'] = 60
NUM_MINUTE_MAPPING['1day'] = 60 * 5.5 # 股票收盘时间是1500开盘是930
NUM_MINUTE_MAPPING['1day'] = 60 * 5.5 # 股票收盘时间是1500开盘是930
# 常量
QSIZE = 800
STOCK_CONFIG_FILE = 'tdx_stock_config.pkb2'
# 通达信 <=> 交易所代码 映射
TDX_VN_STOCK_MARKET_MAP = {
TDXParams.MARKET_SH: Exchange.SSE, # 1: 上交所
TDXParams.MARKET_SZ: Exchange.SZSE # 0 深交所
}
VN_TDX_STOCK_MARKET_MAP = {v: k for k, v in TDX_VN_STOCK_MARKET_MAP.items()}
# 通达信 <=> rq交易所代码 映射
TDX_RQ_STOCK_MARKET_MAP = {
TDXParams.MARKET_SH: 'XSHG', # 1: 上交所
TDXParams.MARKET_SZ: 'XSHE' # 0 深交所
}
RQ_TDX_STOCK_MARKET_MAP = {v: k for k, v in TDX_RQ_STOCK_MARKET_MAP.items()}
class TdxStockData(object):
best_ip = None
symbol_exchange_dict = {} # tdx合约与vn交易所的字典
symbol_market_dict = {} # tdx合约与tdx市场的字典
# ----------------------------------------------------------------------
def __init__(self, strategy=None):
@ -52,24 +73,35 @@ class TdxStockData(object):
self.connection_status = False # 连接状态
self.strategy = strategy
self.best_ip = None
self.symbol_exchange_dict = {} # tdx合约与vn交易所的字典
self.symbol_market_dict = {} # tdx合约与tdx市场的字典
self.connect()
self.config = get_cache_config(STOCK_CONFIG_FILE)
self.symbol_dict = self.config.get('symbol_dict', {})
self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7))
if len(self.symbol_dict) == 0 or self.cache_time < datetime.now() - timedelta(days=1):
self.cache_config()
def write_log(self, content):
"""记录日志"""
if self.strategy:
self.strategy.write_log(content)
else:
print(content)
def write_error(self, content):
"""记录错误"""
if self.strategy:
self.strategy.write_log(content, level=ERROR)
else:
print(content, file=sys.stderr)
def connect(self):
def connect(self, is_reconnect: bool = False):
"""
连接API
:param:is_reconnect 是否重新连接
:return:
"""
# 创建api连接对象实例
@ -79,24 +111,74 @@ class TdxStockData(object):
self.api = TdxHq_API(heartbeat=True, auto_retry=True, raise_exception=True)
# 选取最佳服务器
if TdxStockData.best_ip is None:
if is_reconnect or self.best_ip is None:
self.best_ip = self.config.get('best_ip', {})
if len(self.best_ip) == 0:
from pytdx.util.best_ip import select_best_ip
TdxStockData.best_ip = select_best_ip()
self.best_ip = select_best_ip()
self.config.update({'best_ip': self.best_ip})
save_cache_config(self.config, STOCK_CONFIG_FILE)
self.api.connect(self.best_ip.get('ip'), self.best_ip.get('port'))
self.write_log(f'创建tdx连接, : {self.best_ip}')
TdxStockData.connection_status = True
self.connection_status = True
except Exception as ex:
self.write_log(u'连接服务器tdx异常:{},{}'.format(str(ex), traceback.format_exc()))
return
def disconnect(self):
"""断开连接"""
if self.api is not None:
self.api = None
def cache_config(self):
"""缓存所有股票的清单"""
for market_id in range(2):
print('get market_id:{}'.format(market_id))
security_list = self.get_security_list(market_id)
if len(security_list) == 0:
continue
for security in security_list:
tdx_symbol = security.get('code', None)
if tdx_symbol:
self.symbol_dict.update({f'{tdx_symbol}_{market_id}': security})
self.config.update({'symbol_dict': self.symbol_dict, 'cache_time': datetime.now()})
save_cache_config(data=self.config, config_file_name=STOCK_CONFIG_FILE)
def get_security_list(self, market_id: int = 0):
"""
获取市场代码
:param: market_id: 1上交所 0 深交所
:return:
"""
if self.api is None:
self.connect()
start = 0
results = []
# 接口有数据量连续,循环获取,直至取不到结果为止
while True:
try:
result = self.api.get_security_list(market_id, start)
except Exception:
break
if len(result) > 0:
start += len(result)
else:
break
results.extend(result)
return results
# ----------------------------------------------------------------------
def get_bars(self, symbol, period, callback, bar_is_completed=False, bar_freq=1, start_dt=None):
def get_bars(self,
symbol: str,
period: str,
callback=None,
bar_freq: int = 1,
start_dt: datetime = None):
"""
返回k线数据
symbol股票 000001.XG
@ -104,11 +186,15 @@ class TdxStockData(object):
"""
if self.api is None:
self.connect()
ret_bars = []
if self.api is None:
return False, []
# 新版一劳永逸偷懒写法zzz
# symbol => tdx_code, market_id
if '.' in symbol:
tdx_code, market_str = symbol.split('.')
market_code = 1 if market_str.upper() == 'XSHG' else 0
# 1, 上交所 0 深交所
market_code = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0
self.symbol_exchange_dict.update({tdx_code: symbol}) # tdx合约与vn交易所的字典
self.symbol_market_dict.update({tdx_code: market_code}) # tdx合约与tdx市场的字典
else:
@ -117,37 +203,30 @@ class TdxStockData(object):
self.symbol_exchange_dict.update({symbol: symbol}) # tdx合约与vn交易所的字典
self.symbol_market_dict.update({symbol: market_code}) # tdx合约与tdx市场的字典
# https://github.com/rainx/pytdx/issues/33
# 0 - 深圳, 1 - 上海
ret_bars = []
# period => tdx_period
if period not in PERIOD_MAPPING.keys():
self.write_error(u'{} 周期{}不在下载清单中: {}'
.format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
# print(u'{} 周期{}不在下载清单中: {}'.format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
return False, ret_bars
if self.api is None:
return False, ret_bars
tdx_period = PERIOD_MAPPING.get(period)
# start_dt => qry_start_dt & qry_end_dt
if start_dt is None:
self.write_log(u'没有设置开始时间缺省为10天前')
qry_start_date = datetime.now() - timedelta(days=10)
start_dt = qry_start_date
else:
qry_start_date = start_dt
end_date = datetime.now()
if qry_start_date > end_date:
qry_start_date = end_date
qry_end_date = datetime.now()
if qry_start_date > qry_end_date:
qry_start_date = qry_end_date
self.write_log('{}开始下载tdx股票:{} {}数据, {} to {}.'
.format(datetime.now(), tdx_code, tdx_period, qry_start_date, end_date))
.format(datetime.now(), tdx_code, tdx_period, qry_start_date, qry_end_date))
try:
_start_date = end_date
_start_date = qry_end_date
_bars = []
_pos = 0
while _start_date > qry_start_date:
@ -189,7 +268,8 @@ class TdxStockData(object):
return False, ret_bars
# 通达信是以bar的结束时间标记的vnpy是以bar开始时间标记的,所以要扣减bar本身的分钟数
data['datetime'] = data['datetime'].apply(lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1)))
data['datetime'] = data['datetime'].apply(
lambda x: x - timedelta(minutes=NUM_MINUTE_MAPPING.get(period, 1)))
data['trading_date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d')))
data['date'] = data['datetime'].apply(lambda x: (x.strftime('%Y-%m-%d')))
data['time'] = data['datetime'].apply(lambda x: (x.strftime('%H:%M:%S')))
@ -203,10 +283,10 @@ class TdxStockData(object):
add_bar.date = row['date']
add_bar.time = row['time']
add_bar.trading_date = row['trading_date']
add_bar.open = float(row['open'])
add_bar.high = float(row['high'])
add_bar.low = float(row['low'])
add_bar.close = float(row['close'])
add_bar.open_price = float(row['open'])
add_bar.high_price = float(row['high'])
add_bar.low_price = float(row['low'])
add_bar.close_price = float(row['close'])
add_bar.volume = float(row['volume'])
except Exception as ex:
self.write_error('error when convert bar:{},ex:{},t:{}'
@ -237,11 +317,15 @@ class TdxStockData(object):
self.write_error('exception in get:{},{},{}'.format(tdx_code, str(ex), traceback.format_exc()))
# print('exception in get:{},{},{}'.format(tdx_symbol,str(ex), traceback.format_exc()))
self.write_log(u'重置连接')
TdxStockData.api = None
self.connect()
self.api = None
self.connect(is_reconnect=True)
return False, ret_bars
def save_cache(self, cache_folder, cache_symbol, cache_date, data_list):
def save_cache(self,
cache_folder: str,
cache_symbol: str,
cache_date: str,
data_list: list):
"""保存文件到缓存"""
os.makedirs(cache_folder, exist_ok=True)
@ -257,7 +341,10 @@ class TdxStockData(object):
pickle.dump(data_list, f)
self.write_log(u'缓存成功:{}'.format(save_file))
def load_cache(self, cache_folder, cache_symbol, cache_date):
def load_cache(self,
cache_folder: str,
cache_symbol: str,
cache_date: str):
"""加载缓存数据"""
if not os.path.exists(cache_folder):
self.write_error('缓存目录:{}不存在,不能读取'.format(cache_folder))
@ -277,22 +364,32 @@ class TdxStockData(object):
return None
def get_history_transaction_data(self, symbol, date, cache_folder=None):
"""获取当某一交易日的历史成交记录"""
def get_history_transaction_data(self,
symbol: str,
trading_date,
cache_folder: str = None):
"""
获取当某一交易日的历史成交记录
:param symbol: 查询合约 xxxxxx.交易所
:param trading_date: 可以是日期参数或者字符串参数,支持 2019-01-01 20190101格式
:param cache_folder:
:return:
"""
ret_datas = []
if isinstance(date, datetime):
date = date.strftime('%Y%m%d')
if isinstance(date, str):
date = int(date)
# trading_date ,转为为查询数字类型
if isinstance(trading_date, datetime):
trading_date = trading_date.strftime('%Y%m%d')
if isinstance(trading_date, str):
trading_date = int(trading_date.replace('-', ''))
cache_symbol = symbol
cache_date = str(date)
cache_date = str(trading_date)
max_data_size = sys.maxsize
# symbol.exchange => tdx_code market_code
if '.' in symbol:
tdx_code, market_str = symbol.split('.')
market_code = 1 if market_str.upper() == 'XSHG' else 0
market_code = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0
self.symbol_exchange_dict.update({tdx_code: symbol}) # tdx合约与vn交易所的字典
self.symbol_market_dict.update({tdx_code: market_code}) # tdx合约与tdx市场的字典
else:
@ -310,27 +407,29 @@ class TdxStockData(object):
if buffer_data:
return True, buffer_data
self.write_log(u'开始下载{} 历史{}分笔数据'.format(date, symbol))
self.write_log(u'开始下载{} 历史{}分笔数据'.format(trading_date, symbol))
is_today = False
if date == int(datetime.now().strftime('%Y%m%d')):
if trading_date == int(datetime.now().strftime('%Y%m%d')):
is_today = True
try:
_datas = []
_pos = 0
while(True):
while True:
if is_today:
# 获取当前交易日得交易记录
_res = self.api.get_transaction_data(
market=self.symbol_market_dict[symbol],
code=symbol,
start=_pos,
count=q_size)
else:
# 获取历史交易记录
_res = self.api.get_history_transaction_data(
market=self.symbol_market_dict[symbol],
date=date,
trading_date=trading_date,
code=symbol,
start=_pos,
count=q_size)
@ -341,7 +440,7 @@ class TdxStockData(object):
if _res is not None and len(_res) > 0:
self.write_log(u'分段取{}分笔数据:{} ~{}, {}条,累计:{}'
.format(date, _res[0]['time'], _res[-1]['time'], len(_res), _pos))
.format(trading_date, _res[0]['time'], _res[-1]['time'], len(_res), _pos))
else:
break
@ -349,11 +448,11 @@ class TdxStockData(object):
break
if len(_datas) == 0:
self.write_error(u'{}分笔成交数据获取为空'.format(date))
self.write_error(u'{}分笔成交数据获取为空'.format(trading_date))
return False, _datas
for d in _datas:
dt = datetime.strptime(str(date) + ' ' + d.get('time'), '%Y%m%d %H:%M')
dt = datetime.strptime(str(trading_date) + ' ' + d.get('time'), '%Y%m%d %H:%M')
if last_dt is None or last_dt < dt:
last_dt = dt
else:
@ -372,26 +471,34 @@ class TdxStockData(object):
return True, _datas
except Exception as ex:
self.write_error('exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc()))
self.write_error(
'exception in get_transaction_data:{},{},{}'.format(symbol, str(ex), traceback.format_exc()))
return False, ret_datas
if __name__ == "__main__":
class T(object):
def write_log(self, content, level=INFO):
if level == INFO:
print(content)
else:
print(content, file=sys.stderr)
def display_bar(self, bar, bar_is_completed=True, freq=1):
print(u'{} {}'.format(bar.vtSymbol, bar.datetime))
t1 = T()
t2 = T()
from tdx_common import FakeStrategy
import json
t1 = FakeStrategy()
t2 = FakeStrategy()
# 创建API对象
api_01 = TdxStockData(t1)
# 获取市场下股票
for market_id in range(2):
print('get market_id:{}'.format(market_id))
security_list = api_01.get_security_list(market_id)
if len(security_list) == 0:
continue
for security in security_list:
if security.get('code', '').startswith('12') or u'转债' in security.get('name', ''):
str_security = json.dumps(security, indent=1, ensure_ascii=False)
print('market_id:{},{}'.format(market_id, str_security))
# str_markets = json.dumps(security_list, indent=1, ensure_ascii=False)
# print(u'{}'.format(str_markets))
# 获取历史分钟线
# api_01.get_bars('002024', period='1hour', callback=t1.display_bar)