[Mod] complete general test of coinbase gateway

This commit is contained in:
vn.py 2019-09-05 11:23:37 +08:00
parent 019da35a55
commit 77c2b52842
2 changed files with 41 additions and 20 deletions

View File

@ -28,6 +28,7 @@ from vnpy.gateway.okexs import OkexsGateway
# from vnpy.gateway.tora import ToraGateway
# from vnpy.gateway.alpaca import AlpacaGateway
from vnpy.gateway.da import DaGateway
from vnpy.gateway.coinbase import CoinbaseGateway
from vnpy.app.cta_strategy import CtaStrategyApp
# from vnpy.app.csv_loader import CsvLoaderApp
@ -69,8 +70,9 @@ def main():
# main_engine.add_gateway(TapGateway)
# main_engine.add_gateway(ToraGateway)
# main_engine.add_gateway(AlpacaGateway)
main_engine.add_gateway(OkexsGateway)
main_engine.add_gateway(DaGateway)
# main_engine.add_gateway(OkexsGateway)
# main_engine.add_gateway(DaGateway)
main_engine.add_gateway(CoinbaseGateway)
main_engine.add_app(CtaStrategyApp)
main_engine.add_app(CtaBacktesterApp)

View File

@ -41,6 +41,7 @@ from vnpy.trader.object import (
REST_HOST = "https://api.pro.coinbase.com"
WEBSOCKET_HOST = "wss://ws-feed.pro.coinbase.com"
SANDBOX_REST_HOST = "https://api-public.sandbox.pro.coinbase.com"
SANDBOX_WEBSOCKET_HOST = "wss://ws-feed-public.sandbox.pro.coinbase.com"
@ -68,6 +69,7 @@ TIMEDELTA_MAP = {
cancelDict = {} # orderid:cancelreq
orderDict = {} # sysid:order
orderSysDict = {} # orderid:sysid
symbol_name_map = {}
class CoinbaseGateway(BaseGateway):
@ -84,11 +86,13 @@ class CoinbaseGateway(BaseGateway):
"proxy_host": "",
"proxy_port": "",
}
exchanges = [Exchange.COINBASE]
def __init__(self, event_engine):
"""Constructor"""
super(CoinbaseGateway, self).__init__(event_engine, "COINBASE")
self.rest_api = CoinbaseRestApi(self)
self.ws_api = CoinbaseWebsocketApi(self)
@ -107,8 +111,15 @@ class CoinbaseGateway(BaseGateway):
else:
proxy_port = 0
self.rest_api.connect(key, secret, passphrase, session_number, server,
proxy_host, proxy_port)
self.rest_api.connect(
key,
secret,
passphrase,
session_number,
server,
proxy_host,
proxy_port
)
self.ws_api.connect(
key,
@ -186,8 +197,8 @@ class CoinbaseWebsocketApi(WebsocketClient):
"done": self.on_order_done,
"match": self.on_order_match,
}
self.ticks = {}
self.ticks = {}
self.accounts = {}
self.orderbooks = {}
@ -349,10 +360,9 @@ class CoinbaseWebsocketApi(WebsocketClient):
orderid=order.orderid,
tradeid=packet['trade_id'],
direction=DIRECTION_COINBASE2VT[packet['side']],
price=packet['price'],
volume=packet['size'],
time=datetime.strptime(
packet['time'], "%Y-%m-%dT%H:%M:%S.%fZ"),
price=float(packet['price']),
volume=float(packet['size']),
time=packet['time'],
gateway_name=self.gateway_name,
)
self.gateway.on_trade(trade)
@ -372,7 +382,12 @@ class OrderBook():
self.bids = dict()
self.gateway = gateway
self.newest_tick = TickData(
"COINBASE", symbol, exchange, datetime.now())
symbol=symbol,
exchange=exchange,
name=symbol_name_map.get(symbol, ""),
datetime=datetime.now(),
gateway_name=gateway.gateway_name,
)
self.first_update = False
def on_message(self, d: dict):
@ -583,16 +598,19 @@ class CoinbaseRestApi(RestClient):
def on_query_account(self, data, request):
""""""
for acc in data:
account_id = str(acc['id'])
account_id = str(acc['currency'])
account = self.accounts.get(account_id, None)
if not account:
account = AccountData(accountid=account_id,
gateway_name=self.gateway_name)
account = AccountData(
accountid=account_id,
gateway_name=self.gateway_name
)
self.accounts[account_id] = account
account.balance = acc.get("balance", account.balance)
account.available = acc.get("available", account.available)
account.frozen = acc.get("hold", account.frozen)
account.balance = float(acc.get("balance", account.balance))
account.available = float(acc.get("available", account.available))
account.frozen = float(acc.get("hold", account.frozen))
self.gateway.on_account(copy(account))
@ -670,17 +688,18 @@ class CoinbaseRestApi(RestClient):
exchange=Exchange.COINBASE,
name=d['display_name'],
product=Product.SPOT,
pricetick=d['quote_increment'],
size=d['base_min_size'],
stop_supported=False,
pricetick=float(d['quote_increment']),
size=1,
min_volume=float(d['base_min_size']),
net_position=True,
history_data=False,
gateway_name=self.gateway_name,
)
self.gateway.on_contract(contract)
symbol_name_map[contract.symbol] = contract.name
self.gateway.write_log("")
self.gateway.write_log("合约信息查询成功")
def send_order(self, req: OrderRequest):
""""""