diff --git a/vnpy/api/rest/rest_client.py b/vnpy/api/rest/rest_client.py index ee8e98ba..df7b6ea7 100644 --- a/vnpy/api/rest/rest_client.py +++ b/vnpy/api/rest/rest_client.py @@ -1,20 +1,17 @@ -import logging +import json import multiprocessing import os import sys import traceback -import uuid from datetime import datetime from enum import Enum from multiprocessing.dummy import Pool -from threading import Lock +from threading import Lock, Thread from types import TracebackType from typing import Any, Callable, List, Optional, Type, Union import requests -from vnpy.trader.utility import get_file_logger - class RequestStatus(Enum): ready = 0 # Request created @@ -25,6 +22,11 @@ class RequestStatus(Enum): pool: multiprocessing.pool.Pool = Pool(os.cpu_count() * 20) +CALLBACK_TYPE = Callable[[dict, "Request"], Any] +ON_FAILED_TYPE = Callable[[int, "Request"], Any] +ON_ERROR_TYPE = Callable[[Type, Exception, TracebackType, "Request"], Any] +CONNECTED_TYPE = Callable[["Request"], Any] + class Request(object): """ @@ -38,9 +40,11 @@ class Request(object): params: dict, data: Union[dict, str, bytes], headers: dict, - callback: Callable = None, - on_failed: Callable = None, - on_error: Callable = None, + callback: CALLBACK_TYPE = None, + on_failed: ON_FAILED_TYPE = None, + on_error: ON_ERROR_TYPE = None, + stream: bool = False, + on_connected: CONNECTED_TYPE = None, # for streaming request extra: Any = None, client: "RestClient" = None, ): @@ -52,6 +56,10 @@ class Request(object): self.data = data self.headers = headers + self.stream = stream + self.on_connected = on_connected + self.processing_line: Optional[str] = '' + self.on_failed = on_failed self.on_error = on_error self.extra = extra @@ -66,6 +74,14 @@ class Request(object): else: status_code = self.response.status_code + if self.stream: + response = self.processing_line + else: + if self.response is None: + response = None + else: + response = self.response.text + return ( "request: {method} {path} {http_code}: \n" "full_url: {full_url}\n" @@ -82,7 +98,7 @@ class Request(object): headers=self.headers, params=self.params, data=self.data, - response="" if self.response is None else self.response.text, + response=response, ) ) @@ -114,39 +130,28 @@ class RestClient(object): """ """ self.url_base = '' # type: str - self.logger: Optional[logging.Logger] = None + self._active = False self.proxies = None - self._active = False - self._tasks_lock = Lock() self._tasks: List[multiprocessing.pool.AsyncResult] = [] self._sessions_lock = Lock() self._sessions: List[requests.Session] = [] + self._streams_lock = Lock() + self._streams: List[Thread] = [] + @property def alive(self): return self._active - def init(self, - url_base: str, - proxy_host: str = "", - proxy_port: int = 0, - log_path: Optional[str] = None, - ): + def init(self, url_base: str, proxy_host: str = "", proxy_port: int = 0): """ Init rest client with url_base which is the API root address. e.g. 'https://www.bitmex.com/api/v1/' - :param url_base: - :param proxy_host: - :param proxy_port: - :param log_path: optional. file to save log. """ self.url_base = url_base - if log_path is not None: - self.logger = get_file_logger(log_path) - self.logger.setLevel(logging.DEBUG) if proxy_host and proxy_port: proxy = f"{proxy_host}:{proxy_port}" @@ -177,16 +182,55 @@ class RestClient(object): for task in self._tasks: task.wait() + def add_streaming_request( + self, + method: str, + path: str, + callback: CALLBACK_TYPE, + params: dict = None, + data: Union[dict, str, bytes] = None, + headers: dict = None, + on_connected: CONNECTED_TYPE = None, + on_failed: ON_FAILED_TYPE = None, + on_error: ON_ERROR_TYPE = None, + extra: Any = None, + ): + """ + See add_request for usage. + """ + request = Request( + method=method, + path=path, + params=params, + data=data, + headers=headers, + callback=callback, + on_failed=on_failed, + on_error=on_error, + extra=extra, + client=self, + stream=True, + on_connected=on_connected, + ) + # stream task should not push to thread pool + # because it is usually no return. + th = Thread( + target=self._process_stream_request, + args=[request, ], + ) + th.start() + return request + def add_request( self, method: str, path: str, - callback: Callable[[dict, "Request"], Any], + callback: CALLBACK_TYPE, params: dict = None, data: Union[dict, str, bytes] = None, headers: dict = None, - on_failed: Callable[[int, "Request"], Any] = None, - on_error: Callable[[Type, Exception, TracebackType, "Request"], Any] = None, + on_failed: ON_FAILED_TYPE = None, + on_error: ON_ERROR_TYPE = None, extra: Any = None, ): """ @@ -232,6 +276,11 @@ class RestClient(object): not_finished_tasks = [i for i in self._tasks if not i.ready()] self._tasks = not_finished_tasks + def _clean_finished_streams(self): + with self._streams_lock: + not_finished_streams = [i for i in self._streams if i.is_alive()] + self._streams = not_finished_streams + def _get_session(self): with self._sessions_lock: if self._sessions: @@ -293,10 +342,12 @@ class RestClient(object): """ return True - def _log(self, msg, *args): - logger = self.logger - if logger: - logger.debug(msg, *args) + def _process_stream_request(self, request: Request): + """ + Sending request to server and get result. + """ + self._process_request(request) + self._clean_finished_streams() def _process_request( self, request: Request @@ -306,55 +357,47 @@ class RestClient(object): """ try: with self._get_session() as session: - # sign request = self.sign(request) - # send request url = self.make_full_url(request.path) - uid = uuid.uuid4() - method = request.method - headers = request.headers - params = request.params - data = request.data - self._log("[%s] sending request %s %s, headers:%s, params:%s, data:%s", - uid, method, url, - headers, params, data) + stream = request.stream response = session.request( - method, + request.method, url, - headers=headers, - params=params, - data=data, + headers=request.headers, + params=request.params, + data=request.data, proxies=self.proxies, + stream=stream, ) request.response = response - self._log("[%s] received response from %s:%s", uid, method, url) - - # check result & call corresponding callbacks status_code = response.status_code - success = False - json_body: Optional[dict] = None - try: + if not stream: # normal API: + # just call callback with all contents received. if status_code // 100 == 2: # 2xx codes are all successful if status_code == 204: json_body = None else: json_body = response.json() - - if self.is_request_success(json_body, request): - success = True - finally: - if success: - request.status = RequestStatus.success - request.callback(json_body, request) + self._process_json_body(json_body, request) else: - request.status = RequestStatus.failed if request.on_failed: + request.status = RequestStatus.failed request.on_failed(status_code, request) else: self.on_failed(status_code, request) + else: # streaming API: + if request.on_connected: + request.on_connected(request) + # split response by lines, and call one callback for each line. + for line in response.iter_lines(chunk_size=None): + if line: + request.processing_line = line + json_body = json.loads(line) + self._process_json_body(json_body, request) + request.status = RequestStatus.success except Exception: request.status = RequestStatus.error t, v, tb = sys.exc_info() @@ -363,6 +406,18 @@ class RestClient(object): else: self.on_error(t, v, tb, request) + def _process_json_body(self, json_body: Optional[dict], request: "Request"): + status_code = request.response.status_code + if self.is_request_success(json_body, request): + request.status = RequestStatus.success + request.callback(json_body, request) + else: + request.status = RequestStatus.failed + if request.on_failed: + request.on_failed(status_code, request) + else: + self.on_failed(status_code, request) + def make_full_url(self, path: str): """ Make relative api path into full url. diff --git a/vnpy/gateway/oanda/__init__.py b/vnpy/gateway/oanda/__init__.py new file mode 100644 index 00000000..ccf8b071 --- /dev/null +++ b/vnpy/gateway/oanda/__init__.py @@ -0,0 +1 @@ +from .oanda_gateway import OandaGateway diff --git a/vnpy/gateway/oanda/oanda_api_base.py b/vnpy/gateway/oanda/oanda_api_base.py new file mode 100644 index 00000000..c3fe192f --- /dev/null +++ b/vnpy/gateway/oanda/oanda_api_base.py @@ -0,0 +1,54 @@ +import json +import sys +from typing import TYPE_CHECKING + +from vnpy.api.rest import Request, RestClient + +if TYPE_CHECKING: + from vnpy.gateway.oanda import OandaGateway + +_ = lambda x: x # noqa + + +class OandaApiBase(RestClient): + """ + Oanda Base API + """ + + def __init__(self, gateway: "OandaGateway"): + super().__init__() + self.gateway = gateway + self.gateway_name = gateway.gateway_name + + self.key = "" + self.secret = b"" + + def sign(self, request): + """ + Generate BitMEX signature. + """ + if request.data: + request.data = json.dumps(request.data) + request.headers = { + "Authorization": f"Bearer {self.key}", + "Content-Type": "application/json", + } + return request + + def is_request_success(self, data: dict, request: "Request"): + # message should be check too + # but checking only this is enough for us. + return super().is_request_success(data, request) and 'errorMessage' not in data + + def on_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback to handler request exception. + """ + msg = f"触发异常,状态码:{exception_type},信息:{exception_value}" + self.gateway.write_log(msg) + + sys.stderr.write( + self.exception_detail(exception_type, exception_value, tb, request) + ) diff --git a/vnpy/gateway/oanda/oanda_common.py b/vnpy/gateway/oanda/oanda_common.py new file mode 100644 index 00000000..968063c7 --- /dev/null +++ b/vnpy/gateway/oanda/oanda_common.py @@ -0,0 +1,86 @@ +import time +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from vnpy.trader.constant import Direction, Interval, OrderType, Status + +if TYPE_CHECKING: + # noinspection PyUnresolvedReferences + from vnpy.gateway.oanda import OandaGateway # noqa + +STATUS_OANDA2VT = { + "PENDING": Status.NOTTRADED, + "FILLED": Status.ALLTRADED, + "CANCELLED": Status.CANCELLED, + # "TRIGGERED": Status.REJECTED, +} +STOP_ORDER_STATUS_OANDA2VT = { + "Untriggered": Status.NOTTRADED, + "Triggered": Status.NOTTRADED, + # Active: triggered and placed. + # since price is market price, placed == AllTraded? + "Active": Status.ALLTRADED, + "Cancelled": Status.CANCELLED, + "Rejected": Status.REJECTED, +} +DIRECTION_VT2OANDA = {Direction.LONG: "Buy", Direction.SHORT: "Sell"} +DIRECTION_OANDA2VT = {v: k for k, v in DIRECTION_VT2OANDA.items()} +DIRECTION_OANDA2VT.update({ + "None": Direction.LONG +}) + +OPPOSITE_DIRECTION = { + Direction.LONG: Direction.SHORT, + Direction.SHORT: Direction.LONG, +} + +ORDER_TYPE_VT2OANDA = { + OrderType.LIMIT: "LIMIT", + OrderType.MARKET: "MARKET", + OrderType.STOP: "STOP", +} + +ORDER_TYPE_OANDA2VT = {v: k for k, v in ORDER_TYPE_VT2OANDA.items()} +ORDER_TYPE_OANDA2VT.update({ + 'LIMIT_ORDER': OrderType.LIMIT, + 'MARKET_ORDER': OrderType.MARKET, + 'STOP_ORDER': OrderType.STOP, +}) + +INTERVAL_VT2OANDA = { + Interval.MINUTE: "M1", + Interval.HOUR: "H1", + Interval.DAILY: "D", + Interval.WEEKLY: "W", +} +INTERVAL_VT2OANDA_INT = { + Interval.MINUTE: 1, + Interval.HOUR: 60, + Interval.DAILY: 60 * 24, + Interval.WEEKLY: 60 * 24 * 7, +} +INTERVAL_VT2OANDA_DELTA = { + Interval.MINUTE: timedelta(minutes=1), + Interval.HOUR: timedelta(hours=1), + Interval.DAILY: timedelta(days=1), + Interval.WEEKLY: timedelta(days=7), +} + +utc_tz = timezone.utc +local_tz = datetime.now(timezone.utc).astimezone().tzinfo + + +def generate_timestamp(expire_after: float = 30) -> int: + """ + :param expire_after: expires in seconds. + :return: timestamp in milliseconds + """ + return int(time.time() * 1000 + expire_after * 1000) + + +def parse_datetime(dt: str) -> datetime: + return datetime.fromisoformat(dt[:-4]) + + +def parse_time(dt: str) -> str: + return dt[11:26] diff --git a/vnpy/gateway/oanda/oanda_gateway.py b/vnpy/gateway/oanda/oanda_gateway.py new file mode 100644 index 00000000..f099ac69 --- /dev/null +++ b/vnpy/gateway/oanda/oanda_gateway.py @@ -0,0 +1,204 @@ +""" +author: nanoric +todo: + * query history +""" +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from vnpy.event import Event +from vnpy.trader.constant import (Direction, Exchange, Interval, Offset, OrderType, Status) +from vnpy.trader.event import EVENT_TIMER +from vnpy.trader.gateway import BaseGateway +from vnpy.trader.object import (BarData, CancelRequest, HistoryRequest, OrderData, OrderRequest, + SubscribeRequest) +from .oanda_common import INTERVAL_VT2OANDA_DELTA, ORDER_TYPE_OANDA2VT, local_tz, parse_time, utc_tz +from .oanda_rest_api import HistoryDataNextInfo, OandaRestApi +from .oanda_stream_api import OandaStreamApi + +_ = lambda x: x # noqa + + +@dataclass() +class HistoryDataInfo: + bars: List[BarData] + extra: Any + + +class OandaGateway(BaseGateway): + """ + VN Trader Gateway for BitMEX connection. + """ + + default_setting = { + "APIKey": "", + "会话数": 3, + "服务器": ["REAL", "TESTNET"], + "代理地址": "", + "代理端口": "", + } + HISTORY_RECORD_PER_REQUEST = 5000 # # of records per history request + + exchanges = [Exchange.OANDA] + + def __init__(self, event_engine): + """Constructor""" + super(OandaGateway, self).__init__(event_engine, "OANDA") + + self.rest_api = OandaRestApi(self) + self.stream_api = OandaStreamApi(self) + + self.account_id: Optional[str] = None + + self.orders: Dict[str, OrderData] = {} + self.local2sys_map: Dict[str, str] = {} + event_engine.register(EVENT_TIMER, self.process_timer_event) + + def connect(self, setting: dict): + """""" + key = setting["APIKey"] + session_number = setting["会话数"] + server = setting["服务器"] + proxy_host = setting["代理地址"] + proxy_port = setting["代理端口"] + + if proxy_port.isdigit(): + proxy_port = int(proxy_port) + else: + proxy_port = 0 + + self.stream_api.connect(key, session_number, + server, proxy_host, proxy_port) + self.rest_api.connect(key, session_number, + server, proxy_host, proxy_port) + self.query_account() + + def subscribe(self, req: SubscribeRequest): + """""" + assert self.account_id is not None, _("请先初始化并连接RestAPI") + self.stream_api.subscribe(req) + + def send_order(self, req: OrderRequest): + """""" + return self.rest_api.send_order(req) + + def cancel_order(self, req: CancelRequest): + """""" + self.rest_api.cancel_order(req) + + def query_account(self): + """""" + self.rest_api.query_accounts() + + def query_position(self): + """""" + self.rest_api.query_positions() + + def query_first_history(self, + symbol: str, + interval: Interval, + start: datetime, + ) -> Tuple[List[BarData], "HistoryDataNextInfo"]: + + # datetime for a bar is close_time + # we got open_time from API. + adjustment = INTERVAL_VT2OANDA_DELTA[interval] + + utc_time = start.replace(tzinfo=local_tz).astimezone(tz=utc_tz) + return self.rest_api.query_history( + symbol=symbol, + interval=interval, + + # todo: vnpy: shall all datetime object use tzinfo? + start=utc_time - adjustment, + ) + + def query_next_history(self, + next_info: Any, + ): + data: "HistoryDataNextInfo" = next_info + return self.rest_api.query_history( + symbol=data.symbol, + interval=data.interval, + start=data.end, + ) + + def query_history(self, req: HistoryRequest): + """ + todo: vnpy: download in parallel + todo: vnpy: use yield to simplify logic + :raises RequestFailedException: if server reply an error. + Any Exception might be raised from requests.request: network error. + """ + history = [] + + symbol = req.symbol + interval = req.interval + start = req.start + + bars, next_info = self.query_first_history( + symbol=symbol, + interval=interval, + start=start, + ) + msg = f"获取历史数据成功,{req.symbol} - {req.interval.value},{bars[0].datetime} - {bars[-1].datetime}" + self.write_log(msg) + history.extend(bars) + + end = req.end + if end is None: + end = datetime.now() + if bars[-1].datetime >= end or len(bars) < self.HISTORY_RECORD_PER_REQUEST: + return history + while True: + bars, next_info = self.query_next_history(next_info) + history.extend(bars) + + msg = f"获取历史数据成功,{req.symbol} - {req.interval.value},{bars[0].datetime} - {bars[-1].datetime}" + self.write_log(msg) + + # Break if total data count less than (latest date collected) + if bars[-1].datetime >= end or len(bars) < self.HISTORY_RECORD_PER_REQUEST: + break + return history + + def close(self): + """""" + self.rest_api.stop() + self.stream_api.stop() + + def process_timer_event(self, event: Event): + """""" + if self.rest_api.fully_initialized: + self.rest_api.query_account_changes() + + def write_log(self, msg: str): + print(msg) + return super().write_log(msg) + + def parse_order_data(self, data, status: Status, time_key: str): + client_extension = data.get('clientExtensions', None) + if client_extension is None: + order_id = data['id'] + else: + order_id = client_extension['id'] + vol = int(data['units']) + type_ = ORDER_TYPE_OANDA2VT[data['type']] + + order = OrderData( + gateway_name=self.gateway_name, + symbol=data['instrument'], + exchange=Exchange.OANDA, + orderid=order_id, + type=type_, + direction=Direction.LONG if vol > 0 else Direction.SHORT, + offset=Offset.NONE, + price=float(data['price']) if type_ is not OrderType.MARKET else 0.0, + volume=abs(vol), + # status=STATUS_OANDA2VT[data['state']], + status=status, + time=parse_time(data[time_key]), + ) + self.orders[order_id] = order + return order diff --git a/vnpy/gateway/oanda/oanda_rest_api.py b/vnpy/gateway/oanda/oanda_rest_api.py new file mode 100644 index 00000000..852931c8 --- /dev/null +++ b/vnpy/gateway/oanda/oanda_rest_api.py @@ -0,0 +1,451 @@ +from dataclasses import dataclass +from datetime import datetime +from threading import Lock +from typing import Dict, List, Optional, TYPE_CHECKING, Tuple + +from requests import ConnectionError + +from vnpy.api.rest import Request +from vnpy.gateway.oanda.oanda_api_base import OandaApiBase +from vnpy.trader.constant import Direction, Exchange, Interval, Offset, Product, Status +from vnpy.trader.object import AccountData, BarData, CancelRequest, ContractData, OrderRequest, \ + PositionData +from .oanda_common import (INTERVAL_VT2OANDA, INTERVAL_VT2OANDA_DELTA, ORDER_TYPE_VT2OANDA, + STATUS_OANDA2VT, local_tz, parse_datetime, parse_time, utc_tz) + +if TYPE_CHECKING: + from vnpy.gateway.oanda import OandaGateway +_ = lambda x: x # noqa + +HOST = "https://api-fxtrade.oanda.com" +TEST_HOST = "https://api-fxpractice.oanda.com" + +# asked from official developer +PRICE_TICKS = { + "BTCUSD": 0.5, + "ETHUSD": 0.05, + "EOSUSD": 0.001, + "XRPUSD": 0.0001, +} + + +@dataclass() +class HistoryDataNextInfo: + symbol: str + interval: Interval + end: datetime + + +class RequestFailedException(Exception): + pass + + +class OandaRestApi(OandaApiBase): + """ + Oanda Rest API + """ + + def __init__(self, gateway: "OandaGateway"): + """""" + super().__init__(gateway) + + self.order_count = 1_000_000 + self.order_count_lock = Lock() + + self.connect_time = 0 + + self.contracts: Dict[str, ContractData] = {} + + self.last_account_transaction_id: Optional[str] = None + + # used for automatic tests. + self.account_initialized = False + self.orders_initialized = False + + def connect( + self, + key: str, + session_number: int, + server: str, + proxy_host: str, + proxy_port: int, + ): + """ + Initialize connection to REST server. + """ + self.key = key + + self.connect_time = ( + int(datetime.now().strftime("%y%m%d%H%M%S")) * self.order_count + ) + + if server == "REAL": + self.init(HOST, proxy_host, proxy_port) + else: + self.init(TEST_HOST, proxy_host, proxy_port) + + self.start(session_number) + + self.gateway.write_log(_("REST API启动成功")) + + def _new_order_id(self) -> str: + """""" + with self.order_count_lock: + self.order_count += 1 + return f'a{self.connect_time}{self.order_count}' + + @staticmethod + def is_local_order_id(order_id: str): + return order_id[0] == 'a' + + def send_order(self, req: OrderRequest): + """""" + req.offset = Offset.NONE + order_id = self._new_order_id() + + symbol = req.symbol + vol = int(req.volume) + direction = req.direction + data = { + "instrument": symbol, + # positive for long , negative for short + "units": vol if direction is Direction.LONG else -vol, + "clientExtensions": { + "id": order_id, + } + } + + order = req.create_order_data(order_id, self.gateway_name) + order.time = parse_time(datetime.now().isoformat()) + + # Only add price for limit order. + data["type"] = ORDER_TYPE_VT2OANDA[req.type] + data["price"] = str(req.price) + self.gateway.orders[order.orderid] = order + self.add_request( + "POST", + f"/v3/accounts/{self.gateway.account_id}/orders", + callback=self.on_send_order, + data={'order': data}, + extra=order, + on_failed=self.on_send_order_failed, + on_error=self.on_send_order_error, + ) + self.gateway.on_order(order) + return order.vt_orderid + + def cancel_order(self, req: CancelRequest): + """""" + order_id = req.orderid + order = self.gateway.orders[order_id] + + if self.is_local_order_id(order_id): + order_id = '@' + order_id + self.add_request( + "PUT", + f"/v3/accounts/{self.gateway.account_id}/orders/{order_id}/cancel", + callback=self.on_cancel_order, + extra=order, + ) + + def query_history(self, + symbol: str, + interval: Interval, + start: datetime, + limit: int = None, + ) -> Tuple[List[BarData], "HistoryDataNextInfo"]: + """ + Get history data synchronously. + """ + if limit is None: + limit = self.gateway.HISTORY_RECORD_PER_REQUEST + + bars = [] + + # datetime for a bar is close_time + # we got open_time from API. + adjustment = INTERVAL_VT2OANDA_DELTA[interval] + + # todo: RestClient: return RestClient.Request object instead of requests.Response. + resp = self.request( + "GET", + f"/v3/instruments/{symbol}/candles", + params={ + # "price": "M", # M for mids, B for bids, A for asks + "granularity": INTERVAL_VT2OANDA[interval], + "count": 5000, + "from": start.isoformat(), + } + ) + + # Break if request failed with other status code + raw_data = resp.json() + # noinspection PyTypeChecker + if not self.is_request_success(raw_data, None): + msg = f"获取历史数据失败,状态码:{resp.status_code},信息:{resp.text}" + self.gateway.write_log(msg) + raise RequestFailedException(msg) + result = raw_data['candles'] + for data in result: + bar_data = data['mid'] + open_time = parse_datetime(data["time"]) + close_time = open_time + adjustment + bar = BarData( + symbol=symbol, + exchange=Exchange.OANDA, + datetime=close_time, + interval=interval, + volume=data["volume"], + open_price=float(bar_data["o"]), + high_price=float(bar_data["h"]), + low_price=float(bar_data["l"]), + close_price=float(bar_data["c"]), + gateway_name=self.gateway_name + ) + bars.append(bar) + + end = bars[-1].datetime.replace(tzinfo=utc_tz).astimezone(tz=local_tz) + return bars, HistoryDataNextInfo(symbol, interval, end) + + def on_send_order_failed(self, status_code: int, request: Request): + """ + Callback when sending order failed on server. + """ + data = request.response.json() + + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + msg = f"委托失败,错误代码:{data.get('errorCode', '')}, 错误信息:{data['errorMessage']}" + self.gateway.write_log(msg) + + def on_send_order_error( + self, exception_type: type, exception_value: Exception, tb, request: Request + ): + """ + Callback when sending order caused exception. + """ + order = request.extra + order.status = Status.REJECTED + self.gateway.on_order(order) + + # Record exception if not ConnectionError + if not issubclass(exception_type, ConnectionError): + self.on_error(exception_type, exception_value, tb, request) + + def on_send_order(self, raw_data: dict, request: Request): + """""" + # order: OrderData = request.extra + # creation = raw_data.get('orderCreateTransaction', None) + # if creation is not None: + # order.status = Status.NOTTRADED + # order.time = creation['time'][11:19] + # self.gateway.on_order(order) + + # cancel = raw_data.get('orderCancelTransaction', None) + # if cancel is not None: + # # potential bug: stream API will generate a Status.Cancel Order for this transaction + # order.status = Status.REJECTED + # order.time = parse_time(cancel['time']) + pass + + def on_cancel_order(self, raw_data: dict, request: Request): + """""" + # order: OrderData = request.extra + # order.status = Status.CANCELLED + # order.time = raw_data['orderCancelTransaction']['time'][11:19] + # self.gateway.on_order(order) + pass + + def on_failed(self, status_code: int, request: Request): + """ + Callback to handle request failed. + """ + data = request.response.json() + self._handle_error_response(data, request) + + def _handle_error_response(self, data, request, operation_name: str = None): + if operation_name is None: + operation_name = request.path + + # todo: rate limit? + + error_msg = data.get("message", None) + if error_msg is None: + error_msg = data.get("errorMessage", None) + msg = f"请求{operation_name}失败,状态码:{request.status},错误消息:{error_msg}" + msg += f'\n{request}' + self.gateway.write_log(msg) + + def query_positions(self): + self.add_request("GET", + f"/v3/accounts/{self.gateway.account_id}/positions", + self.on_query_positions + ) + + def query_account_changes(self): + do_nothing = lambda a, b, c, d: None # noqa + + if self.last_account_transaction_id is not None: + account_id = self.gateway.account_id + self.add_request("GET", + f"/v3/accounts/{account_id}/changes", + params={ + "sinceTransactionID": self.last_account_transaction_id + }, + callback=self.on_query_account_changes, + extra=account_id, + on_error=do_nothing, + ) + + def on_query_account_changes(self, raw_data: dict, request: "Request"): + account_id: str = request.extra + + # state: we focus mainly on account balance + state = raw_data['state'] + NAV = float(state['NAV']) + pnl = float(state['unrealizedPL']) + account = AccountData( + gateway_name=self.gateway_name, + accountid=account_id, + balance=NAV - pnl, + frozen=0, # no frozen + ) + self.gateway.on_account(account) + + # changes: we focus mainly on position changes + changes = raw_data['changes'] + positions = changes['positions'] + unrealized_pnls: Dict[(str, Direction), float] = {} + + # pnl in query_account_changes is different from data returned by query_account + # we have to get pnl from 'state' record. + for pos_state_data in state['positions']: + symbol = pos_state_data['instrument'] + unrealized_pnls[(symbol, Direction.LONG)] = float(pos_state_data['longUnrealizedPL']) + unrealized_pnls[(symbol, Direction.SHORT)] = float(pos_state_data['shortUnrealizedPL']) + + for pos_data in positions: + pos_long, pos_short = self.parse_position_data(pos_data) + symbol = pos_long.symbol + pos_long.pnl = unrealized_pnls[(symbol, Direction.LONG)] + pos_short.pnl = unrealized_pnls[(symbol, Direction.SHORT)] + self.gateway.on_position(pos_long) + self.gateway.on_position(pos_short) + self.last_account_transaction_id = raw_data['lastTransactionID'] + + def on_query_positions(self, raw_data: dict, request: "Request"): + for pos_data in raw_data['positions']: + pos_long, pos_short = self.parse_position_data(pos_data) + self.gateway.on_position(pos_long) + self.gateway.on_position(pos_short) + + def parse_position_data(self, pos_data) -> Tuple[PositionData, PositionData]: + symbol = pos_data['instrument'] + pos_long, pos_short = [ + PositionData( + gateway_name=self.gateway_name, + symbol=symbol, + exchange=Exchange.OANDA, + direction=direction, + volume=int(data['units']), + price=float(data.get('averagePrice', 0.0)), + pnl=float(data.get('unrealizedPL', 0.0)), + ) + for direction, data in ( + (Direction.LONG, pos_data['long']), + (Direction.SHORT, pos_data['short']) + ) + ] + return pos_long, pos_short + + def query_orders(self): + """ + query all orders, including stop orders + """ + self.add_request("GET", + f"/v3/accounts/{self.gateway.account_id}/orders?state=ALL", + callback=self.on_query_orders, + ) + + def on_query_orders(self, raw_data: dict, request: "Request"): + for data in raw_data['orders']: + order = self.gateway.parse_order_data(data, + STATUS_OANDA2VT[data['state']], + 'createTime') + self.gateway.on_order(order) + self.orders_initialized = True + + def query_accounts(self): + self.add_request("GET", + "/v3/accounts", + callback=self.on_query_accounts, + ) + + def on_query_accounts(self, raw_data: dict, request: "Request"): + """ + {"accounts":[{"id":"101-001-12185735-001","tags":[]}]} + """ + for acc in raw_data['accounts']: + account_id: str = acc['id'] + self.gateway.account_id = account_id + + self.query_account(account_id) + self.query_contracts(account_id) + self.query_orders() + # self.query_positions() + + def query_account(self, account_id: str): + self.add_request("GET", + f"/v3/accounts/{account_id}", + callback=self.on_query_account, + ) + + def on_query_account(self, raw_data: dict, request: "Request"): + acc = raw_data['account'] + account_data = AccountData( + gateway_name=self.gateway_name, + accountid=acc['id'], + balance=float(acc['balance']), + frozen=0, # no frozen + ) + self.gateway.on_account(account_data) + for data in acc['orders']: + # order = self.parse_order_data(data) + # self.gateway.on_order(order) + pass + for pos_data in acc['positions']: + pos_long, pos_short = self.parse_position_data(pos_data) + self.gateway.on_position(pos_long) + self.gateway.on_position(pos_short) + for trade in acc['trades']: + pass + self.last_account_transaction_id = acc['lastTransactionID'] + self.gateway.stream_api.subscribe_transaction() + self.account_initialized = True + + @property + def fully_initialized(self): + return self.account_initialized and self.orders_initialized + + def query_contracts(self, account_id): + self.add_request("GET", + f"/v3/accounts/{account_id}/instruments", + self.on_query_contracts) + + def on_query_contracts(self, data: dict, request: "Request"): + for ins in data['instruments']: + symbol = ins['name'] + contract = ContractData( + gateway_name=self.gateway_name, + symbol=symbol, + exchange=Exchange.OANDA, + name=symbol, + product=Product.FOREX, + size=1, + pricetick=1.0 / pow(10, ins['displayPrecision']), + ) + self.gateway.on_contract(contract) + self.contracts[symbol] = contract + self.gateway.write_log(_("合约信息查询成功")) diff --git a/vnpy/gateway/oanda/oanda_stream_api.py b/vnpy/gateway/oanda/oanda_stream_api.py new file mode 100644 index 00000000..4a433d02 --- /dev/null +++ b/vnpy/gateway/oanda/oanda_stream_api.py @@ -0,0 +1,227 @@ +from copy import copy +from dataclasses import dataclass +from functools import partial +from http.client import IncompleteRead, RemoteDisconnected +from typing import Callable, TYPE_CHECKING, Type + +from urllib3.exceptions import ProtocolError + +from vnpy.api.rest import Request +from vnpy.trader.constant import Exchange, Interval, Offset, Status +from vnpy.trader.object import OrderData, SubscribeRequest, TickData, TradeData +from .oanda_api_base import OandaApiBase +from .oanda_common import (parse_datetime, parse_time) + +if TYPE_CHECKING: + from vnpy.gateway.oanda import OandaGateway +_ = lambda x: x # noqa + +HOST = "https://stream-fxtrade.oanda.com" +TEST_HOST = "https://stream-fxpractice.oanda.com" + +# asked from official developer +PRICE_TICKS = { + "BTCUSD": 0.5, + "ETHUSD": 0.05, + "EOSUSD": 0.001, + "XRPUSD": 0.0001, +} + + +@dataclass() +class HistoryDataNextInfo: + symbol: str + interval: Interval + end: int + + +class RequestFailedException(Exception): + pass + + +class OandaStreamApi(OandaApiBase): + """ + Oanda Streaming API + """ + + def __init__(self, gateway: "OandaGateway"): + """""" + super().__init__(gateway) + + self.fully_initialized = False + + self._transaction_callbacks = { + 'ORDER_FILL': self.on_order_filled, + 'MARKET_ORDER': self.on_order, + 'LIMIT_ORDER': self.on_order, + 'STOP_ORDER': self.on_order, + 'ORDER_CANCEL': self.on_order_canceled, + + # 'HEARTBEAT': do_nothing, + } + + def connect( + self, + key: str, + session_number: int, + server: str, + proxy_host: str, + proxy_port: int, + ): + """ + Initialize connection to REST server. + """ + self.key = key + + if server == "REAL": + self.init(HOST, proxy_host, proxy_port) + else: + self.init(TEST_HOST, proxy_host, proxy_port) + + self.start(session_number) + + self.gateway.write_log(_("Streaming API启动成功")) + + def subscribe(self, req: SubscribeRequest): + # noinspection PyTypeChecker + self.add_streaming_request( + "GET", + f"/v3/accounts/{self.gateway.account_id}/pricing/stream?instruments={req.symbol}", + callback=self.on_price, + on_error=partial(self.on_streaming_error, partial(self.subscribe, copy(req))), + ) + + def on_price(self, data: dict, request: Request): + type_ = data['type'] + if type_ == 'PRICE': + symbol = data['instrument'] + # only one level of bids/asks + bid = data['bids'][0] + ask = data['asks'][0] + tick = TickData( + gateway_name=self.gateway_name, + symbol=symbol, + exchange=Exchange.OANDA, + datetime=parse_datetime(data['time']), + name=symbol, + bid_price_1=float(bid['price']), + bid_volume_1=bid['liquidity'], + ask_price_1=float(ask['price']), + ask_volume_1=ask['liquidity'], + ) + self.gateway.on_tick(tick) + + def has_error(self, target_type: Type[Exception], e: Exception): + """check if error type \a target_error exists inside \a e""" + if isinstance(e, target_type): + return True + for arg in e.args: + if isinstance(arg, Exception) and self.has_error(target_type, arg): + return True + return False + + def on_streaming_error(self, + re_subscribe: Callable, + exception_type: type, + exception_value: Exception, + tb, + request: Request, + ): + """normally triggered by network error.""" + # skip known errors + known = False + for et in (ProtocolError, IncompleteRead, RemoteDisconnected,): + if self.has_error(et, exception_value): + known = True + break + + if known: + # re-subscribe + re_subscribe() + # write log for any unknown errors + else: + super().on_error(exception_type, exception_value, tb, request) + + def subscribe_transaction(self): + # noinspection PyTypeChecker + self.add_streaming_request( + "GET", + f"/v3/accounts/{self.gateway.account_id}/transactions/stream", + callback=self.on_transaction, + on_connected=self.on_subscribed_transaction, + on_error=partial(self.on_streaming_error, partial(self.subscribe_transaction, )), + ) + + def on_subscribed_transaction(self, request: "Request"): + self.fully_initialized = True + + def on_transaction(self, data: dict, request: "Request"): + type_ = data['type'] + callback = self._transaction_callbacks.get(type_, None) + if callback is not None: + callback(data, request) + elif type_ != "HEARTBEAT": + print(type_) + + def on_order(self, data: dict, request: "Request"): + order = self.gateway.parse_order_data(data, + Status.NOTTRADED, + 'time', + ) + self.gateway.on_order(order) + + def on_order_canceled(self, data: dict, request: "Request"): + order_id = data.get('clientOrderID', None) + if order_id is None: + order_id = data['id'] + order = self.gateway.orders[order_id] + order.status = Status.CANCELLED + order.time = parse_time(data['time']) + self.gateway.on_order(order) + + def on_order_filled(self, data: dict, request: "Request"): + order_id = data.get('clientOrderID', None) + if order_id is None: + order_id = data['orderID'] + + order: OrderData = self.gateway.orders[order_id] + + # # new API: + # price = 0.0 + # if 'tradeOpened' in data: + # price += float(data['tradeOpened']['price']) + # if 'tradeReduced' in data: + # price += float(data['tradeReduced']['price']) + # if 'tradeClosed' in data: + # price += sum([float(i['price']) for i in data['tradeClosed']]) + + # note: 'price' record is Deprecated + # but since this is faster and easier, we use this record. + price = float(data['price']) + + # for Oanda, one order fill is a single trade. + trade = TradeData( + gateway_name=self.gateway_name, + symbol=order.symbol, + exchange=Exchange.OANDA, + orderid=order_id, + tradeid=order_id, + direction=order.direction, + offset=Offset.NONE, + price=price, + volume=order.volume, + time=parse_time(data['time']), + ) + self.gateway.on_trade(trade) + + # this message indicate that this order is full filled. + # ps: oanda's order has only two state: NOTTRADED, ALLTRADED. It it my settings error? + order.traded = order.volume + order.status = Status.ALLTRADED + # order.time = trade.time + order.time = parse_time(data['time']) + self.gateway.on_order(order) + + # quick references + on_tick = on_price + on_trade = on_order_filled diff --git a/vnpy/trader/constant.py b/vnpy/trader/constant.py index 5ed93469..41838766 100644 --- a/vnpy/trader/constant.py +++ b/vnpy/trader/constant.py @@ -111,6 +111,8 @@ class Exchange(Enum): EUNX = "EUNX" # Euronext Exchange KRX = "KRX" # Korean Exchange + OANDA = "OANDA" # oanda.com + # CryptoCurrency BITMEX = "BITMEX" OKEX = "OKEX"