Update tiger_gateway.py

This commit is contained in:
1122455801 2019-03-12 17:39:28 +08:00
parent 82f89fc2ec
commit d0073ea6c6

View File

@ -9,6 +9,7 @@ from datetime import datetime
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
from queue import Empty, Queue from queue import Empty, Queue
import functools import functools
import traceback
import pandas as pd import pandas as pd
from pandas import DataFrame from pandas import DataFrame
@ -70,7 +71,7 @@ STATUS_TIGER2VT = {
ORDER_STATUS.CANCELLED: Status.CANCELLED, ORDER_STATUS.CANCELLED: Status.CANCELLED,
ORDER_STATUS.PENDING_CANCEL: Status.CANCELLED, ORDER_STATUS.PENDING_CANCEL: Status.CANCELLED,
ORDER_STATUS.REJECTED: Status.REJECTED, ORDER_STATUS.REJECTED: Status.REJECTED,
ORDER_STATUS.EXPIRED: Status.NOTTRADED, ORDER_STATUS.EXPIRED: Status.NOTTRADED
} }
PUSH_STATUS_TIGER2VT = { PUSH_STATUS_TIGER2VT = {
@ -81,7 +82,7 @@ PUSH_STATUS_TIGER2VT = {
"Submitted": Status.SUBMITTING, "Submitted": Status.SUBMITTING,
"PendingSubmit": Status.SUBMITTING, "PendingSubmit": Status.SUBMITTING,
"Filled": Status.ALLTRADED, "Filled": Status.ALLTRADED,
"Inactive": Status.REJECTED, "Inactive": Status.REJECTED
} }
@ -91,13 +92,13 @@ class TigerGateway(BaseGateway):
"tiger_id": "", "tiger_id": "",
"account": "", "account": "",
"standard_account": "", "standard_account": "",
"private_key": '',
} }
def __init__(self, event_engine): def __init__(self, event_engine):
"""Constructor""" """Constructor"""
super(TigerGateway, self).__init__(event_engine, "TIGER") super(TigerGateway, self).__init__(event_engine, "TIGER")
self.private_key = ""
self.tiger_id = "" self.tiger_id = ""
self.account = "" self.account = ""
self.standard_account = "" self.standard_account = ""
@ -108,6 +109,7 @@ class TigerGateway(BaseGateway):
self.quote_client = None self.quote_client = None
self.push_client = None self.push_client = None
self.local_id = 1000000
self.tradeid = 0 self.tradeid = 0
self.active = False self.active = False
@ -124,6 +126,7 @@ class TigerGateway(BaseGateway):
while self.active: while self.active:
try: try:
func, arg = self.queue.get(timeout=0.1) func, arg = self.queue.get(timeout=0.1)
print(func, arg)
if arg: if arg:
func(arg) func(arg)
else: else:
@ -137,7 +140,7 @@ class TigerGateway(BaseGateway):
def connect(self, setting: dict): def connect(self, setting: dict):
"""""" """"""
self.private_key = self.private_key self.private_key = setting['private_key']
self.tiger_id = setting["tiger_id"] self.tiger_id = setting["tiger_id"]
self.account = setting["account"] self.account = setting["account"]
self.standard_account = setting["standard_account"] self.standard_account = setting["standard_account"]
@ -154,6 +157,7 @@ class TigerGateway(BaseGateway):
self.add_task(self.connect_quote) self.add_task(self.connect_quote)
self.add_task(self.connect_trade) self.add_task(self.connect_trade)
self.add_task(self.connect_push) self.add_task(self.connect_push)
self.write_log("行情接口连接成功")
# self.thread.start() # self.thread.start()
@ -171,17 +175,15 @@ class TigerGateway(BaseGateway):
""" """
Connect to market data server. Connect to market data server.
""" """
self.quote_client = QuoteClient(self.client_config)
try: try:
self.quote_client = QuoteClient(self.client_config)
self.symbol_names = dict(self.quote_client.get_symbol_names(lang=Language.zh_CN)) self.symbol_names = dict(self.quote_client.get_symbol_names(lang=Language.zh_CN))
self.query_contract()
except ApiException: except ApiException:
self.write_log("行情接口连接失败") self.write_log("查询合约失败")
return return
if self.symbol_names: self.write_log("合约查询成功")
self.add_task(self.query_contract)
self.write_log("行情接口连接成功")
def connect_trade(self): def connect_trade(self):
""" """
@ -189,15 +191,12 @@ class TigerGateway(BaseGateway):
""" """
self.trade_client = TradeClient(self.client_config) self.trade_client = TradeClient(self.client_config)
try: try:
data = self.trade_client.get_managed_accounts()
except ApiException:
self.write_log("交易接口连接失败")
return
if data:
self.add_task(self.query_order) self.add_task(self.query_order)
self.add_task(self.query_position) self.add_task(self.query_position)
self.add_task(self.query_account) self.add_task(self.query_account)
except ApiException:
self.write_log("交易接口连接失败")
return
self.write_log("交易接口连接成功") self.write_log("交易接口连接成功")
@ -286,6 +285,7 @@ class TigerGateway(BaseGateway):
def on_order_change(self, tiger_account: str, data: list): def on_order_change(self, tiger_account: str, data: list):
"""""" """"""
print("委托", data) print("委托", data)
self.local_id += 1
data = dict(data) data = dict(data)
symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"]) symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"])
status = PUSH_STATUS_TIGER2VT[data["status"]] status = PUSH_STATUS_TIGER2VT[data["status"]]
@ -293,7 +293,8 @@ class TigerGateway(BaseGateway):
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
orderid=data["order_id"], # orderid=data["order_id"],
orderid=self.local_id,
direction=Direction.NET, direction=Direction.NET,
price=data.get("limit_price", 0), price=data.get("limit_price", 0),
volume=data["quantity"], volume=data["quantity"],
@ -322,6 +323,11 @@ class TigerGateway(BaseGateway):
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
self.local_id += 1
order = req.create_order_data(self.local_id, self.gateway_name)
return order.vt_orderid
self.on_order(order)
self.add_task(self._send_order, req) self.add_task(self._send_order, req)
def _send_order(self, req: OrderRequest): def _send_order(self, req: OrderRequest):
@ -331,12 +337,6 @@ class TigerGateway(BaseGateway):
# first, get contract # first, get contract
try: try:
contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0] contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0]
except ApiException:
self.write_log("获取合约对象失败")
return
# second, create order
try:
order = self.trade_client.create_order( order = self.trade_client.create_order(
account=self.account, account=self.account,
contract=contract, contract=contract,
@ -345,24 +345,11 @@ class TigerGateway(BaseGateway):
quantity=int(req.volume), quantity=int(req.volume),
limit_price=req.price, limit_price=req.price,
) )
except ApiException: self.trade_client.place_order(order)
self.write_log("创建订单失败") except: # noqa
self.write_log("发单失败")
return return
# third, place order
try:
data = self.trade_client.place_order(order)
except ApiException:
self.write_log("发送订单失败")
return
if data:
orderid = order.order_id
order = req.create_order_data(orderid, self.gateway_name)
self.on_order(order)
return order.vt_orderid
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
"""""" """"""
self.add_task(self._cancel_order, req) self.add_task(self._cancel_order, req)
@ -373,7 +360,6 @@ class TigerGateway(BaseGateway):
data = self.trade_client.cancel_order(order_id=req.orderid) data = self.trade_client.cancel_order(order_id=req.orderid)
except ApiException: except ApiException:
self.write_log(f"撤单失败:{req.orderid}") self.write_log(f"撤单失败:{req.orderid}")
return
if not data: if not data:
self.write_log('撤单成功') self.write_log('撤单成功')
@ -382,12 +368,8 @@ class TigerGateway(BaseGateway):
"""""" """"""
# HK Stock # HK Stock
try:
symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.HK) symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.HK)
contract_names_HK = DataFrame(symbols_names_HK, columns=['symbol', 'name']) contract_names_HK = DataFrame(symbols_names_HK, columns=['symbol', 'name'])
except ApiException:
self.write_log("查询合约失败")
return
contractList = list(contract_names_HK["symbol"]) contractList = list(contract_names_HK["symbol"])
i, n = 0, len(contractList) i, n = 0, len(contractList)
@ -397,7 +379,6 @@ class TigerGateway(BaseGateway):
c = contractList[i - 500:i] c = contractList[i - 500:i]
r = self.quote_client.get_trade_metas(c) r = self.quote_client.get_trade_metas(c)
result = result.append(r) result = result.append(r)
pass
contract_detail_HK = result.sort_values(by="symbol", ascending=True) contract_detail_HK = result.sort_values(by="symbol", ascending=True)
contract_HK = pd.merge(contract_names_HK, contract_detail_HK, how='left', on='symbol') contract_HK = pd.merge(contract_names_HK, contract_detail_HK, how='left', on='symbol')
@ -439,7 +420,8 @@ class TigerGateway(BaseGateway):
for ix, row in contract_CN.iterrows(): for ix, row in contract_CN.iterrows():
symbol = row["symbol"] symbol = row["symbol"]
symbol, exchange = convert_symbol_tiger2vt(symbol) symbol, exchange = convert_symbol_tiger2vt(symbol)
if symbol == '600001':
print(f"symbol: {symbol} t:{type(symbol)} l:{len(symbol)} ex:{exchange} n:{row['name']}")
contract = ContractData( contract = ContractData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -452,8 +434,6 @@ class TigerGateway(BaseGateway):
self.on_contract(contract) self.on_contract(contract)
self.contracts[contract.vt_symbol] = contract self.contracts[contract.vt_symbol] = contract
self.write_log("合约查询成功")
def query_account(self): def query_account(self):
"""""" """"""
try: try:
@ -500,7 +480,9 @@ class TigerGateway(BaseGateway):
"""""" """"""
try: try:
data = self.trade_client.get_orders() data = self.trade_client.get_orders()
except ApiException: data = sorted(data, key=lambda x: x.order_time, reverse=False)
except: # noqa
traceback.print_exc()
self.write_log("查询委托失败") self.write_log("查询委托失败")
return return
@ -518,11 +500,12 @@ class TigerGateway(BaseGateway):
"""""" """"""
for i in data: for i in data:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
orderid=str(i.order_id), orderid=self.local_id,
# orderid=str(i.order_id),
direction=Direction.NET, direction=Direction.NET,
price=i.limit_price if i.limit_price else 0.0, price=i.limit_price if i.limit_price else 0.0,
volume=i.quantity, volume=i.quantity,
@ -541,6 +524,7 @@ class TigerGateway(BaseGateway):
for i in reversed(data): for i in reversed(data):
if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1
self.tradeid += 1 self.tradeid += 1
trade = TradeData( trade = TradeData(
@ -548,7 +532,7 @@ class TigerGateway(BaseGateway):
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
tradeid=self.tradeid, tradeid=self.tradeid,
orderid=i.order_id, orderid=self.local_id,
price=i.avg_fill_price, price=i.avg_fill_price,
volume=i.filled, volume=i.filled,
time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"),