[update] gateway

This commit is contained in:
msincenselee 2021-07-15 08:09:41 +08:00
parent d79157dd8e
commit 917562f81f
15 changed files with 998 additions and 1157 deletions

View File

@ -155,6 +155,8 @@ class CtaEngine(BaseEngine):
self.health_status = {} self.health_status = {}
self.symbol_bar_dict = {} # vt_symbol: bar(一分钟bar)
def init_engine(self): def init_engine(self):
""" """
""" """
@ -266,6 +268,9 @@ class CtaEngine(BaseEngine):
def process_bar_event(self, event: Event): def process_bar_event(self, event: Event):
"""处理bar到达事件""" """处理bar到达事件"""
bar = event.data bar = event.data
# 更新bar
self.symbol_bar_dict[bar.vt_symbol] = bar
# 寻找订阅了该bar的策略
strategies = self.symbol_strategy_map[bar.vt_symbol] strategies = self.symbol_strategy_map[bar.vt_symbol]
if not strategies: if not strategies:
return return
@ -867,6 +872,10 @@ class CtaEngine(BaseEngine):
if tick: if tick:
return tick.last_price return tick.last_price
bar = self.symbol_bar_dict.get(vt_symbol)
if bar:
return bar.close_price
return None return None
def get_contract(self, vt_symbol): def get_contract(self, vt_symbol):

View File

@ -1008,7 +1008,7 @@ class CtaStockTemplate(CtaTemplate):
if len(grid.order_ids) > 0: if len(grid.order_ids) > 0:
continue continue
if grid.volume == grid.traded_volume: if grid.volume <= grid.traded_volume:
self.write_log(u'网格计划卖出:{},已成交:{}'.format(grid.volume, grid.traded_volume)) self.write_log(u'网格计划卖出:{},已成交:{}'.format(grid.volume, grid.traded_volume))
self.tns_finish_sell_grid(grid) self.tns_finish_sell_grid(grid)
continue continue
@ -1033,30 +1033,31 @@ class CtaStockTemplate(CtaTemplate):
# 实盘运行时,要加入市场买卖量的判断 # 实盘运行时,要加入市场买卖量的判断
if not self.backtesting: if not self.backtesting:
symbol_tick = self.cta_engine.get_tick(vt_symbol) cur_price = self.cta_engine.get_price(vt_symbol)
if symbol_tick is None: if not cur_price:
self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol) self.cta_engine.subscribe_symbol(strategy_name=self.strategy_name, vt_symbol=vt_symbol)
self.write_log(f'获取不到{vt_symbol}得tick,无法根据市场深度进行计算')
continue continue
symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol) symbol_tick = self.cta_engine.get_tick(vt_symbol)
# 根据市场计算前5档买单数量 if symbol_tick:
if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, symbol_volume_tick = self.cta_engine.get_volume_tick(vt_symbol)
symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \ # 根据市场计算前5档买单数量
and all( if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3,
[symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4, symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \
symbol_tick.bid_volume_5]): and all(
market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5 [symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4,
market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5 symbol_tick.bid_volume_5]):
org_sell_volume = sell_volume market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5
if market_bid_volumes > 0 and market_ask_volumes > 0 and org_sell_volume >= 2 * symbol_volume_tick: market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5
sell_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, sell_volume) org_sell_volume = sell_volume
sell_volume = max(round_to(value=sell_volume, target=symbol_volume_tick), symbol_volume_tick) if market_bid_volumes > 0 and market_ask_volumes > 0 and org_sell_volume >= 2 * symbol_volume_tick:
if org_sell_volume != sell_volume: sell_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, sell_volume)
self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume)) sell_volume = max(round_to(value=sell_volume, target=symbol_volume_tick), symbol_volume_tick)
if org_sell_volume != sell_volume:
self.write_log(u'修正批次卖出{}数量:{}=>{}'.format(vt_symbol, org_sell_volume, sell_volume))
# 获取当前价格 # 获取当前价格
sell_price = self.cta_engine.get_price(vt_symbol) - self.cta_engine.get_price_tick(vt_symbol) sell_price = cur_price - self.cta_engine.get_price_tick(vt_symbol)
# 发出委托卖出 # 发出委托卖出
vt_orderids = self.sell( vt_orderids = self.sell(
vt_symbol=vt_symbol, vt_symbol=vt_symbol,
@ -1138,7 +1139,7 @@ class CtaStockTemplate(CtaTemplate):
if len(grid.order_ids) > 0: if len(grid.order_ids) > 0:
continue continue
if grid.volume == grid.traded_volume: if grid.volume <= grid.traded_volume:
self.write_log(u'网格计划买入:{},已成交:{}'.format(grid.volume, grid.traded_volume)) self.write_log(u'网格计划买入:{},已成交:{}'.format(grid.volume, grid.traded_volume))
self.tns_finish_buy_grid(grid) self.tns_finish_buy_grid(grid)
return return
@ -1157,6 +1158,10 @@ class CtaStockTemplate(CtaTemplate):
continue continue
buy_volume = ordering_grid.volume - ordering_grid.traded_volume buy_volume = ordering_grid.volume - ordering_grid.traded_volume
# if buy_volume <= 0:
# self.write_error(f'{grid.vt_symbol} 已买入数量:{ordering_grid.traded_volume} 超过委托数量:{ordering_grid.volume}')
# continue
min_trade_volume = self.cta_engine.get_volume_tick(vt_symbol) min_trade_volume = self.cta_engine.get_volume_tick(vt_symbol)
if availiable < buy_volume * cur_price: if availiable < buy_volume * cur_price:
self.write_error(f'可用资金{availiable},不满足买入{vt_symbol},数量:{buy_volume} X价格{cur_price}') self.write_error(f'可用资金{availiable},不满足买入{vt_symbol},数量:{buy_volume} X价格{cur_price}')
@ -1175,17 +1180,18 @@ class CtaStockTemplate(CtaTemplate):
# 实盘运行时,要加入市场买卖量的判断 # 实盘运行时,要加入市场买卖量的判断
if not self.backtesting and 'market' in ordering_grid.snapshot: if not self.backtesting and 'market' in ordering_grid.snapshot:
symbol_tick = self.cta_engine.get_tick(vt_symbol) symbol_tick = self.cta_engine.get_tick(vt_symbol)
# 根据市场计算前5档买单数量 if symbol_tick:
if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3, # 根据市场计算前5档买单数量
symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \ if all([symbol_tick.ask_volume_1, symbol_tick.ask_volume_2, symbol_tick.ask_volume_3,
and all( symbol_tick.ask_volume_4, symbol_tick.ask_volume_5]) \
[symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4, and all(
symbol_tick.bid_volume_5]): [symbol_tick.bid_volume_1, symbol_tick.bid_volume_2, symbol_tick.bid_volume_3, symbol_tick.bid_volume_4,
market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5 symbol_tick.bid_volume_5]):
market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5 market_ask_volumes = symbol_tick.ask_volume_1 + symbol_tick.ask_volume_2 + symbol_tick.ask_volume_3 + symbol_tick.ask_volume_4 + symbol_tick.ask_volume_5
if market_bid_volumes > 0 and market_ask_volumes > 0: market_bid_volumes = symbol_tick.bid_volume_1 + symbol_tick.bid_volume_2 + symbol_tick.bid_volume_3 + symbol_tick.bid_volume_4 + symbol_tick.bid_volume_5
buy_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, buy_volume) if market_bid_volumes > 0 and market_ask_volumes > 0:
buy_volume = max(buy_volume - buy_volume % min_trade_volume, min_trade_volume) buy_volume = min(market_bid_volumes / 4, market_ask_volumes / 4, buy_volume)
buy_volume = max(buy_volume - buy_volume % min_trade_volume, min_trade_volume)
buy_price = cur_price + self.cta_engine.get_price_tick(vt_symbol) * 10 buy_price = cur_price + self.cta_engine.get_price_tick(vt_symbol) * 10

View File

@ -813,8 +813,21 @@ class CtaLineBar(object):
if not self.is_7x24 and (tick.datetime.hour == 8 or tick.datetime.hour == 20): if not self.is_7x24 and (tick.datetime.hour == 8 or tick.datetime.hour == 20):
self.write_log(u'{}竞价排名tick时间:{}'.format(self.name, tick.datetime)) self.write_log(u'{}竞价排名tick时间:{}'.format(self.name, tick.datetime))
return return
self.cur_datetime = tick.datetime
# 过滤一些 异常的tick价格
if self.cur_price is not None and self.cur_price !=0 and tick.last_price is not None and tick.last_price != 0:
# 前后价格超过10%
if abs(tick.last_price - self.cur_price)/self.cur_price >= 0.1:
# 是同一天,都不接受这些tick
if self.cur_datetime and self.cur_datetime.date == tick.datetime.date:
return
else:
# 不是同一天只过滤当前这个tick如果下个tick还是有变化就接受
self.cur_price = tick.last_price
self.cur_datetime = tick.datetime
return
self.cur_datetime = tick.datetime
self.cur_tick = copy.copy(tick) self.cur_tick = copy.copy(tick)
# 兼容 标准套利合约它没有last_price # 兼容 标准套利合约它没有last_price
@ -5614,6 +5627,20 @@ class CtaLineBar(object):
bi_list = self.bi_list[-bi_len:] bi_list = self.bi_list[-bi_len:]
return round(sum([bi.height for bi in bi_list]) / max(1, len(bi_list)), self.round_n) return round(sum([bi.height for bi in bi_list]) / max(1, len(bi_list)), self.round_n)
def duan_atan_ma(self, duan_len=20):
"""返回段得平均斜率"""
if not self.chanlun_calculated:
self.__count_chanlun()
duan_list = self.duan_list[-duan_len:]
return round(sum([d.atan for d in duan_list]) / max(1, len(duan_list)), self.round_n)
def bi_atan_ma(self, bi_len=20):
"""返回分笔得平均斜率"""
if not self.chanlun_calculated:
self.__count_chanlun()
bi_list = self.bi_list[-bi_len:]
return round(sum([bi.atan for bi in bi_list]) / max(1, len(bi_list)), self.round_n)
def export_chan(self): def export_chan(self):
""" """
输出缠论 = csv文件 输出缠论 = csv文件

View File

@ -73,10 +73,31 @@ def check_bi_not_rt(kline: CtaLineBar, direction: Direction) -> bool:
if not kline.cur_bi or kline.cur_bi.direction != direction: if not kline.cur_bi or kline.cur_bi.direction != direction:
return False return False
if kline.cur_bi.direction == kline.fenxing_list[-1].direction: if not kline.cur_fenxing:
if not kline.fenxing_list[-1].is_rt: return False
if kline.cur_bi.direction == kline.cur_fenxing.direction:
if not kline.cur_fenxing.is_rt:
return True return True
else: else:
if direction == 1:
# 判断还没走完的bar是否满足顶分型
if float(kline.cur_fenxing.high) == float(kline.high_array[-1]) \
and kline.cur_fenxing.index == kline.index_list[-1] \
and kline.line_bar[-1].datetime.strftime('%Y-%m-%d %H:%M:%S') > kline.cur_fenxing.index\
and kline.line_bar[-1].high_price < float(kline.cur_fenxing.high) \
and kline.line_bar[-1].low_price < kline.line_bar[-2].low_price:
return True
else:
# 判断还没走完的bar是否满足底分型
if float(kline.cur_fenxing.low) == float(kline.low_array[-1]) \
and kline.cur_fenxing.index == kline.index_list[-1] \
and kline.line_bar[-1].datetime.strftime('%Y-%m-%d %H:%M:%S') > kline.cur_fenxing.index \
and kline.line_bar[-1].low_price > float(kline.cur_fenxing.low) \
and kline.line_bar[-1].high_price < kline.line_bar[-2].high_price:
return True
return False return False
return True return True
@ -274,9 +295,14 @@ def check_chan_xt_five_bi(kline: CtaLineBar, bi_list: List[ChanObject]):
or (min_low == bi_3.low and bi_5.high > bi_3.high > bi_5.low > bi_3.low): or (min_low == bi_3.low and bi_5.high > bi_3.high > bi_5.low > bi_3.low):
v = ChanSignals.LG0.value v = ChanSignals.LG0.value
# 五笔三买要求bi_5.high是最高点 # 五笔三买要求bi_5.high是最高点, 或者bi_4.height超过笔2、笔3两倍
if max(bi_1.low, bi_3.low) < min(bi_1.high, bi_3.high) < bi_5.low and bi_5.high == max_high: if max(bi_1.low, bi_3.low) < min(bi_1.high, bi_3.high) < bi_5.low:
v = ChanSignals.LI0.value if bi_5.high == max_high:
v = ChanSignals.LI0.value
elif bi_4.low == min_low and bi_1.high == max_high \
and bi_4.height > max(bi_1.height, bi_2.height, bi_3.height) \
and bi_4.height > 2 * max(bi_2.height, bi_3.height):
v = ChanSignals.LI0.value
# 向上三角扩张中枢 # 向上三角扩张中枢
if bi_1.high < bi_3.high < bi_5.high and bi_1.low > bi_3.low > bi_5.low: if bi_1.high < bi_3.high < bi_5.high and bi_1.low > bi_3.low > bi_5.low:
@ -306,8 +332,14 @@ def check_chan_xt_five_bi(kline: CtaLineBar, bi_list: List[ChanObject]):
v = ChanSignals.SG0.value v = ChanSignals.SG0.value
# 五笔三卖要求bi_5.low是最低点中枢可能是1~3 # 五笔三卖要求bi_5.low是最低点中枢可能是1~3
if min(bi_1.high, bi_3.high) > max(bi_1.low, bi_3.low) > bi_5.high and bi_5.low == min_low: if min(bi_1.high, bi_3.high) > max(bi_1.low, bi_3.low) > bi_5.high:
v = ChanSignals.SI0.value if bi_5.low == min_low:
v = ChanSignals.SI0.value
elif bi_4.high == max_high and bi_1.low == min_low \
and bi_4.height > max(bi_1.height, bi_2.height, bi_3.height)\
and bi_4.height > 2 * max(bi_2.height,bi_3.height):
v = ChanSignals.SI0.value
# elif bi_1.high == max_high and bi_1.low == min_low:
# 向下三角扩张中枢 # 向下三角扩张中枢
if bi_1.high < bi_3.high < bi_5.high and bi_1.low > bi_3.low > bi_5.low: if bi_1.high < bi_3.high < bi_5.high and bi_1.low > bi_3.low > bi_5.low:
@ -369,7 +401,7 @@ def check_chan_xt_seven_bi(kline: CtaLineBar, bi_list: List[ChanObject]):
if bi_5.high == max_high and bi_5.high > bi_7.high \ if bi_5.high == max_high and bi_5.high > bi_7.high \
and bi_5.low > bi_7.low > min(bi_1.high, bi_3.high) > max(bi_1.low, bi_3.low): and bi_5.low > bi_7.low > min(bi_1.high, bi_3.high) > max(bi_1.low, bi_3.low):
v = ChanSignals.LI0.value v = ChanSignals.LI0.value
#
elif bi_7.direction == 1: elif bi_7.direction == 1:
# 顶背驰 # 顶背驰
if bi_1.low == min_low and bi_7.high == max_high: if bi_1.low == min_low and bi_7.high == max_high:

View File

@ -235,10 +235,9 @@ class TdxFutureData(object):
self.best_ip = get_cache_json(TDX_FUTURE_CONFIG) self.best_ip = get_cache_json(TDX_FUTURE_CONFIG)
if is_reconnect: if is_reconnect:
if is_reconnect: selected_ip = self.best_ip.get('ip')
selected_ip = self.best_ip.get('ip') if selected_ip not in self.exclude_ips:
if selected_ip not in self.exclude_ips: self.exclude_ips.append(selected_ip)
self.exclude_ips.append(selected_ip)
self.best_ip = {} self.best_ip = {}
else: else:
# 超时的话,重新选择 # 超时的话,重新选择

View File

@ -64,6 +64,7 @@ RQ_TDX_STOCK_MARKET_MAP = {v: k for k, v in TDX_RQ_STOCK_MARKET_MAP.items()}
# 本地缓存文件 # 本地缓存文件
class TdxStockData(object): class TdxStockData(object):
exclude_ips = []
def __init__(self, strategy=None, proxy_ip="", proxy_port=0): def __init__(self, strategy=None, proxy_ip="", proxy_port=0):
""" """
@ -93,6 +94,8 @@ class TdxStockData(object):
self.config = get_cache_config(TDX_STOCK_CONFIG) self.config = get_cache_config(TDX_STOCK_CONFIG)
self.symbol_dict = self.config.get('symbol_dict', {}) self.symbol_dict = self.config.get('symbol_dict', {})
self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7)) self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7))
self.best_ip = self.config.get('best_ip',{})
self.exclude_ips = self.config.get('exclude_ips',[])
if len(self.symbol_dict) == 0 or self.cache_time < datetime.now() - timedelta(days=1): if len(self.symbol_dict) == 0 or self.cache_time < datetime.now() - timedelta(days=1):
self.cache_config() self.cache_config()
@ -111,6 +114,32 @@ class TdxStockData(object):
else: else:
print(content, file=sys.stderr) print(content, file=sys.stderr)
def select_best_ip(self, ip_list, proxy_ip="", proxy_port=0, exclude_ips=[]):
"""
选取最快的IP
:param ip_list:
:param proxy_ip: 代理
:param proxy_port: 代理端口
:param exclude_ips: 排除清单
:return:
"""
from pytdx.util.best_ip import ping
data = [ping(ip=x['ip'], port=x['port'], type_='stock', proxy_ip=proxy_ip, proxy_port=proxy_port) for x in
ip_list if x['ip'] not in exclude_ips]
results = []
for i in range(len(data)):
# 删除ping不通的数据
if data[i] < timedelta(0, 9, 0):
results.append((data[i], ip_list[i]))
else:
if ip_list[i].get('ip') not in self.exclude_ips:
self.exclude_ips.append(ip_list[i].get('ip'))
# 按照ping值从小大大排序
results = [x[1] for x in sorted(results, key=lambda x: x[0])]
return results[0]
def connect(self, is_reconnect: bool = False): def connect(self, is_reconnect: bool = False):
""" """
连接API 连接API
@ -126,11 +155,38 @@ class TdxStockData(object):
# 选取最佳服务器 # 选取最佳服务器
if is_reconnect or self.best_ip is None: if is_reconnect or self.best_ip is None:
self.best_ip = self.config.get('best_ip', {}) self.best_ip = self.config.get('best_ip', {})
if is_reconnect:
selected_ip = self.best_ip.get('ip')
if selected_ip not in self.exclude_ips:
self.exclude_ips.append(selected_ip)
self.best_ip = {}
else:
# 超时的话,重新选择
last_datetime_str = self.best_ip.get('datetime', None)
if last_datetime_str:
try:
last_datetime = datetime.strptime(last_datetime_str, '%Y-%m-%d %H:%M:%S')
ip = self.best_ip.get('ip')
is_bad_ip = ip and ip in self.best_ip.get('exclude_ips', [])
if (datetime.now() - last_datetime).total_seconds() > 60 * 60 * 2 or is_bad_ip:
self.best_ip = {}
if not is_bad_ip:
self.exclude_ips = []
except Exception as ex: # noqa
self.best_ip = {}
else:
self.best_ip = {}
if len(self.best_ip) == 0: if len(self.best_ip) == 0:
from pytdx.util.best_ip import select_best_ip from pytdx.util.best_ip import stock_ip
self.best_ip = select_best_ip(_type='socket', proxy_ip=self.proxy_ip, proxy_port=self.proxy_port) self.best_ip = self.select_best_ip(ip_list=stock_ip,
self.config.update({'best_ip': self.best_ip}) proxy_ip=self.proxy_ip,
proxy_port=self.proxy_port,
exclude_ips=self.exclude_ips)
# 保存最新的选择,排除
self.config.update({'best_ip': self.best_ip,
'select_dt': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'exclude_ips': self.exclude_ips})
save_cache_config(self.config, TDX_STOCK_CONFIG) save_cache_config(self.config, TDX_STOCK_CONFIG)
# 如果配置proxy5使用vnpy项目下的pytdx # 如果配置proxy5使用vnpy项目下的pytdx
@ -316,10 +372,10 @@ class TdxStockData(object):
for index, row in data.iterrows(): for index, row in data.iterrows():
try: try:
add_bar = BarData( add_bar = BarData(
gateway_name='tdx', gateway_name='tdx',
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
datetime=index datetime=index
) )
add_bar.date = row['date'] add_bar.date = row['date']
add_bar.time = row['time'] add_bar.time = row['time']
@ -365,6 +421,75 @@ class TdxStockData(object):
self.connect(is_reconnect=True) self.connect(is_reconnect=True)
return False, ret_bars return False, ret_bars
def get_last_bars(self, symbol: str, period: str = '1min', n: int = 2, return_bar: bool = True):
"""
获取最后n根bar
:param symbol:
:param period:
:param n:取bar数量
:param return_bar:
:return:
"""
if not self.api:
self.connect()
ret_bars = []
if self.api is None:
return False, []
# symbol => tdx_code, market_id
if '.' in symbol:
tdx_code, market_str = symbol.split('.')
# 1, 上交所 0 深交所
market_id = 1 if market_str.upper() in ['XSHG', Exchange.SSE.value] else 0
self.symbol_market_dict.update({tdx_code: market_id}) # tdx合约与tdx市场的字典
else:
market_id = get_tdx_market_code(symbol)
tdx_code = symbol
self.symbol_market_dict.update({symbol: market_id}) # tdx合约与tdx市场的字典
# period => tdx_period
if period not in PERIOD_MAPPING.keys():
self.write_error(u'{} 周期{}不在下载清单中: {}'
.format(datetime.now(), period, list(PERIOD_MAPPING.keys())))
return False, ret_bars
tdx_period = PERIOD_MAPPING.get(period)
try:
datas = self.api.get_security_bars(
category=PERIOD_MAPPING[period],
market=market_id,
code=tdx_code,
start=0,
count=n)
if not datas or len(datas) == 0:
return False, ret_bars
if not return_bar:
return True, datas
exchange = TDX_VN_STOCK_MARKET_MAP.get(market_id, Exchange.LOCAL)
delta_minutes = NUM_MINUTE_MAPPING.get(period, 1)
for data in datas:
bar_dt = datetime.strptime(data.get('datetime'), '%Y-%m-%d %H:%M')
bar_dt = bar_dt - timedelta(minutes=delta_minutes)
add_bar = BarData(
gateway_name='tdx',
symbol=symbol,
exchange=exchange,
datetime=bar_dt
)
add_bar.date = bar_dt.strftime('%Y-%m-%d')
add_bar.time = bar_dt.strftime('%H:%M:%S')
add_bar.trading_day = add_bar.date
add_bar.open_price = float(data['open'])
add_bar.high_price = float(data['high'])
add_bar.low_price = float(data['low'])
add_bar.close_price = float(data['close'])
add_bar.volume = float(data['vol'])
ret_bars.append(add_bar)
return True, ret_bars
except Exception as ex:
self.write_error(f'获取{symbol}数据失败:{str(ex)}')
return False, ret_bars
def save_cache(self, def save_cache(self,
cache_folder: str, cache_folder: str,
cache_symbol: str, cache_symbol: str,

View File

@ -374,8 +374,8 @@ class PbGateway(BaseGateway):
product_id=product_id, product_id=product_id,
unit_id=unit_id, unit_id=unit_id,
holder_ids=holder_ids) holder_ids=holder_ids)
#self.tq_api = TqMdApi(self) # self.tq_api = TqMdApi(self)
#self.tq_api.connect() # self.tq_api.connect()
self.init_query() self.init_query()
def close(self) -> None: def close(self) -> None:
@ -1054,6 +1054,9 @@ class PbTdApi(object):
# 未获取本地更新检查的orderid清单 # 未获取本地更新检查的orderid清单
self.unchecked_orderids = [] self.unchecked_orderids = []
# 警告
self.warning_dict = {}
def close(self): def close(self):
pass pass
@ -1110,9 +1113,9 @@ class PbTdApi(object):
"""获取资金账号信息""" """获取资金账号信息"""
# dbf 文件名 # dbf 文件名
account_dbf = os.path.abspath(os.path.join(self.account_folder, account_dbf = os.path.abspath(os.path.join(self.account_folder,
'{}{}.dbf'.format( '{}{}.dbf'.format(
PB_FILE_NAMES.get('accounts'), PB_FILE_NAMES.get('accounts'),
self.trading_date))) self.trading_date)))
try: try:
# dbf => 资金帐号信息 # dbf => 资金帐号信息
self.gateway.write_log(f'扫描资金帐号信息:{account_dbf}') self.gateway.write_log(f'扫描资金帐号信息:{account_dbf}')
@ -1125,18 +1128,28 @@ class PbTdApi(object):
account = AccountData( account = AccountData(
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
accountid=self.userid, accountid=self.userid,
balance=float(data.dyjz), # ["单元净值"] balance=float(data.dyjz), # ["单元净值"]
frozen=float(data.dyjz) - float(data.kyye), # data["可用余额"] frozen=float(data.dyjz) - float(data.kyye), # data["可用余额"]
currency="人民币", currency="人民币",
trading_day=self.trading_day trading_day=self.trading_day
) )
self.gateway.on_account(account) self.gateway.on_account(account)
table.close() table.close()
self.warning_dict.pop('query_account', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf扫描资金帐号异常:{str(ex)}') err_msg = f'dbf扫描资金帐号异常:{str(ex)}'
self.gateway.write_error(traceback.format_exc()) tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_account', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_account': err_info})
def query_account_csv(self): def query_account_csv(self):
"""获取资金账号信息""" """获取资金账号信息"""
@ -1184,9 +1197,9 @@ class PbTdApi(object):
# , 'jysc', 'jybz', 'dryk', 'ljyk', 'fdyk', 'fyl', 'ykl', 'tzlx', 'gddm', 'mrsl', 'mcsl', 'mrje', 'mcje', 'zdf', 'bbj', 'qjcb', 'gtcb', 'gtyk', 'zgb'] # , 'jysc', 'jybz', 'dryk', 'ljyk', 'fdyk', 'fyl', 'ykl', 'tzlx', 'gddm', 'mrsl', 'mcsl', 'mrje', 'mcje', 'zdf', 'bbj', 'qjcb', 'gtcb', 'gtyk', 'zgb']
# dbf 文件名 # dbf 文件名
position_dbf = os.path.abspath(os.path.join(self.account_folder, position_dbf = os.path.abspath(os.path.join(self.account_folder,
'{}{}.dbf'.format( '{}{}.dbf'.format(
PB_FILE_NAMES.get('positions'), PB_FILE_NAMES.get('positions'),
self.trading_date))) self.trading_date)))
try: try:
# dbf => 股票持仓信息 # dbf => 股票持仓信息
self.gateway.write_log(f'扫描股票持仓信息:{position_dbf}') self.gateway.write_log(f'扫描股票持仓信息:{position_dbf}')
@ -1195,7 +1208,7 @@ class PbTdApi(object):
for data in table: for data in table:
if str(data.zjzh).strip() != self.userid: if str(data.zjzh).strip() != self.userid:
continue continue
symbol = str(data.zqdm).strip() #["证券代码"] symbol = str(data.zqdm).strip() # ["证券代码"]
# symbol => Exchange # symbol => Exchange
exchange = symbol_exchange_map.get(symbol, None) exchange = symbol_exchange_map.get(symbol, None)
@ -1207,30 +1220,41 @@ class PbTdApi(object):
name = symbol_name_map.get(symbol, None) name = symbol_name_map.get(symbol, None)
if not name: if not name:
name = data.zqmc # ["证券名称"] name = data.zqmc # ["证券名称"]
symbol_name_map.update({symbol: name}) symbol_name_map.update({symbol: name})
position = PositionData( position = PositionData(
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
accountid=self.userid, accountid=self.userid,
symbol=symbol, #["证券代码"], symbol=symbol, # ["证券代码"],
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
name=name, name=name,
volume=int(data.ccsl), # ["持仓数量"] volume=int(data.ccsl), # ["持仓数量"]
yd_volume=int(data.kysl),# ["可用数量"] yd_volume=int(data.kysl), # ["可用数量"]
price=float(data.cbjg), # ["成本价"] price=float(data.cbjg), # ["成本价"]
cur_price=float(data.zxjg), # ["最新价"] cur_price=float(data.zxjg), # ["最新价"]
pnl=float(data.fdyk), # ["浮动盈亏"] pnl=float(data.fdyk), # ["浮动盈亏"]
holder_id=str(data.gddm).strip() #["股东"] holder_id=str(data.gddm).strip() # ["股东"]
) )
self.gateway.on_position(position) self.gateway.on_position(position)
table.close() table.close()
self.warning_dict.pop('query_position', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf扫描股票持仓异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf扫描股票持仓异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_position', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_position': err_info})
def query_position_csv(self): def query_position_csv(self):
"""从csv获取持仓信息""" """从csv获取持仓信息"""
@ -1297,24 +1321,24 @@ class PbTdApi(object):
# fields:['zqgs', 'zjzh', 'zhlx', 'cpbh', 'cpmc', 'dybh', 'dymc', 'wtph', 'wtxh', 'zqdm', 'zqmc', 'wtfx', 'jglx', 'wtjg', 'wtsl', 'wtzt', 'cjsl', 'wtje' # fields:['zqgs', 'zjzh', 'zhlx', 'cpbh', 'cpmc', 'dybh', 'dymc', 'wtph', 'wtxh', 'zqdm', 'zqmc', 'wtfx', 'jglx', 'wtjg', 'wtsl', 'wtzt', 'cjsl', 'wtje'
# , 'cjjj', 'cdsl', 'jysc', 'fdyy', 'wtly', 'wtrq', 'wtsj', 'jybz'] # , 'cjjj', 'cdsl', 'jysc', 'fdyy', 'wtly', 'wtrq', 'wtsj', 'jybz']
orders_dbf = os.path.abspath(os.path.join(self.account_folder, orders_dbf = os.path.abspath(os.path.join(self.account_folder,
'{}{}.dbf'.format( '{}{}.dbf'.format(
PB_FILE_NAMES.get('orders'), PB_FILE_NAMES.get('orders'),
self.trading_date))) self.trading_date)))
try: try:
# dbf => 股票委托信息 # dbf => 股票委托信息
self.gateway.write_log(f'扫描股票委托信息:{orders_dbf}') self.gateway.write_log(f'扫描股票委托信息:{orders_dbf}')
table = dbf.Table(orders_dbf, codepage='cp936') table = dbf.Table(orders_dbf, codepage='cp936')
table.open(dbf.READ_ONLY) table.open(dbf.READ_ONLY)
for data in table: for data in table:
if str(data.zjzh).strip() != self.userid: # ["资金账户"] if str(data.zjzh).strip() != self.userid: # ["资金账户"]
continue continue
sys_orderid = str(data.wtxh).strip() # ["委托序号"] sys_orderid = str(data.wtxh).strip() # ["委托序号"]
# 检查是否存在本地order_manager缓存中 # 检查是否存在本地order_manager缓存中
order = self.gateway.order_manager.get_order_with_sys_orderid(sys_orderid) order = self.gateway.order_manager.get_order_with_sys_orderid(sys_orderid)
order_date = str(data.wtrq).strip() #["委托日期"] order_date = str(data.wtrq).strip() # ["委托日期"]
order_time = str(data.wtsj).strip() #["委托时间"] order_time = str(data.wtsj).strip() # ["委托时间"]
order_status = STATUS_NAME2VT.get(str(data.wtzt).strip()) # ["委托状态"] order_status = STATUS_NAME2VT.get(str(data.wtzt).strip()) # ["委托状态"]
# 检查是否存在本地orders缓存中系统级别的委托单 # 检查是否存在本地orders缓存中系统级别的委托单
@ -1341,7 +1365,7 @@ class PbTdApi(object):
sys_order = OrderData( sys_order = OrderData(
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
symbol=str(data.zqdm).strip(), # ["证券代码"] symbol=str(data.zqdm).strip(), # ["证券代码"]
exchange=EXCHANGE_NAME2VT.get(str(data.jysc).strip()), # ["交易市场"] exchange=EXCHANGE_NAME2VT.get(str(data.jysc).strip()), # ["交易市场"]
orderid=sys_orderid, orderid=sys_orderid,
sys_orderid=sys_orderid, sys_orderid=sys_orderid,
accountid=self.userid, accountid=self.userid,
@ -1364,8 +1388,8 @@ class PbTdApi(object):
# 存在账号缓存,判断状态是否更新 # 存在账号缓存,判断状态是否更新
else: else:
# 暂不处理交给XHPT_WTCX模块处理 # 暂不处理交给XHPT_WTCX模块处理
if sys_order.status != order_status or sys_order.traded != float(data.cjsl): # ["成交数量"] if sys_order.status != order_status or sys_order.traded != float(data.cjsl): # ["成交数量"]
sys_order.traded = float(data.cjsl) # ["成交数量"] sys_order.traded = float(data.cjsl) # ["成交数量"]
sys_order.status = order_status sys_order.status = order_status
self.orders.update({sys_order.sys_orderid: sys_order}) self.orders.update({sys_order.sys_orderid: sys_order})
self.gateway.write_log(f'账号订单查询,更新:{sys_order.__dict__}') self.gateway.write_log(f'账号订单查询,更新:{sys_order.__dict__}')
@ -1373,10 +1397,20 @@ class PbTdApi(object):
continue continue
table.close() table.close()
self.warning_dict.pop('query_orders', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf扫描股票委托异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf扫描股票委托异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_orders', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_orders': err_info})
def query_orders_csv(self): def query_orders_csv(self):
"""获取所有委托""" """获取所有委托"""
@ -1566,9 +1600,20 @@ class PbTdApi(object):
continue continue
table.close() table.close()
self.warning_dict.pop('query_update_order', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf查询委托库异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf查询委托库异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_update_order', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_update_order': err_info})
def query_update_orders_csv(self): def query_update_orders_csv(self):
"""扫描批量下单的委托查询(csv文件格式""" """扫描批量下单的委托查询(csv文件格式"""
@ -1646,11 +1691,11 @@ class PbTdApi(object):
table = dbf.Table(trades_dbf, codepage='cp936') table = dbf.Table(trades_dbf, codepage='cp936')
table.open(dbf.READ_ONLY) table.open(dbf.READ_ONLY)
for data in table: for data in table:
if str(data.zjzh).strip()!= self.userid: # ["资金账户"] if str(data.zjzh).strip() != self.userid: # ["资金账户"]
continue continue
sys_orderid = str(data.wtxh) # ["委托序号"] sys_orderid = str(data.wtxh) # ["委托序号"]
sys_tradeid = str(data.cjxh) # ["成交序号"] sys_tradeid = str(data.cjxh) # ["成交序号"]
# 检查是否存在本地trades缓存中 # 检查是否存在本地trades缓存中
trade = self.trades.get(sys_tradeid, None) trade = self.trades.get(sys_tradeid, None)
@ -1658,10 +1703,10 @@ class PbTdApi(object):
# 如果交易不再本地映射关系 # 如果交易不再本地映射关系
if trade is None and order is None: if trade is None and order is None:
trade_date = str(data.cjrq).strip() #["成交日期"] trade_date = str(data.cjrq).strip() # ["成交日期"]
trade_time = str(data.cjsj).strip() #["成交时间"] trade_time = str(data.cjsj).strip() # ["成交时间"]
trade_dt = datetime.strptime(f'{trade_date} {trade_time}', "%Y%m%d %H%M%S") trade_dt = datetime.strptime(f'{trade_date} {trade_time}', "%Y%m%d %H%M%S")
direction = DIRECTION_STOCK_NAME2VT.get(str(data.wtfx).strip()) # ["委托方向"] direction = DIRECTION_STOCK_NAME2VT.get(str(data.wtfx).strip()) # ["委托方向"]
offset = Offset.NONE offset = Offset.NONE
if direction is None: if direction is None:
direction = Direction.NET direction = Direction.NET
@ -1671,8 +1716,8 @@ class PbTdApi(object):
offset = Offset.CLOSE offset = Offset.CLOSE
trade = TradeData( trade = TradeData(
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
symbol=str(data.zqdm).strip(), # ["证券代码"] symbol=str(data.zqdm).strip(), # ["证券代码"]
exchange=EXCHANGE_NAME2VT.get(str(data.jysc).strip()), # ["交易市场"] exchange=EXCHANGE_NAME2VT.get(str(data.jysc).strip()), # ["交易市场"]
orderid=sys_tradeid, orderid=sys_tradeid,
tradeid=sys_tradeid, tradeid=sys_tradeid,
sys_orderid=sys_orderid, sys_orderid=sys_orderid,
@ -1680,21 +1725,30 @@ class PbTdApi(object):
direction=direction, direction=direction,
offset=offset, offset=offset,
price=float(data.cjjg), # ["成交价格"] price=float(data.cjjg), # ["成交价格"]
volume=float(data.cjsl), # ["成交数量"] volume=float(data.cjsl), # ["成交数量"]
datetime=trade_dt, datetime=trade_dt,
time=trade_dt.strftime('%H:%M:%S'), time=trade_dt.strftime('%H:%M:%S'),
trade_amount=float(data.cjje), # ["成交金额"] trade_amount=float(data.cjje), # ["成交金额"]
commission=float(data.zfy) # ["总费用"] commission=float(data.zfy) # ["总费用"]
) )
self.trades[sys_tradeid] = trade self.trades[sys_tradeid] = trade
self.gateway.on_trade(copy.copy(trade)) self.gateway.on_trade(copy.copy(trade))
continue continue
table.close() table.close()
self.warning_dict.pop('query_trades', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf扫描股票成交异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc())
err_msg = f'dbf扫描股票成交异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_trades', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_trades': err_info})
def query_trades_csv(self): def query_trades_csv(self):
"""获取所有成交""" """获取所有成交"""
@ -1835,9 +1889,21 @@ class PbTdApi(object):
continue continue
table.close() table.close()
self.warning_dict.pop('query_update_trade', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf查询成交库异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf查询成交库异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_update_trade', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_update_trade': err_info})
def query_update_trades_csv(self): def query_update_trades_csv(self):
"""获取接口的csv成交更新""" """获取接口的csv成交更新"""
@ -1954,7 +2020,7 @@ class PbTdApi(object):
self.gateway.write_error(msg=f'{order.direction.value},{order.vt_symbol},{err_msg}', self.gateway.write_error(msg=f'{order.direction.value},{order.vt_symbol},{err_msg}',
error={"ErrorID": err_id, "ErrorMsg": "委托失败"}) error={"ErrorID": err_id, "ErrorMsg": "委托失败"})
if sys_orderid != '0': if sys_orderid not in ['0','None']:
self.gateway.order_manager.update_orderid_map(local_orderid=local_orderid, self.gateway.order_manager.update_orderid_map(local_orderid=local_orderid,
sys_orderid=sys_orderid) sys_orderid=sys_orderid)
order.sys_orderid = sys_orderid order.sys_orderid = sys_orderid
@ -1969,10 +2035,21 @@ class PbTdApi(object):
self.unchecked_orderids.remove(local_orderid) self.unchecked_orderids.remove(local_orderid)
table.close() table.close()
self.warning_dict.pop('query_send_order', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf查询系统委托号异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf查询系统委托号异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('query_send_order', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'query_send_order': err_info})
def check_send_order_csv(self): def check_send_order_csv(self):
"""检查更新委托文件csv""" """检查更新委托文件csv"""
@ -2097,9 +2174,21 @@ class PbTdApi(object):
table.append(data) table.append(data)
# 关闭dbf文件 # 关闭dbf文件
table.close() table.close()
self.warning_dict.pop('send_order', None)
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf添加发单记录异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf添加发单记录异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('send_order', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'send_order': err_info})
return "" return ""
# 设置状态为提交中 # 设置状态为提交中
@ -2213,8 +2302,8 @@ class PbTdApi(object):
sys_orderid = self.gateway.order_manager.get_sys_orderid(req.orderid) sys_orderid = self.gateway.order_manager.get_sys_orderid(req.orderid)
if sys_orderid is None or len(sys_orderid) == 0: if sys_orderid is None or len(sys_orderid) == 0 or sys_orderid == 'None':
self.gateway.write_error(f'订单{req.orderid}=》系统委托id不存在,撤单失败') self.gateway.write_error(f'订单{req.orderid}=》系统委托id:{sys_orderid}不存在,撤单失败')
return False return False
data = ( data = (
@ -2240,10 +2329,22 @@ class PbTdApi(object):
table.append(data) table.append(data)
# 关闭dbf文件 # 关闭dbf文件
table.close() table.close()
self.warning_dict.pop('cancel_order', None)
return True return True
except Exception as ex: except Exception as ex:
self.gateway.write_error(f'dbf委托撤单异常:{str(ex)}')
self.gateway.write_error(traceback.format_exc()) err_msg = f'dbf委托撤单异常:{str(ex)}'
tra_msg = traceback.format_exc()
err_info = self.warning_dict.get('cancel_order', {})
err_count = err_info.get('err_count', 1)
if err_count > 10:
self.gateway.write_error(err_msg)
self.gateway.write_error(tra_msg)
else:
err_count += 1
err_info.update({'err_count': err_count, 'err_msg': err_msg, 'tra_msg': tra_msg})
self.warning_dict.update({'cancel_order': err_info})
return False return False
def cancel_order_csv(self, req: CancelRequest): def cancel_order_csv(self, req: CancelRequest):

View File

@ -1,14 +1,22 @@
import sys
import traceback import traceback
import json import json
from copy import deepcopy from copy import deepcopy
from uuid import uuid1 from uuid import uuid1
from datetime import datetime, timedelta from datetime import datetime, timedelta
from time import sleep
from threading import Thread from threading import Thread
from multiprocessing.dummy import Pool
from typing import Dict
import pandas as pd
from vnpy.event import Event from vnpy.event import Event
from vnpy.rpc import RpcClient from vnpy.rpc import RpcClient
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import ( from vnpy.trader.object import (
TickData, TickData,
BarData,
ContractData,
SubscribeRequest, SubscribeRequest,
CancelRequest, CancelRequest,
OrderRequest OrderRequest
@ -21,19 +29,32 @@ from vnpy.trader.event import (
EVENT_ACCOUNT, EVENT_ACCOUNT,
EVENT_CONTRACT, EVENT_CONTRACT,
EVENT_LOG) EVENT_LOG)
from vnpy.trader.constant import Exchange from vnpy.trader.constant import Exchange, Product
from vnpy.amqp.consumer import subscriber from vnpy.amqp.consumer import subscriber
from vnpy.amqp.producer import task_creator from vnpy.amqp.producer import task_creator
from vnpy.data.tdx.tdx_common import get_stock_type_sz, get_stock_type_sh
STOCK_CONFIG_FILE = 'tdx_stock_config.pkb2'
from pytdx.hq import TdxHq_API
# 通达信股票行情
from vnpy.data.tdx.tdx_common import get_cache_config, get_tdx_market_code
from vnpy.trader.utility import get_stock_exchange
from pytdx.config.hosts import hq_hosts
from pytdx.params import TDXParams
class StockRpcGateway(BaseGateway): class StockRpcGateway(BaseGateway):
""" """
股票交易得RPC接口 股票交易得RPC接口
交易使用RPC实现 交易使用RPC实现
行情使用RabbitMQ订阅获取 行情1:
需要启动单独得进程运行stock_tick_publisher 使用RabbitMQ订阅获取
Cta_Stock => 行情订阅 =StockRpcGateway =RabbitMQ (task)= stock_tick_publisher =订阅(worker) 需要启动单独得进程运行stock_tick_publisher
stock_tick_publisher => restful接口获取股票行情 =RabbitMQ(pub) => StockRpcGateway =>on_tick event Cta_Stock => 行情订阅 =StockRpcGateway =RabbitMQ (task)= stock_tick_publisher =订阅(worker)
stock_tick_publisher => restful接口获取股票行情 =RabbitMQ(pub) => StockRpcGateway =>on_tick event
行情2
使用tdx进行bar订阅
""" """
default_setting = { default_setting = {
@ -53,6 +74,7 @@ class StockRpcGateway(BaseGateway):
self.client = RpcClient() self.client = RpcClient()
self.client.callback = self.client_callback self.client.callback = self.client_callback
self.rabbit_api = None self.rabbit_api = None
self.tdx_api = None
self.rabbit_dict = {} self.rabbit_dict = {}
# 远程RPC端gateway_name # 远程RPC端gateway_name
self.remote_gw_name = gateway_name self.remote_gw_name = gateway_name
@ -75,11 +97,17 @@ class StockRpcGateway(BaseGateway):
# self.client.subscribe_topic(EVENT_LOG) # self.client.subscribe_topic(EVENT_LOG)
self.client.start(req_address, pub_address) self.client.start(req_address, pub_address)
self.status.update({"con":True})
self.rabbit_dict = setting.get('rabbit', {}) self.rabbit_dict = setting.get('rabbit', {})
self.write_log(f'激活RabbitMQ行情接口.配置:\n{self.rabbit_dict}') if len(self.rabbit_dict) > 0:
self.rabbit_api = SubMdApi(gateway=self) self.write_log(f'激活RabbitMQ行情接口.配置:\n{self.rabbit_dict}')
self.rabbit_api.connect(self.rabbit_dict) self.rabbit_api = SubMdApi(gateway=self)
self.rabbit_api.connect(self.rabbit_dict)
else:
self.write_log(f'激活tdx行情订阅接口')
self.tdx_api = TdxMdApi(gateway=self)
self.tdx_api.connect()
self.write_log("服务器连接成功,开始初始化查询") self.write_log("服务器连接成功,开始初始化查询")
@ -97,6 +125,11 @@ class StockRpcGateway(BaseGateway):
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""行情订阅""" """行情订阅"""
if self.tdx_api:
self.tdx_api.subscribe(req)
return
self.write_log(f'创建订阅任务=> rabbitMQ') self.write_log(f'创建订阅任务=> rabbitMQ')
host = self.rabbit_dict.get('host', 'localhost') host = self.rabbit_dict.get('host', 'localhost')
port = self.rabbit_dict.get('port', 5672) port = self.rabbit_dict.get('port', 5672)
@ -247,6 +280,548 @@ class StockRpcGateway(BaseGateway):
self.event_engine.put(event) self.event_engine.put(event)
# 代码 <=> 中文名称
symbol_name_map: Dict[str, str] = {}
# 代码 <=> 交易所
symbol_exchange_map: Dict[str, Exchange] = {}
class TdxMdApi(object):
"""通达信行情和基础数据"""
def __init__(self, gateway: StockRpcGateway):
""""""
super().__init__()
self.gateway: StockRpcGateway = gateway
self.gateway_name: str = gateway.gateway_name
self.connect_status: bool = False
self.login_status: bool = False
self.req_interval = 0.5 # 操作请求间隔500毫秒
self.req_id = 0 # 操作请求编号
self.connection_status = False # 连接状态
self.symbol_exchange_dict = {} # tdx合约与vn交易所的字典
self.symbol_market_dict = {} # tdx合约与tdx市场的字典
self.symbol_vn_dict = {} # tdx合约与vtSymbol的对应
self.symbol_bar_dict = {} # tdx合约与最后一个bar得字典
self.registed_symbol_set = set()
self.config = get_cache_config(STOCK_CONFIG_FILE)
self.symbol_dict = self.config.get('symbol_dict', {})
# 最佳IP地址
self.best_ip = self.config.get('best_ip', {})
# 排除的异常地址
self.exclude_ips = self.config.get('exclude_ips', [])
# 选择时间
self.select_time = self.config.get('select_time', datetime.now() - timedelta(days=7))
# 缓存时间
self.cache_time = self.config.get('cache_time', datetime.now() - timedelta(days=7))
self.commission_dict = {}
self.contract_dict = {}
# self.queue = Queue() # 请求队列
self.pool = None # 线程池
# self.req_thread = None # 定时器线程
# copy.copy(hq_hosts)
self.ip_list = [{'ip': "180.153.18.170", 'port': 7709},
{'ip': "180.153.18.171", 'port': 7709},
{'ip': "180.153.18.172", 'port': 80},
{'ip': "202.108.253.130", 'port': 7709},
{'ip': "202.108.253.131", 'port': 7709},
{'ip': "202.108.253.139", 'port': 80},
{'ip': "60.191.117.167", 'port': 7709},
{'ip': "115.238.56.198", 'port': 7709},
{'ip': "218.75.126.9", 'port': 7709},
{'ip': "115.238.90.165", 'port': 7709},
{'ip': "124.160.88.183", 'port': 7709},
{'ip': "60.12.136.250", 'port': 7709},
{'ip': "218.108.98.244", 'port': 7709},
# {'ip': "218.108.47.69", 'port': 7709},
{'ip': "114.80.63.12", 'port': 7709},
{'ip': "114.80.63.35", 'port': 7709},
{'ip': "180.153.39.51", 'port': 7709},
# {'ip': '14.215.128.18', 'port': 7709},
# {'ip': '59.173.18.140', 'port': 7709}
]
self.best_ip = {'ip': None, 'port': None}
self.api_dict = {} # API 的连接会话对象字典
self.last_bar_dt = {} # 记录该合约的最后一个bar(结束)时间
self.last_api_bar_dict = {} # 记录会话最后一个bar的时间
self.security_count = 50000
# 股票code name列表
self.stock_codelist = None
def ping(self, ip, port=7709):
"""
ping行情服务器
:param ip:
:param port:
:param type_:
:return:
"""
apix = TdxHq_API()
__time1 = datetime.now()
try:
with apix.connect(ip, port):
if apix.get_security_count(TDXParams.MARKET_SZ) > 9000: # 0深市 股票数量 = 9260
_timestamp = datetime.now() - __time1
self.gateway.write_log('服务器{}:{},耗时:{}'.format(ip, port, _timestamp))
return _timestamp
else:
self.gateway.write_log(u'该服务器IP {}无响应'.format(ip))
return timedelta(9, 9, 0)
except:
self.gateway.write_error(u'tdx ping服务器异常的响应{}'.format(ip))
return timedelta(9, 9, 0)
def select_best_ip(self, ip_list, proxy_ip="", proxy_port=0, exclude_ips=[]):
"""
选取最快的IP
:param ip_list:
:param proxy_ip: 代理
:param proxy_port: 代理端口
:param exclude_ips: 排除清单
:return:
"""
from pytdx.util.best_ip import ping
data = [ping(ip=x['ip'], port=x['port'], type_='stock', proxy_ip=proxy_ip, proxy_port=proxy_port) for x in
ip_list if x['ip'] not in exclude_ips]
results = []
for i in range(len(data)):
# 删除ping不通的数据
if data[i] < timedelta(0, 9, 0):
results.append((data[i], ip_list[i]))
else:
if ip_list[i].get('ip') not in self.exclude_ips:
self.exclude_ips.append(ip_list[i].get('ip'))
# 按照ping值从小大大排序
results = [x[1] for x in sorted(results, key=lambda x: x[0])]
return results[0]
def connect(self, n=3):
"""
连接通达讯行情服务器
:param n:
:return:
"""
if self.connection_status:
for api in self.api_dict:
if api is not None or getattr(api, "client", None) is not None:
self.gateway.write_log(u'当前已经连接,不需要重新连接')
return
self.gateway.write_log(u'开始通达信行情服务器')
if len(self.symbol_dict) == 0:
self.gateway.write_error(f'本地没有股票信息的缓存配置文件')
else:
self.cov_contracts()
# 选取最佳服务器
if self.best_ip['ip'] is None and self.best_ip['port'] is None:
self.best_ip = self.select_best_ip(ip_list=self.ip_list,
proxy_ip="",
proxy_port=0,
exclude_ips=self.exclude_ips)
# 创建n个api连接对象实例
for i in range(n):
try:
api = TdxHq_API(heartbeat=True, auto_retry=True, raise_exception=True)
api.connect(self.best_ip['ip'], self.best_ip['port'])
# 尝试获取市场合约统计
c = api.get_security_count(TDXParams.MARKET_SZ)
if c is None or c < 10:
err_msg = u'该服务器IP {}/{}无响应'.format(self.best_ip['ip'], self.best_ip['port'])
self.gateway.write_error(err_msg)
else:
self.gateway.write_log(u'创建第{}个tdx连接'.format(i + 1))
self.api_dict[i] = api
self.last_bar_dt[i] = datetime.now()
self.connection_status = True
self.security_count = c
# if len(symbol_name_map) == 0:
# self.get_stock_list()
except Exception as ex:
self.gateway.write_error(u'连接服务器tdx[{}]异常:{},{}'.format(i, str(ex), traceback.format_exc()))
self.gateway.status.update({"tdx_status":False, "tdx_error":str(ex)})
return
# 创建连接池每个连接都调用run方法
self.pool = Pool(n)
self.pool.map_async(self.run, range(n))
# 设置上层的连接状态
self.gateway.status.update({"tdx_con":True, 'tdx_con_time':datetime.now().strftime('%H:%M:%S')})
def reconnect(self, i):
"""
重连
:param i:
:return:
"""
try:
self.best_ip = self.select_best_ip(ip_list=self.ip_list, exclude_ips=self.exclude_ips)
api = TdxHq_API(heartbeat=True, auto_retry=True)
api.connect(self.best_ip['ip'], self.best_ip['port'])
# 尝试获取市场合约统计
c = api.get_security_count(TDXParams.MARKET_SZ)
if c is None or c < 10:
err_msg = u'该服务器IP {}/{}无响应'.format(self.best_ip['ip'], self.best_ip['port'])
self.gateway.write_error(err_msg)
else:
self.gateway.write_log(u'重新创建第{}个tdx连接'.format(i + 1))
self.api_dict[i] = api
sleep(1)
except Exception as ex:
self.gateway.write_error(u'重新连接服务器tdx[{}]异常:{},{}'.format(i, str(ex), traceback.format_exc()))
self.gateway.status.update({"tdx_status":False, "tdx_error":str(ex)})
return
def close(self):
"""退出API"""
self.connection_status = False
# 设置上层的连接状态
self.gateway.status.update({'tdx_con':False})
if self.pool is not None:
self.pool.close()
self.pool.join()
def subscribe(self, req):
"""订阅合约"""
# 这里的设计是,如果尚未登录就调用了订阅方法
# 则先保存订阅请求,登录完成后会自动订阅
vn_symbol = str(req.symbol)
if '.' in vn_symbol:
vn_symbol = vn_symbol.split('.')[0]
self.gateway.write_log(u'通达信行情订阅 {}'.format(str(vn_symbol)))
tdx_symbol = vn_symbol # [0:-2] + 'L9'
tdx_symbol = tdx_symbol.upper()
self.gateway.write_log(u'{}=>{}'.format(vn_symbol, tdx_symbol))
self.symbol_vn_dict[tdx_symbol] = vn_symbol
if tdx_symbol not in self.registed_symbol_set:
self.registed_symbol_set.add(tdx_symbol)
# 查询股票信息
self.qry_instrument(vn_symbol)
self.check_status()
def check_status(self):
"""
tdx行情接口状态监控
:return:
"""
self.gateway.write_log(u'检查tdx接口状态')
try:
# 一共订阅的数量
self.gateway.status.update({"tdx_symbols_count":len(self.registed_symbol_set)})
dt_now = datetime.now()
if len(self.registed_symbol_set) > 0 and '0935' < dt_now.strftime("%H%M") < '1500':
# 若还没有启动连接,就启动连接
over_time = [((dt_now - dt).total_seconds() > 60) for dt in self.last_api_bar_dict.values()]
if not self.connection_status or len(self.api_dict) == 0 or any(over_time):
self.gateway.write_log(u'tdx还没有启动连接就启动连接')
self.close()
self.pool = None
self.api_dict = {}
pool_cout = getattr(self.gateway, 'tdx_pool_count', 3)
self.connect(pool_cout)
api_bar_times = [f'{k}:{v.hour}:{v.minute}' for k,v in self.last_api_bar_dict.items()]
if len(api_bar_times) > 0:
self.gateway.status.update({"tdx_api_dt":api_bar_times,'tdx_status':True})
#self.gateway.write_log(u'tdx接口状态正常')
except Exception as ex:
msg = f'检查tdx接口时异常:{str(ex)}' + traceback.format_exc()
self.gateway.write_error(msg)
def qry_instrument(self, symbol):
"""
查询/更新股票信息
:return:
"""
if not self.connection_status:
return
api = self.api_dict.get(0)
if api is None:
self.gateway.write_log(u'取不到api连接更新合约信息失败')
return
# TODO 取得股票的中文名
market_code = get_tdx_market_code(symbol)
api.to_df(api.get_finance_info(market_code, symbol))
# 如果有预定的订阅合约,提前订阅
# if len(all_contacts) > 0:
# cur_folder = os.path.dirname(__file__)
# export_file = os.path.join(cur_folder,'contracts.csv')
# if not os.path.exists(export_file):
# df = pd.DataFrame(all_contacts)
# df.to_csv(export_file)
def cov_contracts(self):
"""转换本地缓存=》合约信息推送"""
for symbol_marketid, info in self.symbol_dict.items():
symbol, market_id = symbol_marketid.split('_')
exchange = info.get('exchange', '')
if len(exchange) == 0:
continue
vn_exchange_str = get_stock_exchange(symbol)
# 排除通达信的指数代码
if exchange != vn_exchange_str:
continue
exchange = Exchange(exchange)
if info['stock_type'] == 'stock_cn':
product = Product.EQUITY
elif info['stock_type'] in ['bond_cn', 'cb_cn']:
product = Product.BOND
elif info['stock_type'] == 'index_cn':
product = Product.INDEX
elif info['stock_type'] == 'etf_cn':
product = Product.ETF
else:
product = Product.EQUITY
volume_tick = info['volunit']
if symbol.startswith('688'):
volume_tick = 200
contract = ContractData(
gateway_name=self.gateway_name,
symbol=symbol,
exchange=exchange,
name=info['name'],
product=product,
pricetick=round(0.1 ** info['decimal_point'], info['decimal_point']),
size=1,
min_volume=volume_tick,
margin_rate=1
)
if product != Product.INDEX:
# 缓存 合约 =》 中文名
symbol_name_map.update({contract.symbol: contract.name})
# 缓存代码和交易所的印射关系
symbol_exchange_map[contract.symbol] = contract.exchange
self.contract_dict.update({contract.symbol: contract})
self.contract_dict.update({contract.vt_symbol: contract})
# 推送
self.gateway.on_contract(contract)
def get_stock_list(self):
"""股票所有的code&name列表"""
api = self.api_dict.get(0)
if api is None:
self.gateway.write_log(u'取不到api连接更新合约信息失败')
return None
self.gateway.write_log(f'查询所有的股票信息')
data = pd.concat(
[pd.concat([api.to_df(api.get_security_list(j, i * 1000)).assign(sse='sz' if j == 0 else 'sh').set_index(
['code', 'sse'], drop=False) for i in range(int(api.get_security_count(j) / 1000) + 1)], axis=0) for j
in range(2)], axis=0)
sz = data.query('sse=="sz"')
sh = data.query('sse=="sh"')
sz = sz.assign(sec=sz.code.apply(get_stock_type_sz))
sh = sh.assign(sec=sh.code.apply(get_stock_type_sh))
temp_df = pd.concat([sz, sh]).query('sec in ["stock_cn","etf_cn","bond_cn","cb_cn"]').sort_index().assign(
name=data['name'].apply(lambda x: str(x)[0:6]))
hq_codelist = temp_df.loc[:, ['code', 'name']].set_index(['code'], drop=False)
for i in range(0, len(temp_df)):
row = temp_df.iloc[i]
if row['sec'] == 'etf_cn':
product = Product.ETF
elif row['sec'] in ['bond_cn', 'cb_cn']:
product = Product.BOND
else:
product = Product.EQUITY
volume_tick = 100 if product != Product.BOND else 10
if row['code'].startswith('688'):
volume_tick = 200
contract = ContractData(
gateway_name=self.gateway_name,
symbol=row['code'],
exchange=Exchange.SSE if row['sse'] == 'sh' else Exchange.SZSE,
name=row['name'],
product=product,
pricetick=round(0.1 ** row['decimal_point'], row['decimal_point']),
size=1,
min_volume=volume_tick,
margin_rate=1
)
# 缓存 合约 =》 中文名
symbol_name_map.update({contract.symbol: contract.name})
# 缓存代码和交易所的印射关系
symbol_exchange_map[contract.symbol] = contract.exchange
self.contract_dict.update({contract.symbol: contract})
self.contract_dict.update({contract.vt_symbol: contract})
# 推送
self.gateway.on_contract(contract)
return hq_codelist
def run(self, i):
"""
版本1Pool内得线程持续运行,每个线程从queue中获取一个请求并处理
版本2Pool内线程从订阅合约集合中取出符合自己下标 mode n = 0的合约并发送请求
:param i:
:return:
"""
# 版本2
try:
api_count = len(self.api_dict)
last_dt = datetime.now()
last_minute = None
self.gateway.write_log(u'开始运行tdx[{}],{}'.format(i, last_dt))
while self.connection_status:
dt = datetime.now()
# 每个自然分钟的1~5秒进行
if last_minute == dt.minute or 1 < dt.second < 5:
continue
last_minute = dt.minute
symbols = set()
for idx, tdx_symbol in enumerate(list(self.registed_symbol_set)):
# self.gateway.write_log(u'tdx[{}], api_count:{}, idx:{}, tdx_symbol:{}'.format(i, api_count, idx, tdx_symbol))
if idx % api_count == i:
try:
symbols.add(tdx_symbol)
self.processReq(tdx_symbol, i)
except BrokenPipeError as bex:
self.gateway.write_error(u'BrokenPipeError{},重试重连tdx[{}]'.format(str(bex), i))
self.reconnect(i)
sleep(5)
break
except Exception as ex:
self.gateway.write_error(
u'tdx[{}] exception:{},{}'.format(i, str(ex), traceback.format_exc()))
self.gateway.write_error(u'重试重连tdx[{}]'.format(i))
print(u'重试重连tdx[{}]'.format(i), file=sys.stderr)
self.reconnect(i)
# self.gateway.write_log(u'tdx[{}] sleep'.format(i))
sleep(self.req_interval)
if last_dt.minute != dt.minute:
self.gateway.write_log('tdx[{}] check point. {}, process symbols:{}'.format(i, dt, symbols))
last_dt = dt
except Exception as ex:
self.gateway.write_error(u'tdx[{}] pool.run exception:{},{}'.format(i, str(ex), traceback.format_exc()))
self.gateway.write_error(u'tdx[{}] {}退出'.format(i, datetime.now()))
def processReq(self, req, i):
"""
处理行情信息bar请求
:param req:
:param i:
:return:
"""
symbol = req
if '.' in symbol:
symbol, exchange = symbol.split('.')
if exchange == 'SZSE':
market_id = 0
else:
market_id = 1
else:
market_id = get_tdx_market_code(symbol)
exchange = get_stock_exchange(symbol)
exchange = Exchange(exchange)
api = self.api_dict.get(i, None)
if api is None:
self.gateway.write_log(u'tdx[{}] Api is None'.format(i))
raise Exception(u'tdx[{}] Api is None'.format(i))
symbol_config = self.symbol_dict.get('{}_{}'.format(symbol, market_id), {})
decimal_point = symbol_config.get('decimal_point', 2)
rt_list = api.get_security_bars(
category=8,
market=market_id,
code=symbol,
start=0,
count=1)
if rt_list is None or len(rt_list) == 0:
self.gateway.write_log(u'tdx[{}]: rt_list为空'.format(i))
return
data = rt_list[0]
# tdx 返回bar的结束时间
bar_dt = datetime.strptime(data.get('datetime'), '%Y-%m-%d %H:%M')
# 更新api的获取bar结束时间
self.last_api_bar_dict[i] = bar_dt
if i in self.last_bar_dt:
if self.last_bar_dt[i] < bar_dt:
self.last_bar_dt[i] = bar_dt
pre_bar = self.symbol_bar_dict.get(symbol)
# 存在上一根Bar
if pre_bar and (datetime.now() - bar_dt).total_seconds() > 60:
return
# vnpy bar开始时间
bar_dt = bar_dt - timedelta(minutes=1)
bar = BarData(
gateway_name='tdx',
symbol=symbol,
exchange=exchange,
datetime=bar_dt
)
bar.trading_day = bar_dt.strftime('%Y-%m-%d')
bar.open_price = float(data['open'])
bar.high_price = float(data['high'])
bar.low_price = float(data['low'])
bar.close_price = float(data['close'])
bar.volume = float(data['vol'])
self.symbol_bar_dict[symbol] = bar
self.gateway.on_bar(deepcopy(bar))
class SubMdApi(): class SubMdApi():
""" """
RabbitMQ Subscriber 数据行情接收API RabbitMQ Subscriber 数据行情接收API
@ -282,7 +857,7 @@ class SubMdApi():
# 未有数据到达 # 未有数据到达
if self.last_tick_dt is None: if self.last_tick_dt is None:
d.update({"sub_status": False, "sub_error": u"rabbitmq未有行情数据到达"}) d.update({"sub_status": False, "sub_error": u"rabbitmq未有行情数据到达"})
else: # 有数据 else: # 有数据
# 超时5分钟以上 # 超时5分钟以上
if (dt_now - self.last_tick_dt).total_seconds() > 60 * 5: if (dt_now - self.last_tick_dt).total_seconds() > 60 * 5:

View File

@ -1 +0,0 @@
from .tora_gateway import ToraGateway

View File

@ -1,132 +0,0 @@
from typing import Dict, Tuple
from vnpy.api.tora.vntora import TORA_TSTP_D_Buy, TORA_TSTP_D_Sell, TORA_TSTP_EXD_SSE, \
TORA_TSTP_EXD_SZSE, TORA_TSTP_OPT_LimitPrice, TORA_TSTP_OST_AllTraded, TORA_TSTP_OST_Canceled, \
TORA_TSTP_OST_NoTradeQueueing, TORA_TSTP_OST_PartTradedQueueing, TORA_TSTP_OST_Unknown, \
TORA_TSTP_PID_SHBond, TORA_TSTP_PID_SHFund, TORA_TSTP_PID_SHStock, TORA_TSTP_PID_SZBond, \
TORA_TSTP_PID_SZFund, TORA_TSTP_PID_SZStock, TORA_TSTP_TC_GFD, TORA_TSTP_TC_IOC, TORA_TSTP_VC_AV
from vnpy.trader.constant import Direction, Exchange, OrderType, Product, Status
EXCHANGE_TORA2VT = {
TORA_TSTP_EXD_SSE: Exchange.SSE,
TORA_TSTP_EXD_SZSE: Exchange.SZSE,
# TORA_TSTP_EXD_HK: Exchange.SEHK,
}
EXCHANGE_VT2TORA = {v: k for k, v in EXCHANGE_TORA2VT.items()}
PRODUCT_TORA2VT = {
# 通用(内部使用)
# TORA_TSTP_PID_COMMON: 0,
# 上海股票
TORA_TSTP_PID_SHStock: Product.EQUITY,
# 上海配股配债
# TORA_TSTP_PID_SHWarrant: 0,
# 上海基金
TORA_TSTP_PID_SHFund: Product.FUND,
# 上海债券
TORA_TSTP_PID_SHBond: Product.BOND,
# 上海标准券
# TORA_TSTP_PID_SHStandard: 0,
# 上海质押式回购
# TORA_TSTP_PID_SHRepurchase: 0,
# 深圳股票
TORA_TSTP_PID_SZStock: Product.EQUITY,
# 深圳配股配债
# TORA_TSTP_PID_SZWarrant: 0,
# 深圳基金
TORA_TSTP_PID_SZFund: Product.FUND,
# 深圳债券
TORA_TSTP_PID_SZBond: Product.BOND,
# 深圳标准券
# TORA_TSTP_PID_SZStandard: 0,
# 深圳质押式回购
# TORA_TSTP_PID_SZRepurchase: 0,
}
PRODUCT_VT2TORA = {v: k for k, v in PRODUCT_TORA2VT.items()}
DIRECTION_TORA2VT = {
# 买入
TORA_TSTP_D_Buy: Direction.LONG,
# 卖出
TORA_TSTP_D_Sell: Direction.SHORT,
# # ETF申购
# TORA_TSTP_D_ETFPur: 0,
# # ETF赎回
# TORA_TSTP_D_ETFRed: 0,
# # 新股申购
# TORA_TSTP_D_IPO: 0,
# # 正回购
# TORA_TSTP_D_Repurchase: 0,
# # 逆回购
# TORA_TSTP_D_ReverseRepur: 0,
# # 开放式基金申购
# TORA_TSTP_D_OeFundPur: 0,
# # 开放式基金赎回
# TORA_TSTP_D_OeFundRed: 0,
# # 担保品划入
# TORA_TSTP_D_CollateralIn: 0,
# # 担保品划出
# TORA_TSTP_D_CollateralOut: 0,
# # 质押入库
# TORA_TSTP_D_PledgeIn: 0,
# # 质押出库
# TORA_TSTP_D_PledgeOut: 0,
# # 配股配债
# TORA_TSTP_D_Rationed: 0,
# # 开放式基金拆分
# TORA_TSTP_D_Split: 0,
# # 开放式基金合并
# TORA_TSTP_D_Merge: 0,
# # 融资买入
# TORA_TSTP_D_MarginBuy: 0,
# # 融券卖出
# TORA_TSTP_D_ShortSell: 0,
# # 卖券还款
# TORA_TSTP_D_SellRepayment: 0,
# # 买券还券
# TORA_TSTP_D_BuyRepayment: 0,
# # 还券划转
# TORA_TSTP_D_SecurityRepay: 0,
# # 余券划转
# TORA_TSTP_D_RemainTransfer: 0,
# # 债转股
# TORA_TSTP_D_BondConvertStock: 0,
# # 债券回售
# TORA_TSTP_D_BondPutback: 0,
}
DIRECTION_VT2TORA = {v: k for k, v in DIRECTION_TORA2VT.items()}
# OrderType-> (OrderPriceType, TimeCondition, VolumeCondition)
ORDER_TYPE_VT2TORA: Dict[OrderType, Tuple[str, str, str]] = {
OrderType.FOK: (TORA_TSTP_OPT_LimitPrice, TORA_TSTP_TC_IOC, TORA_TSTP_VC_AV),
OrderType.FAK: (TORA_TSTP_OPT_LimitPrice, TORA_TSTP_TC_IOC, TORA_TSTP_VC_AV),
OrderType.LIMIT: (TORA_TSTP_OPT_LimitPrice, TORA_TSTP_TC_GFD, TORA_TSTP_VC_AV),
}
ORDER_TYPE_TORA2VT: Dict[Tuple[str, str, str], OrderType] = {
v: k for k, v in ORDER_TYPE_VT2TORA.items()
}
ORDER_STATUS_TORA2VT: Dict[str, Status] = {
# 全部成交
TORA_TSTP_OST_AllTraded: Status.ALLTRADED,
# 部分成交还在队列中
TORA_TSTP_OST_PartTradedQueueing: Status.PARTTRADED,
# 部分成交不在队列中
# TORA_TSTP_OST_PartTradedNotQueueing: _,
# 未成交还在队列中
TORA_TSTP_OST_NoTradeQueueing: Status.NOTTRADED,
# 未成交不在队列中
# TORA_TSTP_OST_NoTradeNotQueueing: _,
# 撤单
TORA_TSTP_OST_Canceled: Status.CANCELLED,
# 未知
TORA_TSTP_OST_Unknown: Status.NOTTRADED, # todo: unknown status???
# 尚未触发
# TORA_TSTP_OST_NotTouched: _,
# 已触发
# TORA_TSTP_OST_Touched: _,
# 预埋
# TORA_TSTP_OST_Cached: _,
}

View File

@ -1,19 +0,0 @@
error_codes = {
0: "没有错误",
-1: "TCP连接没建立",
-2: "交互通道无效",
-3: "用户未登录",
-4: "非本前置会话不能订阅私有流",
-5: "重复的私有流订阅请求",
-6: "打开私有流文件失败",
-7: "内部通信错误",
-8: "创建会话通道失败",
-9: "超出流控限制",
}
def get_error_msg(error_code: int):
try:
return error_codes[error_code]
except KeyError:
return "未知错误"

View File

@ -1,176 +0,0 @@
from datetime import datetime
from typing import Any, List, Optional
from vnpy.api.tora.vntora import (CTORATstpMarketDataField, CTORATstpMdApi, CTORATstpMdSpi,
CTORATstpRspInfoField, CTORATstpRspUserLoginField,
CTORATstpUserLogoutField)
from vnpy.gateway.tora.error_codes import get_error_msg
from vnpy.trader.constant import Exchange
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
)
from .constant import EXCHANGE_TORA2VT, EXCHANGE_VT2TORA
def parse_datetime(date: str, time: str):
# sampled :
# date: '20190611'
# time: '16:28:24'
return datetime.strptime(f'{date}-{time}', "%Y%m%d-%H:%M:%S")
class ToraMdSpi(CTORATstpMdSpi):
""""""
def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"):
""""""
super().__init__()
self.gateway = gateway
self._api = api
def OnFrontConnected(self) -> Any:
""""""
self.gateway.write_log("行情服务器连接成功")
def OnFrontDisconnected(self, error_code: int) -> Any:
""""""
self.gateway.write_log(
f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
def OnRspError(
self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool
) -> Any:
""""""
error_id = error_info.ErrorID
error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务收到错误消息({error_id}){error_msg}")
def OnRspUserLogin(
self,
info: CTORATstpRspUserLoginField,
error_info: CTORATstpRspInfoField,
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务登录失败({error_id}){error_msg}")
return
self.gateway.write_log("行情服务器登录成功")
def OnRspUserLogout(
self,
info: CTORATstpUserLogoutField,
error_info: CTORATstpRspInfoField,
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务登出失败({error_id}){error_msg}")
return
self.gateway.write_log("行情服务器登出成功")
def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any:
""""""
if data.ExchangeID not in EXCHANGE_TORA2VT:
return
tick_data = TickData(
gateway_name=self.gateway.gateway_name,
symbol=data.SecurityID,
exchange=EXCHANGE_TORA2VT[data.ExchangeID],
datetime=parse_datetime(data.TradingDay, data.UpdateTime),
name=data.SecurityName,
volume=0,
last_price=data.LastPrice,
last_volume=data.Volume, # to verify: is this correct?
limit_up=data.UpperLimitPrice,
limit_down=data.LowerLimitPrice,
open_price=data.OpenPrice,
high_price=data.HighestPrice,
low_price=data.LowestPrice,
pre_close=data.PreClosePrice,
bid_price_1=data.BidPrice1,
bid_price_2=data.BidPrice2,
bid_price_3=data.BidPrice3,
bid_price_4=data.BidPrice4,
bid_price_5=data.BidPrice5,
ask_price_1=data.AskPrice1,
ask_price_2=data.AskPrice2,
ask_price_3=data.AskPrice3,
ask_price_4=data.AskPrice4,
ask_price_5=data.AskPrice5,
bid_volume_1=data.BidVolume1,
bid_volume_2=data.BidVolume2,
bid_volume_3=data.BidVolume3,
bid_volume_4=data.BidVolume4,
bid_volume_5=data.BidVolume5,
ask_volume_1=data.AskVolume1,
ask_volume_2=data.AskVolume2,
ask_volume_3=data.AskVolume3,
ask_volume_4=data.AskVolume4,
ask_volume_5=data.AskVolume5,
)
self.gateway.on_tick(tick_data)
class ToraMdApi:
""""""
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.md_address = ""
self._native_api: Optional[CTORATstpMdApi] = None
self._spi: Optional["ToraMdApi"] = None
def stop(self):
"""
:note not thread-safe
"""
if self._native_api:
self._native_api.RegisterSpi(None)
self._spi = None
self._native_api.Release()
self._native_api = None
def join(self):
"""
:note not thread-safe
"""
if self._native_api:
self._native_api.Join()
def connect(self):
"""
:note not thread-safe
"""
self._native_api = CTORATstpMdApi.CreateTstpMdApi()
self._spi = ToraMdSpi(self, self.gateway)
self._native_api.RegisterSpi(self._spi)
self._native_api.RegisterFront(self.md_address)
self._native_api.Init()
return True
def subscribe(self, symbols: List[str], exchange: Exchange):
""""""
err = self._native_api.SubscribeMarketData(
symbols, EXCHANGE_VT2TORA[exchange])
self._if_error_write_log(err, "subscribe")
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'
self.gateway.write_log(msg)
return True

View File

@ -1,611 +0,0 @@
import functools
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, Optional
from vnpy.api.tora.vntora import (CTORATstpConditionOrderField, CTORATstpInputOrderActionField,
CTORATstpInputOrderField, CTORATstpInvestorField,
CTORATstpOrderActionField, CTORATstpOrderField,
CTORATstpPositionField, CTORATstpQryExchangeField,
CTORATstpQryInvestorField, CTORATstpQryMarketDataField,
CTORATstpQryOrderField, CTORATstpQryPositionField,
CTORATstpQrySecurityField, CTORATstpQryShareholderAccountField,
CTORATstpQryTradeField, CTORATstpQryTradingAccountField,
CTORATstpReqUserLoginField, CTORATstpRspInfoField,
CTORATstpRspUserLoginField, CTORATstpSecurityField,
CTORATstpShareholderAccountField, CTORATstpTradeField,
CTORATstpTraderApi, CTORATstpTraderSpi,
CTORATstpTradingAccountField, TORA_TE_RESUME_TYPE,
TORA_TSTP_AF_Delete, TORA_TSTP_FCC_NotForceClose,
TORA_TSTP_HF_Speculation, TORA_TSTP_LACT_AccountID,
TORA_TSTP_OF_Open, TORA_TSTP_OPERW_PCClient)
from vnpy.event import EVENT_TIMER
from vnpy.trader.constant import Direction, Exchange, Offset, OrderType, Status
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import AccountData, CancelRequest, ContractData, OrderData, OrderRequest, \
PositionData, TradeData
from vnpy.trader.utility import get_folder_path
from .constant import DIRECTION_TORA2VT, DIRECTION_VT2TORA, EXCHANGE_TORA2VT, EXCHANGE_VT2TORA, \
ORDER_STATUS_TORA2VT, ORDER_TYPE_TORA2VT, ORDER_TYPE_VT2TORA, PRODUCT_TORA2VT
from .error_codes import get_error_msg
def _check_error(none_return: bool = True,
error_return: bool = True,
write_log: bool = True,
print_function_name: bool = False):
"""
:param none_return: return if info is None
:param error_return: return if errors
:param write_log: write_log if errors
:param print_function_name print function name for every entry of this wrapper
"""
def wrapper(func):
@functools.wraps(func)
def wrapped(self, info, error_info, *args):
function_name = func.__name__
if print_function_name:
print(function_name, "info" if info else "None",
error_info.ErrorID)
# print if errors
error_code = error_info.ErrorID
if error_code != 0:
error_msg = error_info.ErrorMsg
msg = f'{function_name} 中收到错误({error_code}){error_msg}'
if write_log:
self.gateway.write_log(msg)
if error_return:
return
# return if flag is set
if none_return and info is None:
return
# call original function
func(self, info, error_info, *args)
return wrapped
return wrapper
class QueryLoop:
""""""
def __init__(self, gateway: "BaseGateway"):
""""""
self.event_engine = gateway.event_engine
self._seconds_left = 0
self._query_functions = [gateway.query_account, gateway.query_position]
def start(self):
""""""
self.event_engine.register(EVENT_TIMER, self._process_timer_event)
def stop(self):
""""""
self.event_engine.unregister(EVENT_TIMER, self._process_timer_event)
def _process_timer_event(self, event):
""""""
if self._seconds_left != 0:
self._seconds_left -= 1
return
# do a query every 2 seconds.
self._seconds_left = 2
# get the last one and re-queue it
# works fine if there is no so much items
func = self._query_functions.pop(0)
self._query_functions.append(func)
# call it
func()
OrdersType = Dict[str, "OrderInfo"]
class ToraTdSpi(CTORATstpTraderSpi):
""""""
def __init__(self, session_info: "SessionInfo",
api: "ToraTdApi",
gateway: "BaseGateway",
orders: OrdersType):
""""""
super().__init__()
self.session_info = session_info
self.gateway = gateway
self.orders = orders
self._api: "ToraTdApi" = api
def OnRtnTrade(self, info: CTORATstpTradeField) -> None:
""""""
try:
trade_data = TradeData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
exchange=EXCHANGE_TORA2VT[info.ExchangeID],
orderid=info.OrderRef,
tradeid=info.TradeID,
direction=DIRECTION_TORA2VT[info.Direction],
offset=Offset.OPEN,
price=info.Price,
volume=info.Volume,
time=info.TradeTime,
)
self.gateway.on_trade(trade_data)
except KeyError:
return
def OnRtnOrder(self, info: CTORATstpOrderField) -> None:
""""""
self._api.update_last_local_order_id(int(info.OrderRef))
try:
order_data = self.parse_order_field(info)
except KeyError:
return
order_data.status = ORDER_STATUS_TORA2VT[info.OrderStatus]
self.orders[info.OrderRef] = OrderInfo(local_order_id=info.OrderRef,
exchange_id=info.ExchangeID,
session_id=info.SessionID,
front_id=info.FrontID,
)
self.gateway.on_order(order_data)
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField) -> None:
""""""
try:
self._api.update_last_local_order_id(int(info.OrderRef))
except ValueError:
pass
try:
order_data = self.parse_order_field(info)
except KeyError:
# no prints here because we don't care about insertion failure.
return
order_data.status = Status.REJECTED
self.gateway.on_order(order_data)
self.gateway.write_log(f"拒单({order_data.orderid}):"
f"错误码:{error_info.ErrorID}, 错误消息:{error_info.ErrorMsg}")
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField,
error_info: CTORATstpRspInfoField) -> None:
""""""
pass
@_check_error()
def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None:
""""""
pass
@_check_error()
def OnRspOrderAction(self, info: CTORATstpInputOrderActionField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
pass
@_check_error()
def OnRspOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
try:
order_data = self.parse_order_field(info)
except KeyError:
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
return
self.gateway.on_order(order_data)
# @_check_error(print_function_name=False)
# def OnRspQryTrade(self, info: CTORATstpTradeField, error_info: CTORATstpRspInfoField,
# request_id: int, is_last: bool) -> None:
# return
#
# @_check_error(print_function_name=False)
# def OnRspQryOrder(self, info: CTORATstpOrderField, error_info: CTORATstpRspInfoField,
# request_id: int, is_last: bool) -> None:
# order_data = self.parse_order_field(info)
# self.gateway.on_order(order_data)
@_check_error(print_function_name=False)
def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if info.InvestorID != self.session_info.investor_id:
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
return
if info.ExchangeID not in EXCHANGE_TORA2VT:
self.gateway.write_log(
f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
return
volume = info.CurrentPosition
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + \
info.TodayPRFrozen + info.TodaySMPosFrozen
position_data = PositionData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
exchange=EXCHANGE_TORA2VT[info.ExchangeID],
direction=Direction.NET,
volume=volume, # verify this: which one should vnpy use?
frozen=frozen, # verify this: which one should i use?
price=info.TotalPosCost / volume,
# verify this: is this formula correct
pnl=info.LastPrice * volume - info.TotalPosCost,
yd_volume=info.HistoryPos,
)
self.gateway.on_position(position_data)
@_check_error(print_function_name=False)
def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
self.session_info.account_id = info.AccountID
account_data = AccountData(
gateway_name=self.gateway.gateway_name,
accountid=info.AccountID,
balance=info.Available,
frozen=info.FrozenCash + info.FrozenMargin + info.FrozenCommission
)
self.gateway.on_account(account_data)
@_check_error()
def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
exchange = EXCHANGE_TORA2VT[info.ExchangeID]
self.session_info.shareholder_ids[exchange] = info.ShareholderID
@_check_error(print_function_name=False)
def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
self.session_info.investor_id = info.InvestorID
@_check_error(none_return=False, print_function_name=False)
def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if is_last:
self.gateway.write_log("合约信息查询成功")
if not info:
return
if info.ProductID not in PRODUCT_TORA2VT:
return
if info.ExchangeID not in EXCHANGE_TORA2VT:
return
contract_data = ContractData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
exchange=EXCHANGE_TORA2VT[info.ExchangeID],
name=info.SecurityName,
product=PRODUCT_TORA2VT[info.ProductID],
size=info.VolumeMultiple, # to verify
pricetick=info.PriceTick,
min_volume=info.MinLimitOrderBuyVolume, # verify: buy?sell
stop_supported=False,
net_position=True,
history_data=False,
)
self.gateway.on_contract(contract_data)
def OnFrontConnected(self) -> None:
""""""
self.gateway.write_log("交易服务器连接成功")
self._api.login()
@_check_error(print_function_name=False)
def OnRspUserLogin(self, info: CTORATstpRspUserLoginField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
self._api.update_last_local_order_id(int(info.MaxOrderRef))
self.session_info.front_id = info.FrontID
self.session_info.session_id = info.SessionID
self.gateway.write_log("交易服务器登录成功")
self._api.query_initialize_status()
self._api.start_query_loop() # stop at ToraTdApi.stop()
def OnFrontDisconnected(self, error_code: int) -> None:
""""""
self.gateway.write_log(
f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
def parse_order_field(self, info):
"""
:raise KeyError
:param info:
:return:
"""
opt = info.OrderPriceType
tc = info.TimeCondition
vc = info.VolumeCondition
order_type = ORDER_TYPE_TORA2VT[(opt, tc, vc)]
order_data = OrderData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
exchange=EXCHANGE_TORA2VT[info.ExchangeID],
orderid=info.OrderRef,
type=order_type,
direction=DIRECTION_TORA2VT[info.Direction],
offset=Offset.OPEN,
price=info.LimitPrice,
volume=info.VolumeTotalOriginal,
traded=0,
status=Status.NOTTRADED,
time=datetime.now().isoformat() # note: server doesn't provide a timestamp
)
return order_data
class ToraTdApi:
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.username = ""
self.password = ""
self.td_address = ""
self.session_info: "SessionInfo" = SessionInfo()
self.orders: OrdersType = {}
self._native_api: Optional["CTORATstpTraderApi"] = None
self._spi: Optional["ToraTdSpi"] = None
self._query_loop: Optional["QueryLoop"] = None
self._last_req_id = 0
self._next_local_order_id = int(1e5)
def get_shareholder_id(self, exchange: Exchange):
""""""
return self.session_info.shareholder_ids[exchange]
def update_last_local_order_id(self, new_val: int):
""""""
cur = self._next_local_order_id
self._next_local_order_id = max(cur, new_val + 1)
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'
self.gateway.write_log(msg)
return True
def _get_new_req_id(self):
""""""
req_id = self._last_req_id
self._last_req_id += 1
return req_id
def _get_new_order_id(self) -> str:
""""""
order_id = self._next_local_order_id
self._next_local_order_id += 1
return str(order_id)
def query_contracts(self):
""""""
info = CTORATstpQrySecurityField()
err = self._native_api.ReqQrySecurity(info, self._get_new_req_id())
self._if_error_write_log(err, "query_contracts")
def query_exchange(self, exchange: Exchange):
""""""
info = CTORATstpQryExchangeField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
err = self._native_api.ReqQryExchange(info, self._get_new_req_id())
self._if_error_write_log(err, "query_exchange")
def query_market_data(self, symbol: str, exchange: Exchange):
""""""
info = CTORATstpQryMarketDataField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
info.SecurityID = symbol
err = self._native_api.ReqQryMarketData(info, self._get_new_req_id())
self._if_error_write_log(err, "query_market_data")
def stop(self):
""""""
self.stop_query_loop()
if self._native_api:
self._native_api.RegisterSpi(None)
self._spi = None
self._native_api.Release()
self._native_api = None
def join(self):
if self._native_api:
self._native_api.Join()
def login(self):
"""
send login request using self.username, self.password
:return:
"""
info = CTORATstpReqUserLoginField()
info.LogInAccount = self.username
info.LogInAccountType = TORA_TSTP_LACT_AccountID
info.Password = self.password
self._native_api.ReqUserLogin(info, self._get_new_req_id())
def connect(self):
"""
connect to self.td_address using self.username, self.password
:return:
"""
flow_path = str(get_folder_path(self.gateway.gateway_name.lower()))
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(
flow_path, True)
self._spi = ToraTdSpi(self.session_info, self,
self.gateway, self.orders)
self._native_api.RegisterSpi(self._spi)
self._native_api.RegisterFront(self.td_address)
self._native_api.SubscribePublicTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.Init()
return True
def send_order(self, req: OrderRequest):
""""""
if req.type is OrderType.STOP:
raise NotImplementedError()
if req.type is OrderType.FAK or req.type is OrderType.FOK:
assert req.exchange is Exchange.SZSE
order_id = self._get_new_order_id()
info = CTORATstpInputOrderField()
info.InvestorID = self.session_info.investor_id
info.SecurityID = req.symbol
info.OrderRef = order_id
info.ShareholderID = self.get_shareholder_id(req.exchange)
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange]
info.Direction = DIRECTION_VT2TORA[req.direction]
info.CombOffsetFlag = TORA_TSTP_OF_Open
info.CombHedgeFlag = TORA_TSTP_HF_Speculation
if req.type is not OrderType.MARKET:
info.LimitPrice = req.price
info.VolumeTotalOriginal = int(req.volume)
opt, tc, vc = ORDER_TYPE_VT2TORA[req.type]
info.OrderPriceType = opt
info.TimeCondition = tc
info.VolumeCondition = vc
# info.MinVolume = 0 # 当成交量类型为MV时有效
info.ForceCloseReason = TORA_TSTP_FCC_NotForceClose
# info.IsSwapOrder = 0 # 不支持互换单
# info.UserForceClose = 0 # 用户强评标志
info.Operway = TORA_TSTP_OPERW_PCClient # 委托方式PC端委托
self.orders[order_id] = OrderInfo(info.OrderRef,
info.ExchangeID,
self.session_info.session_id,
self.session_info.front_id,
)
self.gateway.on_order(req.create_order_data(
order_id, self.gateway.gateway_name))
# err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id())
err = self._native_api.ReqOrderInsert(info, self._get_new_req_id())
self._if_error_write_log(err, "send_order:ReqOrderInsert")
def cancel_order(self, req: CancelRequest):
""""""
info = CTORATstpInputOrderActionField()
info.InvestorID = self.session_info.investor_id
# 没有的话:(16608)VIP:未知的交易所代码
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange]
info.SecurityID = req.symbol
# info.OrderActionRef = str(self._get_new_req_id())
order_info = self.orders[req.orderid]
info.OrderRef = req.orderid
info.FrontID = order_info.front_id
info.SessionID = order_info.session_id
info.ActionFlag = TORA_TSTP_AF_Delete # (12673)VIP:撤单与原报单信息不符
# info.ActionFlag = TORA_TSTP_AF_ForceDelete # (12368)VIP:报单状态异常
err = self._native_api.ReqOrderAction(info, self._get_new_req_id())
self._if_error_write_log(err, "cancel_order:ReqOrderAction")
def query_initialize_status(self):
""""""
self.query_contracts()
self.query_investors()
self.query_shareholder_ids()
self.query_accounts()
self.query_positions()
self.query_orders()
self.query_trades()
def query_accounts(self):
""""""
info = CTORATstpQryTradingAccountField()
err = self._native_api.ReqQryTradingAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_accounts")
def query_shareholder_ids(self):
""""""
info = CTORATstpQryShareholderAccountField()
err = self._native_api.ReqQryShareholderAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_shareholder_ids")
def query_investors(self):
""""""
info = CTORATstpQryInvestorField()
err = self._native_api.ReqQryInvestor(info, self._get_new_req_id())
self._if_error_write_log(err, "query_investors")
def query_positions(self):
""""""
info = CTORATstpQryPositionField()
err = self._native_api.ReqQryPosition(info, self._get_new_req_id())
self._if_error_write_log(err, "query_positions")
def query_orders(self):
""""""
info = CTORATstpQryOrderField()
err = self._native_api.ReqQryOrder(info, self._get_new_req_id())
self._if_error_write_log(err, "query_orders")
def query_trades(self):
""""""
info = CTORATstpQryTradeField()
err = self._native_api.ReqQryTrade(info, self._get_new_req_id())
self._if_error_write_log(err, "query_trades")
def start_query_loop(self):
""""""
if not self._query_loop:
self._query_loop = QueryLoop(self.gateway)
self._query_loop.start()
def stop_query_loop(self):
""""""
if self._query_loop:
self._query_loop.stop()
self._query_loop = None
@dataclass()
class OrderInfo:
local_order_id: str
exchange_id: str
session_id: int
front_id: int
@dataclass()
class SessionInfo:
investor_id: str = None # one investor per session
shareholder_ids: Dict[Exchange, str] = field(
default_factory=dict) # one share holder id per exchange
account_id: str = None # one account per session
front_id: int = None
session_id: int = None

View File

@ -1,94 +0,0 @@
"""
author: nanoric
TODO:
* Linux support
"""
from vnpy.api.tora.vntora import (
AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.event import EventEngine
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (CancelRequest, OrderRequest, SubscribeRequest)
from .constant import EXCHANGE_VT2TORA
from .md import ToraMdApi
from .td import ToraTdApi
def is_valid_front_address(address: str):
return address.startswith("tcp://") or address.startswith("udp://")
class ToraGateway(BaseGateway):
""""""
default_setting = {
"账号": "",
"密码": "",
"交易服务器": "",
"行情服务器": "",
}
exchanges = list(EXCHANGE_VT2TORA.keys())
def __init__(self, event_engine: EventEngine):
""""""
super().__init__(event_engine, "TORA")
self._md_api = ToraMdApi(self)
self._td_api = ToraTdApi(self)
set_async_callback_exception_handler(
self._async_callback_exception_handler)
def connect(self, setting: dict):
""""""
username = setting['账号']
password = setting['密码']
td_address = setting["交易服务器"]
md_address = setting["行情服务器"]
if not is_valid_front_address(td_address):
td_address = "tcp://" + td_address
if not is_valid_front_address(md_address):
md_address = "tcp://" + md_address
self._md_api.md_address = md_address
self._md_api.connect()
self._td_api.username = username
self._td_api.password = password
self._td_api.td_address = td_address
self._td_api.connect()
def close(self):
""""""
self._md_api.stop()
self._td_api.stop()
self._md_api.join()
self._td_api.join()
def subscribe(self, req: SubscribeRequest):
""""""
self._md_api.subscribe([req.symbol], req.exchange)
def send_order(self, req: OrderRequest) -> str:
""""""
return self._td_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self._td_api.cancel_order(req)
def query_account(self):
""""""
self._td_api.query_accounts()
def query_position(self):
""""""
self._td_api.query_positions()
def _async_callback_exception_handler(self, e: AsyncDispatchException):
error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}"
self.write_log(error_str)

View File

@ -140,7 +140,7 @@ class BaseGateway(ABC):
"""市场行情推送""" """市场行情推送"""
# bar, 或者 barDict # bar, 或者 barDict
self.on_event(EVENT_BAR, bar) self.on_event(EVENT_BAR, bar)
self.write_log(f'on_bar Event:{bar.__dict__}') #self.write_log(f'on_bar Event:{bar.__dict__}')
def on_trade(self, trade: TradeData) -> None: def on_trade(self, trade: TradeData) -> None:
""" """