From 897f8988441776d929f9fb21c4b337377fc030cb Mon Sep 17 00:00:00 2001 From: 1122455801 Date: Wed, 6 Mar 2019 10:54:56 +0800 Subject: [PATCH] Update tiger_gateway.py --- vnpy/gateway/tiger/tiger_gateway.py | 95 ++++++++++++----------------- 1 file changed, 38 insertions(+), 57 deletions(-) diff --git a/vnpy/gateway/tiger/tiger_gateway.py b/vnpy/gateway/tiger/tiger_gateway.py index 2ab0badb..697bed03 100644 --- a/vnpy/gateway/tiger/tiger_gateway.py +++ b/vnpy/gateway/tiger/tiger_gateway.py @@ -5,7 +5,6 @@ pip install tigeropen """ from copy import copy -from datetime import datetime from threading import Thread from time import sleep import time @@ -13,7 +12,7 @@ import pandas as pd from pandas import DataFrame from tigeropen.tiger_open_config import TigerOpenClientConfig -from tigeropen.common.consts import Language, Currency ,Market +from tigeropen.common.consts import Language, Currency, Market from tigeropen.quote.quote_client import QuoteClient from tigeropen.trade.trade_client import TradeClient from tigeropen.trade.domain.order import ORDER_STATUS @@ -35,7 +34,7 @@ from vnpy.trader.object import ( SubscribeRequest, OrderRequest, CancelRequest, - ) +) PRODUCT_VT2TIGER = { @@ -46,18 +45,18 @@ PRODUCT_VT2TIGER = { Product.FUTURES: "FUT", Product.OPTION: "FOP", Product.FOREX: "CASH" - } +} DIRECTION_VT2TIGER = { Direction.LONG: "BUY", Direction.SHORT: "SELL", - } +} DIRECTION_TIGER2VT = {v: k for k, v in DIRECTION_VT2TIGER.items()} PRICETYPE_VT2TIGER = { PriceType.LIMIT: "LMT", PriceType.MARKET: "MKT", - } +} STATUS_TIGER2VT = { ORDER_STATUS.PENDING_NEW: Status.SUBMITTING, @@ -69,7 +68,7 @@ STATUS_TIGER2VT = { ORDER_STATUS.PENDING_CANCEL: Status.CANCELLED, ORDER_STATUS.REJECTED: Status.REJECTED, ORDER_STATUS.EXPIRED: Status.NOTTRADED - } +} class TigerGateway(BaseGateway): @@ -82,7 +81,7 @@ class TigerGateway(BaseGateway): "standard_account": "", "paper_account": "DU575569", "language": "Language.zh_CN", - } + } def __init__(self, event_engine): """Constructor""" @@ -109,7 +108,7 @@ class TigerGateway(BaseGateway): # For query function. self.count = 0 self.interval = 1 - self.query_funcs = [self.query_account, self.query_position,self.query_order, self.query_trade] + self.query_funcs = [self.query_account, self.query_position, self.query_order] def connect(self, setting: dict): """""" @@ -120,7 +119,6 @@ class TigerGateway(BaseGateway): self.paper_account = setting["paper_account"] self.languege = setting["language"] - self.get_client_config() self.connect_quote() self.connect_trade() @@ -132,23 +130,19 @@ class TigerGateway(BaseGateway): """ Query all data necessary. """ - #self.write_log("query_data") sleep(2.0) # Wait 2 seconds till connection completed. self.query_contract() - self.query_trade() #1 - self.query_order() #2 - self.query_position() # + # self.query_trade() + self.query_order() + self.query_position() self.query_account() - # Start fixed interval query. self.event_engine.register(EVENT_TIMER, self.process_timer_event) - def process_timer_event(self, event): """""" - #self.write_log("process_time_event") self.count += 1 if self.count < self.interval: return @@ -157,7 +151,6 @@ class TigerGateway(BaseGateway): func() self.query_funcs.append(func) - def get_client_config(self, sandbox=True): """""" self.client_config = TigerOpenClientConfig(sandbox_debug=sandbox) @@ -209,35 +202,37 @@ class TigerGateway(BaseGateway): self.write_log("推送接口连接成功") -########################################################## def subscribe(self, req: SubscribeRequest): """""" - symbol = convert_symbol_vt2tiger(req.symbol, req.exchange) - self.push_client.subscribe_quote(symbol) + # symbol = convert_symbol_vt2tiger(req.symbol, req.exchange) + self.push_client.subscribe_quote([req.symbol]) - data = self.push_client.on_quote_change(symbol=req.symbol) - tick = self.ticks.get(data,None) + def on_quote_change(*args): + print(args) + data = self.push_client.quote_changed = on_quote_change + + tick = self.ticks.get(data, None) if not tick: tick = TickData( - symbol=symbol, - exchange=None, + symbol=req.symbol, + exchange=req.exchange, datetime=None, gateway_name=self.gateway_name, ) self.ticks[data] = tick - self.process_quote() + self.process_quote(data) contract = self.contracts.get(tick.vt_symbol, None) if contract: tick.name = contract.name return tick - def process_quote(self,data): + def process_quote(self, data): """报价推送""" - symbol,info,_= data + symbol, info, _ = data volume, latest_price, high_price, prev_close, low_price, open_price, latest_time = [i[1] for i in info] - tick = self.get_tick(symbol) - time_local = time.localtime(latest_time/1000) + tick = self.get_tick(symbol) + time_local = time.localtime(latest_time / 1000) tick.datetime = time.strftime("%Y-%m-%d %H:%M:%S", time_local) tick.open_price = open_price tick.high_price = open_price @@ -248,12 +243,6 @@ class TigerGateway(BaseGateway): self.on_tick(copy(tick)) - - - - - -########################################### def send_order(self, req: OrderRequest): """""" symbol = convert_symbol_vt2tiger(req.symbol, req.exchange) @@ -262,7 +251,7 @@ class TigerGateway(BaseGateway): # first, get contract try: - contract = self.trade_client.get_contracts(symbol=symbol,currency=currency)[0] + contract = self.trade_client.get_contracts(symbol=symbol, currency=currency)[0] except ApiException: self.write_log("获取合约对象失败") return @@ -311,8 +300,8 @@ class TigerGateway(BaseGateway): # HK Stock try: - symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN ,market=Market.HK) - contract_names_HK = DataFrame(symbols_names_HK,columns=['symbol', 'name']) + symbols_names_HK = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.HK) + contract_names_HK = DataFrame(symbols_names_HK, columns=['symbol', 'name']) except ApiException: self.write_log("查询合约失败") return @@ -341,11 +330,11 @@ class TigerGateway(BaseGateway): gateway_name=self.gateway_name, ) self.on_contract(contract) - self.contracts[contract.vt_symbol] = contract + self.contracts[contract.vt_symbol] = contract # US Stock - symbols_names_US = self.quote_client.get_symbol_names(lang=Language.zh_CN ,market=Market.US) - contract_US = DataFrame(symbols_names_US,columns=['symbol', 'name']) + symbols_names_US = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.US) + contract_US = DataFrame(symbols_names_US, columns=['symbol', 'name']) for ix, row in contract_US.iterrows(): contract = ContractData( @@ -361,8 +350,8 @@ class TigerGateway(BaseGateway): self.contracts[contract.vt_symbol] = contract # CN Stock - symbols_names_CN = self.quote_client.get_symbol_names(lang=Language.zh_CN ,market=Market.CN) - contract_CN = DataFrame(symbols_names_CN,columns=['symbol', 'name']) + symbols_names_CN = self.quote_client.get_symbol_names(lang=Language.zh_CN, market=Market.CN) + contract_CN = DataFrame(symbols_names_CN, columns=['symbol', 'name']) for ix, row in contract_CN.iterrows(): symbol = row["symbol"] @@ -382,10 +371,8 @@ class TigerGateway(BaseGateway): self.write_log("合约查询成功") - def query_account(self): """""" - #self.write_log("query_account_##") try: assets = self.trade_client.get_assets() except ApiException: @@ -404,7 +391,6 @@ class TigerGateway(BaseGateway): def query_position(self): """""" - #self.write_log("query_position_##") try: position = self.trade_client.get_positions() except ApiException: @@ -436,7 +422,6 @@ class TigerGateway(BaseGateway): def query_order(self): """""" - #self.write_log("query_order") try: data = self.trade_client.get_orders() except ApiException: @@ -445,11 +430,10 @@ class TigerGateway(BaseGateway): self.process_order(data) self.process_deal(data) - #self.write_log("委托查询成功") def query_trade(self): """""" - #self.write_log("query_trade") + pass def close(self): """""" @@ -458,13 +442,12 @@ class TigerGateway(BaseGateway): def process_order(self, data): """""" - #self.write_log("process_order") for i in data: symbol = str(i.contract) symbol, exchange = convert_symbol_tiger2vt(symbol) - time_local = time.localtime(i.order_time/1000) + time_local = time.localtime(i.order_time / 1000) - if i.order_type =="LMT": + if i.order_type == "LMT": price = i.limit_price else: price = i.avg_fill_price @@ -488,12 +471,11 @@ class TigerGateway(BaseGateway): """ Process trade data for both query and update. """ - #self.write_log("process_deal_###") for i in data: if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: symbol = str(i.contract) symbol, exchange = convert_symbol_tiger2vt(symbol) - time_local = time.localtime(i.trade_time/1000) + time_local = time.localtime(i.trade_time / 1000) trade = TradeData( symbol=symbol, @@ -510,8 +492,6 @@ class TigerGateway(BaseGateway): self.on_trade(trade) - - def convert_symbol_tiger2vt(symbol): """ Convert symbol from vt to tiger. @@ -541,6 +521,7 @@ def convert_symbol_vt2tiger(symbol, exchange): symbol = symbol return symbol + def config_symbol_currency(symbol): """ Config symbol to corresponding currency