diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index d0c1f4ab..74d34bce 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -116,6 +116,8 @@ class TigerGateway(BaseGateway): self.queue = Queue() self.pool = None + self.ID_TIGER2VT = {} + self.ID_VT2TIGER = {} self.ticks = {} self.trades = set() self.contracts = {} @@ -125,18 +127,14 @@ class TigerGateway(BaseGateway): """""" while self.active: try: - func, arg = self.queue.get(timeout=0.1) - print(func, arg) - if arg: - func(arg) - else: - func() + func, args = self.queue.get(timeout=0.1) + func(*args) except Empty: pass - def add_task(self, func, arg=None): + def add_task(self, func, *args): """""" - self.queue.put((func, arg)) + self.queue.put((func, [*args])) def connect(self, setting: dict): """""" @@ -157,9 +155,6 @@ class TigerGateway(BaseGateway): self.add_task(self.connect_quote) self.add_task(self.connect_trade) self.add_task(self.connect_push) - self.write_log("行情接口连接成功") - - # self.thread.start() def init_client_config(self, sandbox=True): """""" @@ -183,6 +178,7 @@ class TigerGateway(BaseGateway): self.write_log("查询合约失败") return + self.write_log("行情接口连接成功") self.write_log("合约查询成功") def connect_trade(self): @@ -256,6 +252,8 @@ class TigerGateway(BaseGateway): def on_asset_change(self, tiger_account: str, data: list): """""" data = dict(data) + if "net_liquidation" not in data: + return account = AccountData( accountid=tiger_account, @@ -274,7 +272,7 @@ class TigerGateway(BaseGateway): symbol=symbol, exchange=exchange, direction=Direction.NET, - volume=data["quantity"], + volume=int(data["quantity"]), frozen=0.0, price=data["average_cost"], pnl=data["unrealized_pnl"], @@ -283,18 +281,16 @@ class TigerGateway(BaseGateway): self.on_position(pos) def on_order_change(self, tiger_account: str, data: list): - """""" - print("委托", data) - self.local_id += 1 + """""" data = dict(data) + print("委托推送", data["origin_symbol"], data["order_id"], data["filled"], data["status"]) symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"]) status = PUSH_STATUS_TIGER2VT[data["status"]] order = OrderData( symbol=symbol, exchange=exchange, - # orderid=data["order_id"], - orderid=self.local_id, + orderid=self.ID_TIGER2VT.get(str(data["order_id"]), self.get_new_local_id()), direction=Direction.NET, price=data.get("limit_price", 0), volume=data["quantity"], @@ -313,28 +309,30 @@ class TigerGateway(BaseGateway): exchange=exchange, direction=Direction.NET, tradeid=self.tradeid, - orderid=data["order_id"], + orderid=self.ID_TIGER2VT[str(data["order_id"])], price=data["avg_fill_price"], volume=data["filled"], time=datetime.fromtimestamp(data["trade_time"] / 1000).strftime("%H:%M:%S"), gateway_name=self.gateway_name, ) self.on_trade(trade) + + def get_new_local_id(self): + self.local_id += 1 + return self.local_id def send_order(self, req: OrderRequest): """""" - self.local_id += 1 - order = req.create_order_data(self.local_id, self.gateway_name) - return order.vt_orderid + local_id = self.get_new_local_id() + order = req.create_order_data(local_id, self.gateway_name) + self.on_order(order) + self.add_task(self._send_order, req, local_id) + return order.vt_orderid - self.add_task(self._send_order, req) - - def _send_order(self, req: OrderRequest): + def _send_order(self, req: OrderRequest, local_id): """""" currency = config_symbol_currency(req.symbol) - - # first, get contract try: contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0] order = self.trade_client.create_order( @@ -345,8 +343,14 @@ class TigerGateway(BaseGateway): quantity=int(req.volume), limit_price=req.price, ) + self.ID_TIGER2VT[str(order.order_id)] = local_id + self.ID_VT2TIGER[local_id] = str(order.order_id) + self.trade_client.place_order(order) + print("发单:", order.contract.symbol, order.order_id, order.quantity, order.status) + except: # noqa + traceback.print_exc() self.write_log("发单失败") return @@ -357,7 +361,8 @@ class TigerGateway(BaseGateway): def _cancel_order(self, req: CancelRequest): """""" try: - data = self.trade_client.cancel_order(order_id=req.orderid) + order_id = self.ID_VT2TIGER[req.orderid] + data = self.trade_client.cancel_order(order_id=order_id) except ApiException: self.write_log(f"撤单失败:{req.orderid}") @@ -420,8 +425,7 @@ class TigerGateway(BaseGateway): for ix, row in contract_CN.iterrows(): symbol = row["symbol"] symbol, exchange = convert_symbol_tiger2vt(symbol) - if symbol == '600001': - print(f"symbol: {symbol} t:{type(symbol)} l:{len(symbol)} ex:{exchange} n:{row['name']}") + contract = ContractData( symbol=symbol, exchange=exchange, @@ -467,7 +471,7 @@ class TigerGateway(BaseGateway): symbol=symbol, exchange=exchange, direction=Direction.NET, - volume=i.quantity, + volume=int(i.quantity), frozen=0.0, price=i.average_cost, pnl=float(i.unrealized_pnl), @@ -500,12 +504,12 @@ class TigerGateway(BaseGateway): """""" for i in data: symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) - self.local_id += 1 + local_id = self.get_new_local_id() + order = OrderData( symbol=symbol, exchange=exchange, - orderid=self.local_id, - # orderid=str(i.order_id), + orderid=local_id, direction=Direction.NET, price=i.limit_price if i.limit_price else 0.0, volume=i.quantity, @@ -514,17 +518,20 @@ class TigerGateway(BaseGateway): time=datetime.fromtimestamp(i.order_time / 1000).strftime("%H:%M:%S"), gateway_name=self.gateway_name, ) - + self.ID_TIGER2VT[str(i.order_id)] = local_id self.on_order(order) + + self.ID_VT2TIGER = {v: k for k, v in self.ID_TIGER2VT.items()} + print("原始委托字典", self.ID_TIGER2VT) + print("原始反向字典", self.ID_VT2TIGER) def process_deal(self, data): """ Process trade data for both query and update. """ - for i in reversed(data): + for i in data: if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) - self.local_id += 1 self.tradeid += 1 trade = TradeData( @@ -532,7 +539,7 @@ class TigerGateway(BaseGateway): exchange=exchange, direction=Direction.NET, tradeid=self.tradeid, - orderid=self.local_id, + orderid=self.ID_TIGER2VT[str(i.order_id)], price=i.avg_fill_price, volume=i.filled, time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"),