Merge remote-tracking branch 'remotes/origin/dev' into import_fix

This commit is contained in:
nanoric 2019-06-28 10:59:29 +08:00
commit f9c726fed7
8 changed files with 794 additions and 71 deletions

View File

@ -3,30 +3,32 @@ from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import MainWindow, create_qapp from vnpy.trader.ui import MainWindow, create_qapp
from vnpy.gateway.binance import BinanceGateway # from vnpy.gateway.binance import BinanceGateway
from vnpy.gateway.bitmex import BitmexGateway # from vnpy.gateway.bitmex import BitmexGateway
from vnpy.gateway.futu import FutuGateway # from vnpy.gateway.futu import FutuGateway
from vnpy.gateway.ib import IbGateway # from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway # from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.ctptest import CtptestGateway # from vnpy.gateway.ctptest import CtptestGateway
from vnpy.gateway.femas import FemasGateway # from vnpy.gateway.femas import FemasGateway
from vnpy.gateway.tiger import TigerGateway from vnpy.gateway.tiger import TigerGateway
# from vnpy.gateway.oes import OesGateway # from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway # from vnpy.gateway.okex import OkexGateway
from vnpy.gateway.huobi import HuobiGateway # from vnpy.gateway.huobi import HuobiGateway
from vnpy.gateway.bitfinex import BitfinexGateway # from vnpy.gateway.bitfinex import BitfinexGateway
from vnpy.gateway.onetoken import OnetokenGateway # from vnpy.gateway.onetoken import OnetokenGateway
from vnpy.gateway.okexf import OkexfGateway # from vnpy.gateway.okexf import OkexfGateway
# from vnpy.gateway.xtp import XtpGateway # from vnpy.gateway.xtp import XtpGateway
from vnpy.gateway.hbdm import HbdmGateway from vnpy.gateway.hbdm import HbdmGateway
from vnpy.gateway.tap import TapGateway # from vnpy.gateway.tap import TapGateway
from vnpy.gateway.tora import ToraGateway
from vnpy.gateway.alpaca import AlpacaGateway
from vnpy.app.cta_strategy import CtaStrategyApp # from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp # from vnpy.app.csv_loader import CsvLoaderApp
from vnpy.app.algo_trading import AlgoTradingApp # from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp # from vnpy.app.cta_backtester import CtaBacktesterApp
from vnpy.app.data_recorder import DataRecorderApp # from vnpy.app.data_recorder import DataRecorderApp
from vnpy.app.risk_manager import RiskManagerApp # from vnpy.app.risk_manager import RiskManagerApp
def main(): def main():
@ -37,30 +39,32 @@ def main():
main_engine = MainEngine(event_engine) main_engine = MainEngine(event_engine)
main_engine.add_gateway(BinanceGateway) # main_engine.add_gateway(BinanceGateway)
main_engine.add_gateway(CtpGateway) # main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(CtptestGateway) # main_engine.add_gateway(CtptestGateway)
main_engine.add_gateway(FemasGateway) # main_engine.add_gateway(FemasGateway)
main_engine.add_gateway(IbGateway) # main_engine.add_gateway(IbGateway)
main_engine.add_gateway(FutuGateway) # main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway) # main_engine.add_gateway(BitmexGateway)
main_engine.add_gateway(TigerGateway) # main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(OesGateway) # main_engine.add_gateway(OesGateway)
main_engine.add_gateway(OkexGateway) # main_engine.add_gateway(OkexGateway)
main_engine.add_gateway(HuobiGateway) # main_engine.add_gateway(HuobiGateway)
main_engine.add_gateway(BitfinexGateway) # main_engine.add_gateway(BitfinexGateway)
main_engine.add_gateway(OnetokenGateway) # main_engine.add_gateway(OnetokenGateway)
main_engine.add_gateway(OkexfGateway) # main_engine.add_gateway(OkexfGateway)
main_engine.add_gateway(HbdmGateway) # main_engine.add_gateway(HbdmGateway)
# main_engine.add_gateway(XtpGateway) # main_engine.add_gateway(XtpGateway)
main_engine.add_gateway(TapGateway) # main_engine.add_gateway(TapGateway)
main_engine.add_gateway(ToraGateway)
main_engine.add_gateway(AlpacaGateway)
main_engine.add_app(CtaStrategyApp) # main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp) # main_engine.add_app(CtaBacktesterApp)
main_engine.add_app(CsvLoaderApp) # main_engine.add_app(CsvLoaderApp)
main_engine.add_app(AlgoTradingApp) # main_engine.add_app(AlgoTradingApp)
main_engine.add_app(DataRecorderApp) # main_engine.add_app(DataRecorderApp)
main_engine.add_app(RiskManagerApp) # main_engine.add_app(RiskManagerApp)
main_window = MainWindow(main_engine, event_engine) main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized() main_window.showMaximized()

View File

@ -256,11 +256,14 @@ class RestClient(object):
proxies=self.proxies, proxies=self.proxies,
) )
request.response = response request.response = response
status_code = response.status_code status_code = response.status_code
if status_code // 100 == 2: # 2xx都算成功尽管交易所都用200 if status_code // 100 == 2: # 2xx codes are all successful
jsonBody = response.json() if status_code == 204:
request.callback(jsonBody, request) json_body = None
else:
json_body = response.json()
request.callback(json_body, request)
request.status = RequestStatus.success request.status = RequestStatus.success
else: else:
request.status = RequestStatus.failed request.status = RequestStatus.failed

View File

@ -0,0 +1 @@
from .alpaca_gateway import AlpacaGateway

View File

@ -0,0 +1,651 @@
# encoding: UTF-8
"""
Author: vigarbuaa
"""
import sys
import json
from threading import Lock
from datetime import datetime
from vnpy.api.rest import Request, RestClient
from vnpy.api.websocket import WebsocketClient
from vnpy.event import Event
from vnpy.trader.event import EVENT_TIMER
from vnpy.trader.constant import (
Direction,
Exchange,
OrderType,
Product,
Status
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest,
)
REST_HOST = "https://api.alpaca.markets" # Live trading
WEBSOCKET_HOST = "wss://api.alpaca.markets/stream"
PAPER_REST_HOST = "https://paper-api.alpaca.markets" # Paper Trading
PAPER_WEBSOCKET_HOST = "wss://paper-api.alpaca.markets/stream"
DATA_REST_HOST = "https://data.alpaca.markets"
STATUS_ALPACA2VT = {
"new": Status.NOTTRADED,
"partially_filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"canceled": Status.CANCELLED,
"expired": Status.CANCELLED,
"rejected": Status.REJECTED
}
DIRECTION_VT2ALPACA = {
Direction.LONG: "buy",
Direction.SHORT: "sell"
}
DIRECTION_ALPACA2VT = {
"buy": Direction.LONG,
"sell": Direction.SHORT,
"long": Direction.LONG,
"short": Direction.SHORT
}
ORDERTYPE_VT2ALPACA = {
OrderType.LIMIT: "limit",
OrderType.MARKET: "market"
}
ORDERTYPE_ALPACA2VT = {v: k for k, v in ORDERTYPE_VT2ALPACA.items()}
LOCAL_SYS_MAP = {}
class AlpacaGateway(BaseGateway):
"""
VN Trader Gateway for Alpaca connection.
"""
default_setting = {
"KEY ID": "",
"Secret Key": "",
"会话数": 10,
"服务器": ["REAL", "PAPER"]
}
exchanges = [Exchange.SMART]
def __init__(self, event_engine):
"""Constructor"""
super().__init__(event_engine, "ALPACA")
self.rest_api = AlpacaRestApi(self)
self.ws_api = AlpacaWebsocketApi(self)
self.data_rest_api = AlpacaDataRestApi(self)
def connect(self, setting: dict):
""""""
key = setting["KEY ID"]
secret = setting["Secret Key"]
session = setting["会话数"]
server = setting["服务器"]
rest_url = REST_HOST if server == "REAL" else PAPER_REST_HOST
websocket_url = WEBSOCKET_HOST if server == "REAL" else PAPER_WEBSOCKET_HOST
self.rest_api.connect(key, secret, session, rest_url)
self.data_rest_api.connect(key, secret, session)
self.ws_api.connect(key, secret, websocket_url)
self.init_query()
def subscribe(self, req: SubscribeRequest):
""""""
self.data_rest_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
self.rest_api.query_account()
def query_position(self):
""""""
self.rest_api.query_position()
def close(self):
""""""
self.rest_api.stop()
self.data_rest_api.stop()
self.ws_api.stop()
def init_query(self):
""""""
self.count = 0
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
def process_timer_event(self, event: Event):
""""""
self.data_rest_api.query_bar()
self.count += 1
if self.count < 5:
return
self.count = 0
self.query_account()
self.query_position()
class AlpacaRestApi(RestClient):
"""
Alpaca REST API
"""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.order_count = 1_000_000
self.order_count_lock = Lock()
self.connect_time = 0
self.cancel_reqs = {}
def sign(self, request):
"""
Generate Alpaca signature.
"""
headers = {
"APCA-API-KEY-ID": self.key,
"APCA-API-SECRET-KEY": self.secret,
"Content-Type": "application/json"
}
request.headers = headers
request.allow_redirects = False
request.data = json.dumps(request.data)
return request
def connect(
self,
key: str,
secret: str,
session_num: int,
url: str,
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.connect_time = (
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
)
self.init(url)
self.start(session_num)
self.gateway.write_log("REST API启动成功")
self.query_contract()
self.query_account()
self.query_position()
self.query_order()
def query_contract(self):
""""""
params = {"status": "active"}
self.add_request(
"GET",
"/v2/assets",
params=params,
callback=self.on_query_contract
)
def query_account(self):
""""""
self.add_request(
method="GET",
path="/v2/account",
callback=self.on_query_account
)
def query_position(self):
""""""
self.add_request(
method="GET",
path="/v2/positions",
callback=self.on_query_position
)
def query_order(self):
""""""
params = {
"status": "open"
}
self.add_request(
method="GET",
path="/v2/orders",
params=params,
callback=self.on_query_order
)
def _new_order_id(self):
""""""
with self.order_count_lock:
self.order_count += 1
return self.order_count
def send_order(self, req: OrderRequest):
""""""
local_orderid = str(self.connect_time + self._new_order_id())
data = {
"symbol": req.symbol,
"qty": str(req.volume),
"side": DIRECTION_VT2ALPACA[req.direction],
"type": ORDERTYPE_VT2ALPACA[req.type],
"time_in_force": "day",
"client_order_id": local_orderid
}
if data["type"] == "limit":
data["limit_price"] = str(req.price)
order = req.create_order_data(local_orderid, self.gateway_name)
self.gateway.on_order(order)
self.add_request(
"POST",
"/v2/orders",
callback=self.on_send_order,
data=data,
extra=order,
on_failed=self.on_send_order_failed,
on_error=self.on_send_order_error,
)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
sys_orderid = LOCAL_SYS_MAP.get(req.orderid, None)
if not sys_orderid:
self.cancel_reqs[req.orderid] = req
return
path = f"/v2/orders/{sys_orderid}"
self.add_request(
"DELETE",
path,
callback=self.on_cancel_order,
extra=req
)
def on_query_contract(self, data, request: Request):
""""""
for d in data:
symbol = d["symbol"]
contract = ContractData(
symbol=symbol,
exchange=Exchange.SMART,
name=symbol,
product=Product.SPOT,
size=1,
pricetick=0.01,
gateway_name=self.gateway_name
)
self.gateway.on_contract(contract)
self.gateway.write_log("合约信息查询成功")
def on_query_account(self, data, request):
""""""
account = AccountData(
accountid=data["id"],
balance=float(data["equity"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
def on_query_position(self, data, request):
""""""
for d in data:
position = PositionData(
symbol=d["symbol"],
exchange=Exchange.SMART,
direction=DIRECTION_ALPACA2VT[d["side"]],
volume=int(d["qty"]),
price=float(d["avg_entry_price"]),
pnl=float(d["unrealized_pl"]),
gateway_name=self.gateway_name,
)
self.gateway.on_position(position)
def update_order(self, d: dict):
""""""
sys_orderid = d["id"]
local_orderid = d["client_order_id"]
LOCAL_SYS_MAP[local_orderid] = sys_orderid
direction = DIRECTION_ALPACA2VT[d["side"]]
order_type = ORDERTYPE_ALPACA2VT[d["type"]]
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.SMART,
price=float(d["limit_price"]),
volume=float(d["qty"]),
type=order_type,
direction=direction,
traded=float(d["filled_qty"]),
status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING),
time=d["created_at"],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
def on_query_order(self, data, request):
""""""
for d in data:
self.update_order(d)
self.gateway.write_log("委托信息查询成功")
def on_send_order(self, data, request: Request):
""""""
self.update_order(data)
order = request.extra
if order.orderid in self.cancel_reqs:
req = self.cancel_reqs.pop(order.orderid)
self.cancel_order(req)
def on_send_order_failed(self, status_code: int, request: Request):
"""
Callback to handle request failed.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_send_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback to handler request exception.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type, exception_value, tb, request)
)
def on_cancel_order(self, data, request):
""""""
req = request.extra
msg = f"撤单成功,委托号:{req.orderid}"
self.gateway.write_log(msg)
class AlpacaWebsocketApi(WebsocketClient):
""""""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.trade_count = 0
self.key = ""
self.secret = ""
def connect(
self, key: str, secret: str, url: str
):
""""""
self.key = key
self.secret = secret
self.init(url)
self.start()
def authenticate(self):
""""""
params = {
"action": "authenticate",
"data": {
"key_id": self.key,
"secret_key": self.secret
}
}
self.send_packet(params)
def on_authenticate(self, data):
""""""
if data["status"] == "authorized":
self.gateway.write_log("Websocket API登录成功")
else:
self.gateway.write_log("Websocket API登录失败")
return
params = {
"action": "listen",
"data": {
"streams": ["trade_updates", "account_updates"]
}
}
self.send_packet(params)
def on_connected(self):
""""""
self.gateway.write_log("Websocket API连接成功")
self.authenticate()
def on_disconnected(self):
""""""
self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict):
""""""
stream = packet["stream"]
data = packet["data"]
if stream == "authorization":
self.on_authenticate(data)
elif stream == "listening":
streams = data["streams"]
if "trade_updates" in streams:
self.gateway.write_log("委托成交推送订阅成功")
if "account_updates" in streams:
self.gateway.write_log("资金变化推送订阅成功")
elif stream == "trade_updates":
self.on_order(data)
elif stream == "account_updates":
self.on_account(data)
def on_order(self, data):
""""""
# Update order
d = data["order"]
sys_orderid = d["id"]
local_orderid = d["client_order_id"]
LOCAL_SYS_MAP[local_orderid] = sys_orderid
direction = DIRECTION_ALPACA2VT[d["side"]]
order_type = ORDERTYPE_ALPACA2VT[d["type"]]
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.SMART,
price=float(d["limit_price"]),
volume=float(d["qty"]),
type=order_type,
direction=direction,
traded=float(d["filled_qty"]),
status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING),
time=d["created_at"],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
# Update Trade
event = data.get("event", "")
if event != "fill":
return
self.trade_count += 1
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
price=float(data["price"]),
volume=int(data["qty"]),
time=data["timestamp"],
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def on_account(self, data):
""""""
account = AccountData(
accountid=data["id"],
balance=float(data["equity"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
class AlpacaDataRestApi(RestClient):
"""
Alpaca Market Data REST API
"""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.symbols = set()
def sign(self, request):
"""
Generate Alpaca signature.
"""
headers = {
"APCA-API-KEY-ID": self.key,
"APCA-API-SECRET-KEY": self.secret,
"Content-Type": "application/json"
}
request.headers = headers
request.allow_redirects = False
return request
def connect(
self,
key: str,
secret: str,
session_num: int
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.init(DATA_REST_HOST)
self.start(session_num)
self.gateway.write_log("行情REST API启动成功")
def subscribe(self, req: SubscribeRequest):
""""""
self.symbols.add(req.symbol)
def query_bar(self):
""""""
if not self._active or not self.symbols:
return
params = {
"symbols": ",".join(list(self.symbols)),
"limit": 1
}
self.add_request(
method="GET",
path="/v1/bars/1Min",
params=params,
callback=self.on_query_bar
)
def on_query_bar(self, data, request):
""""""
for symbol, buf in data.items():
d = buf[0]
tick = TickData(
symbol=symbol,
exchange=Exchange.SMART,
datetime=datetime.now(),
name=symbol,
open_price=d["o"],
high_price=d["h"],
low_price=d["l"],
last_price=d["c"],
gateway_name=self.gateway_name
)
self.gateway.on_tick(tick)

View File

@ -23,22 +23,28 @@ def parse_datetime(date: str, time: str):
class ToraMdSpi(CTORATstpMdSpi): class ToraMdSpi(CTORATstpMdSpi):
""""""
def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"): def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"):
""""""
super().__init__() super().__init__()
self.gateway = gateway self.gateway = gateway
self._api = api self._api = api
def OnFrontConnected(self) -> Any: def OnFrontConnected(self) -> Any:
""""""
self.gateway.write_log("行情服务器连接成功") self.gateway.write_log("行情服务器连接成功")
def OnFrontDisconnected(self, error_code: int) -> Any: def OnFrontDisconnected(self, error_code: int) -> Any:
self.gateway.write_log(f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}") """"""
self.gateway.write_log(
f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
def OnRspError( def OnRspError(
self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool
) -> Any: ) -> Any:
""""""
error_id = error_info.ErrorID error_id = error_info.ErrorID
error_msg = error_info.ErrorMsg error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务收到错误消息({error_id}){error_msg}") self.gateway.write_log(f"行情服务收到错误消息({error_id}){error_msg}")
@ -50,6 +56,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int, request_id: int,
is_last: bool, is_last: bool,
) -> Any: ) -> Any:
""""""
error_id = error_info.ErrorID error_id = error_info.ErrorID
if error_id != 0: if error_id != 0:
error_msg = error_info.ErrorMsg error_msg = error_info.ErrorMsg
@ -64,6 +71,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int, request_id: int,
is_last: bool, is_last: bool,
) -> Any: ) -> Any:
""""""
error_id = error_info.ErrorID error_id = error_info.ErrorID
if error_id != 0: if error_id != 0:
error_msg = error_info.ErrorMsg error_msg = error_info.ErrorMsg
@ -72,6 +80,7 @@ class ToraMdSpi(CTORATstpMdSpi):
self.gateway.write_log("行情服务器登出成功") self.gateway.write_log("行情服务器登出成功")
def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any: def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any:
""""""
if data.ExchangeID not in EXCHANGE_TORA2VT: if data.ExchangeID not in EXCHANGE_TORA2VT:
return return
tick_data = TickData( tick_data = TickData(
@ -114,8 +123,10 @@ class ToraMdSpi(CTORATstpMdSpi):
class ToraMdApi: class ToraMdApi:
""""""
def __init__(self, gateway: BaseGateway): def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway self.gateway = gateway
self.md_address = "" self.md_address = ""
@ -151,10 +162,13 @@ class ToraMdApi:
return True return True
def subscribe(self, symbols: List[str], exchange: Exchange): def subscribe(self, symbols: List[str], exchange: Exchange):
err = self._native_api.SubscribeMarketData(symbols, EXCHANGE_VT2TORA[exchange]) """"""
err = self._native_api.SubscribeMarketData(
symbols, EXCHANGE_VT2TORA[exchange])
self._if_error_write_log(err, "subscribe") self._if_error_write_log(err, "subscribe")
def _if_error_write_log(self, error_code: int, function_name: str): def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0: if error_code != 0:
error_msg = get_error_msg(error_code) error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}' msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'

View File

@ -47,7 +47,8 @@ def _check_error(none_return: bool = True,
def wrapped(self, info, error_info, *args): def wrapped(self, info, error_info, *args):
function_name = func.__name__ function_name = func.__name__
if print_function_name: if print_function_name:
print(function_name, "info" if info else "None", error_info.ErrorID) print(function_name, "info" if info else "None",
error_info.ErrorID)
# print if errors # print if errors
error_code = error_info.ErrorID error_code = error_info.ErrorID
@ -72,8 +73,10 @@ def _check_error(none_return: bool = True,
class QueryLoop: class QueryLoop:
""""""
def __init__(self, gateway: "BaseGateway"): def __init__(self, gateway: "BaseGateway"):
""""""
self.event_engine = gateway.event_engine self.event_engine = gateway.event_engine
self._seconds_left = 0 self._seconds_left = 0
@ -84,6 +87,7 @@ class QueryLoop:
self.event_engine.register(EVENT_TIMER, self._process_timer_event) self.event_engine.register(EVENT_TIMER, self._process_timer_event)
def stop(self): def stop(self):
""""""
self.event_engine.unregister(EVENT_TIMER, self._process_timer_event) self.event_engine.unregister(EVENT_TIMER, self._process_timer_event)
def _process_timer_event(self, event): def _process_timer_event(self, event):
@ -96,7 +100,8 @@ class QueryLoop:
self._seconds_left = 2 self._seconds_left = 2
# get the last one and re-queue it # get the last one and re-queue it
func = self._query_functions.pop(0) # works fine if there is no so much items # works fine if there is no so much items
func = self._query_functions.pop(0)
self._query_functions.append(func) self._query_functions.append(func)
# call it # call it
@ -107,11 +112,13 @@ OrdersType = Dict[str, "OrderInfo"]
class ToraTdSpi(CTORATstpTraderSpi): class ToraTdSpi(CTORATstpTraderSpi):
""""""
def __init__(self, session_info: "SessionInfo", def __init__(self, session_info: "SessionInfo",
api: "ToraTdApi", api: "ToraTdApi",
gateway: "BaseGateway", gateway: "BaseGateway",
orders: OrdersType): orders: OrdersType):
""""""
super().__init__() super().__init__()
self.session_info = session_info self.session_info = session_info
self.gateway = gateway self.gateway = gateway
@ -120,6 +127,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api: "ToraTdApi" = api self._api: "ToraTdApi" = api
def OnRtnTrade(self, info: CTORATstpTradeField) -> None: def OnRtnTrade(self, info: CTORATstpTradeField) -> None:
""""""
try: try:
trade_data = TradeData( trade_data = TradeData(
gateway_name=self.gateway.gateway_name, gateway_name=self.gateway.gateway_name,
@ -138,6 +146,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
return return
def OnRtnOrder(self, info: CTORATstpOrderField) -> None: def OnRtnOrder(self, info: CTORATstpOrderField) -> None:
""""""
self._api.update_last_local_order_id(int(info.OrderRef)) self._api.update_last_local_order_id(int(info.OrderRef))
try: try:
@ -155,6 +164,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False) @_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField, def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField) -> None: error_info: CTORATstpRspInfoField) -> None:
""""""
try: try:
self._api.update_last_local_order_id(int(info.OrderRef)) self._api.update_last_local_order_id(int(info.OrderRef))
except ValueError: except ValueError:
@ -173,24 +183,27 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False) @_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField, def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField,
error_info: CTORATstpRspInfoField) -> None: error_info: CTORATstpRspInfoField) -> None:
""""""
pass pass
@_check_error() @_check_error()
def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None: def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None:
""""""
pass pass
@_check_error() @_check_error()
def OnRspOrderAction(self, info: CTORATstpInputOrderActionField, def OnRspOrderAction(self, info: CTORATstpInputOrderActionField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
print("order action succeed!") pass
@_check_error() @_check_error()
def OnRspOrderInsert(self, info: CTORATstpInputOrderField, def OnRspOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
try: try:
order_data = self.parse_order_field(info) order_data = self.parse_order_field(info)
except KeyError: except KeyError:
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})") self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
return return
self.gateway.on_order(order_data) self.gateway.on_order(order_data)
@ -208,14 +221,17 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(print_function_name=False) @_check_error(print_function_name=False)
def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField, def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None: request_id: int, is_last: bool) -> None:
""""""
if info.InvestorID != self.session_info.investor_id: if info.InvestorID != self.session_info.investor_id:
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息") self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
return return
if info.ExchangeID not in EXCHANGE_TORA2VT: if info.ExchangeID not in EXCHANGE_TORA2VT:
self.gateway.write_log(f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}") self.gateway.write_log(
f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
return return
volume = info.CurrentPosition volume = info.CurrentPosition
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + info.TodayPRFrozen + info.TodaySMPosFrozen frozen = info.HistoryPosFrozen + info.TodayBSFrozen + \
info.TodayPRFrozen + info.TodaySMPosFrozen
position_data = PositionData( position_data = PositionData(
gateway_name=self.gateway.gateway_name, gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID, symbol=info.SecurityID,
@ -224,7 +240,8 @@ class ToraTdSpi(CTORATstpTraderSpi):
volume=volume, # verify this: which one should vnpy use? volume=volume, # verify this: which one should vnpy use?
frozen=frozen, # verify this: which one should i use? frozen=frozen, # verify this: which one should i use?
price=info.TotalPosCost / volume, price=info.TotalPosCost / volume,
pnl=info.LastPrice * volume - info.TotalPosCost, # verify this: is this formula correct # verify this: is this formula correct
pnl=info.LastPrice * volume - info.TotalPosCost,
yd_volume=info.HistoryPos, yd_volume=info.HistoryPos,
) )
self.gateway.on_position(position_data) self.gateway.on_position(position_data)
@ -233,7 +250,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField, def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField,
error_info: CTORATstpRspInfoField, request_id: int, error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None: is_last: bool) -> None:
""""""
self.session_info.account_id = info.AccountID self.session_info.account_id = info.AccountID
account_data = AccountData( account_data = AccountData(
gateway_name=self.gateway.gateway_name, gateway_name=self.gateway.gateway_name,
@ -247,19 +264,22 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField, def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField,
error_info: CTORATstpRspInfoField, request_id: int, error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None: is_last: bool) -> None:
""""""
exchange = EXCHANGE_TORA2VT[info.ExchangeID] exchange = EXCHANGE_TORA2VT[info.ExchangeID]
self.session_info.shareholder_ids[exchange] = info.ShareholderID self.session_info.shareholder_ids[exchange] = info.ShareholderID
@_check_error(print_function_name=False) @_check_error(print_function_name=False)
def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField, def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None: request_id: int, is_last: bool) -> None:
""""""
self.session_info.investor_id = info.InvestorID self.session_info.investor_id = info.InvestorID
@_check_error(none_return=False, print_function_name=False) @_check_error(none_return=False, print_function_name=False)
def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField, def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None: request_id: int, is_last: bool) -> None:
""""""
if is_last: if is_last:
self.gateway.write_log("合约信息查询成功!") self.gateway.write_log("合约信息查询成功")
if not info: if not info:
return return
@ -283,12 +303,14 @@ class ToraTdSpi(CTORATstpTraderSpi):
self.gateway.on_contract(contract_data) self.gateway.on_contract(contract_data)
def OnFrontConnected(self) -> None: def OnFrontConnected(self) -> None:
""""""
self.gateway.write_log("交易服务器连接成功") self.gateway.write_log("交易服务器连接成功")
self._api.login() self._api.login()
@_check_error(print_function_name=False) @_check_error(print_function_name=False)
def OnRspUserLogin(self, info: CTORATstpRspUserLoginField, def OnRspUserLogin(self, info: CTORATstpRspUserLoginField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None: error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
self._api.update_last_local_order_id(int(info.MaxOrderRef)) self._api.update_last_local_order_id(int(info.MaxOrderRef))
self.session_info.front_id = info.FrontID self.session_info.front_id = info.FrontID
self.session_info.session_id = info.SessionID self.session_info.session_id = info.SessionID
@ -298,7 +320,9 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api.start_query_loop() # stop at ToraTdApi.stop() self._api.start_query_loop() # stop at ToraTdApi.stop()
def OnFrontDisconnected(self, error_code: int) -> None: def OnFrontDisconnected(self, error_code: int) -> None:
self.gateway.write_log(f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}") """"""
self.gateway.write_log(
f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
def parse_order_field(self, info): def parse_order_field(self, info):
""" """
@ -330,6 +354,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
class ToraTdApi: class ToraTdApi:
def __init__(self, gateway: BaseGateway): def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway self.gateway = gateway
self.username = "" self.username = ""
@ -347,13 +372,16 @@ class ToraTdApi:
self._next_local_order_id = int(1e5) self._next_local_order_id = int(1e5)
def get_shareholder_id(self, exchange: Exchange): def get_shareholder_id(self, exchange: Exchange):
""""""
return self.session_info.shareholder_ids[exchange] return self.session_info.shareholder_ids[exchange]
def update_last_local_order_id(self, new_val: int): def update_last_local_order_id(self, new_val: int):
""""""
cur = self._next_local_order_id cur = self._next_local_order_id
self._next_local_order_id = max(cur, new_val + 1) self._next_local_order_id = max(cur, new_val + 1)
def _if_error_write_log(self, error_code: int, function_name: str): def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0: if error_code != 0:
error_msg = get_error_msg(error_code) error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}' msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'
@ -361,21 +389,25 @@ class ToraTdApi:
return True return True
def _get_new_req_id(self): def _get_new_req_id(self):
""""""
req_id = self._last_req_id req_id = self._last_req_id
self._last_req_id += 1 self._last_req_id += 1
return req_id return req_id
def _get_new_order_id(self) -> str: def _get_new_order_id(self) -> str:
""""""
order_id = self._next_local_order_id order_id = self._next_local_order_id
self._next_local_order_id += 1 self._next_local_order_id += 1
return str(order_id) return str(order_id)
def query_contracts(self): def query_contracts(self):
""""""
info = CTORATstpQrySecurityField() info = CTORATstpQrySecurityField()
err = self._native_api.ReqQrySecurity(info, self._get_new_req_id()) err = self._native_api.ReqQrySecurity(info, self._get_new_req_id())
self._if_error_write_log(err, "query_contracts") self._if_error_write_log(err, "query_contracts")
def query_exchange(self, exchange: Exchange): def query_exchange(self, exchange: Exchange):
""""""
info = CTORATstpQryExchangeField() info = CTORATstpQryExchangeField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange] info.ExchangeID = EXCHANGE_VT2TORA[exchange]
err = self._native_api.ReqQryExchange(info, self._get_new_req_id()) err = self._native_api.ReqQryExchange(info, self._get_new_req_id())
@ -383,6 +415,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_exchange") self._if_error_write_log(err, "query_exchange")
def query_market_data(self, symbol: str, exchange: Exchange): def query_market_data(self, symbol: str, exchange: Exchange):
""""""
info = CTORATstpQryMarketDataField() info = CTORATstpQryMarketDataField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange] info.ExchangeID = EXCHANGE_VT2TORA[exchange]
info.SecurityID = symbol info.SecurityID = symbol
@ -390,6 +423,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_market_data") self._if_error_write_log(err, "query_market_data")
def stop(self): def stop(self):
""""""
self.stop_query_loop() self.stop_query_loop()
if self._native_api: if self._native_api:
@ -419,16 +453,21 @@ class ToraTdApi:
:return: :return:
""" """
flow_path = str(get_folder_path(self.gateway.gateway_name.lower())) flow_path = str(get_folder_path(self.gateway.gateway_name.lower()))
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(flow_path, True) self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(
self._spi = ToraTdSpi(self.session_info, self, self.gateway, self.orders) flow_path, True)
self._spi = ToraTdSpi(self.session_info, self,
self.gateway, self.orders)
self._native_api.RegisterSpi(self._spi) self._native_api.RegisterSpi(self._spi)
self._native_api.RegisterFront(self.td_address) self._native_api.RegisterFront(self.td_address)
self._native_api.SubscribePublicTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) self._native_api.SubscribePublicTopic(
self._native_api.SubscribePrivateTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART) TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.Init() self._native_api.Init()
return True return True
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
""""""
if req.type is OrderType.STOP: if req.type is OrderType.STOP:
raise NotImplementedError() raise NotImplementedError()
if req.type is OrderType.FAK or req.type is OrderType.FOK: if req.type is OrderType.FAK or req.type is OrderType.FOK:
@ -465,16 +504,19 @@ class ToraTdApi:
self.session_info.session_id, self.session_info.session_id,
self.session_info.front_id, self.session_info.front_id,
) )
self.gateway.on_order(req.create_order_data(order_id, self.gateway.gateway_name)) self.gateway.on_order(req.create_order_data(
order_id, self.gateway.gateway_name))
# err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id()) # err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id())
err = self._native_api.ReqOrderInsert(info, self._get_new_req_id()) err = self._native_api.ReqOrderInsert(info, self._get_new_req_id())
self._if_error_write_log(err, "send_order:ReqOrderInsert") self._if_error_write_log(err, "send_order:ReqOrderInsert")
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""""""
info = CTORATstpInputOrderActionField() info = CTORATstpInputOrderActionField()
info.InvestorID = self.session_info.investor_id info.InvestorID = self.session_info.investor_id
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange] # 没有的话:(16608)VIP:未知的交易所代码 # 没有的话:(16608)VIP:未知的交易所代码
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange]
info.SecurityID = req.symbol info.SecurityID = req.symbol
# info.OrderActionRef = str(self._get_new_req_id()) # info.OrderActionRef = str(self._get_new_req_id())
@ -491,6 +533,7 @@ class ToraTdApi:
self._if_error_write_log(err, "cancel_order:ReqOrderAction") self._if_error_write_log(err, "cancel_order:ReqOrderAction")
def query_initialize_status(self): def query_initialize_status(self):
""""""
self.query_contracts() self.query_contracts()
self.query_investors() self.query_investors()
self.query_shareholder_ids() self.query_shareholder_ids()
@ -500,41 +543,51 @@ class ToraTdApi:
self.query_trades() self.query_trades()
def query_accounts(self): def query_accounts(self):
""""""
info = CTORATstpQryTradingAccountField() info = CTORATstpQryTradingAccountField()
err = self._native_api.ReqQryTradingAccount(info, self._get_new_req_id()) err = self._native_api.ReqQryTradingAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_accounts") self._if_error_write_log(err, "query_accounts")
def query_shareholder_ids(self): def query_shareholder_ids(self):
""""""
info = CTORATstpQryShareholderAccountField() info = CTORATstpQryShareholderAccountField()
err = self._native_api.ReqQryShareholderAccount(info, self._get_new_req_id()) err = self._native_api.ReqQryShareholderAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_shareholder_ids") self._if_error_write_log(err, "query_shareholder_ids")
def query_investors(self): def query_investors(self):
""""""
info = CTORATstpQryInvestorField() info = CTORATstpQryInvestorField()
err = self._native_api.ReqQryInvestor(info, self._get_new_req_id()) err = self._native_api.ReqQryInvestor(info, self._get_new_req_id())
self._if_error_write_log(err, "query_investors") self._if_error_write_log(err, "query_investors")
def query_positions(self): def query_positions(self):
""""""
info = CTORATstpQryPositionField() info = CTORATstpQryPositionField()
err = self._native_api.ReqQryPosition(info, self._get_new_req_id()) err = self._native_api.ReqQryPosition(info, self._get_new_req_id())
self._if_error_write_log(err, "query_positions") self._if_error_write_log(err, "query_positions")
def query_orders(self): def query_orders(self):
""""""
info = CTORATstpQryOrderField() info = CTORATstpQryOrderField()
err = self._native_api.ReqQryOrder(info, self._get_new_req_id()) err = self._native_api.ReqQryOrder(info, self._get_new_req_id())
self._if_error_write_log(err, "query_orders") self._if_error_write_log(err, "query_orders")
def query_trades(self): def query_trades(self):
""""""
info = CTORATstpQryTradeField() info = CTORATstpQryTradeField()
err = self._native_api.ReqQryTrade(info, self._get_new_req_id()) err = self._native_api.ReqQryTrade(info, self._get_new_req_id())
self._if_error_write_log(err, "query_trades") self._if_error_write_log(err, "query_trades")
def start_query_loop(self): def start_query_loop(self):
""""""
if not self._query_loop: if not self._query_loop:
self._query_loop = QueryLoop(self.gateway) self._query_loop = QueryLoop(self.gateway)
self._query_loop.start() self._query_loop.start()
def stop_query_loop(self): def stop_query_loop(self):
""""""
if self._query_loop: if self._query_loop:
self._query_loop.stop() self._query_loop.stop()
self._query_loop = None self._query_loop = None

View File

@ -5,7 +5,8 @@ TODO:
* Linux support * Linux support
""" """
from vnpy.api.tora.vntora import (AsyncDispatchException, set_async_callback_exception_handler) from vnpy.api.tora.vntora import (
AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.event import EventEngine from vnpy.event import EventEngine
from vnpy.trader.gateway import BaseGateway from vnpy.trader.gateway import BaseGateway
@ -20,6 +21,8 @@ def is_valid_front_address(address: str):
class ToraGateway(BaseGateway): class ToraGateway(BaseGateway):
""""""
default_setting = { default_setting = {
"账号": "", "账号": "",
"密码": "", "密码": "",
@ -86,13 +89,6 @@ class ToraGateway(BaseGateway):
"""""" """"""
self._td_api.query_positions() self._td_api.query_positions()
def write_log(self, msg: str):
"""
for easier test
"""
print(msg)
super().write_log(msg)
def _async_callback_exception_handler(self, e: AsyncDispatchException): def _async_callback_exception_handler(self, e: AsyncDispatchException):
error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}" error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}"
self.write_log(error_str) self.write_log(error_str)

View File

@ -163,6 +163,7 @@ class MainWindow(QtWidgets.QMainWindow):
def init_toolbar(self): def init_toolbar(self):
"""""" """"""
self.toolbar = QtWidgets.QToolBar(self) self.toolbar = QtWidgets.QToolBar(self)
self.toolbar.setObjectName("工具栏")
self.toolbar.setFloatable(False) self.toolbar.setFloatable(False)
self.toolbar.setMovable(False) self.toolbar.setMovable(False)