Merge branch 'v2.0.1-DEV' of https://github.com/vnpy/vnpy into v2.0.1-DEV

This commit is contained in:
vn.py 2019-03-17 13:11:59 +08:00
commit 8e5ea24a36
5 changed files with 152 additions and 69 deletions

View File

@ -6,7 +6,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from multiprocessing.dummy import Pool from multiprocessing.dummy import Pool
from queue import Empty, Queue from queue import Empty, Queue
from typing import Any, Callable from typing import Any, Callable, Optional
import requests import requests
@ -79,7 +79,7 @@ class RestClient(object):
""" """
HTTP Client designed for all sorts of trading RESTFul API. HTTP Client designed for all sorts of trading RESTFul API.
* Reimplement before_request function to add signature function. * Reimplement sign function to add signature function.
* Reimplement on_failed function to handle Non-2xx responses. * Reimplement on_failed function to handle Non-2xx responses.
* Use on_failed parameter in add_request function for individual Non-2xx response handling. * Use on_failed parameter in add_request function for individual Non-2xx response handling.
* Reimplement on_error function to handle exception msg. * Reimplement on_error function to handle exception msg.
@ -88,7 +88,7 @@ class RestClient(object):
def __init__(self): def __init__(self):
""" """
""" """
self.url_base = None # type: str self.url_base = '' # type: str
self._active = False self._active = False
self._queue = Queue() self._queue = Queue()
@ -208,7 +208,7 @@ class RestClient(object):
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
tb, tb,
request: Request, request: Optional[Request],
): ):
""" """
Default on_error handler for Python exception. Default on_error handler for Python exception.
@ -223,7 +223,7 @@ class RestClient(object):
exception_type: type, exception_type: type,
exception_value: Exception, exception_value: Exception,
tb, tb,
request: Request, request: Optional[Request],
): ):
text = "[{}]: Unhandled RestClient Error:{}\n".format( text = "[{}]: Unhandled RestClient Error:{}\n".format(
datetime.now().isoformat(), exception_type datetime.now().isoformat(), exception_type
@ -236,8 +236,8 @@ class RestClient(object):
return text return text
def _process_request( def _process_request(
self, request: Request, session: requests.session self, request: Request, session: requests.Session
): # type: (Request, requests.Session)->None ):
""" """
Sending request to server and get result. Sending request to server and get result.
""" """

View File

@ -23,13 +23,16 @@ class WebsocketClient(object):
Default serialization format is json. Default serialization format is json.
Callbacks to reimplement: Callbacks to overrides:
* unpack_data
* on_connected * on_connected
* on_disconnected * on_disconnected
* on_packet * on_packet
* on_error * on_error
After start() is called, the ping thread will ping server every 60 seconds. After start() is called, the ping thread will ping server every 60 seconds.
If you want to send anything other than JSON, override send_packet.
""" """
def __init__(self): def __init__(self):
@ -92,22 +95,28 @@ class WebsocketClient(object):
def send_packet(self, packet: dict): def send_packet(self, packet: dict):
""" """
Send a packet (dict data) to server Send a packet (dict data) to server
override this if you want to send non-json packet
""" """
text = json.dumps(packet) text = json.dumps(packet)
self._record_last_sent_text(text) self._record_last_sent_text(text)
return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT) return self._send_text(text)
def send_text(self, text: str): def _send_text(self, text: str):
""" """
Send a text string to server. Send a text string to server.
""" """
return self._get_ws().send(text, opcode=websocket.ABNF.OPCODE_TEXT) ws = self._ws
if ws:
ws.send(text, opcode=websocket.ABNF.OPCODE_TEXT)
def send_binary(self, data: bytes): def _send_binary(self, data: bytes):
""" """
Send bytes data to server. Send bytes data to server.
""" """
return self._get_ws().send_binary(data) ws = self._ws
if ws:
ws._send_binary(data)
def _reconnect(self): def _reconnect(self):
"""""" """"""
@ -137,11 +146,6 @@ class WebsocketClient(object):
self._ws.close() self._ws.close()
self._ws = None self._ws = None
def _get_ws(self):
""""""
with self._ws_lock:
return self._ws
def _run(self): def _run(self):
""" """
Keep running till stop is called. Keep running till stop is called.
@ -152,7 +156,7 @@ class WebsocketClient(object):
# todo: onDisconnect # todo: onDisconnect
while self._active: while self._active:
try: try:
ws = self._get_ws() ws = self._ws
if ws: if ws:
text = ws.recv() text = ws.recv()
@ -189,7 +193,7 @@ class WebsocketClient(object):
""" """
Default serialization format is json. Default serialization format is json.
Reimplement this method if you want to use other serialization format. override this method if you want to use other serialization format.
""" """
return json.loads(data) return json.loads(data)
@ -209,7 +213,7 @@ class WebsocketClient(object):
def _ping(self): def _ping(self):
"""""" """"""
ws = self._get_ws() ws = self._ws
if ws: if ws:
ws.send("ping", websocket.ABNF.OPCODE_PING) ws.send("ping", websocket.ABNF.OPCODE_PING)

View File

@ -8,6 +8,7 @@ import sys
import time import time
from copy import copy from copy import copy
from datetime import datetime from datetime import datetime
from threading import Lock
from urllib.parse import urlencode from urllib.parse import urlencode
from requests import ConnectionError from requests import ConnectionError
@ -62,7 +63,7 @@ class BitmexGateway(BaseGateway):
default_setting = { default_setting = {
"key": "", "key": "",
"secret": "", "secret": "",
"session": 3, "session_number": 3,
"server": ["REAL", "TESTNET"], "server": ["REAL", "TESTNET"],
"proxy_host": "127.0.0.1", "proxy_host": "127.0.0.1",
"proxy_port": 1080, "proxy_port": 1080,
@ -79,15 +80,16 @@ class BitmexGateway(BaseGateway):
"""""" """"""
key = setting["key"] key = setting["key"]
secret = setting["secret"] secret = setting["secret"]
session = setting["session"] session_number = setting["session_number"]
server = setting["server"] server = setting["server"]
proxy_host = setting["proxy_host"] proxy_host = setting["proxy_host"]
proxy_port = setting["proxy_port"] proxy_port = setting["proxy_port"]
self.rest_api.connect(key, secret, session, self.rest_api.connect(key, secret, session_number,
server, proxy_host, proxy_port) server, proxy_host, proxy_port)
self.ws_api.connect(key, secret, server, proxy_host, proxy_port) self.ws_api.connect(key, secret, server, proxy_host, proxy_port)
# websocket will push all account status on connected, including asset, position and orders.
def subscribe(self, req: SubscribeRequest): def subscribe(self, req: SubscribeRequest):
"""""" """"""
@ -131,6 +133,8 @@ class BitmexRestApi(RestClient):
self.secret = "" self.secret = ""
self.order_count = 1_000_000 self.order_count = 1_000_000
self.order_count_lock = Lock()
self.connect_time = 0 self.connect_time = 0
def sign(self, request): def sign(self, request):
@ -172,7 +176,7 @@ class BitmexRestApi(RestClient):
self, self,
key: str, key: str,
secret: str, secret: str,
session: int, session_number: int,
server: str, server: str,
proxy_host: str, proxy_host: str,
proxy_port: int, proxy_port: int,
@ -192,14 +196,18 @@ class BitmexRestApi(RestClient):
else: else:
self.init(TESTNET_REST_HOST, proxy_host, proxy_port) self.init(TESTNET_REST_HOST, proxy_host, proxy_port)
self.start(session) self.start(session_number)
self.gateway.write_log("REST API启动成功") self.gateway.write_log("REST API启动成功")
def _new_order_id(self):
with self.order_count_lock:
self.order_count += 1
return self.order_count
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
self.order_count += 1 orderid = str(self.connect_time + self._new_order_id())
orderid = str(self.connect_time + self.order_count)
data = { data = {
"symbol": req.symbol, "symbol": req.symbol,
@ -272,7 +280,7 @@ class BitmexRestApi(RestClient):
self.on_error(exception_type, exception_value, tb, request) self.on_error(exception_type, exception_value, tb, request)
def on_send_order(self, data, request): def on_send_order(self, data, request):
"""""" """Websocket will push a new order status"""
pass pass
def on_cancel_order_error( def on_cancel_order_error(
@ -286,7 +294,7 @@ class BitmexRestApi(RestClient):
self.on_error(exception_type, exception_value, tb, request) self.on_error(exception_type, exception_value, tb, request)
def on_cancel_order(self, data, request): def on_cancel_order(self, data, request):
"""""" """Websocket will push a new order status"""
pass pass
def on_failed(self, status_code: int, request: Request): def on_failed(self, status_code: int, request: Request):

View File

@ -116,6 +116,8 @@ class TigerGateway(BaseGateway):
self.queue = Queue() self.queue = Queue()
self.pool = None self.pool = None
self.ID_TIGER2VT = {}
self.ID_VT2TIGER = {}
self.ticks = {} self.ticks = {}
self.trades = set() self.trades = set()
self.contracts = {} self.contracts = {}
@ -125,18 +127,14 @@ class TigerGateway(BaseGateway):
"""""" """"""
while self.active: while self.active:
try: try:
func, arg = self.queue.get(timeout=0.1) func, args = self.queue.get(timeout=0.1)
print(func, arg) func(*args)
if arg:
func(arg)
else:
func()
except Empty: except Empty:
pass pass
def add_task(self, func, arg=None): def add_task(self, func, *args):
"""""" """"""
self.queue.put((func, arg)) self.queue.put((func, [*args]))
def connect(self, setting: dict): def connect(self, setting: dict):
"""""" """"""
@ -157,9 +155,6 @@ class TigerGateway(BaseGateway):
self.add_task(self.connect_quote) self.add_task(self.connect_quote)
self.add_task(self.connect_trade) self.add_task(self.connect_trade)
self.add_task(self.connect_push) self.add_task(self.connect_push)
self.write_log("行情接口连接成功")
# self.thread.start()
def init_client_config(self, sandbox=True): def init_client_config(self, sandbox=True):
"""""" """"""
@ -183,6 +178,7 @@ class TigerGateway(BaseGateway):
self.write_log("查询合约失败") self.write_log("查询合约失败")
return return
self.write_log("行情接口连接成功")
self.write_log("合约查询成功") self.write_log("合约查询成功")
def connect_trade(self): def connect_trade(self):
@ -256,6 +252,8 @@ class TigerGateway(BaseGateway):
def on_asset_change(self, tiger_account: str, data: list): def on_asset_change(self, tiger_account: str, data: list):
"""""" """"""
data = dict(data) data = dict(data)
if "net_liquidation" not in data:
return
account = AccountData( account = AccountData(
accountid=tiger_account, accountid=tiger_account,
@ -274,7 +272,7 @@ class TigerGateway(BaseGateway):
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
volume=data["quantity"], volume=int(data["quantity"]),
frozen=0.0, frozen=0.0,
price=data["average_cost"], price=data["average_cost"],
pnl=data["unrealized_pnl"], pnl=data["unrealized_pnl"],
@ -284,17 +282,15 @@ class TigerGateway(BaseGateway):
def on_order_change(self, tiger_account: str, data: list): def on_order_change(self, tiger_account: str, data: list):
"""""" """"""
print("委托", data)
self.local_id += 1
data = dict(data) data = dict(data)
print("委托推送", data["origin_symbol"], data["order_id"], data["filled"], data["status"])
symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"]) symbol, exchange = convert_symbol_tiger2vt(data["origin_symbol"])
status = PUSH_STATUS_TIGER2VT[data["status"]] status = PUSH_STATUS_TIGER2VT[data["status"]]
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
# orderid=data["order_id"], orderid=self.ID_TIGER2VT.get(str(data["order_id"]), self.get_new_local_id()),
orderid=self.local_id,
direction=Direction.NET, direction=Direction.NET,
price=data.get("limit_price", 0), price=data.get("limit_price", 0),
volume=data["quantity"], volume=data["quantity"],
@ -313,7 +309,7 @@ class TigerGateway(BaseGateway):
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
tradeid=self.tradeid, tradeid=self.tradeid,
orderid=data["order_id"], orderid=self.ID_TIGER2VT[str(data["order_id"])],
price=data["avg_fill_price"], price=data["avg_fill_price"],
volume=data["filled"], volume=data["filled"],
time=datetime.fromtimestamp(data["trade_time"] / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(data["trade_time"] / 1000).strftime("%H:%M:%S"),
@ -321,20 +317,22 @@ class TigerGateway(BaseGateway):
) )
self.on_trade(trade) self.on_trade(trade)
def get_new_local_id(self):
self.local_id += 1
return self.local_id
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest):
"""""" """"""
self.local_id += 1 local_id = self.get_new_local_id()
order = req.create_order_data(self.local_id, self.gateway_name) order = req.create_order_data(local_id, self.gateway_name)
return order.vt_orderid
self.on_order(order) self.on_order(order)
self.add_task(self._send_order, req, local_id)
return order.vt_orderid
self.add_task(self._send_order, req) def _send_order(self, req: OrderRequest, local_id):
def _send_order(self, req: OrderRequest):
"""""" """"""
currency = config_symbol_currency(req.symbol) currency = config_symbol_currency(req.symbol)
# first, get contract
try: try:
contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0] contract = self.trade_client.get_contracts(symbol=req.symbol, currency=currency)[0]
order = self.trade_client.create_order( order = self.trade_client.create_order(
@ -345,8 +343,14 @@ class TigerGateway(BaseGateway):
quantity=int(req.volume), quantity=int(req.volume),
limit_price=req.price, limit_price=req.price,
) )
self.ID_TIGER2VT[str(order.order_id)] = local_id
self.ID_VT2TIGER[local_id] = str(order.order_id)
self.trade_client.place_order(order) self.trade_client.place_order(order)
print("发单:", order.contract.symbol, order.order_id, order.quantity, order.status)
except: # noqa except: # noqa
traceback.print_exc()
self.write_log("发单失败") self.write_log("发单失败")
return return
@ -357,7 +361,8 @@ class TigerGateway(BaseGateway):
def _cancel_order(self, req: CancelRequest): def _cancel_order(self, req: CancelRequest):
"""""" """"""
try: try:
data = self.trade_client.cancel_order(order_id=req.orderid) order_id = self.ID_VT2TIGER[req.orderid]
data = self.trade_client.cancel_order(order_id=order_id)
except ApiException: except ApiException:
self.write_log(f"撤单失败:{req.orderid}") self.write_log(f"撤单失败:{req.orderid}")
@ -420,8 +425,7 @@ class TigerGateway(BaseGateway):
for ix, row in contract_CN.iterrows(): for ix, row in contract_CN.iterrows():
symbol = row["symbol"] symbol = row["symbol"]
symbol, exchange = convert_symbol_tiger2vt(symbol) symbol, exchange = convert_symbol_tiger2vt(symbol)
if symbol == '600001':
print(f"symbol: {symbol} t:{type(symbol)} l:{len(symbol)} ex:{exchange} n:{row['name']}")
contract = ContractData( contract = ContractData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
@ -467,7 +471,7 @@ class TigerGateway(BaseGateway):
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
volume=i.quantity, volume=int(i.quantity),
frozen=0.0, frozen=0.0,
price=i.average_cost, price=i.average_cost,
pnl=float(i.unrealized_pnl), pnl=float(i.unrealized_pnl),
@ -500,12 +504,12 @@ class TigerGateway(BaseGateway):
"""""" """"""
for i in data: for i in data:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1 local_id = self.get_new_local_id()
order = OrderData( order = OrderData(
symbol=symbol, symbol=symbol,
exchange=exchange, exchange=exchange,
orderid=self.local_id, orderid=local_id,
# orderid=str(i.order_id),
direction=Direction.NET, direction=Direction.NET,
price=i.limit_price if i.limit_price else 0.0, price=i.limit_price if i.limit_price else 0.0,
volume=i.quantity, volume=i.quantity,
@ -514,17 +518,20 @@ class TigerGateway(BaseGateway):
time=datetime.fromtimestamp(i.order_time / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(i.order_time / 1000).strftime("%H:%M:%S"),
gateway_name=self.gateway_name, gateway_name=self.gateway_name,
) )
self.ID_TIGER2VT[str(i.order_id)] = local_id
self.on_order(order) self.on_order(order)
self.ID_VT2TIGER = {v: k for k, v in self.ID_TIGER2VT.items()}
print("原始委托字典", self.ID_TIGER2VT)
print("原始反向字典", self.ID_VT2TIGER)
def process_deal(self, data): def process_deal(self, data):
""" """
Process trade data for both query and update. Process trade data for both query and update.
""" """
for i in reversed(data): for i in data:
if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED: if i.status == ORDER_STATUS.PARTIALLY_FILLED or i.status == ORDER_STATUS.FILLED:
symbol, exchange = convert_symbol_tiger2vt(str(i.contract)) symbol, exchange = convert_symbol_tiger2vt(str(i.contract))
self.local_id += 1
self.tradeid += 1 self.tradeid += 1
trade = TradeData( trade = TradeData(
@ -532,7 +539,7 @@ class TigerGateway(BaseGateway):
exchange=exchange, exchange=exchange,
direction=Direction.NET, direction=Direction.NET,
tradeid=self.tradeid, tradeid=self.tradeid,
orderid=self.local_id, orderid=self.ID_TIGER2VT[str(i.order_id)],
price=i.avg_fill_price, price=i.avg_fill_price,
volume=i.filled, volume=i.filled,
time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"), time=datetime.fromtimestamp(i.trade_time / 1000).strftime("%H:%M:%S"),

View File

@ -33,6 +33,39 @@ class BaseGateway(ABC):
""" """
Abstract gateway class for creating gateways connection Abstract gateway class for creating gateways connection
to different trading systems. to different trading systems.
# How to implement a gateway:
---
## Basics
A gateway should satisfies:
* this class should be thread-safe:
* all methods should be thread-safe
* no mutable shared properties between objects.
* all methods should be non-blocked
* satisfies all requirements written in docstring for every method and callbacks.
* automatically reconnect if connection lost.
---
## methods must implements:
all @abstractmethod
---
## callbacks must response manually:
* on_tick
* on_trade
* on_order
* on_position
* on_account
* on_contract
All the XxxData passed to callback should be constant, which means that
the object should not be modified after passing to on_xxxx.
So if you use a cache to store reference of data, use copy.copy to create a new object
before passing that data into on_xxxx
""" """
# Fields required in setting dict for connect function. # Fields required in setting dict for connect function.
@ -113,6 +146,21 @@ class BaseGateway(ABC):
def connect(self, setting: dict): def connect(self, setting: dict):
""" """
Start gateway connection. Start gateway connection.
to implement this method, you must:
* connect to server if necessary
* log connected if all necessary connection is established
* do the following query and response corresponding on_xxxx and write_log
* contracts : on_contract
* account asset : on_account
* account holding: on_position
* orders of account: on_order
* trades of account: on_trade
* if any of query above is failed, write log.
future plan:
response callback/change status instead of write_log
""" """
pass pass
@ -131,9 +179,20 @@ class BaseGateway(ABC):
pass pass
@abstractmethod @abstractmethod
def send_order(self, req: OrderRequest): def send_order(self, req: OrderRequest) -> str:
""" """
Send a new order. Send a new order to server.
implementation should finish the tasks blow:
* create an OrderData from req using OrderRequest.create_order_data
* assign a unique(gateway instance scope) id to OrderData.orderid
* send request to server
* if request is sent, OrderData.status should be set to Status.SUBMITTING
* if request is failed to sent, OrderData.status should be set to Status.REJECTED
* response on_order:
* return OrderData.vt_orderid
:return str vt_orderid for created OrderData
""" """
pass pass
@ -141,6 +200,10 @@ class BaseGateway(ABC):
def cancel_order(self, req: CancelRequest): def cancel_order(self, req: CancelRequest):
""" """
Cancel an existing order. Cancel an existing order.
implementation should finish the tasks blow:
* send request to server
""" """
pass pass
@ -148,6 +211,7 @@ class BaseGateway(ABC):
def query_account(self): def query_account(self):
""" """
Query account balance. Query account balance.
""" """
pass pass