diff --git a/examples/vn_trader/run.py b/examples/vn_trader/run.py index c7cd65c5..43d4fca1 100644 --- a/examples/vn_trader/run.py +++ b/examples/vn_trader/run.py @@ -3,30 +3,32 @@ from vnpy.event import EventEngine from vnpy.trader.engine import MainEngine from vnpy.trader.ui import MainWindow, create_qapp -from vnpy.gateway.binance import BinanceGateway -from vnpy.gateway.bitmex import BitmexGateway -from vnpy.gateway.futu import FutuGateway -from vnpy.gateway.ib import IbGateway -from vnpy.gateway.ctp import CtpGateway +# from vnpy.gateway.binance import BinanceGateway +# from vnpy.gateway.bitmex import BitmexGateway +# from vnpy.gateway.futu import FutuGateway +# from vnpy.gateway.ib import IbGateway +# from vnpy.gateway.ctp import CtpGateway # from vnpy.gateway.ctptest import CtptestGateway -from vnpy.gateway.femas import FemasGateway +# from vnpy.gateway.femas import FemasGateway from vnpy.gateway.tiger import TigerGateway # from vnpy.gateway.oes import OesGateway -from vnpy.gateway.okex import OkexGateway -from vnpy.gateway.huobi import HuobiGateway -from vnpy.gateway.bitfinex import BitfinexGateway -from vnpy.gateway.onetoken import OnetokenGateway -from vnpy.gateway.okexf import OkexfGateway +# from vnpy.gateway.okex import OkexGateway +# from vnpy.gateway.huobi import HuobiGateway +# from vnpy.gateway.bitfinex import BitfinexGateway +# from vnpy.gateway.onetoken import OnetokenGateway +# from vnpy.gateway.okexf import OkexfGateway # from vnpy.gateway.xtp import XtpGateway from vnpy.gateway.hbdm import HbdmGateway -from vnpy.gateway.tap import TapGateway +# from vnpy.gateway.tap import TapGateway +from vnpy.gateway.tora import ToraGateway +from vnpy.gateway.alpaca import AlpacaGateway -from vnpy.app.cta_strategy import CtaStrategyApp -from vnpy.app.csv_loader import CsvLoaderApp -from vnpy.app.algo_trading import AlgoTradingApp -from vnpy.app.cta_backtester import CtaBacktesterApp -from vnpy.app.data_recorder import DataRecorderApp -from vnpy.app.risk_manager import RiskManagerApp +# from vnpy.app.cta_strategy import CtaStrategyApp +# from vnpy.app.csv_loader import CsvLoaderApp +# from vnpy.app.algo_trading import AlgoTradingApp +# from vnpy.app.cta_backtester import CtaBacktesterApp +# from vnpy.app.data_recorder import DataRecorderApp +# from vnpy.app.risk_manager import RiskManagerApp def main(): @@ -37,30 +39,32 @@ def main(): main_engine = MainEngine(event_engine) - main_engine.add_gateway(BinanceGateway) - main_engine.add_gateway(CtpGateway) + # main_engine.add_gateway(BinanceGateway) + # main_engine.add_gateway(CtpGateway) # main_engine.add_gateway(CtptestGateway) - main_engine.add_gateway(FemasGateway) - main_engine.add_gateway(IbGateway) - main_engine.add_gateway(FutuGateway) - main_engine.add_gateway(BitmexGateway) - main_engine.add_gateway(TigerGateway) + # main_engine.add_gateway(FemasGateway) + # main_engine.add_gateway(IbGateway) + # main_engine.add_gateway(FutuGateway) + # main_engine.add_gateway(BitmexGateway) + # main_engine.add_gateway(TigerGateway) # main_engine.add_gateway(OesGateway) - main_engine.add_gateway(OkexGateway) - main_engine.add_gateway(HuobiGateway) - main_engine.add_gateway(BitfinexGateway) - main_engine.add_gateway(OnetokenGateway) - main_engine.add_gateway(OkexfGateway) - main_engine.add_gateway(HbdmGateway) + # main_engine.add_gateway(OkexGateway) + # main_engine.add_gateway(HuobiGateway) + # main_engine.add_gateway(BitfinexGateway) + # main_engine.add_gateway(OnetokenGateway) + # main_engine.add_gateway(OkexfGateway) + # main_engine.add_gateway(HbdmGateway) # main_engine.add_gateway(XtpGateway) - main_engine.add_gateway(TapGateway) + # main_engine.add_gateway(TapGateway) + main_engine.add_gateway(ToraGateway) + main_engine.add_gateway(AlpacaGateway) - main_engine.add_app(CtaStrategyApp) - main_engine.add_app(CtaBacktesterApp) - main_engine.add_app(CsvLoaderApp) - main_engine.add_app(AlgoTradingApp) - main_engine.add_app(DataRecorderApp) - main_engine.add_app(RiskManagerApp) + # main_engine.add_app(CtaStrategyApp) + # main_engine.add_app(CtaBacktesterApp) + # main_engine.add_app(CsvLoaderApp) + # main_engine.add_app(AlgoTradingApp) + # main_engine.add_app(DataRecorderApp) + # main_engine.add_app(RiskManagerApp) main_window = MainWindow(main_engine, event_engine) main_window.showMaximized() diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index 646c894e..c2083688 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -256,11 +256,14 @@ class RestClient(object): proxies=self.proxies, ) request.response = response - status_code = response.status_code - if status_code // 100 == 2: # 2xx都算成功,尽管交易所都用200 - jsonBody = response.json() - request.callback(jsonBody, request) + if status_code // 100 == 2: # 2xx codes are all successful + if status_code == 204: + json_body = None + else: + json_body = response.json() + + request.callback(json_body, request) request.status = RequestStatus.success else: request.status = RequestStatus.failed diff --git a/vnpy/gateway/alpaca/__init__.py b/vnpy/gateway/alpaca/__init__.py new file mode 100644 index 00000000..16a9a0b9 --- /dev/null +++ b/vnpy/gateway/alpaca/__init__.py @@ -0,0 +1 @@ +from .alpaca_gateway import AlpacaGateway diff --git a/vnpy/gateway/alpaca/alpaca_gateway.py b/vnpy/gateway/alpaca/alpaca_gateway.py new file mode 100644 index 00000000..036708c3 --- /dev/null +++ b/vnpy/gateway/alpaca/alpaca_gateway.py @@ -0,0 +1,651 @@ +# encoding: UTF-8 +""" +Author: vigarbuaa +""" + +import sys +import json +from threading import Lock +from datetime import datetime +from vnpy.api.rest import Request, RestClient +from vnpy.api.websocket import WebsocketClient +from vnpy.event import Event +from vnpy.trader.event import EVENT_TIMER + +from vnpy.trader.constant import ( + Direction, + Exchange, + OrderType, + Product, + Status +) +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import ( + TickData, + OrderData, + TradeData, + PositionData, + AccountData, + ContractData, + OrderRequest, + CancelRequest, + SubscribeRequest, +) + + +REST_HOST = "https://api.alpaca.markets" # Live trading +WEBSOCKET_HOST = "wss://api.alpaca.markets/stream" +PAPER_REST_HOST = "https://paper-api.alpaca.markets" # Paper Trading +PAPER_WEBSOCKET_HOST = "wss://paper-api.alpaca.markets/stream" + +DATA_REST_HOST = "https://data.alpaca.markets" + + +STATUS_ALPACA2VT = { + "new": Status.NOTTRADED, + "partially_filled": Status.PARTTRADED, + "filled": Status.ALLTRADED, + "canceled": Status.CANCELLED, + "expired": Status.CANCELLED, + "rejected": Status.REJECTED +} + +DIRECTION_VT2ALPACA = { + Direction.LONG: "buy", + Direction.SHORT: "sell" +} +DIRECTION_ALPACA2VT = { + "buy": Direction.LONG, + "sell": Direction.SHORT, + "long": Direction.LONG, + "short": Direction.SHORT +} + +ORDERTYPE_VT2ALPACA = { + OrderType.LIMIT: "limit", + OrderType.MARKET: "market" +} +ORDERTYPE_ALPACA2VT = {v: k for k, v in ORDERTYPE_VT2ALPACA.items()} + +LOCAL_SYS_MAP = {} + + +class AlpacaGateway(BaseGateway): + """ + VN Trader Gateway for Alpaca connection. + """ + + default_setting = { + "KEY ID": "", + "Secret Key": "", + "会话数": 10, + "服务器": ["REAL", "PAPER"] + } + + exchanges = [Exchange.SMART] + + def __init__(self, event_engine): + """Constructor""" + super().__init__(event_engine, "ALPACA") + + self.rest_api = AlpacaRestApi(self) + self.ws_api = AlpacaWebsocketApi(self) + self.data_rest_api = AlpacaDataRestApi(self) + + def connect(self, setting: dict): + """""" + key = setting["KEY ID"] + secret = setting["Secret Key"] + session = setting["会话数"] + server = setting["服务器"] + + rest_url = REST_HOST if server == "REAL" else PAPER_REST_HOST + websocket_url = WEBSOCKET_HOST if server == "REAL" else PAPER_WEBSOCKET_HOST + + self.rest_api.connect(key, secret, session, rest_url) + self.data_rest_api.connect(key, secret, session) + self.ws_api.connect(key, secret, websocket_url) + + self.init_query() + + def subscribe(self, req: SubscribeRequest): + """""" + self.data_rest_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + self.rest_api.query_account() + + def query_position(self): + """""" + self.rest_api.query_position() + + def close(self): + """""" + self.rest_api.stop() + self.data_rest_api.stop() + self.ws_api.stop() + + def init_query(self): + """""" + self.count = 0 + self.event_engine.register(EVENT_TIMER, self.process_timer_event) + + def process_timer_event(self, event: Event): + """""" + self.data_rest_api.query_bar() + + self.count += 1 + if self.count < 5: + return + self.count = 0 + + self.query_account() + self.query_position() + + +class AlpacaRestApi(RestClient): + """ + Alpaca REST API + """ + + def __init__(self, gateway: AlpacaGateway): + """""" + super().__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + + self.order_count = 1_000_000 + self.order_count_lock = Lock() + + self.connect_time = 0 + + self.cancel_reqs = {} + + def sign(self, request): + """ + Generate Alpaca signature. + """ + headers = { + "APCA-API-KEY-ID": self.key, + "APCA-API-SECRET-KEY": self.secret, + "Content-Type": "application/json" + } + + request.headers = headers + request.allow_redirects = False + request.data = json.dumps(request.data) + return request + + def connect( + self, + key: str, + secret: str, + session_num: int, + url: str, + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret + + self.connect_time = ( + int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count + ) + + self.init(url) + self.start(session_num) + + self.gateway.write_log("REST API启动成功") + self.query_contract() + self.query_account() + self.query_position() + self.query_order() + + def query_contract(self): + """""" + params = {"status": "active"} + + self.add_request( + "GET", + "/v2/assets", + params=params, + callback=self.on_query_contract + ) + + def query_account(self): + """""" + self.add_request( + method="GET", + path="/v2/account", + callback=self.on_query_account + ) + + def query_position(self): + """""" + self.add_request( + method="GET", + path="/v2/positions", + callback=self.on_query_position + ) + + def query_order(self): + """""" + params = { + "status": "open" + } + + self.add_request( + method="GET", + path="/v2/orders", + params=params, + callback=self.on_query_order + ) + + def _new_order_id(self): + """""" + with self.order_count_lock: + self.order_count += 1 + return self.order_count + + def send_order(self, req: OrderRequest): + """""" + local_orderid = str(self.connect_time + self._new_order_id()) + + data = { + "symbol": req.symbol, + "qty": str(req.volume), + "side": DIRECTION_VT2ALPACA[req.direction], + "type": ORDERTYPE_VT2ALPACA[req.type], + "time_in_force": "day", + "client_order_id": local_orderid + } + + if data["type"] == "limit": + data["limit_price"] = str(req.price) + + order = req.create_order_data(local_orderid, self.gateway_name) + self.gateway.on_order(order) + + self.add_request( + "POST", + "/v2/orders", + callback=self.on_send_order, + data=data, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, + ) + + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + sys_orderid = LOCAL_SYS_MAP.get(req.orderid, None) + if not sys_orderid: + self.cancel_reqs[req.orderid] = req + return + + path = f"/v2/orders/{sys_orderid}" + + self.add_request( + "DELETE", + path, + callback=self.on_cancel_order, + extra=req + ) + + def on_query_contract(self, data, request: Request): + """""" + for d in data: + symbol = d["symbol"] + + contract = ContractData( + symbol=symbol, + exchange=Exchange.SMART, + name=symbol, + product=Product.SPOT, + size=1, + pricetick=0.01, + gateway_name=self.gateway_name + ) + self.gateway.on_contract(contract) + + self.gateway.write_log("合约信息查询成功") + + def on_query_account(self, data, request): + """""" + account = AccountData( + accountid=data["id"], + balance=float(data["equity"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + def on_query_position(self, data, request): + """""" + for d in data: + position = PositionData( + symbol=d["symbol"], + exchange=Exchange.SMART, + direction=DIRECTION_ALPACA2VT[d["side"]], + volume=int(d["qty"]), + price=float(d["avg_entry_price"]), + pnl=float(d["unrealized_pl"]), + gateway_name=self.gateway_name, + ) + self.gateway.on_position(position) + + def update_order(self, d: dict): + """""" + sys_orderid = d["id"] + local_orderid = d["client_order_id"] + LOCAL_SYS_MAP[local_orderid] = sys_orderid + + direction = DIRECTION_ALPACA2VT[d["side"]] + order_type = ORDERTYPE_ALPACA2VT[d["type"]] + + order = OrderData( + orderid=local_orderid, + symbol=d["symbol"], + exchange=Exchange.SMART, + price=float(d["limit_price"]), + volume=float(d["qty"]), + type=order_type, + direction=direction, + traded=float(d["filled_qty"]), + status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING), + time=d["created_at"], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + def on_query_order(self, data, request): + """""" + for d in data: + self.update_order(d) + + self.gateway.write_log("委托信息查询成功") + + def on_send_order(self, data, request: Request): + """""" + self.update_order(data) + + order = request.extra + if order.orderid in self.cancel_reqs: + req = self.cancel_reqs.pop(order.orderid) + self.cancel_order(req) + + def on_send_order_failed(self, status_code: int, request: Request): + """ + Callback to handle request failed. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"请求失败,状态码:{status_code},信息:{request.response.text}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback to handler request exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write( + self.exception_detail(exception_type, exception_value, tb, request) + ) + + def on_cancel_order(self, data, request): + """""" + req = request.extra + msg = f"撤单成功,委托号:{req.orderid}" + self.gateway.write_log(msg) + + +class AlpacaWebsocketApi(WebsocketClient): + """""" + + def __init__(self, gateway: AlpacaGateway): + """""" + super().__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.trade_count = 0 + + self.key = "" + self.secret = "" + + def connect( + self, key: str, secret: str, url: str + ): + """""" + self.key = key + self.secret = secret + + self.init(url) + self.start() + + def authenticate(self): + """""" + params = { + "action": "authenticate", + "data": { + "key_id": self.key, + "secret_key": self.secret + } + } + self.send_packet(params) + + def on_authenticate(self, data): + """""" + if data["status"] == "authorized": + self.gateway.write_log("Websocket API登录成功") + else: + self.gateway.write_log("Websocket API登录失败") + return + + params = { + "action": "listen", + "data": { + "streams": ["trade_updates", "account_updates"] + } + } + self.send_packet(params) + + def on_connected(self): + """""" + self.gateway.write_log("Websocket API连接成功") + self.authenticate() + + def on_disconnected(self): + """""" + self.gateway.write_log("Websocket API连接断开") + + def on_packet(self, packet: dict): + """""" + stream = packet["stream"] + data = packet["data"] + + if stream == "authorization": + self.on_authenticate(data) + elif stream == "listening": + streams = data["streams"] + + if "trade_updates" in streams: + self.gateway.write_log("委托成交推送订阅成功") + + if "account_updates" in streams: + self.gateway.write_log("资金变化推送订阅成功") + + elif stream == "trade_updates": + self.on_order(data) + elif stream == "account_updates": + self.on_account(data) + + def on_order(self, data): + """""" + # Update order + d = data["order"] + sys_orderid = d["id"] + local_orderid = d["client_order_id"] + LOCAL_SYS_MAP[local_orderid] = sys_orderid + + direction = DIRECTION_ALPACA2VT[d["side"]] + order_type = ORDERTYPE_ALPACA2VT[d["type"]] + + order = OrderData( + orderid=local_orderid, + symbol=d["symbol"], + exchange=Exchange.SMART, + price=float(d["limit_price"]), + volume=float(d["qty"]), + type=order_type, + direction=direction, + traded=float(d["filled_qty"]), + status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING), + time=d["created_at"], + gateway_name=self.gateway_name, + ) + self.gateway.on_order(order) + + # Update Trade + event = data.get("event", "") + if event != "fill": + return + + self.trade_count += 1 + + trade = TradeData( + symbol=order.symbol, + exchange=order.exchange, + orderid=order.orderid, + tradeid=str(self.trade_count), + direction=order.direction, + price=float(data["price"]), + volume=int(data["qty"]), + time=data["timestamp"], + gateway_name=self.gateway_name + ) + self.gateway.on_trade(trade) + + def on_account(self, data): + """""" + account = AccountData( + accountid=data["id"], + balance=float(data["equity"]), + gateway_name=self.gateway_name + ) + self.gateway.on_account(account) + + +class AlpacaDataRestApi(RestClient): + """ + Alpaca Market Data REST API + """ + + def __init__(self, gateway: AlpacaGateway): + """""" + super().__init__() + + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = "" + + self.symbols = set() + + def sign(self, request): + """ + Generate Alpaca signature. + """ + headers = { + "APCA-API-KEY-ID": self.key, + "APCA-API-SECRET-KEY": self.secret, + "Content-Type": "application/json" + } + + request.headers = headers + request.allow_redirects = False + return request + + def connect( + self, + key: str, + secret: str, + session_num: int + ): + """ + Initialize connection to REST server. + """ + self.key = key + self.secret = secret + + self.init(DATA_REST_HOST) + self.start(session_num) + + self.gateway.write_log("行情REST API启动成功") + + def subscribe(self, req: SubscribeRequest): + """""" + self.symbols.add(req.symbol) + + def query_bar(self): + """""" + if not self._active or not self.symbols: + return + + params = { + "symbols": ",".join(list(self.symbols)), + "limit": 1 + } + + self.add_request( + method="GET", + path="/v1/bars/1Min", + params=params, + callback=self.on_query_bar + ) + + def on_query_bar(self, data, request): + """""" + for symbol, buf in data.items(): + d = buf[0] + + tick = TickData( + symbol=symbol, + exchange=Exchange.SMART, + datetime=datetime.now(), + name=symbol, + open_price=d["o"], + high_price=d["h"], + low_price=d["l"], + last_price=d["c"], + gateway_name=self.gateway_name + ) + + self.gateway.on_tick(tick) diff --git a/vnpy/gateway/tora/md.py b/vnpy/gateway/tora/md.py index b35a5b51..ef10176c 100644 --- a/vnpy/gateway/tora/md.py +++ b/vnpy/gateway/tora/md.py @@ -23,22 +23,28 @@ def parse_datetime(date: str, time: str): class ToraMdSpi(CTORATstpMdSpi): + """""" def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"): + """""" super().__init__() self.gateway = gateway self._api = api def OnFrontConnected(self) -> Any: + """""" self.gateway.write_log("行情服务器连接成功") def OnFrontDisconnected(self, error_code: int) -> Any: - self.gateway.write_log(f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}") + """""" + self.gateway.write_log( + f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}") def OnRspError( self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool ) -> Any: + """""" error_id = error_info.ErrorID error_msg = error_info.ErrorMsg self.gateway.write_log(f"行情服务收到错误消息({error_id}):{error_msg}") @@ -50,6 +56,7 @@ class ToraMdSpi(CTORATstpMdSpi): request_id: int, is_last: bool, ) -> Any: + """""" error_id = error_info.ErrorID if error_id != 0: error_msg = error_info.ErrorMsg @@ -64,6 +71,7 @@ class ToraMdSpi(CTORATstpMdSpi): request_id: int, is_last: bool, ) -> Any: + """""" error_id = error_info.ErrorID if error_id != 0: error_msg = error_info.ErrorMsg @@ -72,6 +80,7 @@ class ToraMdSpi(CTORATstpMdSpi): self.gateway.write_log("行情服务器登出成功") def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any: + """""" if data.ExchangeID not in EXCHANGE_TORA2VT: return tick_data = TickData( @@ -114,8 +123,10 @@ class ToraMdSpi(CTORATstpMdSpi): class ToraMdApi: + """""" def __init__(self, gateway: BaseGateway): + """""" self.gateway = gateway self.md_address = "" @@ -151,10 +162,13 @@ class ToraMdApi: return True def subscribe(self, symbols: List[str], exchange: Exchange): - err = self._native_api.SubscribeMarketData(symbols, EXCHANGE_VT2TORA[exchange]) + """""" + err = self._native_api.SubscribeMarketData( + symbols, EXCHANGE_VT2TORA[exchange]) self._if_error_write_log(err, "subscribe") def _if_error_write_log(self, error_code: int, function_name: str): + """""" if error_code != 0: error_msg = get_error_msg(error_code) msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}' diff --git a/vnpy/gateway/tora/td.py b/vnpy/gateway/tora/td.py index db3b3e44..1d0926af 100644 --- a/vnpy/gateway/tora/td.py +++ b/vnpy/gateway/tora/td.py @@ -47,7 +47,8 @@ def _check_error(none_return: bool = True, def wrapped(self, info, error_info, *args): function_name = func.__name__ if print_function_name: - print(function_name, "info" if info else "None", error_info.ErrorID) + print(function_name, "info" if info else "None", + error_info.ErrorID) # print if errors error_code = error_info.ErrorID @@ -72,8 +73,10 @@ def _check_error(none_return: bool = True, class QueryLoop: + """""" def __init__(self, gateway: "BaseGateway"): + """""" self.event_engine = gateway.event_engine self._seconds_left = 0 @@ -84,6 +87,7 @@ class QueryLoop: self.event_engine.register(EVENT_TIMER, self._process_timer_event) def stop(self): + """""" self.event_engine.unregister(EVENT_TIMER, self._process_timer_event) def _process_timer_event(self, event): @@ -96,7 +100,8 @@ class QueryLoop: self._seconds_left = 2 # get the last one and re-queue it - func = self._query_functions.pop(0) # works fine if there is no so much items + # works fine if there is no so much items + func = self._query_functions.pop(0) self._query_functions.append(func) # call it @@ -107,11 +112,13 @@ OrdersType = Dict[str, "OrderInfo"] class ToraTdSpi(CTORATstpTraderSpi): + """""" def __init__(self, session_info: "SessionInfo", api: "ToraTdApi", gateway: "BaseGateway", orders: OrdersType): + """""" super().__init__() self.session_info = session_info self.gateway = gateway @@ -120,6 +127,7 @@ class ToraTdSpi(CTORATstpTraderSpi): self._api: "ToraTdApi" = api def OnRtnTrade(self, info: CTORATstpTradeField) -> None: + """""" try: trade_data = TradeData( gateway_name=self.gateway.gateway_name, @@ -138,6 +146,7 @@ class ToraTdSpi(CTORATstpTraderSpi): return def OnRtnOrder(self, info: CTORATstpOrderField) -> None: + """""" self._api.update_last_local_order_id(int(info.OrderRef)) try: @@ -155,6 +164,7 @@ class ToraTdSpi(CTORATstpTraderSpi): @_check_error(error_return=False, write_log=False, print_function_name=False) def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField, error_info: CTORATstpRspInfoField) -> None: + """""" try: self._api.update_last_local_order_id(int(info.OrderRef)) except ValueError: @@ -173,24 +183,27 @@ class ToraTdSpi(CTORATstpTraderSpi): @_check_error(error_return=False, write_log=False, print_function_name=False) def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField, error_info: CTORATstpRspInfoField) -> None: + """""" pass @_check_error() def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None: + """""" pass @_check_error() def OnRspOrderAction(self, info: CTORATstpInputOrderActionField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: - print("order action succeed!") + pass @_check_error() def OnRspOrderInsert(self, info: CTORATstpInputOrderField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" try: order_data = self.parse_order_field(info) except KeyError: - self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})!") + self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})") return self.gateway.on_order(order_data) @@ -208,14 +221,17 @@ class ToraTdSpi(CTORATstpTraderSpi): @_check_error(print_function_name=False) def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" if info.InvestorID != self.session_info.investor_id: - self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息!") + self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息") return if info.ExchangeID not in EXCHANGE_TORA2VT: - self.gateway.write_log(f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}") + self.gateway.write_log( + f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}") return volume = info.CurrentPosition - frozen = info.HistoryPosFrozen + info.TodayBSFrozen + info.TodayPRFrozen + info.TodaySMPosFrozen + frozen = info.HistoryPosFrozen + info.TodayBSFrozen + \ + info.TodayPRFrozen + info.TodaySMPosFrozen position_data = PositionData( gateway_name=self.gateway.gateway_name, symbol=info.SecurityID, @@ -224,7 +240,8 @@ class ToraTdSpi(CTORATstpTraderSpi): volume=volume, # verify this: which one should vnpy use? frozen=frozen, # verify this: which one should i use? price=info.TotalPosCost / volume, - pnl=info.LastPrice * volume - info.TotalPosCost, # verify this: is this formula correct + # verify this: is this formula correct + pnl=info.LastPrice * volume - info.TotalPosCost, yd_volume=info.HistoryPos, ) self.gateway.on_position(position_data) @@ -233,7 +250,7 @@ class ToraTdSpi(CTORATstpTraderSpi): def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: - + """""" self.session_info.account_id = info.AccountID account_data = AccountData( gateway_name=self.gateway.gateway_name, @@ -247,19 +264,22 @@ class ToraTdSpi(CTORATstpTraderSpi): def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" exchange = EXCHANGE_TORA2VT[info.ExchangeID] self.session_info.shareholder_ids[exchange] = info.ShareholderID @_check_error(print_function_name=False) def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" self.session_info.investor_id = info.InvestorID @_check_error(none_return=False, print_function_name=False) def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" if is_last: - self.gateway.write_log("合约信息查询成功!") + self.gateway.write_log("合约信息查询成功") if not info: return @@ -283,12 +303,14 @@ class ToraTdSpi(CTORATstpTraderSpi): self.gateway.on_contract(contract_data) def OnFrontConnected(self) -> None: + """""" self.gateway.write_log("交易服务器连接成功") self._api.login() @_check_error(print_function_name=False) def OnRspUserLogin(self, info: CTORATstpRspUserLoginField, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: + """""" self._api.update_last_local_order_id(int(info.MaxOrderRef)) self.session_info.front_id = info.FrontID self.session_info.session_id = info.SessionID @@ -298,7 +320,9 @@ class ToraTdSpi(CTORATstpTraderSpi): self._api.start_query_loop() # stop at ToraTdApi.stop() def OnFrontDisconnected(self, error_code: int) -> None: - self.gateway.write_log(f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}") + """""" + self.gateway.write_log( + f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}") def parse_order_field(self, info): """ @@ -330,6 +354,7 @@ class ToraTdSpi(CTORATstpTraderSpi): class ToraTdApi: def __init__(self, gateway: BaseGateway): + """""" self.gateway = gateway self.username = "" @@ -347,13 +372,16 @@ class ToraTdApi: self._next_local_order_id = int(1e5) def get_shareholder_id(self, exchange: Exchange): + """""" return self.session_info.shareholder_ids[exchange] def update_last_local_order_id(self, new_val: int): + """""" cur = self._next_local_order_id self._next_local_order_id = max(cur, new_val + 1) def _if_error_write_log(self, error_code: int, function_name: str): + """""" if error_code != 0: error_msg = get_error_msg(error_code) msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}' @@ -361,21 +389,25 @@ class ToraTdApi: return True def _get_new_req_id(self): + """""" req_id = self._last_req_id self._last_req_id += 1 return req_id def _get_new_order_id(self) -> str: + """""" order_id = self._next_local_order_id self._next_local_order_id += 1 return str(order_id) def query_contracts(self): + """""" info = CTORATstpQrySecurityField() err = self._native_api.ReqQrySecurity(info, self._get_new_req_id()) self._if_error_write_log(err, "query_contracts") def query_exchange(self, exchange: Exchange): + """""" info = CTORATstpQryExchangeField() info.ExchangeID = EXCHANGE_VT2TORA[exchange] err = self._native_api.ReqQryExchange(info, self._get_new_req_id()) @@ -383,6 +415,7 @@ class ToraTdApi: self._if_error_write_log(err, "query_exchange") def query_market_data(self, symbol: str, exchange: Exchange): + """""" info = CTORATstpQryMarketDataField() info.ExchangeID = EXCHANGE_VT2TORA[exchange] info.SecurityID = symbol @@ -390,6 +423,7 @@ class ToraTdApi: self._if_error_write_log(err, "query_market_data") def stop(self): + """""" self.stop_query_loop() if self._native_api: @@ -419,16 +453,21 @@ class ToraTdApi: :return: """ flow_path = str(get_folder_path(self.gateway.gateway_name.lower())) - self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(flow_path, True) - self._spi = ToraTdSpi(self.session_info, self, self.gateway, self.orders) + self._native_api = CTORATstpTraderApi.CreateTstpTraderApi( + flow_path, True) + self._spi = ToraTdSpi(self.session_info, self, + self.gateway, self.orders) self._native_api.RegisterSpi(self._spi) self._native_api.RegisterFront(self.td_address) - self._native_api.SubscribePublicTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) - self._native_api.SubscribePrivateTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) + self._native_api.SubscribePublicTopic( + TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) + self._native_api.SubscribePrivateTopic( + TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) self._native_api.Init() return True def send_order(self, req: OrderRequest): + """""" if req.type is OrderType.STOP: raise NotImplementedError() if req.type is OrderType.FAK or req.type is OrderType.FOK: @@ -465,16 +504,19 @@ class ToraTdApi: self.session_info.session_id, self.session_info.front_id, ) - self.gateway.on_order(req.create_order_data(order_id, self.gateway.gateway_name)) + self.gateway.on_order(req.create_order_data( + order_id, self.gateway.gateway_name)) # err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id()) err = self._native_api.ReqOrderInsert(info, self._get_new_req_id()) self._if_error_write_log(err, "send_order:ReqOrderInsert") def cancel_order(self, req: CancelRequest): + """""" info = CTORATstpInputOrderActionField() info.InvestorID = self.session_info.investor_id - info.ExchangeID = EXCHANGE_VT2TORA[req.exchange] # 没有的话:(16608):VIP:未知的交易所代码 + # 没有的话:(16608):VIP:未知的交易所代码 + info.ExchangeID = EXCHANGE_VT2TORA[req.exchange] info.SecurityID = req.symbol # info.OrderActionRef = str(self._get_new_req_id()) @@ -491,6 +533,7 @@ class ToraTdApi: self._if_error_write_log(err, "cancel_order:ReqOrderAction") def query_initialize_status(self): + """""" self.query_contracts() self.query_investors() self.query_shareholder_ids() @@ -500,41 +543,51 @@ class ToraTdApi: self.query_trades() def query_accounts(self): + """""" info = CTORATstpQryTradingAccountField() - err = self._native_api.ReqQryTradingAccount(info, self._get_new_req_id()) + err = self._native_api.ReqQryTradingAccount( + info, self._get_new_req_id()) self._if_error_write_log(err, "query_accounts") def query_shareholder_ids(self): + """""" info = CTORATstpQryShareholderAccountField() - err = self._native_api.ReqQryShareholderAccount(info, self._get_new_req_id()) + err = self._native_api.ReqQryShareholderAccount( + info, self._get_new_req_id()) self._if_error_write_log(err, "query_shareholder_ids") def query_investors(self): + """""" info = CTORATstpQryInvestorField() err = self._native_api.ReqQryInvestor(info, self._get_new_req_id()) self._if_error_write_log(err, "query_investors") def query_positions(self): + """""" info = CTORATstpQryPositionField() err = self._native_api.ReqQryPosition(info, self._get_new_req_id()) self._if_error_write_log(err, "query_positions") def query_orders(self): + """""" info = CTORATstpQryOrderField() err = self._native_api.ReqQryOrder(info, self._get_new_req_id()) self._if_error_write_log(err, "query_orders") def query_trades(self): + """""" info = CTORATstpQryTradeField() err = self._native_api.ReqQryTrade(info, self._get_new_req_id()) self._if_error_write_log(err, "query_trades") def start_query_loop(self): + """""" if not self._query_loop: self._query_loop = QueryLoop(self.gateway) self._query_loop.start() def stop_query_loop(self): + """""" if self._query_loop: self._query_loop.stop() self._query_loop = None diff --git a/vnpy/gateway/tora/tora_gateway.py b/vnpy/gateway/tora/tora_gateway.py index 76501a94..c09acee3 100644 --- a/vnpy/gateway/tora/tora_gateway.py +++ b/vnpy/gateway/tora/tora_gateway.py @@ -5,7 +5,8 @@ TODO: * Linux support """ -from vnpy.api.tora.vntora import (AsyncDispatchException, set_async_callback_exception_handler) +from vnpy.api.tora.vntora import ( + AsyncDispatchException, set_async_callback_exception_handler) from vnpy.event import EventEngine from vnpy.trader.gateway import BaseGateway @@ -20,6 +21,8 @@ def is_valid_front_address(address: str): class ToraGateway(BaseGateway): + """""" + default_setting = { "账号": "", "密码": "", @@ -86,13 +89,6 @@ class ToraGateway(BaseGateway): """""" self._td_api.query_positions() - def write_log(self, msg: str): - """ - for easier test - """ - print(msg) - super().write_log(msg) - def _async_callback_exception_handler(self, e: AsyncDispatchException): error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}" self.write_log(error_str) diff --git a/vnpy/trader/ui/mainwindow.py b/vnpy/trader/ui/mainwindow.py index 3313c780..8d2ab6c5 100644 --- a/vnpy/trader/ui/mainwindow.py +++ b/vnpy/trader/ui/mainwindow.py @@ -163,6 +163,7 @@ class MainWindow(QtWidgets.QMainWindow): def init_toolbar(self): """""" self.toolbar = QtWidgets.QToolBar(self) + self.toolbar.setObjectName("工具栏") self.toolbar.setFloatable(False) self.toolbar.setMovable(False)