From aa7f45e9290e622560a09027784d99a8efaf217d Mon Sep 17 00:00:00 2001 From: msincenselee Date: Mon, 13 Jan 2020 14:48:24 +0800 Subject: [PATCH] =?UTF-8?q?[bug=20fix]=20=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E5=90=88=E7=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnpy/gateway/ctp/ctp_gateway.py | 198 +++++++++++++++++++++++++++----- vnpy/trader/engine.py | 38 ++++++ vnpy/trader/gateway.py | 17 +++ vnpy/trader/object.py | 2 + vnpy/trader/utility.py | 4 + 5 files changed, 232 insertions(+), 27 deletions(-) diff --git a/vnpy/gateway/ctp/ctp_gateway.py b/vnpy/gateway/ctp/ctp_gateway.py index 0dd9eea1..01f24d37 100644 --- a/vnpy/gateway/ctp/ctp_gateway.py +++ b/vnpy/gateway/ctp/ctp_gateway.py @@ -3,7 +3,7 @@ import traceback import json from datetime import datetime, timedelta -from copy import copy,deepcopy +from copy import copy, deepcopy from vnpy.api.ctp import ( MdApi, @@ -61,10 +61,12 @@ from vnpy.trader.object import ( SubscribeRequest, ) from vnpy.trader.utility import ( + extract_vt_symbol, get_folder_path, get_trading_date, get_underlying_symbol, - round_to + round_to, + BarGenerator ) from vnpy.trader.event import EVENT_TIMER @@ -121,7 +123,9 @@ EXCHANGE_CTP2VT = { "SHFE": Exchange.SHFE, "CZCE": Exchange.CZCE, "DCE": Exchange.DCE, - "INE": Exchange.INE + "INE": Exchange.INE, + "SPD": Exchange.SPD + } PRODUCT_CTP2VT = { @@ -142,6 +146,7 @@ index_contracts = {} # tdx 期货配置本地缓存 future_contracts = get_future_contracts() + class CtpGateway(BaseGateway): """ VN Trader Gateway for CTP . @@ -170,11 +175,13 @@ class CtpGateway(BaseGateway): """Constructor""" super().__init__(event_engine, "CTP") - self.td_api = CtpTdApi(self) - self.md_api = CtpMdApi(self) + self.td_api = None + self.md_api = None self.tdx_api = None self.rabbit_api = None + self.subscribed_symbols = set() # 已订阅合约代码 + self.combiner_conf_dict = {} # 保存合成器配置 # 自定义价差/加比的tick合成器 self.combiners = {} @@ -203,7 +210,20 @@ class CtpGateway(BaseGateway): ): md_address = "tcp://" + md_address + # 获取自定义价差/价比合约的配置 + try: + from vnpy.trader.engine import CustomContract + c = CustomContract() + self.combiner_conf_dict = c.get_config() + if len(self.combiner_conf_dict) > 0: + self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict)) + except Exception as ex: + pass + if not self.td_api: + self.td_api = CtpTdApi(self) self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info) + if not self.md_api: + self.md_api = CtpMdApi(self) self.md_api.connect(md_address, userid, password, brokerid) if rabbit_dict: @@ -215,17 +235,119 @@ class CtpGateway(BaseGateway): self.init_query() + for (vt_symbol, is_bar) in self.subscribed_symbols: + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest( + symbol=symbol, + exchange=exchange, + is_bar=is_bar + ) + # 指数合约,从tdx行情订阅 + if req.symbol[-2:] in ['99']: + req.symbol = req.symbol.upper() + if self.tdx_api is not None: + self.write_log(u'有指数订阅,连接通达信行情服务器') + self.tdx_api.connect() + self.tdx_api.subscribe(req) + elif self.rabbit_api is not None: + self.rabbit_api.subscribe(req) + else: + self.md_api.subscribe(req) + + def check_status(self): + """检查状态""" + if self.tdx_api: + self.tdx_api.check_status() + if self.tdx_api is None or self.md_api is None: + return False + + if not self.td_api.connect_status or self.md_api.connect_status: + return False + + return True + def subscribe(self, req: SubscribeRequest): """""" - # 指数合约,从tdx行情订阅 - if req.symbol[-2:] in ['99']: - req.symbol = req.symbol.upper() - if self.tdx_api: - self.tdx_api.subscribe(req) - elif self.rabbit_api: - self.rabbit_api.subscribe(req) - else: - self.md_api.subscribe(req) + try: + if self.md_api: + # 如果是自定义的套利合约符号 + if req.symbol in self.combiner_conf_dict: + self.write_log(u'订阅自定义套利合约:{}'.format(req.symbol)) + # 创建合成器 + if req.symbol not in self.combiners: + setting = self.combiner_conf_dict.get(req.symbol) + setting.update({"symbol": req.symbol}) + combiner = TickCombiner(self, setting) + # 更新合成器 + self.write_log(u'添加{}与合成器映射'.format(req.symbol)) + self.combiners.update({setting.get('symbol'): combiner}) + + # 增加映射( leg1 对应的合成器列表映射) + leg1_symbol = setting.get('leg1_symbol') + combiner_list = self.tick_combiner_map.get(leg1_symbol, []) + if combiner not in combiner_list: + self.write_log(u'添加Leg1:{}与合成器得映射'.format(leg1_symbol)) + combiner_list.append(combiner) + self.tick_combiner_map.update({leg1_symbol: combiner_list}) + + # 增加映射( leg2 对应的合成器列表映射) + leg2_symbol = setting.get('leg2_symbol') + combiner_list = self.tick_combiner_map.get(leg2_symbol, []) + if combiner not in combiner_list: + self.write_log(u'添加Leg2:{}与合成器得映射'.format(leg2_symbol)) + combiner_list.append(combiner) + self.tick_combiner_map.update({leg2_symbol: combiner_list}) + + self.write_log(u'订阅leg1:{}'.format(leg1_symbol)) + leg1_req = SubscribeRequest( + symbol=leg1_symbol, + exchange=symbol_exchange_map.get(leg1_symbol, Exchange.LOCAL) + ) + self.subscribe(leg1_req) + + self.write_log(u'订阅leg2:{}'.format(leg2_symbol)) + leg2_req = SubscribeRequest( + symbol=leg2_symbol, + exchange=symbol_exchange_map.get(leg1_symbol, Exchange.LOCAL) + ) + self.subscribe(leg2_req) + + self.subscribed_symbols.add((req.vt_symbol, req.is_bar)) + else: + self.write_log(u'{}合成器已经在存在'.format(req.symbol)) + return + elif req.exchange == Exchange.SPD: + self.write_error(u'自定义合约{}不在CTP设置中'.format(req.symbol)) + + # 指数合约,从tdx行情订阅 + if req.symbol[-2:] in ['99']: + req.symbol = req.symbol.upper() + if self.tdx_api: + self.tdx_api.subscribe(req) + elif self.rabbit_api: + self.rabbit_api.subscribe(req) + else: + self.md_api.subscribe(req) + + # Allow the strategies to start before the connection + self.subscribed_symbols.add((req.vt_symbol, req.is_bar)) + if req.is_bar: + self.subscribe_bar(req) + + except Exception as ex: + self.write_error(u'订阅合约异常:{},{}'.format(str(ex), traceback.format_exc())) + + def subscribe_bar(self, req: SubscribeRequest): + """订阅1分钟行情""" + + vt_symbol = req.vt_symbol + if vt_symbol in self.klines: + return + + # 创建1分钟bar产生器 + self.write_log(u'创建:{}的一分钟行情产生器'.format(vt_symbol)) + bg = BarGenerator(on_bar=self.on_bar) + self.klines.update({vt_symbol: bg}) def send_order(self, req: OrderRequest): """""" @@ -245,8 +367,29 @@ class CtpGateway(BaseGateway): def close(self): """""" - self.td_api.close() - self.md_api.close() + if self.md_api: + self.write_log('断开行情API') + tmp1 = self.md_api + self.md_api = None + tmp1.close() + + if self.td_api: + self.write_log('断开交易API') + tmp2 = self.td_api + self.td_api = None + tmp2.close() + + if self.tdx_api: + self.write_log(u'断开tdx指数行情API') + tmp3 = self.tdx_api + self.tdx_api = None + tmp3.close() + + if self.rabbit_api: + self.write_log(u'断开rabbit MQ tdx指数行情API') + tmp4 = self.rabbit_api + self.rabbit_api = None + tmp4.close() def process_timer_event(self, event): """""" @@ -398,6 +541,7 @@ class CtpMdApi(MdApi): tick.ask_volume_5 = data["AskVolume5"] self.gateway.on_tick(tick) + self.gateway.on_custom_tick(tick) def connect(self, address: str, userid: str, password: str, brokerid: int): """ @@ -671,7 +815,7 @@ class CtpTdApi(TdApi): if contract.product == Product.FUTURES: # 生成指数合约信息 - underlying_symbol = data["ProductID"] # 短合约名称 + underlying_symbol = data["ProductID"] # 短合约名称 underlying_symbol = underlying_symbol.upper() # 只推送普通合约的指数 if len(underlying_symbol) <= 2: @@ -689,7 +833,8 @@ class CtpTdApi(TdApi): mi_margin_rate = round(idx_contract.margin_rate, 4) if mi_contract_symbol == contract.symbol: if margin_rate != mi_margin_rate: - self.gateway.write_log(f"{underlying_symbol}合约主力{mi_contract_symbol} 保证金{margin_rate}=>{mi_margin_rate}") + self.gateway.write_log( + f"{underlying_symbol}合约主力{mi_contract_symbol} 保证金{margin_rate}=>{mi_margin_rate}") future_contract.update({'margin_rate': mi_margin_rate}) future_contract.update({'symbol_size': idx_contract.size}) future_contract.update({'price_tick': idx_contract.pricetick}) @@ -937,6 +1082,7 @@ class CtpTdApi(TdApi): if self.connect_status: self.exit() + class TdxMdApi(): """ 通达信数据行情API实现 @@ -957,8 +1103,6 @@ class TdxMdApi(): self.symbol_vn_dict = {} # tdx合约与vtSymbol的对应 self.symbol_tick_dict = {} # tdx合约与最后一个Tick得字典 - - self.registered_symbol_set = set() self.thread = None # 查询线程 @@ -1486,22 +1630,22 @@ class TickCombiner(object): return # 以下情况,基本为单腿涨跌停,不合成价差/价格比 Tick - if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.upperLimit) \ + if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.limit_up) \ and self.last_leg1_tick.ask_volume_1 == 0: self.gateway.write_log( u'leg1:{0}涨停{1},不合成价差Tick'.format(self.last_leg1_tick.vtSymbol, self.last_leg1_tick.bid_price_1)) return - if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.lowerLimit) \ + if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.limit_down) \ and self.last_leg1_tick.bid_volume_1 == 0: self.gateway.write_log( u'leg1:{0}跌停{1},不合成价差Tick'.format(self.last_leg1_tick.vtSymbol, self.last_leg1_tick.ask_price_1)) return - if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.upperLimit) \ + if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.limit_up) \ and self.last_leg2_tick.ask_volume_1 == 0: self.gateway.write_log( u'leg2:{0}涨停{1},不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.bid_price_1)) return - if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.lowerLimit) \ + if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.limit_down) \ and self.last_leg2_tick.bid_volume_1 == 0: self.gateway.write_log( u'leg2:{0}跌停{1},不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.ask_price_1)) @@ -1517,7 +1661,7 @@ class TickCombiner(object): if self.is_spread: spread_tick = TickData(gateway_name=self.gateway_name, symbol=self.symbol, - exchange=tick.exchange, + exchange=Exchange.SPD, datetime=tick.datetime) spread_tick.trading_day = tick.trading_day @@ -1563,9 +1707,9 @@ class TickCombiner(object): self.gateway.on_tick(spread_tick) if self.is_ratio: - ratio_tick = TickData(gatway_name=self.gateway_name, + ratio_tick = TickData(gateway_name=self.gateway_name, symbol=self.symbol, - exchange=tick.exchange, + exchange=Exchange.SPD, datetime=tick.datetime) ratio_tick.trading_day = tick.trading_day diff --git a/vnpy/trader/engine.py b/vnpy/trader/engine.py index 2d2f3514..21ed4920 100644 --- a/vnpy/trader/engine.py +++ b/vnpy/trader/engine.py @@ -532,6 +532,44 @@ class OmsEngine(BaseEngine): ] return active_orders +class CustomContract(object): + """ + 定制合约 + # 适用于初始化系统时,补充到本地合约信息文件中 contracts.vt + # 适用于CTP网关,加载自定义的套利合约,做内部行情撮合 + """ + # 运行本地目录下,定制合约的配置文件(dict) + file_name = 'custom_contracts.json' + + def __init__(self): + """构造函数""" + from vnpy.trader.utility import load_json + self.setting = load_json(self.file_name) # 所有设置 + + def get_config(self): + """获取配置""" + return self.setting + + def get_contracts(self): + """获取所有合约信息""" + d = {} + from vnpy.trader.object import ContractData, Exchange + for symbol, setting in self.setting.items(): + gateway_name = setting.get('gateway_name', None) + if gateway_name is None: + gateway_name= SETTINGS.get('gateway_name','') + vn_exchange = Exchange(setting.get('exchange', 'LOCAL')) + contract = ContractData( + gateway_name=gateway_name, + symbol=symbol, + name=contract.symbol, + size=setting.get('size', 100), + pricetick=setting.get('price_tick', 0.01), + margin_rate=setting.get('margin_rate', 0.1) + ) + d[contract.vt_symbol] = contract + + return d class EmailEngine(BaseEngine): """ diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 12cc4db9..b389ac39 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -11,6 +11,7 @@ from logging import INFO, DEBUG, ERROR from vnpy.event import Event, EventEngine from .event import ( EVENT_TICK, + EVENT_BAR, EVENT_ORDER, EVENT_TRADE, EVENT_POSITION, @@ -20,6 +21,7 @@ from .event import ( ) from .object import ( TickData, + BarData, OrderData, TradeData, PositionData, @@ -60,6 +62,7 @@ class BaseGateway(ABC): --- ## callbacks must response manually: * on_tick + * on_bar * on_trade * on_order * on_position @@ -89,6 +92,9 @@ class BaseGateway(ABC): self.create_logger() + # 所有订阅on_bar的都会添加 + self.klines = {} + def create_logger(self): """ 创建engine独有的日志 @@ -116,6 +122,17 @@ class BaseGateway(ABC): self.on_event(EVENT_TICK, tick) self.on_event(EVENT_TICK + tick.vt_symbol, tick) + # 推送Bar + kline = self.klines.get(tick.vt_symbol, None) + if kline: + kline.update_tick(tick) + + def on_bar(self, bar: BarData): + """市场行情推送""" + # bar, 或者 barDict + self.on_event(EVENT_BAR, bar) + self.write_log(f'on_bar Event:{bar.__dict__}') + def on_trade(self, trade: TradeData): """ Trade event push. diff --git a/vnpy/trader/object.py b/vnpy/trader/object.py index bc3266dc..b10dedeb 100644 --- a/vnpy/trader/object.py +++ b/vnpy/trader/object.py @@ -287,6 +287,8 @@ class SubscribeRequest: """""" self.vt_symbol = f"{self.symbol}.{self.exchange.value}" + def __eq__(self, other): + return self.vt_symbol == other.vt_symbol @dataclass class OrderRequest: diff --git a/vnpy/trader/utility.py b/vnpy/trader/utility.py index 24f242fc..33dcd73c 100644 --- a/vnpy/trader/utility.py +++ b/vnpy/trader/utility.py @@ -300,6 +300,10 @@ def ceil_to(value: float, target: float) -> float: return result +def print_dict(d: dict): + """返回dict的字符串类型""" + return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())]) + class BarGenerator: """ For: