[增强功能] 增强xtp账号信息内,包含股票实时净值权益的计算

This commit is contained in:
msincenselee 2020-05-21 15:24:59 +08:00
parent fa22e74a8a
commit 83d2a40624
2 changed files with 82 additions and 12 deletions

View File

@ -24,7 +24,7 @@ from vnpy.trader.object import (
PositionData, PositionData,
AccountData AccountData
) )
from vnpy.trader.utility import get_folder_path from vnpy.trader.utility import get_folder_path, print_dict, extract_vt_symbol
# 市场id <=> Exchange # 市场id <=> Exchange
MARKET_XTP2VT: Dict[int, Exchange] = { MARKET_XTP2VT: Dict[int, Exchange] = {
@ -299,6 +299,7 @@ class XtpMdApi(MdApi):
tick.ask_volume_1, tick.ask_volume_2, tick.ask_volume_3, tick.ask_volume_4, tick.ask_volume_5 = data["ask_qty"][0:5] tick.ask_volume_1, tick.ask_volume_2, tick.ask_volume_3, tick.ask_volume_4, tick.ask_volume_5 = data["ask_qty"][0:5]
tick.name = symbol_name_map.get(tick.vt_symbol, tick.symbol) tick.name = symbol_name_map.get(tick.vt_symbol, tick.symbol)
self.gateway.prices.update({tick.vt_symbol: tick.last_price})
self.gateway.on_tick(tick) self.gateway.on_tick(tick)
def onSubOrderBook(self, data: dict, error: dict, last: bool) -> None: def onSubOrderBook(self, data: dict, error: dict, last: bool) -> None:
@ -363,6 +364,12 @@ class XtpMdApi(MdApi):
) )
self.gateway.on_contract(contract) self.gateway.on_contract(contract)
# 更新最新价
pre_close_price = float(data["pre_close_price"])
vt_symbol = contract.vt_symbol
if vt_symbol not in self.gateway.prices and pre_close_price>0:
self.gateway.prices.update({vt_symbol: pre_close_price})
# 更新 symbol <=> 中文名称映射 # 更新 symbol <=> 中文名称映射
symbol_name_map[contract.vt_symbol] = contract.name symbol_name_map[contract.vt_symbol] = contract.name
@ -374,8 +381,17 @@ class XtpMdApi(MdApi):
self.gateway.write_log(f"{contract.exchange.value}合约信息查询成功") self.gateway.write_log(f"{contract.exchange.value}合约信息查询成功")
def onQueryTickersPriceInfo(self, data: dict, error: dict, last: bool) -> None: def onQueryTickersPriceInfo(self, data: dict, error: dict, last: bool) -> None:
"""""" """查询最新价"""
pass self.gateway.write_log('最新价:{}'.format(print_dict(data)))
symbol = data.get('ticker')
exchange_id = data.get('exchange_id')
last_price = float(data.get('last_price', 0))
if symbol and exchange_id and last_price > 0:
exchange = EXCHANGE_XTP2VT[exchange_id]
vt_symbol = f'{symbol}.{exchange.value}'
self.gateway.prices.update({vt_symbol: last_price})
self.gateway.write_log(f'{vt_symbol} 最新价: {last_price}')
def onSubscribeAllOptionMarketData(self, data: dict, error: dict) -> None: def onSubscribeAllOptionMarketData(self, data: dict, error: dict) -> None:
"""""" """"""
@ -483,6 +499,10 @@ class XtpTdApi(TdApi):
self.reqid: int = 0 self.reqid: int = 0
self.protocol: int = 0 self.protocol: int = 0
# 证券资产
self.security_asset = None # 未查询获取持仓时是None如果查询过无持仓为0
self.security_volumes: Dict[str, int] = {} # vt_symbol, volume
# Whether current account supports margin or option # Whether current account supports margin or option
self.margin_trading = False self.margin_trading = False
self.option_trading = False self.option_trading = False
@ -590,7 +610,9 @@ class XtpTdApi(TdApi):
last: bool, last: bool,
session: int session: int
) -> None: ) -> None:
"""""" """普通账号持仓"""
# self.gateway.write_log(f"------\n {print_dict(data)}")
if data["market"] == 0: if data["market"] == 0:
return return
@ -605,8 +627,35 @@ class XtpTdApi(TdApi):
yd_volume=data["yesterday_position"], yd_volume=data["yesterday_position"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
vt_symbol = position.vt_symbol
self.gateway.on_position(position) self.gateway.on_position(position)
# 如果持仓>0 获取持仓对应的当前最新价
if position.volume > 0 and vt_symbol not in self.gateway.prices:
req = SubscribeRequest(symbol=position.symbol, exchange=position.exchange)
self.gateway.subscribe(req)
self.security_volumes.update({vt_symbol: data["total_qty"]})
def update_security_asset(self):
"""更新资产净值"""
#self.gateway.write_log(f'更新资产净值')
total_asset = 0
for vt_symbol, volume in self.security_volumes.items():
price = self.gateway.prices.get(vt_symbol, None)
# 获取不到股票的最新价所以当前security_asset不可用
if price is None:
self.gateway.write_log(f'取不到:{vt_symbol}的价格')
self.security_asset = None
symbol, exchange = extract_vt_symbol(vt_symbol)
req = SubscribeRequest(symbol=symbol, exchange=exchange)
self.gateway.subscribe(req)
return
total_asset += volume * price
#self.gateway.write_log(f'资产净值 => {total_asset}')
self.security_asset = total_asset
def onQueryAsset( def onQueryAsset(
self, self,
data: dict, data: dict,
@ -616,12 +665,30 @@ class XtpTdApi(TdApi):
session: int session: int
) -> None: ) -> None:
"""""" """"""
# XTP_ACCOUNT_NORMAL = 0, ///<普通账户
# XTP_ACCOUNT_CREDIT, 1 ///<信用账户
# XTP_ACCOUNT_DERIVE, 2 ///<衍生品账户
# XTP_ACCOUNT_UNKNOWN 3 ///<未知账户类型
if data['account_type'] != 0:
return
# self.gateway.write_log(print_dict(data))
self.update_security_asset()
if self.security_asset is not None:
cash_asset = data["total_asset"]
balance = cash_asset + self.security_asset
account = AccountData( account = AccountData(
accountid=self.userid, accountid=self.userid,
balance=data["buying_power"], balance=balance, # 总资产
margin=self.security_asset, # 证券资产
frozen=data["withholding_amount"], frozen=data["withholding_amount"],
gateway_name=self.gateway_name gateway_name=self.gateway_name
) )
# AccountData缺省的available 计算方法有误,这里直接取可用资金
account.available = cash_asset
self.gateway.on_account(account) self.gateway.on_account(account)
if data["account_type"] == 1: if data["account_type"] == 1:
@ -669,7 +736,8 @@ class XtpTdApi(TdApi):
last: bool, last: bool,
session: int session: int
) -> None: ) -> None:
"""""" """信用账号持仓"""
self.gateway.write_log(f"------\n {print_dict(data)}")
if data["debt_type"] == 1: if data["debt_type"] == 1:
symbol = data["ticker"] symbol = data["ticker"]
exchange = MARKET_XTP2VT[data["market"]] exchange = MARKET_XTP2VT[data["market"]]

View File

@ -96,6 +96,8 @@ class BaseGateway(ABC):
self.klines = {} self.klines = {}
self.status = {'name': gateway_name, 'con': False} self.status = {'name': gateway_name, 'con': False}
self.prices: Dict[str, float] = {} # vt_symbol, last_price
self.query_functions = [] self.query_functions = []
def create_logger(self): def create_logger(self):