diff --git a/vnpy/gateway/xtp/xtp_gateway.py b/vnpy/gateway/xtp/xtp_gateway.py index 5667efd2..6f8ebf52 100644 --- a/vnpy/gateway/xtp/xtp_gateway.py +++ b/vnpy/gateway/xtp/xtp_gateway.py @@ -24,7 +24,7 @@ from vnpy.trader.object import ( PositionData, AccountData ) -from vnpy.trader.utility import get_folder_path +from vnpy.trader.utility import get_folder_path, print_dict, extract_vt_symbol # 市场id <=> Exchange MARKET_XTP2VT: Dict[int, Exchange] = { @@ -299,6 +299,7 @@ class XtpMdApi(MdApi): tick.ask_volume_1, tick.ask_volume_2, tick.ask_volume_3, tick.ask_volume_4, tick.ask_volume_5 = data["ask_qty"][0:5] tick.name = symbol_name_map.get(tick.vt_symbol, tick.symbol) + self.gateway.prices.update({tick.vt_symbol: tick.last_price}) self.gateway.on_tick(tick) def onSubOrderBook(self, data: dict, error: dict, last: bool) -> None: @@ -363,6 +364,12 @@ class XtpMdApi(MdApi): ) self.gateway.on_contract(contract) + # 更新最新价 + pre_close_price = float(data["pre_close_price"]) + vt_symbol = contract.vt_symbol + if vt_symbol not in self.gateway.prices and pre_close_price>0: + self.gateway.prices.update({vt_symbol: pre_close_price}) + # 更新 symbol <=> 中文名称映射 symbol_name_map[contract.vt_symbol] = contract.name @@ -374,8 +381,17 @@ class XtpMdApi(MdApi): self.gateway.write_log(f"{contract.exchange.value}合约信息查询成功") def onQueryTickersPriceInfo(self, data: dict, error: dict, last: bool) -> None: - """""" - pass + """查询最新价""" + self.gateway.write_log('最新价:{}'.format(print_dict(data))) + symbol = data.get('ticker') + exchange_id = data.get('exchange_id') + last_price = float(data.get('last_price', 0)) + + if symbol and exchange_id and last_price > 0: + exchange = EXCHANGE_XTP2VT[exchange_id] + vt_symbol = f'{symbol}.{exchange.value}' + self.gateway.prices.update({vt_symbol: last_price}) + self.gateway.write_log(f'{vt_symbol} 最新价: {last_price}') def onSubscribeAllOptionMarketData(self, data: dict, error: dict) -> None: """""" @@ -483,6 +499,10 @@ class XtpTdApi(TdApi): self.reqid: int = 0 self.protocol: int = 0 + # 证券资产 + self.security_asset = None # 未查询获取持仓时,是None,如果查询过,无持仓,为0 + self.security_volumes: Dict[str, int] = {} # vt_symbol, volume + # Whether current account supports margin or option self.margin_trading = False self.option_trading = False @@ -590,7 +610,9 @@ class XtpTdApi(TdApi): last: bool, session: int ) -> None: - """""" + """普通账号持仓""" + # self.gateway.write_log(f"------\n {print_dict(data)}") + if data["market"] == 0: return @@ -605,8 +627,35 @@ class XtpTdApi(TdApi): yd_volume=data["yesterday_position"], gateway_name=self.gateway_name ) + vt_symbol = position.vt_symbol self.gateway.on_position(position) + # 如果持仓>0 获取持仓对应的当前最新价 + if position.volume > 0 and vt_symbol not in self.gateway.prices: + req = SubscribeRequest(symbol=position.symbol, exchange=position.exchange) + self.gateway.subscribe(req) + self.security_volumes.update({vt_symbol: data["total_qty"]}) + + def update_security_asset(self): + """更新资产净值""" + #self.gateway.write_log(f'更新资产净值') + total_asset = 0 + for vt_symbol, volume in self.security_volumes.items(): + price = self.gateway.prices.get(vt_symbol, None) + # 获取不到股票的最新价,所以当前security_asset不可用 + if price is None: + self.gateway.write_log(f'取不到:{vt_symbol}的价格') + self.security_asset = None + symbol, exchange = extract_vt_symbol(vt_symbol) + req = SubscribeRequest(symbol=symbol, exchange=exchange) + self.gateway.subscribe(req) + return + + total_asset += volume * price + #self.gateway.write_log(f'资产净值 => {total_asset}') + + self.security_asset = total_asset + def onQueryAsset( self, data: dict, @@ -616,13 +665,31 @@ class XtpTdApi(TdApi): session: int ) -> None: """""" - account = AccountData( - accountid=self.userid, - balance=data["buying_power"], - frozen=data["withholding_amount"], - gateway_name=self.gateway_name - ) - self.gateway.on_account(account) + # XTP_ACCOUNT_NORMAL = 0, ///<普通账户 + # XTP_ACCOUNT_CREDIT, 1 ///<信用账户 + # XTP_ACCOUNT_DERIVE, 2 ///<衍生品账户 + # XTP_ACCOUNT_UNKNOWN 3 ///<未知账户类型 + if data['account_type'] != 0: + return + + # self.gateway.write_log(print_dict(data)) + self.update_security_asset() + + if self.security_asset is not None: + cash_asset = data["total_asset"] + balance = cash_asset + self.security_asset + + account = AccountData( + accountid=self.userid, + balance=balance, # 总资产 + margin=self.security_asset, # 证券资产 + frozen=data["withholding_amount"], + gateway_name=self.gateway_name + ) + # AccountData缺省的available 计算方法有误,这里直接取可用资金 + account.available = cash_asset + + self.gateway.on_account(account) if data["account_type"] == 1: self.margin_trading = True @@ -669,7 +736,8 @@ class XtpTdApi(TdApi): last: bool, session: int ) -> None: - """""" + """信用账号持仓""" + self.gateway.write_log(f"------\n {print_dict(data)}") if data["debt_type"] == 1: symbol = data["ticker"] exchange = MARKET_XTP2VT[data["market"]] diff --git a/vnpy/trader/gateway.py b/vnpy/trader/gateway.py index 34da6863..21172eb6 100644 --- a/vnpy/trader/gateway.py +++ b/vnpy/trader/gateway.py @@ -96,6 +96,8 @@ class BaseGateway(ABC): self.klines = {} self.status = {'name': gateway_name, 'con': False} + self.prices: Dict[str, float] = {} # vt_symbol, last_price + self.query_functions = [] def create_logger(self):