[bug fix] 自定义合约

This commit is contained in:
msincenselee 2020-01-13 14:48:24 +08:00
parent 0ef6fc2151
commit aa7f45e929
5 changed files with 232 additions and 27 deletions

View File

@ -3,7 +3,7 @@
import traceback
import json
from datetime import datetime, timedelta
from copy import copy,deepcopy
from copy import copy, deepcopy
from vnpy.api.ctp import (
MdApi,
@ -61,10 +61,12 @@ from vnpy.trader.object import (
SubscribeRequest,
)
from vnpy.trader.utility import (
extract_vt_symbol,
get_folder_path,
get_trading_date,
get_underlying_symbol,
round_to
round_to,
BarGenerator
)
from vnpy.trader.event import EVENT_TIMER
@ -121,7 +123,9 @@ EXCHANGE_CTP2VT = {
"SHFE": Exchange.SHFE,
"CZCE": Exchange.CZCE,
"DCE": Exchange.DCE,
"INE": Exchange.INE
"INE": Exchange.INE,
"SPD": Exchange.SPD
}
PRODUCT_CTP2VT = {
@ -142,6 +146,7 @@ index_contracts = {}
# tdx 期货配置本地缓存
future_contracts = get_future_contracts()
class CtpGateway(BaseGateway):
"""
VN Trader Gateway for CTP .
@ -170,11 +175,13 @@ class CtpGateway(BaseGateway):
"""Constructor"""
super().__init__(event_engine, "CTP")
self.td_api = CtpTdApi(self)
self.md_api = CtpMdApi(self)
self.td_api = None
self.md_api = None
self.tdx_api = None
self.rabbit_api = None
self.subscribed_symbols = set() # 已订阅合约代码
self.combiner_conf_dict = {} # 保存合成器配置
# 自定义价差/加比的tick合成器
self.combiners = {}
@ -203,7 +210,20 @@ class CtpGateway(BaseGateway):
):
md_address = "tcp://" + md_address
# 获取自定义价差/价比合约的配置
try:
from vnpy.trader.engine import CustomContract
c = CustomContract()
self.combiner_conf_dict = c.get_config()
if len(self.combiner_conf_dict) > 0:
self.write_log(u'加载的自定义价差/价比配置:{}'.format(self.combiner_conf_dict))
except Exception as ex:
pass
if not self.td_api:
self.td_api = CtpTdApi(self)
self.td_api.connect(td_address, userid, password, brokerid, auth_code, appid, product_info)
if not self.md_api:
self.md_api = CtpMdApi(self)
self.md_api.connect(md_address, userid, password, brokerid)
if rabbit_dict:
@ -215,8 +235,90 @@ class CtpGateway(BaseGateway):
self.init_query()
for (vt_symbol, is_bar) in self.subscribed_symbols:
symbol, exchange = extract_vt_symbol(vt_symbol)
req = SubscribeRequest(
symbol=symbol,
exchange=exchange,
is_bar=is_bar
)
# 指数合约从tdx行情订阅
if req.symbol[-2:] in ['99']:
req.symbol = req.symbol.upper()
if self.tdx_api is not None:
self.write_log(u'有指数订阅,连接通达信行情服务器')
self.tdx_api.connect()
self.tdx_api.subscribe(req)
elif self.rabbit_api is not None:
self.rabbit_api.subscribe(req)
else:
self.md_api.subscribe(req)
def check_status(self):
"""检查状态"""
if self.tdx_api:
self.tdx_api.check_status()
if self.tdx_api is None or self.md_api is None:
return False
if not self.td_api.connect_status or self.md_api.connect_status:
return False
return True
def subscribe(self, req: SubscribeRequest):
""""""
try:
if self.md_api:
# 如果是自定义的套利合约符号
if req.symbol in self.combiner_conf_dict:
self.write_log(u'订阅自定义套利合约:{}'.format(req.symbol))
# 创建合成器
if req.symbol not in self.combiners:
setting = self.combiner_conf_dict.get(req.symbol)
setting.update({"symbol": req.symbol})
combiner = TickCombiner(self, setting)
# 更新合成器
self.write_log(u'添加{}与合成器映射'.format(req.symbol))
self.combiners.update({setting.get('symbol'): combiner})
# 增加映射( leg1 对应的合成器列表映射)
leg1_symbol = setting.get('leg1_symbol')
combiner_list = self.tick_combiner_map.get(leg1_symbol, [])
if combiner not in combiner_list:
self.write_log(u'添加Leg1:{}与合成器得映射'.format(leg1_symbol))
combiner_list.append(combiner)
self.tick_combiner_map.update({leg1_symbol: combiner_list})
# 增加映射( leg2 对应的合成器列表映射)
leg2_symbol = setting.get('leg2_symbol')
combiner_list = self.tick_combiner_map.get(leg2_symbol, [])
if combiner not in combiner_list:
self.write_log(u'添加Leg2:{}与合成器得映射'.format(leg2_symbol))
combiner_list.append(combiner)
self.tick_combiner_map.update({leg2_symbol: combiner_list})
self.write_log(u'订阅leg1:{}'.format(leg1_symbol))
leg1_req = SubscribeRequest(
symbol=leg1_symbol,
exchange=symbol_exchange_map.get(leg1_symbol, Exchange.LOCAL)
)
self.subscribe(leg1_req)
self.write_log(u'订阅leg2:{}'.format(leg2_symbol))
leg2_req = SubscribeRequest(
symbol=leg2_symbol,
exchange=symbol_exchange_map.get(leg1_symbol, Exchange.LOCAL)
)
self.subscribe(leg2_req)
self.subscribed_symbols.add((req.vt_symbol, req.is_bar))
else:
self.write_log(u'{}合成器已经在存在'.format(req.symbol))
return
elif req.exchange == Exchange.SPD:
self.write_error(u'自定义合约{}不在CTP设置中'.format(req.symbol))
# 指数合约从tdx行情订阅
if req.symbol[-2:] in ['99']:
req.symbol = req.symbol.upper()
@ -227,6 +329,26 @@ class CtpGateway(BaseGateway):
else:
self.md_api.subscribe(req)
# Allow the strategies to start before the connection
self.subscribed_symbols.add((req.vt_symbol, req.is_bar))
if req.is_bar:
self.subscribe_bar(req)
except Exception as ex:
self.write_error(u'订阅合约异常:{},{}'.format(str(ex), traceback.format_exc()))
def subscribe_bar(self, req: SubscribeRequest):
"""订阅1分钟行情"""
vt_symbol = req.vt_symbol
if vt_symbol in self.klines:
return
# 创建1分钟bar产生器
self.write_log(u'创建:{}的一分钟行情产生器'.format(vt_symbol))
bg = BarGenerator(on_bar=self.on_bar)
self.klines.update({vt_symbol: bg})
def send_order(self, req: OrderRequest):
""""""
return self.td_api.send_order(req)
@ -245,8 +367,29 @@ class CtpGateway(BaseGateway):
def close(self):
""""""
self.td_api.close()
self.md_api.close()
if self.md_api:
self.write_log('断开行情API')
tmp1 = self.md_api
self.md_api = None
tmp1.close()
if self.td_api:
self.write_log('断开交易API')
tmp2 = self.td_api
self.td_api = None
tmp2.close()
if self.tdx_api:
self.write_log(u'断开tdx指数行情API')
tmp3 = self.tdx_api
self.tdx_api = None
tmp3.close()
if self.rabbit_api:
self.write_log(u'断开rabbit MQ tdx指数行情API')
tmp4 = self.rabbit_api
self.rabbit_api = None
tmp4.close()
def process_timer_event(self, event):
""""""
@ -398,6 +541,7 @@ class CtpMdApi(MdApi):
tick.ask_volume_5 = data["AskVolume5"]
self.gateway.on_tick(tick)
self.gateway.on_custom_tick(tick)
def connect(self, address: str, userid: str, password: str, brokerid: int):
"""
@ -689,7 +833,8 @@ class CtpTdApi(TdApi):
mi_margin_rate = round(idx_contract.margin_rate, 4)
if mi_contract_symbol == contract.symbol:
if margin_rate != mi_margin_rate:
self.gateway.write_log(f"{underlying_symbol}合约主力{mi_contract_symbol} 保证金{margin_rate}=>{mi_margin_rate}")
self.gateway.write_log(
f"{underlying_symbol}合约主力{mi_contract_symbol} 保证金{margin_rate}=>{mi_margin_rate}")
future_contract.update({'margin_rate': mi_margin_rate})
future_contract.update({'symbol_size': idx_contract.size})
future_contract.update({'price_tick': idx_contract.pricetick})
@ -937,6 +1082,7 @@ class CtpTdApi(TdApi):
if self.connect_status:
self.exit()
class TdxMdApi():
"""
通达信数据行情API实现
@ -957,8 +1103,6 @@ class TdxMdApi():
self.symbol_vn_dict = {} # tdx合约与vtSymbol的对应
self.symbol_tick_dict = {} # tdx合约与最后一个Tick得字典
self.registered_symbol_set = set()
self.thread = None # 查询线程
@ -1486,22 +1630,22 @@ class TickCombiner(object):
return
# 以下情况,基本为单腿涨跌停,不合成价差/价格比 Tick
if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.upperLimit) \
if (self.last_leg1_tick.ask_price_1 == 0 or self.last_leg1_tick.bid_price_1 == self.last_leg1_tick.limit_up) \
and self.last_leg1_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}涨停{1}不合成价差Tick'.format(self.last_leg1_tick.vtSymbol, self.last_leg1_tick.bid_price_1))
return
if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.lowerLimit) \
if (self.last_leg1_tick.bid_price_1 == 0 or self.last_leg1_tick.ask_price_1 == self.last_leg1_tick.limit_down) \
and self.last_leg1_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg1:{0}跌停{1}不合成价差Tick'.format(self.last_leg1_tick.vtSymbol, self.last_leg1_tick.ask_price_1))
return
if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.upperLimit) \
if (self.last_leg2_tick.ask_price_1 == 0 or self.last_leg2_tick.bid_price_1 == self.last_leg2_tick.limit_up) \
and self.last_leg2_tick.ask_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}涨停{1}不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.bid_price_1))
return
if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.lowerLimit) \
if (self.last_leg2_tick.bid_price_1 == 0 or self.last_leg2_tick.ask_price_1 == self.last_leg2_tick.limit_down) \
and self.last_leg2_tick.bid_volume_1 == 0:
self.gateway.write_log(
u'leg2:{0}跌停{1}不合成价差Tick'.format(self.last_leg2_tick.vtSymbol, self.last_leg2_tick.ask_price_1))
@ -1517,7 +1661,7 @@ class TickCombiner(object):
if self.is_spread:
spread_tick = TickData(gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=tick.exchange,
exchange=Exchange.SPD,
datetime=tick.datetime)
spread_tick.trading_day = tick.trading_day
@ -1563,9 +1707,9 @@ class TickCombiner(object):
self.gateway.on_tick(spread_tick)
if self.is_ratio:
ratio_tick = TickData(gatway_name=self.gateway_name,
ratio_tick = TickData(gateway_name=self.gateway_name,
symbol=self.symbol,
exchange=tick.exchange,
exchange=Exchange.SPD,
datetime=tick.datetime)
ratio_tick.trading_day = tick.trading_day

View File

@ -532,6 +532,44 @@ class OmsEngine(BaseEngine):
]
return active_orders
class CustomContract(object):
"""
定制合约
# 适用于初始化系统时,补充到本地合约信息文件中 contracts.vt
# 适用于CTP网关加载自定义的套利合约做内部行情撮合
"""
# 运行本地目录下定制合约的配置文件dict
file_name = 'custom_contracts.json'
def __init__(self):
"""构造函数"""
from vnpy.trader.utility import load_json
self.setting = load_json(self.file_name) # 所有设置
def get_config(self):
"""获取配置"""
return self.setting
def get_contracts(self):
"""获取所有合约信息"""
d = {}
from vnpy.trader.object import ContractData, Exchange
for symbol, setting in self.setting.items():
gateway_name = setting.get('gateway_name', None)
if gateway_name is None:
gateway_name= SETTINGS.get('gateway_name','')
vn_exchange = Exchange(setting.get('exchange', 'LOCAL'))
contract = ContractData(
gateway_name=gateway_name,
symbol=symbol,
name=contract.symbol,
size=setting.get('size', 100),
pricetick=setting.get('price_tick', 0.01),
margin_rate=setting.get('margin_rate', 0.1)
)
d[contract.vt_symbol] = contract
return d
class EmailEngine(BaseEngine):
"""

View File

@ -11,6 +11,7 @@ from logging import INFO, DEBUG, ERROR
from vnpy.event import Event, EventEngine
from .event import (
EVENT_TICK,
EVENT_BAR,
EVENT_ORDER,
EVENT_TRADE,
EVENT_POSITION,
@ -20,6 +21,7 @@ from .event import (
)
from .object import (
TickData,
BarData,
OrderData,
TradeData,
PositionData,
@ -60,6 +62,7 @@ class BaseGateway(ABC):
---
## callbacks must response manually:
* on_tick
* on_bar
* on_trade
* on_order
* on_position
@ -89,6 +92,9 @@ class BaseGateway(ABC):
self.create_logger()
# 所有订阅on_bar的都会添加
self.klines = {}
def create_logger(self):
"""
创建engine独有的日志
@ -116,6 +122,17 @@ class BaseGateway(ABC):
self.on_event(EVENT_TICK, tick)
self.on_event(EVENT_TICK + tick.vt_symbol, tick)
# 推送Bar
kline = self.klines.get(tick.vt_symbol, None)
if kline:
kline.update_tick(tick)
def on_bar(self, bar: BarData):
"""市场行情推送"""
# bar, 或者 barDict
self.on_event(EVENT_BAR, bar)
self.write_log(f'on_bar Event:{bar.__dict__}')
def on_trade(self, trade: TradeData):
"""
Trade event push.

View File

@ -287,6 +287,8 @@ class SubscribeRequest:
""""""
self.vt_symbol = f"{self.symbol}.{self.exchange.value}"
def __eq__(self, other):
return self.vt_symbol == other.vt_symbol
@dataclass
class OrderRequest:

View File

@ -300,6 +300,10 @@ def ceil_to(value: float, target: float) -> float:
return result
def print_dict(d: dict):
"""返回dict的字符串类型"""
return '\n'.join([f'{key}:{d[key]}' for key in sorted(d.keys())])
class BarGenerator:
"""
For: