Merge pull request #1485 from 1122455801/tiger_gateway_version_06

[Mod] tiger_gateway.py
This commit is contained in:
vn.py 2019-03-17 11:07:45 +08:00 committed by GitHub
commit 9cad3587a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,6 +116,8 @@ class TigerGateway(BaseGateway):
self.queue = Queue() self.queue = Queue()
self.pool = None self.pool = None
self.ID_TIGER2VT = {}
self.ID_VT2TIGER = {}
self.ticks = {} self.ticks = {}
self.trades = set() self.trades = set()
self.contracts = {} self.contracts = {}
@ -125,18 +127,14 @@ class TigerGateway(BaseGateway):
"""""" """"""
while self.active: while self.active:
try: try:
func, arg = self.queue.get(timeout=0.1) func, args = self.queue.get(timeout=0.1)
print(func, arg) func(*args)
if arg:
func(arg)
else:
func()
except Empty: except Empty:
pass pass
def add_task(self, func, arg=None): def add_task(self, func, *args):
"""""" """"""
self.queue.put((func, arg)) self.queue.put((func, [*args]))
def connect(self, setting: dict): def connect(self, setting: dict):
"""""" """"""
@ -157,9 +155,6 @@ class TigerGateway(BaseGateway):
self.add_task(self.connect_quote) self.add_task(self.connect_quote)
self.add_task(self.connect_trade) self.add_task(self.connect_trade)
self.add_task(self.connect_push) self.add_task(self.connect_push)
self.write_log("行情接口连接成功")
# self.thread.start()
def init_client_config(self, sandbox=True): def init_client_config(self, sandbox=True):
"""""" """"""
@ -183,6 +178,7 @@ class TigerGateway(BaseGateway):
self.write_log("查询合约失败") self.write_log("查询合约失败")
return return
self.write_log("行情接口连接成功")
self.write_log("合约查询成功") self.write_log("合约查询成功")
def connect_trade(self): def connect_trade(self):
@ -256,6 +252,8 @@ class TigerGateway(BaseGateway):
def on_asset_change(self, tiger_account: str, data: list): def on_asset_change(self, tiger_account: str, data: list):
"""""" """"""
data = dict(data) data = dict(data)
if "net_liquidation" not in data:
return
account = AccountData( account = AccountData(
accountid=tiger_account, accountid=tiger_account,
@ -274,7 +272,7 @@ class TigerGateway(BaseGateway):
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
volume=data["quantity"], volume=int(data["quantity"]),
frozen=0.0, frozen=0.0,
price=data["average_cost"], price=data["average_cost"],
pnl=data["unrealized_pnl"], pnl=data["unrealized_pnl"],
@ -284,17 +282,15 @@ class TigerGateway(BaseGateway):
def on_order_change(self, tiger_account: str, data: list): def on_order_change(self, tiger_account: str, data: list):
"""""" """"""
print("委托", data)
self.local_id += 1
data = dict(data) data = dict(data)
print("委托推送", data["origin_symbol"], data["order_id"], data["filled"], data["status"])
symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"]) symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"])
status = PUSH_STATUS_TIGER2VT[data["status"]] status = PUSH_STATUS_TIGER2VT[data["status"]]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
# orderid=data["order_id"], orderid=self.ID_TIGER2VT.get(str(data["order_id"]), self.get_new_local_id()),
orderid=self.local_id,
direction=Direction.NET, direction=Direction.NET,
price=data.get("limit_price", 0), price=data.get("limit_price", 0),
volume=data["quantity"], volume=data["quantity"],
@ -313,7 +309,7 @@ class TigerGateway(BaseGateway):
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
tradeid=self.tradeid, tradeid=self.tradeid,
orderid=data["order_id"], orderid=self.ID_TIGER2VT[str(data["order_id"])],
price=data["avg_fill_price"], price=data["avg_fill_price"],
volume=data["filled"], volume=data["filled"],
time=datetime.fromtimestamp(data["trade_time"] / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(data["trade_time"] / 1000).strftime("%H:%M:%S"),
@ -321,20 +317,22 @@ class TigerGateway(BaseGateway):
) )
self.on_trade(trade) self.on_trade(trade)
def get_new_local_id(self):
self.local_id += 1
return self.local_id
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
self.local_id += 1 local_id = self.get_new_local_id()
order = req.create_order_data(self.local_id, self.gateway_name) order = req.create_order_data(local_id, self.gateway_name)
return order.vt_orderid
self.on_order(order) self.on_order(order)
self.add_task(self._send_order, req, local_id)
return order.vt_orderid
self.add_task(self._send_order, req) def _send_order(self, req: OrderRequest, local_id):
def _send_order(self, req: OrderRequest):
"""""" """"""
currency = config_symbol_currency(req.symbol) currency = config_symbol_currency(req.symbol)
# first, get contract
try: try:
contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0] contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0]
order = self.trade_client.create_order( order = self.trade_client.create_order(
@ -345,8 +343,14 @@ class TigerGateway(BaseGateway):
quantity=int(req.volume), quantity=int(req.volume),
limit_price=req.price, limit_price=req.price,
) )
self.ID_TIGER2VT[str(order.order_id)] = local_id
self.ID_VT2TIGER[local_id] = str(order.order_id)
self.trade_client.place_order(order) self.trade_client.place_order(order)
print("发单:", order.contract.symbol, order.order_id, order.quantity, order.status)
except: # noqa except: # noqa
traceback.print_exc()
self.write_log("发单失败") self.write_log("发单失败")
return return
@ -357,7 +361,8 @@ class TigerGateway(BaseGateway):
def _cancel_order(self, req: CancelRequest): def _cancel_order(self, req: CancelRequest):
"""""" """"""
try: try:
data = self.trade_client.cancel_order(order_id=req.orderid) order_id = self.ID_VT2TIGER[req.orderid]
data = self.trade_client.cancel_order(order_id=order_id)
except ApiException: except ApiException:
self.write_log(f"撤单失败:{req.orderid}") self.write_log(f"撤单失败:{req.orderid}")
@ -420,8 +425,7 @@ class TigerGateway(BaseGateway):
for ix, row in contract_CN.iterrows(): for ix, row in contract_CN.iterrows():
symbol = row["symbol"] symbol = row["symbol"]
symbol, exchange = convert_symbol_tiger2vt(symbol) symbol, exchange = convert_symbol_tiger2vt(symbol)
if symbol == '600001':
print(f"symbol: {symbol} t:{type(symbol)} l:{len(symbol)} ex:{exchange} n:{row['name']}")
contract = ContractData( contract = ContractData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -467,7 +471,7 @@ class TigerGateway(BaseGateway):
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
volume=i.quantity, volume=int(i.quantity),
frozen=0.0, frozen=0.0,
price=i.average_cost, price=i.average_cost,
pnl=float(i.unrealized_pnl), pnl=float(i.unrealized_pnl),
@ -500,12 +504,12 @@ class TigerGateway(BaseGateway):
"""""" """"""
for i in data: for i in data:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1 local_id = self.get_new_local_id()
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
orderid=self.local_id, orderid=local_id,
# orderid=str(i.order_id),
direction=Direction.NET, direction=Direction.NET,
price=i.limit_price if i.limit_price else 0.0, price=i.limit_price if i.limit_price else 0.0,
volume=i.quantity, volume=i.quantity,
@ -514,17 +518,20 @@ class TigerGateway(BaseGateway):
time=datetime.fromtimestamp(i.order_time / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(i.order_time / 1000).strftime("%H:%M:%S"),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.ID_TIGER2VT[str(i.order_id)] = local_id
self.on_order(order) self.on_order(order)
self.ID_VT2TIGER = {v: k for k, v in self.ID_TIGER2VT.items()}
print("原始委托字典", self.ID_TIGER2VT)
print("原始反向字典", self.ID_VT2TIGER)
def process_deal(self, data): def process_deal(self, data):
""" """
Process trade data for both query and update. Process trade data for both query and update.
""" """
for i in reversed(data): for i in data:
if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1
self.tradeid += 1 self.tradeid += 1
trade = TradeData( trade = TradeData(
@ -532,7 +539,7 @@ class TigerGateway(BaseGateway):
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
tradeid=self.tradeid, tradeid=self.tradeid,
orderid=self.local_id, orderid=self.ID_TIGER2VT[str(i.order_id)],
price=i.avg_fill_price, price=i.avg_fill_price,
volume=i.filled, volume=i.filled,
time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"),