[改进] 显示小数浮点长度、增加gw状态

This commit is contained in:
msincenselee 2020-03-18 15:03:42 +08:00
parent 0b626d7501
commit 9b9389c9cf
15 changed files with 216 additions and 80 deletions

View File

@ -567,7 +567,7 @@ class BackTestingEngine(object):
# 更新策略的资金K线
fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None)
if fund_kline:
hold_pnl = fund_kline.get_hold_pnl()
hold_pnl, _ = fund_kline.get_hold_pnl()
if hold_pnl != 0:
fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl)
@ -598,7 +598,7 @@ class BackTestingEngine(object):
# 更新策略的资金K线
fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None)
if fund_kline:
hold_pnl = fund_kline.get_hold_pnl()
hold_pnl, _ = fund_kline.get_hold_pnl()
if hold_pnl != 0:
fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl)
@ -1349,6 +1349,7 @@ class BackTestingEngine(object):
if cover_volume >= open_trade.volume:
self.write_log(f'cover volume:{cover_volume}, 满足:{open_trade.volume}')
cover_volume = cover_volume - open_trade.volume
cover_volume = round(cover_volume, 7)
if cover_volume > 0:
self.write_log(u'剩余待平数量:{}'.format(cover_volume))
@ -1364,6 +1365,7 @@ class BackTestingEngine(object):
slippage=self.get_slippage(trade.vt_symbol),
size=self.get_size(trade.vt_symbol),
group_id=g_id,
margin_rate=self.get_margin_rate(trade.vt_symbol),
fix_commission=self.get_fix_commission(trade.vt_symbol))
t = OrderedDict()
@ -1415,6 +1417,7 @@ class BackTestingEngine(object):
# 开空volume,大于平仓volume需要更新减少tradeDict的数量。
else:
remain_volume = open_trade.volume - cover_volume
remain_volume = round(remain_volume, 7)
self.write_log(f'{open_trade.vt_symbol} short pos: {open_trade.volume} => {remain_volume}')
result = TradingResult(open_price=open_trade.price,
@ -1426,6 +1429,7 @@ class BackTestingEngine(object):
slippage=self.get_slippage(trade.vt_symbol),
size=self.get_size(trade.vt_symbol),
group_id=g_id,
margin_rate=self.get_margin_rate(trade.vt_symbol),
fix_commission=self.get_fix_commission(trade.vt_symbol))
t = OrderedDict()
@ -1512,7 +1516,7 @@ class BackTestingEngine(object):
if sell_volume >= open_trade.volume:
self.write_log(f'{open_trade.vt_symbol},Sell Volume:{sell_volume} 满足:{open_trade.volume}')
sell_volume = sell_volume - open_trade.volume
sell_volume = round(sell_volume, 7)
self.write_log(f'{open_trade.vt_symbol},sell, price:{trade.price},volume:{open_trade.volume}')
result = TradingResult(open_price=open_trade.price,
@ -1524,6 +1528,7 @@ class BackTestingEngine(object):
slippage=self.get_slippage(trade.vt_symbol),
size=self.get_size(trade.vt_symbol),
group_id=g_id,
margin_rate=self.get_margin_rate(trade.vt_symbol),
fix_commission=self.get_fix_commission(trade.vt_symbol))
t = OrderedDict()
@ -1571,6 +1576,7 @@ class BackTestingEngine(object):
# 开多volume,大于平仓volume需要更新减少tradeDict的数量。
else:
remain_volume = open_trade.volume - sell_volume
remain_volume = round(remain_volume, 7)
self.write_log(f'{open_trade.vt_symbol} short pos: {open_trade.volume} => {remain_volume}')
result = TradingResult(open_price=open_trade.price,
@ -1582,6 +1588,7 @@ class BackTestingEngine(object):
slippage=self.get_slippage(trade.vt_symbol),
size=self.get_size(trade.vt_symbol),
group_id=g_id,
margin_rate=self.get_margin_rate(trade.vt_symbol),
fix_commission=self.get_fix_commission(trade.vt_symbol))
t = OrderedDict()
@ -2107,7 +2114,7 @@ class TradingResult(object):
"""每笔交易的结果"""
def __init__(self, open_price, open_datetime, exit_price, close_datetime, volume, rate, slippage, size, group_id,
fix_commission=0.0):
margin_rate, fix_commission=0.0):
"""Constructor"""
self.open_price = open_price # 开仓价格
self.exit_price = exit_price # 平仓价格
@ -2118,11 +2125,11 @@ class TradingResult(object):
self.volume = volume # 交易数量(+/-代表方向)
self.group_id = group_id # 主交易ID针对多手平仓
self.turnover = (self.open_price + self.exit_price) * abs(volume) # 成交金额
self.turnover = (self.open_price + self.exit_price) * abs(volume) * margin_rate # 成交金额(实际保证金金额)
if fix_commission > 0:
self.commission = fix_commission * abs(self.volume)
else:
self.commission = abs(self.turnover * rate) # 手续费成本
self.slippage = slippage * 2 * abs(volume) # 滑点成本
self.pnl = ((self.exit_price - self.open_price) * volume * size
self.slippage = slippage * 2 * abs(self.turnover) # 滑点成本
self.pnl = ((self.exit_price - self.open_price) * volume
- self.commission - self.slippage) # 净盈亏

View File

@ -164,6 +164,29 @@ class CtaEngine(BaseEngine):
:return:
"""
self.main_engine.get_strategy_status = self.get_strategy_status
self.main_engine.get_strategy_pos = self.get_strategy_pos
self.main_engine.add_strategy = self.add_strategy
self.main_engine.init_strategy = self.init_strategy
self.main_engine.start_strategy = self.start_strategy
self.main_engine.stop_strategy = self.stop_strategy
self.main_engine.remove_strategy = self.remove_strategy
self.main_engine.reload_strategy = self.reload_strategy
self.main_engine.save_strategy_data = self.save_strategy_data
self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot
# 注册到远程服务调用
rpc_service = self.main_engine.apps.get('RpcService')
if rpc_service:
rpc_service.register(self.main_engine.get_strategy_status)
rpc_service.register(self.main_engine.get_strategy_pos)
rpc_service.register(self.main_engine.add_strategy)
rpc_service.register(self.main_engine.init_strategy)
rpc_service.register(self.main_engine.start_strategy)
rpc_service.register(self.main_engine.stop_strategy)
rpc_service.register(self.main_engine.remove_strategy)
rpc_service.register(self.main_engine.reload_strategy)
rpc_service.register(self.main_engine.save_strategy_data)
rpc_service.register(self.main_engine.save_strategy_snapshot)
def process_timer_event(self, event: Event):
""" 处理定时器事件"""
@ -431,7 +454,6 @@ class CtaEngine(BaseEngine):
if contract.gateway_name and not gateway_name:
gateway_name = contract.gateway_name
# Send Orders
vt_orderids = []
@ -805,7 +827,6 @@ class CtaEngine(BaseEngine):
vt_position_id = f"{gateway_name}.{vt_symbol}.{direction.value}"
return self.main_engine.get_position(vt_position_id)
def get_position_holding(self, vt_symbol: str, gateway_name: str = ''):
""" 查询合约在账号的持仓(包含多空)"""
k = f'{gateway_name}.{vt_symbol}'
@ -819,7 +840,6 @@ class CtaEngine(BaseEngine):
self.holdings[k] = holding
return holding
def get_engine_type(self):
""""""
return self.engine_type
@ -835,11 +855,11 @@ class CtaEngine(BaseEngine):
return log_path
def load_bar(
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable[[BarData], None]
self,
vt_symbol: str,
days: int,
interval: Interval,
callback: Callable[[BarData], None]
):
""""""
symbol, exchange = extract_vt_symbol(vt_symbol)
@ -867,10 +887,10 @@ class CtaEngine(BaseEngine):
callback(bar)
def load_tick(
self,
vt_symbol: str,
days: int,
callback: Callable[[TickData], None]
self,
vt_symbol: str,
days: int,
callback: Callable[[TickData], None]
):
""""""
symbol, exchange = extract_vt_symbol(vt_symbol)
@ -887,7 +907,6 @@ class CtaEngine(BaseEngine):
for tick in ticks:
callback(tick)
def call_strategy_func(
self, strategy: CtaTemplate, func: Callable, params: Any = None
):
@ -1304,15 +1323,7 @@ class CtaEngine(BaseEngine):
:param strategy_name:
:return:
"""
inited = False
trading = False
strategy = self.strategies.get(strategy_name, None)
if strategy:
inited = strategy.inited
trading = strategy.trading
return inited, trading
return [{k: {'inited': v.inited, 'trading': v.trading}} for k, v in self.strategies.items()]
def get_strategy_pos(self, name, strategy=None):
"""

View File

@ -441,7 +441,7 @@ class CtaFutureTemplate(CtaTemplate):
backtesting = False
# 逻辑过程日志
dist_fieldnames = ['datetime', 'symbol', 'volume', 'price',
dist_fieldnames = ['datetime', 'symbol', 'volume', 'price','margin',
'operation', 'signal', 'stop_price', 'target_price',
'long_pos', 'short_pos']
@ -679,6 +679,7 @@ class CtaFutureTemplate(CtaTemplate):
dist_record['datetime'] = ' '.join([self.cur_datetime.strftime('%Y-%m-%d'), trade.time])
dist_record['volume'] = trade.volume
dist_record['price'] = trade.price
dist_record['margin'] = trade.price * trade.volume * self.cta_engine.get_margin_rate(trade.vt_symbol)
dist_record['symbol'] = trade.vt_symbol
if trade.direction == Direction.LONG and trade.offset == Offset.OPEN:
@ -1056,7 +1057,12 @@ class CtaFutureTemplate(CtaTemplate):
self.account_pos.long_pos,
grid.volume))
vt_orderids = self.sell(price=sell_price, volume=grid.volume, order_time=self.cur_datetime, grid=grid)
vt_orderids = self.sell(
vt_symbol=self.vt_symbol,
price=sell_price,
volume=grid.volume,
order_time=self.cur_datetime,
grid=grid)
if len(vt_orderids) == 0:
if self.backtesting:
self.write_error(u'多单平仓委托失败')
@ -1092,10 +1098,12 @@ class CtaFutureTemplate(CtaTemplate):
grid.volume -= grid.traded_volume
grid.traded_volume = 0
vt_orderids = self.cover(price=cover_price,
volume=grid.volume,
order_time=self.cur_datetime,
grid=grid)
vt_orderids = self.cover(
price=cover_price,
vt_symbol=self.vt_symbol,
volume=grid.volume,
order_time=self.cur_datetime,
grid=grid)
if len(vt_orderids) == 0:
if self.backtesting:
self.write_error(u'空单平仓委托失败')
@ -1295,6 +1303,8 @@ class CtaFutureTemplate(CtaTemplate):
else:
save_path = self.cta_engine.get_data_path()
try:
if 'margin' not in dist_data:
dist_data.update({'margin': dist_data.get('price', 0) * dist_data.get('volume', 0) * self.cta_engine.get_margin_rate(dist_data.get('symbol', self.vt_symbol))})
if self.position and 'long_pos' not in dist_data:
dist_data.update({'long_pos': self.position.long_pos})
if self.position and 'short_pos' not in dist_data:

View File

@ -552,7 +552,7 @@ class BackTestingEngine(object):
# 更新策略的资金K线
fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None)
if fund_kline:
hold_pnl = fund_kline.get_hold_pnl()
hold_pnl, _ = fund_kline.get_hold_pnl()
if hold_pnl != 0:
fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl)
@ -583,7 +583,7 @@ class BackTestingEngine(object):
# 更新策略的资金K线
fund_kline = self.fund_kline_dict.get(strategy.strategy_name, None)
if fund_kline:
hold_pnl = fund_kline.get_hold_pnl()
hold_pnl, _ = fund_kline.get_hold_pnl()
if hold_pnl != 0:
fund_kline.update_strategy(dt=self.last_dt, hold_pnl=hold_pnl)

View File

@ -162,6 +162,29 @@ class CtaEngine(BaseEngine):
:return:
"""
self.main_engine.get_strategy_status = self.get_strategy_status
self.main_engine.get_strategy_pos = self.get_strategy_pos
self.main_engine.add_strategy = self.add_strategy
self.main_engine.init_strategy = self.init_strategy
self.main_engine.start_strategy = self.start_strategy
self.main_engine.stop_strategy = self.stop_strategy
self.main_engine.remove_strategy = self.remove_strategy
self.main_engine.reload_strategy = self.reload_strategy
self.main_engine.save_strategy_data = self.save_strategy_data
self.main_engine.save_strategy_snapshot = self.save_strategy_snapshot
# 注册到远程服务调用
rpc_service = self.main_engine.apps.get('RpcService')
if rpc_service:
rpc_service.register(self.main_engine.get_strategy_status)
rpc_service.register(self.main_engine.get_strategy_pos)
rpc_service.register(self.main_engine.add_strategy)
rpc_service.register(self.main_engine.init_strategy)
rpc_service.register(self.main_engine.start_strategy)
rpc_service.register(self.main_engine.stop_strategy)
rpc_service.register(self.main_engine.remove_strategy)
rpc_service.register(self.main_engine.reload_strategy)
rpc_service.register(self.main_engine.save_strategy_data)
rpc_service.register(self.main_engine.save_strategy_snapshot)
def process_timer_event(self, event: Event):
""" 处理定时器事件"""
@ -1211,21 +1234,14 @@ class CtaEngine(BaseEngine):
"""
return list(self.classes.keys())
def get_strategy_status(self, strategy_name):
def get_strategy_status(self):
"""
return strategy inited/trading status
:param strategy_name:
return strategy name list with inited/trading status
:param :
:return:
"""
inited = False
trading = False
return [{k: {'inited': v.inited, 'trading': v.trading}} for k, v in self.strategies.items()]
strategy = self.strategies.get(strategy_name, None)
if strategy:
inited = strategy.inited
trading = strategy.trading
return inited, trading
def get_strategy_pos(self, name, strategy=None):
"""
@ -1355,6 +1371,9 @@ class CtaEngine(BaseEngine):
d['date'] = dt.strftime('%Y%m%d')
d['hour'] = dt.hour
d['datetime'] = datetime.now()
strategy = self.strategies.get(strategy_name)
d['inited'] = strategy.inited
d['trading'] = strategy.trading
try:
d['pos'] = self.get_strategy_pos(name=strategy_name)
except Exception as ex:

View File

@ -0,0 +1,15 @@
# encoding: UTF-8
import os
from pathlib import Path
from vnpy.trader.app import BaseApp
from .dispatch_engine import DispatchEngine, APP_NAME
class DispatchApp(BaseApp):
""""""
app_name = APP_NAME
app_module = __module__
app_path = Path(__file__).parent
display_name = u'调度引擎'
engine_class = DispatchEngine

View File

@ -0,0 +1,23 @@
# encoding: UTF-8
# 策略调度引擎
# 华富资产
from vnpy.event import EventEngine
from vnpy.trader.constant import Exchange
from vnpy.trader.engine import BaseEngine, MainEngine
from vnpy.trader.event import EVENT_TIMER
APP_NAME = 'DispatchEngine'
class DispatchEngine(BaseEngine):
def __init__(self, main_engine: MainEngine, event_engine: EventEngine):
""""""
super().__init__(main_engine, event_engine, APP_NAME)
self.main_engine = main_engine
self.event_engine = event_engine
self.create_logger(logger_name=APP_NAME)

View File

@ -1,5 +1,5 @@
""""""
import sys
import traceback
from typing import Optional, Callable
@ -35,6 +35,7 @@ class RpcEngine(BaseEngine):
""""""
self.server = RpcServer()
self.server.register(self.main_engine.get_all_gateway_status)
self.server.register(self.main_engine.subscribe)
self.server.register(self.main_engine.send_order)
self.server.register(self.main_engine.send_orders)
@ -43,6 +44,7 @@ class RpcEngine(BaseEngine):
self.server.register(self.main_engine.query_history)
self.server.register(self.main_engine.get_tick)
self.server.register(self.main_engine.get_price)
self.server.register(self.main_engine.get_order)
self.server.register(self.main_engine.get_trade)
self.server.register(self.main_engine.get_position)
@ -55,6 +57,7 @@ class RpcEngine(BaseEngine):
self.server.register(self.main_engine.get_all_accounts)
self.server.register(self.main_engine.get_all_contracts)
self.server.register(self.main_engine.get_all_active_orders)
self.server.register(self.main_engine.get_all_custom_contracts)
def register(self, func: Callable):
""" 扩展注册接口"""
@ -75,25 +78,27 @@ class RpcEngine(BaseEngine):
}
save_json(self.setting_filename, setting)
def start(self, rep_address: str, pub_address: str):
def start(self, rep_address: str = None, pub_address: str = None):
""""""
if self.server.is_active():
self.write_log("RPC服务运行中")
return False
self.rep_address = rep_address
self.pub_address = pub_address
return False, "RPC服务运行中"
if rep_address:
self.rep_address = rep_address
if pub_address:
self.pub_address = pub_address
try:
self.server.start(rep_address, pub_address)
self.server.start(self.rep_address, self.pub_address)
except: # noqa
msg = traceback.format_exc()
print(msg, file=sys.stderr)
self.write_log(f"RPC服务启动失败{msg}")
return False
return False, msg
self.save_setting()
self.write_log("RPC服务启动成功")
return True
return True,"RPC服务启动成功"
def stop(self):
""""""

View File

@ -32,6 +32,8 @@ class CtaPosition(CtaComponent):
self.write_log(f'净:{self.pos}->{self.pos + volume}')
self.long_pos += volume
self.pos += volume
self.long_pos = round(self.long_pos, 7)
self.pos = round(self.pos, 7)
if direction == Direction.SHORT: # 加空仓
if (min(self.pos, self.short_pos) - volume) < (0 - self.maxPos):
@ -41,7 +43,8 @@ class CtaPosition(CtaComponent):
self.write_log(f'净:{self.pos}->{self.pos - volume}')
self.short_pos -= volume
self.pos -= volume
self.short_pos = round(self.short_pos, 7)
self.pos = round(self.pos, 7)
return True
def close_pos(self, direction: Direction, volume: float):
@ -56,6 +59,8 @@ class CtaPosition(CtaComponent):
self.write_log(f'净:{self.pos}->{self.pos + volume}')
self.short_pos += volume
self.pos += volume
self.short_pos = round(self.short_pos, 7)
self.pos = round(self.pos, 7)
# 更新上层策略的pos。该方法不推荐使用
self.strategy.pos = self.pos
@ -69,6 +74,8 @@ class CtaPosition(CtaComponent):
self.long_pos -= volume
self.pos -= volume
self.long_pos = round(self.long_pos, 7)
self.pos = round(self.pos, 7)
return True

View File

@ -154,7 +154,10 @@ class BinanceGateway(BaseGateway):
def process_timer_event(self, event: Event):
""""""
self.rest_api.keep_user_stream()
if self.status.get('td_con', False) \
and self.status.get('tdws_con', False) \
and self.status.get('mdws_con', False):
self.status.update({'con': True})
class BinanceRestApi(RestClient):
"""
@ -254,7 +257,7 @@ class BinanceRestApi(RestClient):
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.gateway.status.update({'md_con': True, 'md_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
self.query_time()
self.query_account()
self.query_order()
@ -625,6 +628,7 @@ class BinanceTradeWebsocketApi(WebsocketClient):
def on_connected(self):
""""""
self.gateway.write_log("交易Websocket API连接成功")
self.gateway.status.update({'tdws_con': True, 'tdws_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def on_packet(self, packet: dict): # type: (dict)->None
""""""
@ -714,6 +718,7 @@ class BinanceDataWebsocketApi(WebsocketClient):
def on_connected(self):
""""""
self.gateway.write_log("行情Websocket API连接刷新")
self.gateway.status.update({'mdws_con': True, 'mdws_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def subscribe(self, req: SubscribeRequest):
""""""

View File

@ -163,6 +163,10 @@ class BinancefGateway(BaseGateway):
def process_timer_event(self, event: Event) -> None:
""""""
self.rest_api.keep_user_stream()
if self.status.get('td_con', False) \
and self.status.get('tdws_con', False) \
and self.status.get('mdws_con', False):
self.status.update({'con': True})
def get_order(self, orderid: str):
return self.rest_api.get_order(orderid)
@ -275,6 +279,7 @@ class BinancefRestApi(RestClient):
self.start(session_number)
self.gateway.write_log("REST API启动成功")
self.gateway.status.update({'md_con': True, 'md_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
self.query_time()
self.query_account()
@ -799,6 +804,7 @@ class BinancefTradeWebsocketApi(WebsocketClient):
def on_connected(self) -> None:
""""""
self.gateway.write_log("交易Websocket API连接成功")
self.gateway.status.update({'tdws_con': True, 'tdws_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def on_packet(self, packet: dict) -> None: # type: (dict)->None
""""""
@ -916,6 +922,7 @@ class BinancefDataWebsocketApi(WebsocketClient):
def on_connected(self) -> None:
""""""
self.gateway.write_log("行情Websocket API连接刷新")
self.gateway.status.update({'mdws_con': True, 'mdws_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def subscribe(self, req: SubscribeRequest) -> None:
""""""

View File

@ -261,6 +261,10 @@ class CtpGateway(BaseGateway):
def check_status(self):
"""检查状态"""
if self.td_api.connect_status and self.md_api.connect_status:
self.status.update({'con': True})
if self.tdx_api:
self.tdx_api.check_status()
if self.tdx_api is None or self.md_api is None:
@ -449,6 +453,7 @@ class CtpMdApi(MdApi):
"""
self.gateway.write_log("行情服务器连接成功")
self.login()
self.gateway.status.update({'md_con': True, 'md_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def onFrontDisconnected(self, reason: int):
"""
@ -456,6 +461,7 @@ class CtpMdApi(MdApi):
"""
self.login_status = False
self.gateway.write_log(f"行情服务器连接断开,原因{reason}")
self.gateway.status.update({'md_con': False, 'md_dis_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def onRspUserLogin(self, data: dict, error: dict, reqid: int, last: bool):
"""
@ -642,11 +648,13 @@ class CtpTdApi(TdApi):
self.authenticate()
else:
self.login()
self.gateway.status.update({'td_con': True, 'td_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def onFrontDisconnected(self, reason: int):
""""""
self.login_status = False
self.gateway.write_log(f"交易服务器连接断开,原因{reason}")
self.gateway.status.update({'td_con': True, 'td_dis_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def onRspAuthenticate(self, data: dict, error: dict, reqid: int, last: bool):
""""""
@ -784,16 +792,16 @@ class CtpTdApi(TdApi):
account = AccountData(
accountid=data["AccountID"],
pre_balance=data['PreBalance'],
balance=data["Balance"],
frozen=data["FrozenMargin"] + data["FrozenCash"] + data["FrozenCommission"],
pre_balance=round(float(data['PreBalance']), 7),
balance=round(float(data["Balance"]), 7),
frozen=round(data["FrozenMargin"] + data["FrozenCash"] + data["FrozenCommission"], 7),
gateway_name=self.gateway_name
)
account.available = data["Available"]
account.commission = data['Commission']
account.margin = data['CurrMargin']
account.close_profit = data['CloseProfit']
account.holding_profit = data['PositionProfit']
account.available = round(float(data["Available"]), 7)
account.commission = round(float(data['Commission']), 7)
account.margin = round(float(data['CurrMargin']), 7)
account.close_profit = round(float(data['CloseProfit']), 7)
account.holding_profit = round(float(data['PositionProfit']),7)
account.trading_day = str(data['TradingDay'])
if '-' not in account.trading_day and len(account.trading_day) == 8:
account.trading_day = '-'.join(
@ -1257,7 +1265,7 @@ class TdxMdApi():
else:
self.gateway.write_log(u'创建tdx连接, IP: {}/{}'.format(self.best_ip['ip'], self.best_ip['port']))
self.connection_status = True
self.gateway.status.update({'tdx_con': True, 'tdx_con_time': datetime.now().strftime('%Y-%m-%d %H:%M%S')})
self.thread = Thread(target=self.run)
self.thread.start()
@ -1536,7 +1544,7 @@ class SubMdApi():
self.thread = Thread(target=self.sub.start)
self.thread.start()
self.connect_status = True
self.gateway.status.update({'sub_con': True, 'sub_con_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
except Exception as ex:
self.gateway.write_error(u'连接RabbitMQ {} 异常:{}'.format(self.setting, str(ex)))
self.gateway.write_error(traceback.format_exc())

View File

@ -73,8 +73,8 @@ class RpcServer:
return self.__active
def start(
self,
rep_address: str,
self,
rep_address: str,
pub_address: str,
server_secretkey_path: str = ""
) -> None:
@ -89,12 +89,12 @@ class RpcServer:
self.__authenticator = ThreadAuthenticator(self.__context)
self.__authenticator.start()
self.__authenticator.configure_curve(
domain="*",
domain="*",
location=zmq.auth.CURVE_ALLOW_ANY
)
publickey, secretkey = zmq.auth.load_certificate(server_secretkey_path)
self.__socket_pub.curve_secretkey = secretkey
self.__socket_pub.curve_publickey = publickey
self.__socket_pub.curve_server = True
@ -236,8 +236,8 @@ class RpcClient:
return dorpc
def start(
self,
req_address: str,
self,
req_address: str,
sub_address: str,
client_secretkey_path: str = "",
server_publickey_path: str = ""
@ -253,13 +253,13 @@ class RpcClient:
self.__authenticator = ThreadAuthenticator(self.__context)
self.__authenticator.start()
self.__authenticator.configure_curve(
domain="*",
domain="*",
location=zmq.auth.CURVE_ALLOW_ANY
)
publickey, secretkey = zmq.auth.load_certificate(client_secretkey_path)
serverkey, _ = zmq.auth.load_certificate(server_publickey_path)
self.__socket_sub.curve_secretkey = secretkey
self.__socket_sub.curve_publickey = publickey
self.__socket_sub.curve_serverkey = serverkey
@ -297,6 +297,11 @@ class RpcClient:
self.__thread.join()
self.__thread = None
def close(self):
"""close receiver, exit"""
self.stop()
self.join()
def run(self) -> None:
"""
Run RpcClient function
@ -347,4 +352,4 @@ def generate_certificates(name: str) -> None:
if not keys_path.exists():
os.mkdir(keys_path)
zmq.auth.create_certificates(keys_path, name)
zmq.auth.create_certificates(keys_path, name)

View File

@ -163,6 +163,13 @@ class MainEngine:
"""
return list(self.gateways.keys())
def get_all_gateway_status(self) -> List[dict]:
"""
Get all gateway status
:return:
"""
return list([{k: v.get_status()} for k, v in self.gateways.items()])
def get_all_apps(self) -> List[BaseApp]:
"""
Get all app objects.

View File

@ -95,6 +95,7 @@ class BaseGateway(ABC):
# 所有订阅on_bar的都会添加
self.klines = {}
self.status = {'name': gateway_name, 'con': False}
def create_logger(self):
"""
@ -314,6 +315,12 @@ class BaseGateway(ABC):
"""
return self.default_setting
def get_status(self) -> Dict[str, Any]:
"""
return gateway status
:return:
"""
return self.status
class LocalOrderManager:
"""
@ -343,7 +350,7 @@ class LocalOrderManager:
self.cancel_request_buf: Dict[str, CancelRequest] = {} # local_orderid: req
# Hook cancel order function
self._cancel_order: Callable[CancelRequest] = gateway.cancel_order
self._cancel_order = gateway.cancel_order
gateway.cancel_order = self.cancel_order
def new_local_orderid(self) -> str: