Merge remote-tracking branch 'remotes/origin/dev' into import_fix

This commit is contained in:
nanoric 2019-06-28 10:59:29 +08:00
commit f9c726fed7
8 changed files with 794 additions and 71 deletions

View File

@ -3,30 +3,32 @@ from vnpy.event import EventEngine
from vnpy.trader.engine import MainEngine
from vnpy.trader.ui import MainWindow, create_qapp
from vnpy.gateway.binance import BinanceGateway
from vnpy.gateway.bitmex import BitmexGateway
from vnpy.gateway.futu import FutuGateway
from vnpy.gateway.ib import IbGateway
from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.binance import BinanceGateway
# from vnpy.gateway.bitmex import BitmexGateway
# from vnpy.gateway.futu import FutuGateway
# from vnpy.gateway.ib import IbGateway
# from vnpy.gateway.ctp import CtpGateway
# from vnpy.gateway.ctptest import CtptestGateway
from vnpy.gateway.femas import FemasGateway
# from vnpy.gateway.femas import FemasGateway
from vnpy.gateway.tiger import TigerGateway
# from vnpy.gateway.oes import OesGateway
from vnpy.gateway.okex import OkexGateway
from vnpy.gateway.huobi import HuobiGateway
from vnpy.gateway.bitfinex import BitfinexGateway
from vnpy.gateway.onetoken import OnetokenGateway
from vnpy.gateway.okexf import OkexfGateway
# from vnpy.gateway.okex import OkexGateway
# from vnpy.gateway.huobi import HuobiGateway
# from vnpy.gateway.bitfinex import BitfinexGateway
# from vnpy.gateway.onetoken import OnetokenGateway
# from vnpy.gateway.okexf import OkexfGateway
# from vnpy.gateway.xtp import XtpGateway
from vnpy.gateway.hbdm import HbdmGateway
from vnpy.gateway.tap import TapGateway
# from vnpy.gateway.tap import TapGateway
from vnpy.gateway.tora import ToraGateway
from vnpy.gateway.alpaca import AlpacaGateway
from vnpy.app.cta_strategy import CtaStrategyApp
from vnpy.app.csv_loader import CsvLoaderApp
from vnpy.app.algo_trading import AlgoTradingApp
from vnpy.app.cta_backtester import CtaBacktesterApp
from vnpy.app.data_recorder import DataRecorderApp
from vnpy.app.risk_manager import RiskManagerApp
# from vnpy.app.cta_strategy import CtaStrategyApp
# from vnpy.app.csv_loader import CsvLoaderApp
# from vnpy.app.algo_trading import AlgoTradingApp
# from vnpy.app.cta_backtester import CtaBacktesterApp
# from vnpy.app.data_recorder import DataRecorderApp
# from vnpy.app.risk_manager import RiskManagerApp
def main():
@ -37,30 +39,32 @@ def main():
main_engine = MainEngine(event_engine)
main_engine.add_gateway(BinanceGateway)
main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(BinanceGateway)
# main_engine.add_gateway(CtpGateway)
# main_engine.add_gateway(CtptestGateway)
main_engine.add_gateway(FemasGateway)
main_engine.add_gateway(IbGateway)
main_engine.add_gateway(FutuGateway)
main_engine.add_gateway(BitmexGateway)
main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(FemasGateway)
# main_engine.add_gateway(IbGateway)
# main_engine.add_gateway(FutuGateway)
# main_engine.add_gateway(BitmexGateway)
# main_engine.add_gateway(TigerGateway)
# main_engine.add_gateway(OesGateway)
main_engine.add_gateway(OkexGateway)
main_engine.add_gateway(HuobiGateway)
main_engine.add_gateway(BitfinexGateway)
main_engine.add_gateway(OnetokenGateway)
main_engine.add_gateway(OkexfGateway)
main_engine.add_gateway(HbdmGateway)
# main_engine.add_gateway(OkexGateway)
# main_engine.add_gateway(HuobiGateway)
# main_engine.add_gateway(BitfinexGateway)
# main_engine.add_gateway(OnetokenGateway)
# main_engine.add_gateway(OkexfGateway)
# main_engine.add_gateway(HbdmGateway)
# main_engine.add_gateway(XtpGateway)
main_engine.add_gateway(TapGateway)
# main_engine.add_gateway(TapGateway)
main_engine.add_gateway(ToraGateway)
main_engine.add_gateway(AlpacaGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)
main_engine.add_app(CsvLoaderApp)
main_engine.add_app(AlgoTradingApp)
main_engine.add_app(DataRecorderApp)
main_engine.add_app(RiskManagerApp)
# main_engine.add_app(CtaStrategyApp)
# main_engine.add_app(CtaBacktesterApp)
# main_engine.add_app(CsvLoaderApp)
# main_engine.add_app(AlgoTradingApp)
# main_engine.add_app(DataRecorderApp)
# main_engine.add_app(RiskManagerApp)
main_window = MainWindow(main_engine, event_engine)
main_window.showMaximized()

View File

@ -256,11 +256,14 @@ class RestClient(object):
proxies=self.proxies,
)
request.response = response
status_code = response.status_code
if status_code // 100 == 2: # 2xx都算成功尽管交易所都用200
jsonBody = response.json()
request.callback(jsonBody, request)
if status_code // 100 == 2: # 2xx codes are all successful
if status_code == 204:
json_body = None
else:
json_body = response.json()
request.callback(json_body, request)
request.status = RequestStatus.success
else:
request.status = RequestStatus.failed

View File

@ -0,0 +1 @@
from .alpaca_gateway import AlpacaGateway

View File

@ -0,0 +1,651 @@
# encoding: UTF-8
"""
Author: vigarbuaa
"""
import sys
import json
from threading import Lock
from datetime import datetime
from vnpy.api.rest import Request, RestClient
from vnpy.api.websocket import WebsocketClient
from vnpy.event import Event
from vnpy.trader.event import EVENT_TIMER
from vnpy.trader.constant import (
Direction,
Exchange,
OrderType,
Product,
Status
)
from vnpy.trader.gateway import BaseGateway
from vnpy.trader.object import (
TickData,
OrderData,
TradeData,
PositionData,
AccountData,
ContractData,
OrderRequest,
CancelRequest,
SubscribeRequest,
)
REST_HOST = "https://api.alpaca.markets" # Live trading
WEBSOCKET_HOST = "wss://api.alpaca.markets/stream"
PAPER_REST_HOST = "https://paper-api.alpaca.markets" # Paper Trading
PAPER_WEBSOCKET_HOST = "wss://paper-api.alpaca.markets/stream"
DATA_REST_HOST = "https://data.alpaca.markets"
STATUS_ALPACA2VT = {
"new": Status.NOTTRADED,
"partially_filled": Status.PARTTRADED,
"filled": Status.ALLTRADED,
"canceled": Status.CANCELLED,
"expired": Status.CANCELLED,
"rejected": Status.REJECTED
}
DIRECTION_VT2ALPACA = {
Direction.LONG: "buy",
Direction.SHORT: "sell"
}
DIRECTION_ALPACA2VT = {
"buy": Direction.LONG,
"sell": Direction.SHORT,
"long": Direction.LONG,
"short": Direction.SHORT
}
ORDERTYPE_VT2ALPACA = {
OrderType.LIMIT: "limit",
OrderType.MARKET: "market"
}
ORDERTYPE_ALPACA2VT = {v: k for k, v in ORDERTYPE_VT2ALPACA.items()}
LOCAL_SYS_MAP = {}
class AlpacaGateway(BaseGateway):
"""
VN Trader Gateway for Alpaca connection.
"""
default_setting = {
"KEY ID": "",
"Secret Key": "",
"会话数": 10,
"服务器": ["REAL", "PAPER"]
}
exchanges = [Exchange.SMART]
def __init__(self, event_engine):
"""Constructor"""
super().__init__(event_engine, "ALPACA")
self.rest_api = AlpacaRestApi(self)
self.ws_api = AlpacaWebsocketApi(self)
self.data_rest_api = AlpacaDataRestApi(self)
def connect(self, setting: dict):
""""""
key = setting["KEY ID"]
secret = setting["Secret Key"]
session = setting["会话数"]
server = setting["服务器"]
rest_url = REST_HOST if server == "REAL" else PAPER_REST_HOST
websocket_url = WEBSOCKET_HOST if server == "REAL" else PAPER_WEBSOCKET_HOST
self.rest_api.connect(key, secret, session, rest_url)
self.data_rest_api.connect(key, secret, session)
self.ws_api.connect(key, secret, websocket_url)
self.init_query()
def subscribe(self, req: SubscribeRequest):
""""""
self.data_rest_api.subscribe(req)
def send_order(self, req: OrderRequest):
""""""
return self.rest_api.send_order(req)
def cancel_order(self, req: CancelRequest):
""""""
self.rest_api.cancel_order(req)
def query_account(self):
""""""
self.rest_api.query_account()
def query_position(self):
""""""
self.rest_api.query_position()
def close(self):
""""""
self.rest_api.stop()
self.data_rest_api.stop()
self.ws_api.stop()
def init_query(self):
""""""
self.count = 0
self.event_engine.register(EVENT_TIMER, self.process_timer_event)
def process_timer_event(self, event: Event):
""""""
self.data_rest_api.query_bar()
self.count += 1
if self.count < 5:
return
self.count = 0
self.query_account()
self.query_position()
class AlpacaRestApi(RestClient):
"""
Alpaca REST API
"""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.order_count = 1_000_000
self.order_count_lock = Lock()
self.connect_time = 0
self.cancel_reqs = {}
def sign(self, request):
"""
Generate Alpaca signature.
"""
headers = {
"APCA-API-KEY-ID": self.key,
"APCA-API-SECRET-KEY": self.secret,
"Content-Type": "application/json"
}
request.headers = headers
request.allow_redirects = False
request.data = json.dumps(request.data)
return request
def connect(
self,
key: str,
secret: str,
session_num: int,
url: str,
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.connect_time = (
int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count
)
self.init(url)
self.start(session_num)
self.gateway.write_log("REST API启动成功")
self.query_contract()
self.query_account()
self.query_position()
self.query_order()
def query_contract(self):
""""""
params = {"status": "active"}
self.add_request(
"GET",
"/v2/assets",
params=params,
callback=self.on_query_contract
)
def query_account(self):
""""""
self.add_request(
method="GET",
path="/v2/account",
callback=self.on_query_account
)
def query_position(self):
""""""
self.add_request(
method="GET",
path="/v2/positions",
callback=self.on_query_position
)
def query_order(self):
""""""
params = {
"status": "open"
}
self.add_request(
method="GET",
path="/v2/orders",
params=params,
callback=self.on_query_order
)
def _new_order_id(self):
""""""
with self.order_count_lock:
self.order_count += 1
return self.order_count
def send_order(self, req: OrderRequest):
""""""
local_orderid = str(self.connect_time + self._new_order_id())
data = {
"symbol": req.symbol,
"qty": str(req.volume),
"side": DIRECTION_VT2ALPACA[req.direction],
"type": ORDERTYPE_VT2ALPACA[req.type],
"time_in_force": "day",
"client_order_id": local_orderid
}
if data["type"] == "limit":
data["limit_price"] = str(req.price)
order = req.create_order_data(local_orderid, self.gateway_name)
self.gateway.on_order(order)
self.add_request(
"POST",
"/v2/orders",
callback=self.on_send_order,
data=data,
extra=order,
on_failed=self.on_send_order_failed,
on_error=self.on_send_order_error,
)
return order.vt_orderid
def cancel_order(self, req: CancelRequest):
""""""
sys_orderid = LOCAL_SYS_MAP.get(req.orderid, None)
if not sys_orderid:
self.cancel_reqs[req.orderid] = req
return
path = f"/v2/orders/{sys_orderid}"
self.add_request(
"DELETE",
path,
callback=self.on_cancel_order,
extra=req
)
def on_query_contract(self, data, request: Request):
""""""
for d in data:
symbol = d["symbol"]
contract = ContractData(
symbol=symbol,
exchange=Exchange.SMART,
name=symbol,
product=Product.SPOT,
size=1,
pricetick=0.01,
gateway_name=self.gateway_name
)
self.gateway.on_contract(contract)
self.gateway.write_log("合约信息查询成功")
def on_query_account(self, data, request):
""""""
account = AccountData(
accountid=data["id"],
balance=float(data["equity"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
def on_query_position(self, data, request):
""""""
for d in data:
position = PositionData(
symbol=d["symbol"],
exchange=Exchange.SMART,
direction=DIRECTION_ALPACA2VT[d["side"]],
volume=int(d["qty"]),
price=float(d["avg_entry_price"]),
pnl=float(d["unrealized_pl"]),
gateway_name=self.gateway_name,
)
self.gateway.on_position(position)
def update_order(self, d: dict):
""""""
sys_orderid = d["id"]
local_orderid = d["client_order_id"]
LOCAL_SYS_MAP[local_orderid] = sys_orderid
direction = DIRECTION_ALPACA2VT[d["side"]]
order_type = ORDERTYPE_ALPACA2VT[d["type"]]
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.SMART,
price=float(d["limit_price"]),
volume=float(d["qty"]),
type=order_type,
direction=direction,
traded=float(d["filled_qty"]),
status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING),
time=d["created_at"],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
def on_query_order(self, data, request):
""""""
for d in data:
self.update_order(d)
self.gateway.write_log("委托信息查询成功")
def on_send_order(self, data, request: Request):
""""""
self.update_order(data)
order = request.extra
if order.orderid in self.cancel_reqs:
req = self.cancel_reqs.pop(order.orderid)
self.cancel_order(req)
def on_send_order_failed(self, status_code: int, request: Request):
"""
Callback to handle request failed.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"请求失败,状态码:{status_code},信息:{request.response.text}"
self.gateway.write_log(msg)
def on_send_order_error(
self, exception_type: type, exception_value: Exception, tb, request: Request
):
"""
Callback to handler request exception.
"""
order = request.extra
order.status = Status.REJECTED
self.gateway.on_order(order)
msg = f"触发异常,状态码:{exception_type},信息:{exception_value}"
self.gateway.write_log(msg)
sys.stderr.write(
self.exception_detail(exception_type, exception_value, tb, request)
)
def on_cancel_order(self, data, request):
""""""
req = request.extra
msg = f"撤单成功,委托号:{req.orderid}"
self.gateway.write_log(msg)
class AlpacaWebsocketApi(WebsocketClient):
""""""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.trade_count = 0
self.key = ""
self.secret = ""
def connect(
self, key: str, secret: str, url: str
):
""""""
self.key = key
self.secret = secret
self.init(url)
self.start()
def authenticate(self):
""""""
params = {
"action": "authenticate",
"data": {
"key_id": self.key,
"secret_key": self.secret
}
}
self.send_packet(params)
def on_authenticate(self, data):
""""""
if data["status"] == "authorized":
self.gateway.write_log("Websocket API登录成功")
else:
self.gateway.write_log("Websocket API登录失败")
return
params = {
"action": "listen",
"data": {
"streams": ["trade_updates", "account_updates"]
}
}
self.send_packet(params)
def on_connected(self):
""""""
self.gateway.write_log("Websocket API连接成功")
self.authenticate()
def on_disconnected(self):
""""""
self.gateway.write_log("Websocket API连接断开")
def on_packet(self, packet: dict):
""""""
stream = packet["stream"]
data = packet["data"]
if stream == "authorization":
self.on_authenticate(data)
elif stream == "listening":
streams = data["streams"]
if "trade_updates" in streams:
self.gateway.write_log("委托成交推送订阅成功")
if "account_updates" in streams:
self.gateway.write_log("资金变化推送订阅成功")
elif stream == "trade_updates":
self.on_order(data)
elif stream == "account_updates":
self.on_account(data)
def on_order(self, data):
""""""
# Update order
d = data["order"]
sys_orderid = d["id"]
local_orderid = d["client_order_id"]
LOCAL_SYS_MAP[local_orderid] = sys_orderid
direction = DIRECTION_ALPACA2VT[d["side"]]
order_type = ORDERTYPE_ALPACA2VT[d["type"]]
order = OrderData(
orderid=local_orderid,
symbol=d["symbol"],
exchange=Exchange.SMART,
price=float(d["limit_price"]),
volume=float(d["qty"]),
type=order_type,
direction=direction,
traded=float(d["filled_qty"]),
status=STATUS_ALPACA2VT.get(d["status"], Status.SUBMITTING),
time=d["created_at"],
gateway_name=self.gateway_name,
)
self.gateway.on_order(order)
# Update Trade
event = data.get("event", "")
if event != "fill":
return
self.trade_count += 1
trade = TradeData(
symbol=order.symbol,
exchange=order.exchange,
orderid=order.orderid,
tradeid=str(self.trade_count),
direction=order.direction,
price=float(data["price"]),
volume=int(data["qty"]),
time=data["timestamp"],
gateway_name=self.gateway_name
)
self.gateway.on_trade(trade)
def on_account(self, data):
""""""
account = AccountData(
accountid=data["id"],
balance=float(data["equity"]),
gateway_name=self.gateway_name
)
self.gateway.on_account(account)
class AlpacaDataRestApi(RestClient):
"""
Alpaca Market Data REST API
"""
def __init__(self, gateway: AlpacaGateway):
""""""
super().__init__()
self.gateway = gateway
self.gateway_name = gateway.gateway_name
self.key = ""
self.secret = ""
self.symbols = set()
def sign(self, request):
"""
Generate Alpaca signature.
"""
headers = {
"APCA-API-KEY-ID": self.key,
"APCA-API-SECRET-KEY": self.secret,
"Content-Type": "application/json"
}
request.headers = headers
request.allow_redirects = False
return request
def connect(
self,
key: str,
secret: str,
session_num: int
):
"""
Initialize connection to REST server.
"""
self.key = key
self.secret = secret
self.init(DATA_REST_HOST)
self.start(session_num)
self.gateway.write_log("行情REST API启动成功")
def subscribe(self, req: SubscribeRequest):
""""""
self.symbols.add(req.symbol)
def query_bar(self):
""""""
if not self._active or not self.symbols:
return
params = {
"symbols": ",".join(list(self.symbols)),
"limit": 1
}
self.add_request(
method="GET",
path="/v1/bars/1Min",
params=params,
callback=self.on_query_bar
)
def on_query_bar(self, data, request):
""""""
for symbol, buf in data.items():
d = buf[0]
tick = TickData(
symbol=symbol,
exchange=Exchange.SMART,
datetime=datetime.now(),
name=symbol,
open_price=d["o"],
high_price=d["h"],
low_price=d["l"],
last_price=d["c"],
gateway_name=self.gateway_name
)
self.gateway.on_tick(tick)

View File

@ -23,22 +23,28 @@ def parse_datetime(date: str, time: str):
class ToraMdSpi(CTORATstpMdSpi):
""""""
def __init__(self, api: "ToraMdApi", gateway: "BaseGateway"):
""""""
super().__init__()
self.gateway = gateway
self._api = api
def OnFrontConnected(self) -> Any:
""""""
self.gateway.write_log("行情服务器连接成功")
def OnFrontDisconnected(self, error_code: int) -> Any:
self.gateway.write_log(f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
""""""
self.gateway.write_log(
f"行情服务器连接断开({error_code}):{get_error_msg(error_code)}")
def OnRspError(
self, error_info: CTORATstpRspInfoField, request_id: int, is_last: bool
) -> Any:
""""""
error_id = error_info.ErrorID
error_msg = error_info.ErrorMsg
self.gateway.write_log(f"行情服务收到错误消息({error_id}){error_msg}")
@ -50,6 +56,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
@ -64,6 +71,7 @@ class ToraMdSpi(CTORATstpMdSpi):
request_id: int,
is_last: bool,
) -> Any:
""""""
error_id = error_info.ErrorID
if error_id != 0:
error_msg = error_info.ErrorMsg
@ -72,6 +80,7 @@ class ToraMdSpi(CTORATstpMdSpi):
self.gateway.write_log("行情服务器登出成功")
def OnRtnDepthMarketData(self, data: CTORATstpMarketDataField) -> Any:
""""""
if data.ExchangeID not in EXCHANGE_TORA2VT:
return
tick_data = TickData(
@ -114,8 +123,10 @@ class ToraMdSpi(CTORATstpMdSpi):
class ToraMdApi:
""""""
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.md_address = ""
@ -151,10 +162,13 @@ class ToraMdApi:
return True
def subscribe(self, symbols: List[str], exchange: Exchange):
err = self._native_api.SubscribeMarketData(symbols, EXCHANGE_VT2TORA[exchange])
""""""
err = self._native_api.SubscribeMarketData(
symbols, EXCHANGE_VT2TORA[exchange])
self._if_error_write_log(err, "subscribe")
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'

View File

@ -47,7 +47,8 @@ def _check_error(none_return: bool = True,
def wrapped(self, info, error_info, *args):
function_name = func.__name__
if print_function_name:
print(function_name, "info" if info else "None", error_info.ErrorID)
print(function_name, "info" if info else "None",
error_info.ErrorID)
# print if errors
error_code = error_info.ErrorID
@ -72,8 +73,10 @@ def _check_error(none_return: bool = True,
class QueryLoop:
""""""
def __init__(self, gateway: "BaseGateway"):
""""""
self.event_engine = gateway.event_engine
self._seconds_left = 0
@ -84,6 +87,7 @@ class QueryLoop:
self.event_engine.register(EVENT_TIMER, self._process_timer_event)
def stop(self):
""""""
self.event_engine.unregister(EVENT_TIMER, self._process_timer_event)
def _process_timer_event(self, event):
@ -96,7 +100,8 @@ class QueryLoop:
self._seconds_left = 2
# get the last one and re-queue it
func = self._query_functions.pop(0) # works fine if there is no so much items
# works fine if there is no so much items
func = self._query_functions.pop(0)
self._query_functions.append(func)
# call it
@ -107,11 +112,13 @@ OrdersType = Dict[str, "OrderInfo"]
class ToraTdSpi(CTORATstpTraderSpi):
""""""
def __init__(self, session_info: "SessionInfo",
api: "ToraTdApi",
gateway: "BaseGateway",
orders: OrdersType):
""""""
super().__init__()
self.session_info = session_info
self.gateway = gateway
@ -120,6 +127,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api: "ToraTdApi" = api
def OnRtnTrade(self, info: CTORATstpTradeField) -> None:
""""""
try:
trade_data = TradeData(
gateway_name=self.gateway.gateway_name,
@ -138,6 +146,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
return
def OnRtnOrder(self, info: CTORATstpOrderField) -> None:
""""""
self._api.update_last_local_order_id(int(info.OrderRef))
try:
@ -155,6 +164,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField) -> None:
""""""
try:
self._api.update_last_local_order_id(int(info.OrderRef))
except ValueError:
@ -173,24 +183,27 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(error_return=False, write_log=False, print_function_name=False)
def OnErrRtnOrderAction(self, info: CTORATstpOrderActionField,
error_info: CTORATstpRspInfoField) -> None:
""""""
pass
@_check_error()
def OnRtnCondOrder(self, info: CTORATstpConditionOrderField) -> None:
""""""
pass
@_check_error()
def OnRspOrderAction(self, info: CTORATstpInputOrderActionField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
print("order action succeed!")
pass
@_check_error()
def OnRspOrderInsert(self, info: CTORATstpInputOrderField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
try:
order_data = self.parse_order_field(info)
except KeyError:
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
self.gateway.write_log(f"收到无法识别的下单回执({info.OrderRef})")
return
self.gateway.on_order(order_data)
@ -208,14 +221,17 @@ class ToraTdSpi(CTORATstpTraderSpi):
@_check_error(print_function_name=False)
def OnRspQryPosition(self, info: CTORATstpPositionField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if info.InvestorID != self.session_info.investor_id:
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
self.gateway.write_log("OnRspQryPosition:收到其他账户的仓位信息")
return
if info.ExchangeID not in EXCHANGE_TORA2VT:
self.gateway.write_log(f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
self.gateway.write_log(
f"OnRspQryPosition:忽略不支持的交易所:{info.ExchangeID}")
return
volume = info.CurrentPosition
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + info.TodayPRFrozen + info.TodaySMPosFrozen
frozen = info.HistoryPosFrozen + info.TodayBSFrozen + \
info.TodayPRFrozen + info.TodaySMPosFrozen
position_data = PositionData(
gateway_name=self.gateway.gateway_name,
symbol=info.SecurityID,
@ -224,7 +240,8 @@ class ToraTdSpi(CTORATstpTraderSpi):
volume=volume, # verify this: which one should vnpy use?
frozen=frozen, # verify this: which one should i use?
price=info.TotalPosCost / volume,
pnl=info.LastPrice * volume - info.TotalPosCost, # verify this: is this formula correct
# verify this: is this formula correct
pnl=info.LastPrice * volume - info.TotalPosCost,
yd_volume=info.HistoryPos,
)
self.gateway.on_position(position_data)
@ -233,7 +250,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryTradingAccount(self, info: CTORATstpTradingAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
self.session_info.account_id = info.AccountID
account_data = AccountData(
gateway_name=self.gateway.gateway_name,
@ -247,19 +264,22 @@ class ToraTdSpi(CTORATstpTraderSpi):
def OnRspQryShareholderAccount(self, info: CTORATstpShareholderAccountField,
error_info: CTORATstpRspInfoField, request_id: int,
is_last: bool) -> None:
""""""
exchange = EXCHANGE_TORA2VT[info.ExchangeID]
self.session_info.shareholder_ids[exchange] = info.ShareholderID
@_check_error(print_function_name=False)
def OnRspQryInvestor(self, info: CTORATstpInvestorField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
self.session_info.investor_id = info.InvestorID
@_check_error(none_return=False, print_function_name=False)
def OnRspQrySecurity(self, info: CTORATstpSecurityField, error_info: CTORATstpRspInfoField,
request_id: int, is_last: bool) -> None:
""""""
if is_last:
self.gateway.write_log("合约信息查询成功!")
self.gateway.write_log("合约信息查询成功")
if not info:
return
@ -283,12 +303,14 @@ class ToraTdSpi(CTORATstpTraderSpi):
self.gateway.on_contract(contract_data)
def OnFrontConnected(self) -> None:
""""""
self.gateway.write_log("交易服务器连接成功")
self._api.login()
@_check_error(print_function_name=False)
def OnRspUserLogin(self, info: CTORATstpRspUserLoginField,
error_info: CTORATstpRspInfoField, request_id: int, is_last: bool) -> None:
""""""
self._api.update_last_local_order_id(int(info.MaxOrderRef))
self.session_info.front_id = info.FrontID
self.session_info.session_id = info.SessionID
@ -298,7 +320,9 @@ class ToraTdSpi(CTORATstpTraderSpi):
self._api.start_query_loop() # stop at ToraTdApi.stop()
def OnFrontDisconnected(self, error_code: int) -> None:
self.gateway.write_log(f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
""""""
self.gateway.write_log(
f"交易服务器连接断开({error_code}):{get_error_msg(error_code)}")
def parse_order_field(self, info):
"""
@ -330,6 +354,7 @@ class ToraTdSpi(CTORATstpTraderSpi):
class ToraTdApi:
def __init__(self, gateway: BaseGateway):
""""""
self.gateway = gateway
self.username = ""
@ -347,13 +372,16 @@ class ToraTdApi:
self._next_local_order_id = int(1e5)
def get_shareholder_id(self, exchange: Exchange):
""""""
return self.session_info.shareholder_ids[exchange]
def update_last_local_order_id(self, new_val: int):
""""""
cur = self._next_local_order_id
self._next_local_order_id = max(cur, new_val + 1)
def _if_error_write_log(self, error_code: int, function_name: str):
""""""
if error_code != 0:
error_msg = get_error_msg(error_code)
msg = f'在执行 {function_name} 时发生错误({error_code}): {error_msg}'
@ -361,21 +389,25 @@ class ToraTdApi:
return True
def _get_new_req_id(self):
""""""
req_id = self._last_req_id
self._last_req_id += 1
return req_id
def _get_new_order_id(self) -> str:
""""""
order_id = self._next_local_order_id
self._next_local_order_id += 1
return str(order_id)
def query_contracts(self):
""""""
info = CTORATstpQrySecurityField()
err = self._native_api.ReqQrySecurity(info, self._get_new_req_id())
self._if_error_write_log(err, "query_contracts")
def query_exchange(self, exchange: Exchange):
""""""
info = CTORATstpQryExchangeField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
err = self._native_api.ReqQryExchange(info, self._get_new_req_id())
@ -383,6 +415,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_exchange")
def query_market_data(self, symbol: str, exchange: Exchange):
""""""
info = CTORATstpQryMarketDataField()
info.ExchangeID = EXCHANGE_VT2TORA[exchange]
info.SecurityID = symbol
@ -390,6 +423,7 @@ class ToraTdApi:
self._if_error_write_log(err, "query_market_data")
def stop(self):
""""""
self.stop_query_loop()
if self._native_api:
@ -419,16 +453,21 @@ class ToraTdApi:
:return:
"""
flow_path = str(get_folder_path(self.gateway.gateway_name.lower()))
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(flow_path, True)
self._spi = ToraTdSpi(self.session_info, self, self.gateway, self.orders)
self._native_api = CTORATstpTraderApi.CreateTstpTraderApi(
flow_path, True)
self._spi = ToraTdSpi(self.session_info, self,
self.gateway, self.orders)
self._native_api.RegisterSpi(self._spi)
self._native_api.RegisterFront(self.td_address)
self._native_api.SubscribePublicTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePublicTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.SubscribePrivateTopic(
TORA_TE_RESUME_TYPE.TORA_TERT_RESTART)
self._native_api.Init()
return True
def send_order(self, req: OrderRequest):
""""""
if req.type is OrderType.STOP:
raise NotImplementedError()
if req.type is OrderType.FAK or req.type is OrderType.FOK:
@ -465,16 +504,19 @@ class ToraTdApi:
self.session_info.session_id,
self.session_info.front_id,
)
self.gateway.on_order(req.create_order_data(order_id, self.gateway.gateway_name))
self.gateway.on_order(req.create_order_data(
order_id, self.gateway.gateway_name))
# err = self._native_api.ReqCondOrderInsert(info, self._get_new_req_id())
err = self._native_api.ReqOrderInsert(info, self._get_new_req_id())
self._if_error_write_log(err, "send_order:ReqOrderInsert")
def cancel_order(self, req: CancelRequest):
""""""
info = CTORATstpInputOrderActionField()
info.InvestorID = self.session_info.investor_id
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange] # 没有的话:(16608)VIP:未知的交易所代码
# 没有的话:(16608)VIP:未知的交易所代码
info.ExchangeID = EXCHANGE_VT2TORA[req.exchange]
info.SecurityID = req.symbol
# info.OrderActionRef = str(self._get_new_req_id())
@ -491,6 +533,7 @@ class ToraTdApi:
self._if_error_write_log(err, "cancel_order:ReqOrderAction")
def query_initialize_status(self):
""""""
self.query_contracts()
self.query_investors()
self.query_shareholder_ids()
@ -500,41 +543,51 @@ class ToraTdApi:
self.query_trades()
def query_accounts(self):
""""""
info = CTORATstpQryTradingAccountField()
err = self._native_api.ReqQryTradingAccount(info, self._get_new_req_id())
err = self._native_api.ReqQryTradingAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_accounts")
def query_shareholder_ids(self):
""""""
info = CTORATstpQryShareholderAccountField()
err = self._native_api.ReqQryShareholderAccount(info, self._get_new_req_id())
err = self._native_api.ReqQryShareholderAccount(
info, self._get_new_req_id())
self._if_error_write_log(err, "query_shareholder_ids")
def query_investors(self):
""""""
info = CTORATstpQryInvestorField()
err = self._native_api.ReqQryInvestor(info, self._get_new_req_id())
self._if_error_write_log(err, "query_investors")
def query_positions(self):
""""""
info = CTORATstpQryPositionField()
err = self._native_api.ReqQryPosition(info, self._get_new_req_id())
self._if_error_write_log(err, "query_positions")
def query_orders(self):
""""""
info = CTORATstpQryOrderField()
err = self._native_api.ReqQryOrder(info, self._get_new_req_id())
self._if_error_write_log(err, "query_orders")
def query_trades(self):
""""""
info = CTORATstpQryTradeField()
err = self._native_api.ReqQryTrade(info, self._get_new_req_id())
self._if_error_write_log(err, "query_trades")
def start_query_loop(self):
""""""
if not self._query_loop:
self._query_loop = QueryLoop(self.gateway)
self._query_loop.start()
def stop_query_loop(self):
""""""
if self._query_loop:
self._query_loop.stop()
self._query_loop = None

View File

@ -5,7 +5,8 @@ TODO:
* Linux support
"""
from vnpy.api.tora.vntora import (AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.api.tora.vntora import (
AsyncDispatchException, set_async_callback_exception_handler)
from vnpy.event import EventEngine
from vnpy.trader.gateway import BaseGateway
@ -20,6 +21,8 @@ def is_valid_front_address(address: str):
class ToraGateway(BaseGateway):
""""""
default_setting = {
"账号": "",
"密码": "",
@ -86,13 +89,6 @@ class ToraGateway(BaseGateway):
""""""
self._td_api.query_positions()
def write_log(self, msg: str):
"""
for easier test
"""
print(msg)
super().write_log(msg)
def _async_callback_exception_handler(self, e: AsyncDispatchException):
error_str = f"发生内部错误:\n" f"位置:{e.instance}.{e.function_name}" f"详细信息:{e.what}"
self.write_log(error_str)

View File

@ -163,6 +163,7 @@ class MainWindow(QtWidgets.QMainWindow):
def init_toolbar(self):
""""""
self.toolbar = QtWidgets.QToolBar(self)
self.toolbar.setObjectName("工具栏")
self.toolbar.setFloatable(False)
self.toolbar.setMovable(False)