diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index 9dd8c2dc..d0c1f4ab 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -9,6 +9,7 @@ from datetime import datetime from multiprocessing.dummy import Pool from queue import Empty, Queue import functools +import traceback import pandas as pd from pandas import DataFrame @@ -70,7 +71,7 @@ STATUS_TIGER2VT = { ORDER_STATUS.CANCELLED: Status.CANCELLED, ORDER_STATUS.PENDING_CANCEL: Status.CANCELLED, ORDER_STATUS.REJECTED: Status.REJECTED, - ORDER_STATUS.EXPIRED: Status.NOTTRADED, + ORDER_STATUS.EXPIRED: Status.NOTTRADED } PUSH_STATUS_TIGER2VT = { @@ -81,7 +82,7 @@ PUSH_STATUS_TIGER2VT = { "Submitted": Status.SUBMITTING, "PendingSubmit": Status.SUBMITTING, "Filled": Status.ALLTRADED, - "Inactive": Status.REJECTED, + "Inactive": Status.REJECTED } @@ -91,13 +92,13 @@ class TigerGateway(BaseGateway): "tiger_id": "", "account": "", "standard_account": "", + "private_key": '', } def __init__(self, event_engine): """Constructor""" super(TigerGateway, self).__init__(event_engine, "TIGER") - self.private_key = "" self.tiger_id = "" self.account = "" self.standard_account = "" @@ -108,6 +109,7 @@ class TigerGateway(BaseGateway): self.quote_client = None self.push_client = None + self.local_id = 1000000 self.tradeid = 0 self.active = False @@ -124,6 +126,7 @@ class TigerGateway(BaseGateway): while self.active: try: func, arg = self.queue.get(timeout=0.1) + print(func, arg) if arg: func(arg) else: @@ -137,7 +140,7 @@ class TigerGateway(BaseGateway): def connect(self, setting: dict): """""" - self.private_key = self.private_key + self.private_key = setting['private_key'] self.tiger_id = setting["tiger_id"] self.account = setting["account"] self.standard_account = setting["standard_account"] @@ -154,6 +157,7 @@ class TigerGateway(BaseGateway): self.add_task(self.connect_quote) self.add_task(self.connect_trade) self.add_task(self.connect_push) + self.write_log("行情接口连接成功") # self.thread.start() @@ -171,17 +175,15 @@ class TigerGateway(BaseGateway): """ Connect to market data server. """ - self.quote_client = QuoteClient(self.client_config) try: + self.quote_client = QuoteClient(self.client_config) self.symbol_names = dict(self.quote_client.get_symbol_names(lang=Language.zh_CN)) + self.query_contract() except ApiException: - self.write_log("行情接口连接失败") + self.write_log("查询合约失败") return - - if self.symbol_names: - self.add_task(self.query_contract) - self.write_log("行情接口连接成功") + self.write_log("合约查询成功") def connect_trade(self): """ @@ -189,17 +191,14 @@ class TigerGateway(BaseGateway): """ self.trade_client = TradeClient(self.client_config) try: - data = self.trade_client.get_managed_accounts() + self.add_task(self.query_order) + self.add_task(self.query_position) + self.add_task(self.query_account) except ApiException: self.write_log("交易接口连接失败") return - if data: - self.add_task(self.query_order) - self.add_task(self.query_position) - self.add_task(self.query_account) - - self.write_log("交易接口连接成功") + self.write_log("交易接口连接成功") def connect_push(self): """ @@ -286,6 +285,7 @@ class TigerGateway(BaseGateway): def on_order_change(self, tiger_account: str, data: list): """""" print("委托", data) + self.local_id += 1 data = dict(data) symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"]) status = PUSH_STATUS_TIGER2VT[data["status"]] @@ -293,7 +293,8 @@ class TigerGateway(BaseGateway): order = OrderData( symbol=symbol, exchange=exchange, - orderid=data["order_id"], + # orderid=data["order_id"], + orderid=self.local_id, direction=Direction.NET, price=data.get("limit_price", 0), volume=data["quantity"], @@ -322,6 +323,11 @@ class TigerGateway(BaseGateway): def send_order(self, req: OrderRequest): """""" + self.local_id += 1 + order = req.create_order_data(self.local_id, self.gateway_name) + return order.vt_orderid + self.on_order(order) + self.add_task(self._send_order, req) def _send_order(self, req: OrderRequest): @@ -331,12 +337,6 @@ class TigerGateway(BaseGateway): # first, get contract try: contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0] - except ApiException: - self.write_log("获取合约对象失败") - return - - # second, create order - try: order = self.trade_client.create_order( account=self.account, contract=contract, @@ -345,24 +345,11 @@ class TigerGateway(BaseGateway): quantity=int(req.volume), limit_price=req.price, ) - except ApiException: - self.write_log("创建订单失败") + self.trade_client.place_order(order) + except: # noqa + self.write_log("发单失败") return - # third, place order - try: - data = self.trade_client.place_order(order) - except ApiException: - self.write_log("发送订单失败") - return - - if data: - orderid = order.order_id - - order = req.create_order_data(orderid, self.gateway_name) - self.on_order(order) - return order.vt_orderid - def cancel_order(self, req: CancelRequest): """""" self.add_task(self._cancel_order, req) @@ -373,7 +360,6 @@ class TigerGateway(BaseGateway): data = self.trade_client.cancel_order(order_id=req.orderid) except ApiException: self.write_log(f"撤单失败:{req.orderid}") - return if not data: self.write_log('撤单成功') @@ -382,12 +368,8 @@ class TigerGateway(BaseGateway): """""" # HK Stock - try: - symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.HK) - contract_names_HK = DataFrame(symbols_names_HK, columns=['symbol', 'name']) - except ApiException: - self.write_log("查询合约失败") - return + symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.HK) + contract_names_HK = DataFrame(symbols_names_HK, columns=['symbol', 'name']) contractList = list(contract_names_HK["symbol"]) i, n = 0, len(contractList) @@ -397,7 +379,6 @@ class TigerGateway(BaseGateway): c = contractList[i - 500:i] r = self.quote_client.get_trade_metas(c) result = result.append(r) - pass contract_detail_HK = result.sort_values(by="symbol", ascending=True) contract_HK = pd.merge(contract_names_HK, contract_detail_HK, how='left', on='symbol') @@ -439,7 +420,8 @@ class TigerGateway(BaseGateway): for ix, row in contract_CN.iterrows(): symbol = row["symbol"] symbol, exchange = convert_symbol_tiger2vt(symbol) - + if symbol == '600001': + print(f"symbol: {symbol} t:{type(symbol)} l:{len(symbol)} ex:{exchange} n:{row['name']}") contract = ContractData( symbol=symbol, exchange=exchange, @@ -451,9 +433,7 @@ class TigerGateway(BaseGateway): ) self.on_contract(contract) self.contracts[contract.vt_symbol] = contract - - self.write_log("合约查询成功") - + def query_account(self): """""" try: @@ -500,7 +480,9 @@ class TigerGateway(BaseGateway): """""" try: data = self.trade_client.get_orders() - except ApiException: + data = sorted(data, key=lambda x: x.order_time, reverse=False) + except: # noqa + traceback.print_exc() self.write_log("查询委托失败") return @@ -518,11 +500,12 @@ class TigerGateway(BaseGateway): """""" for i in data: symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) - + self.local_id += 1 order = OrderData( symbol=symbol, exchange=exchange, - orderid=str(i.order_id), + orderid=self.local_id, + # orderid=str(i.order_id), direction=Direction.NET, price=i.limit_price if i.limit_price else 0.0, volume=i.quantity, @@ -541,6 +524,7 @@ class TigerGateway(BaseGateway): for i in reversed(data): if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) + self.local_id += 1 self.tradeid += 1 trade = TradeData( @@ -548,7 +532,7 @@ class TigerGateway(BaseGateway): exchange=exchange, direction=Direction.NET, tradeid=self.tradeid, - orderid=i.order_id, + orderid=self.local_id, price=i.avg_fill_price, volume=i.filled, time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"),