/* Copyright (C) 2013 Interactive Brokers LLC. All rights reserved. This code is subject to the terms
 * and conditions of the IB API Non-Commercial License or the IB API Commercial License, as applicable. */

#include "StdAfx.h"

#include "../client/EPosixClientSocketPlatform.h"
#include "EClientSocketSSL.h"

#include "../client/TwsSocketClientErrors.h"
#include "../client/EWrapper.h"
#include "../client/EDecoder.h"
#include "../client/EReaderSignal.h"
#include "EReaderSSL.h"
#include "../client/EMessage.h"

#include <string.h>
#include <assert.h>
#include <ostream>

const int MIN_SERVER_VER_SUPPORTED = 38; //all supported server versions are defined in EDecoder.h

// One lock per OpenSSL lock slot; indexed by the `type` argument of lockingFunc().
std::vector<EMutex> EClientSocketSSL::sslLocks(CRYPTO_num_locks());

// Locking callback registered with CRYPTO_set_locking_callback(): enters the
// lock for slot `type` when CRYPTO_LOCK is set in `mode`, leaves it otherwise.
// `file` and `line` are part of the callback signature and are not used.
void EClientSocketSSL::lockingFunc(int mode, int type, const char *file, int line)
{
    if (mode & CRYPTO_LOCK)
        sslLocks[type].Enter();
    else
        sslLocks[type].Leave();
}

// Thread-id callback registered with CRYPTO_set_id_callback().
unsigned long EClientSocketSSL::thIdFunc()
{
#if defined(IB_POSIX)
    return syscall(SYS_gettid);
#elif defined(IB_WIN32)
    return GetCurrentThreadId();
#else
#   error "Not implemented on this platform"
#endif
}

///////////////////////////////////////////////////////////
// member funcs

// Creates a disconnected client. m_fd is -1 when SocketsInit() succeeded and
// -2 when it failed; eConnect() and the destructor test for the -2 sentinel.
EClientSocketSSL::EClientSocketSSL(EWrapper *ptr, EReaderSignal *pSignal) :
    EClient( ptr, new ESocketSSL())
{
    m_fd = SocketsInit() ? -1 : -2;
    m_allowRedirect = false;
    m_asyncEConnect = false;
    m_pSignal = pSignal;

    // eDisconnect() tests these for null, so they must never be indeterminate.
    m_pSSL = nullptr;
    m_pCTX = nullptr;

    SSL_load_error_strings();
    ERR_load_BIO_strings();
    SSL_library_init();

    CRYPTO_set_locking_callback(lockingFunc);

    CRYPTO_w_lock(CRYPTO_LOCK_DYNLOCK);
    CRYPTO_set_id_callback(thIdFunc);
    CRYPTO_w_unlock(CRYPTO_LOCK_DYNLOCK);
}

// Undoes SocketsInit() unless it had failed in the constructor.
EClientSocketSSL::~EClientSocketSSL()
{
    if( m_fd != -2)
        SocketsDestroy();
}

// Returns whether eConnectImpl() skips waiting for the server version.
bool EClientSocketSSL::asyncEConnect() const
{
    return m_asyncEConnect;
}

// Selects asynchronous (true) or blocking (false) connect behaviour.
void EClientSocketSSL::asyncEConnect(bool val)
{
    m_asyncEConnect = val;
}

// Connects to host:port; an empty or null host means "127.0.0.1".
// Reports FAIL_CREATE_SOCK if socket initialisation failed at construction and
// ALREADY_CONNECTED if a socket is already open. Returns eConnectImpl()'s result.
bool EClientSocketSSL::eConnect( const char *host, unsigned int port, int clientId, bool extraAuth)
{
    if( m_fd == -2) {
        getWrapper()->error( NO_VALID_ID, FAIL_CREATE_SOCK.code(), FAIL_CREATE_SOCK.msg());
        return false;
    }

    // reset errno
    errno = 0;

    // already connected?
    if( m_fd >= 0) {
        errno = EISCONN;
        getWrapper()->error( NO_VALID_ID, ALREADY_CONNECTED.code(), ALREADY_CONNECTED.msg());
        return false;
    }

    // normalize host
    m_hostNorm = (host && *host) ? host : "127.0.0.1";

    // initialize host and port
    setHost( m_hostNorm);
    setPort( port);

    // try to connect to specified host and port
    ConnState resState = CS_DISCONNECTED;

    return eConnectImpl( clientId, extraAuth, &resState);
}

// Returns the transport downcast to ESocketSSL (type-checked in debug builds only).
ESocketSSL *EClientSocketSSL::getTransport()
{
    assert(dynamic_cast<ESocketSSL*>(m_transport.get()) != 0);

    return static_cast<ESocketSSL*>(m_transport.get());
}

// On Windows, copies every X.509 certificate of the system "ROOT" store into
// the certificate store of m_pCTX. Compiles to a no-op on other platforms.
void EClientSocketSSL::ImportRootCertificatesFromWindowsCertStore()
{
#if defined(IB_WIN32)
    auto store = CertOpenSystemStore(0, "ROOT");

    // Nothing to import if the system store cannot be opened.
    if (!store)
        return;

    auto osslStore = SSL_CTX_get_cert_store(m_pCTX);

    for (auto certCtx = CertEnumCertificatesInStore(store, 0); certCtx; certCtx = CertEnumCertificatesInStore(store, certCtx)) {
        if (!(certCtx->dwCertEncodingType & X509_ASN_ENCODING))
            continue;

        // d2i_X509 advances the pointer it is given, so hand it a copy.
        const unsigned char *pBuf = certCtx->pbCertEncoded;
        X509 *cert = d2i_X509(0, &pBuf, certCtx->cbCertEncoded);

        // Skip entries OpenSSL cannot parse instead of adding a null certificate.
        if (!cert)
            continue;

        X509_STORE_add_cert(osslStore, cert);

        // The store keeps its own reference; release ours so the certificate is not leaked.
        X509_free(cert);
    }

    CertCloseStore(store, 0);
#endif
}

// Performs the connection: resolves the host, opens and connects a TCP socket,
// runs the TLS handshake, sends the connect request and, unless asyncEConnect
// is set, pumps the reader until the server version arrives.
// When stateOutPt is non-null it receives CS_REDIRECT or CS_CONNECTED on the
// paths that reach those states. Returns true if the socket is open at the end.
bool EClientSocketSSL::eConnectImpl(int clientId, bool extraAuth, ConnState* stateOutPt)
{
    // resolve host (gethostbyname yields IPv4 addresses for AF_INET only)
    struct hostent* hostEnt = gethostbyname( host().c_str());

    if ( !hostEnt) {
        getWrapper()->error( NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg());
        return false;
    }

    // create socket
    m_fd = socket(AF_INET, SOCK_STREAM, 0);

    // cannot create socket
    if( m_fd < 0) {
        getWrapper()->error( NO_VALID_ID, FAIL_CREATE_SOCK.code(), FAIL_CREATE_SOCK.msg());
        return false;
    }

    // starting to connect to server
    struct sockaddr_in sa;
    memset( &sa, 0, sizeof(sa));
    sa.sin_family = AF_INET;
    sa.sin_port = htons( port());
    sa.sin_addr.s_addr = ((in_addr*)hostEnt->h_addr)->s_addr;

    // try to connect
    if( (connect( m_fd, (struct sockaddr *) &sa, sizeof( sa))) < 0) {
        // error connecting
        SocketClose( m_fd);
        m_fd = -1;
        getWrapper()->error( NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg());
        return false;
    }

    // True once SSL_new() has produced a session for THIS attempt. Only such a
    // session may be freed on failure: the transport has not been given it yet.
    bool freshSession = false;

    // Common failure path for the TLS setup steps. `nothingReported` is the
    // result of the error handler: true means it found no pending error, so a
    // generic SSL failure is reported. The half-open connection is then torn
    // down so that m_fd no longer claims an open socket.
    const auto abortTlsSetup = [this, &freshSession](bool nothingReported) -> bool {
        if (nothingReported)
            getWrapper()->error(NO_VALID_ID, SSL_FAIL.code(), SSL_FAIL.msg());

        if (freshSession && m_pSSL) {
            SSL_free(m_pSSL);
            m_pSSL = nullptr;
        }

        if (m_pCTX) {
            SSL_CTX_free(m_pCTX);
            m_pCTX = nullptr;
        }

        if (m_fd >= 0) {
            SocketClose(m_fd);
            m_fd = -1;
        }

        return false;
    };

    m_pCTX = SSL_CTX_new(SSLv23_client_method());

    if (!m_pCTX)
        return abortTlsSetup(handleSocketError());

    ImportRootCertificatesFromWindowsCertStore();

    m_pSSL = SSL_new(m_pCTX);

    if (!m_pSSL)
        return abortTlsSetup(handleSocketError());

    freshSession = true;

    if (!SSL_set_fd(m_pSSL, m_fd))
        return abortTlsSetup(handleSocketError());

    // SSL_connect reports failure with 0 or a negative value; only a positive
    // result is a completed handshake.
    const int handshake = SSL_connect(m_pSSL);

    if (handshake <= 0)
        return abortTlsSetup(handleSocketError(handshake));

    // NOTE(review): a failed certificate verification is only reported; the
    // connection is still used (the early return below was already disabled).
    // Peer authentication is therefore not enforced - confirm this is intended.
    if(SSL_get_verify_result(m_pSSL) != X509_V_OK) {
        getWrapper()->error(NO_VALID_ID, SSL_FAIL.code(), SSL_FAIL.msg() + "certificate verification failure");
        //return false;
    }

    getTransport()->fd(m_pSSL);

    // set client id
    setClientId( clientId);
    setExtraAuth( extraAuth);

    int res = sendConnectRequest();

    if (res == 0 || (res < 0 && !handleSocketError(res)))
        return false;

    if( !isConnected()) {
        if( connState() != CS_DISCONNECTED) {
            assert( connState() == CS_REDIRECT);
            if( stateOutPt) {
                *stateOutPt = connState();
            }
            eDisconnect();
        }
        return false;
    }

    // set socket to non-blocking state
    if ( !SetSocketNonBlocking(m_fd)) {
        // error setting socket to non-blocking
        eDisconnect();
        getWrapper()->error( NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg());
        return false;
    }

    assert( connState() == CS_CONNECTED);
    if( stateOutPt) {
        *stateOutPt = connState();
    }

    if (!m_asyncEConnect) {
        // Blocking connect: process incoming messages until the server version
        // is known or the socket has been closed.
        EReaderSSL reader(this, m_pSignal);

        while (m_pSignal && !m_serverVersion && isSocketOK()) {
            reader.checkClient();
            m_pSignal->waitForSignal();
            reader.processMsgs();
        }
    }

    // successfully connected
    return isSocketOK();
}

// Writes the payload length (network byte order) into the HEADER_LEN bytes at
// `offset`. The payload is everything after that header. Reports BAD_LENGTH
// and leaves the header untouched when the payload exceeds MAX_MSG_LEN.
void EClientSocketSSL::encodeMsgLen(std::string& msg, unsigned offset) const
{
    assert( !msg.empty());
    assert( m_useV100Plus);

    assert( sizeof(unsigned) == HEADER_LEN);
    assert( msg.size() > offset + HEADER_LEN);
    unsigned len = msg.size() - HEADER_LEN - offset;
    if( len > MAX_MSG_LEN) {
        m_pEWrapper->error( NO_VALID_ID, BAD_LENGTH.code(), BAD_LENGTH.msg());
        return;
    }

    unsigned netlen = htonl( len);
    memcpy( &msg[offset], &netlen, HEADER_LEN);
}

// Finalises a message (fills in the length header when the V100+ protocol is
// in use) and sends it. A negative send result is passed to handleSocketError(int).
bool EClientSocketSSL::closeAndSend(std::string msg, unsigned offset)
{
    assert( !msg.empty());
    if( m_useV100Plus) {
        encodeMsgLen( msg, offset);
    }

    int res = bufferedSend(msg);

    if (res < 0)
        return handleSocketError(res);

    return true;
}

// Reserves HEADER_LEN zero bytes in `buf`; encodeMsgLen() overwrites them later.
void EClientSocketSSL::prepareBufferImpl(std::ostream& buf) const
{
    assert( m_useV100Plus);
    assert( sizeof(unsigned) == HEADER_LEN);

    char header[HEADER_LEN] = { 0 };
    buf.write( header, sizeof(header));
}

// Reserves the length header only when the V100+ protocol is in use.
void EClientSocketSSL::prepareBuffer(std::ostream& buf) const
{
    if( !m_useV100Plus)
        return;

    prepareBufferImpl( buf);
}

// Shuts the TLS session down, releases the SSL context, closes the socket and
// resets the base-class connection state. Safe to call more than once.
void EClientSocketSSL::eDisconnect()
{
    // NOTE(review): the session is shut down but deliberately not freed here.
    // eConnectImpl() hands the same pointer to the transport, so freeing it
    // could leave the transport with a dangling pointer. The session object is
    // therefore leaked once per connection - revisit together with ESocketSSL.
    if (m_pSSL)
        SSL_shutdown(m_pSSL);

    if (m_pCTX) {
        SSL_CTX_free(m_pCTX);
        // Clear the pointer so a repeated eDisconnect() cannot free it twice.
        m_pCTX = nullptr;
    }

    if ( m_fd >= 0 )
        // close socket
        SocketClose( m_fd);
    m_fd = -1;

    eDisconnectBase();
}

// True while a socket descriptor is held.
bool EClientSocketSSL::isSocketOK() const
{
    return ( m_fd >= 0);
}

// Returns the raw socket descriptor (negative when not connected).
int EClientSocketSSL::fd() const
{
    return m_fd;
}

// Reads up to `sz` bytes into `buf`.
// Returns the number of bytes read, 0 when nothing was read (including a
// closed connection, which also triggers onClose()), or -1 when
// handleSocketError(int) rejects the failure.
int EClientSocketSSL::receive(char* buf, size_t sz)
{
    if( sz == 0)
        return 0;

    // SSL_read takes an int length; make the narrowing explicit.
    int nResult = SSL_read(m_pSSL, buf, static_cast<int>(sz));

    if( nResult < 0 && !handleSocketError(nResult)) {
        return -1;
    }
    if( nResult == 0) {
        onClose();
    }
    if( nResult <= 0) {
        return 0;
    }
    return nResult;
}

// Records the server version and time received during the handshake.
// Reports UNSUPPORTED_VERSION and disconnects when the version is out of
// range; otherwise starts the API unless connecting asynchronously.
void EClientSocketSSL::serverVersion(int version, const char *time)
{
    m_serverVersion = version;
    m_TwsTime = time;

    if( usingV100Plus() ? (m_serverVersion < MIN_CLIENT_VER || m_serverVersion > MAX_CLIENT_VER) : m_serverVersion < MIN_SERVER_VER_SUPPORTED ) {
        getWrapper()->error( NO_VALID_ID, UNSUPPORTED_VERSION.code(), UNSUPPORTED_VERSION.msg());
        eDisconnect();
        // The connection is gone; starting the API on it makes no sense.
        return;
    }

    if (!m_asyncEConnect)
        startApi();
}

// Handles a redirect request to host:port. An empty or null host means
// "127.0.0.1"; a non-positive port means "keep the current port".
// Does nothing when the target equals the current endpoint, reports
// CONNECT_FAIL when redirects are not allowed, and otherwise reconnects to the
// new endpoint.
void EClientSocketSSL::redirect(const char *host, int port)
{
    const std::string targetHost = (host && *host) ? host : "127.0.0.1";

    const bool hostChanged = targetHost != this->host();
    const bool portChanged = port > 0 && port != static_cast<int>(this->port());

    if (!hostChanged && !portChanged)
        return;

    if (!m_allowRedirect) {
        getWrapper()->error(NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg());

        return;
    }

    // Captured before eDisconnect() in case the base-class teardown resets them.
    const int redirectClientId = clientId();
    const bool redirectExtraAuth = extraAuth();

    eDisconnect();

    // Point the client at the redirect target; eConnectImpl() reads host()/port().
    setHost(targetHost);

    if (portChanged)
        setPort(port);

    eConnectImpl( redirectClientId, redirectExtraAuth, nullptr);
}

// Translates the result of an SSL I/O call into something the socket error
// handler understands. On entry ret_code is the call's return value; on exit
// it is 0 for NONE/WANT_READ/WANT_WRITE, the system error code for
// SSL_ERROR_SYSCALL, or the SSL_get_error() category for anything else.
// Returns false only in that last case.
bool EClientSocketSSL::handleSSLError(int &ret_code)
{
    ret_code = SSL_get_error(m_pSSL, ret_code);

    switch (ret_code) {
    case SSL_ERROR_NONE:
    case SSL_ERROR_WANT_READ:
    case SSL_ERROR_WANT_WRITE:
        ret_code = 0;

        return true;

    case SSL_ERROR_SYSCALL:
#if defined(IB_POSIX)
        ret_code = errno;
#elif defined(IB_WIN32)
        ret_code = GetLastError();
#else
#   error "Not implemented on this platform"
#endif
        return true;
    }

    return false;
}

// Inspects the system error state. On Windows a non-zero `hr` is reported as
// SOCKET_EXCEPTION and yields false. Otherwise errno decides: 0 and EISCONN
// yield true, EWOULDBLOCK yields false without side effects, anything else is
// reported, errno is cleared, the client is disconnected and false is returned.
bool EClientSocketSSL::handleSocketErrorInternal(int hr)
{
#if defined(IB_WIN32)
    if (hr != 0) {
        LPTSTR buf = 0;
        const DWORD written = FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, 0, hr, 0, (LPTSTR)&buf, 0, 0);

        // FormatMessage leaves `buf` unset when it fails; do not append it then.
        if (written != 0 && buf)
            getWrapper()->error( NO_VALID_ID, SOCKET_EXCEPTION.code(), SOCKET_EXCEPTION.msg() + buf);
        else
            getWrapper()->error( NO_VALID_ID, SOCKET_EXCEPTION.code(), SOCKET_EXCEPTION.msg());

        if (buf)
            LocalFree(buf);

        return false;
    }
#endif

    // no error
    if( errno == 0)
        return true;

    // Socket is already connected
    if( errno == EISCONN) {
        return true;
    }

    if( errno == EWOULDBLOCK)
        return false;

    if( errno == ECONNREFUSED) {
        getWrapper()->error( NO_VALID_ID, CONNECT_FAIL.code(), CONNECT_FAIL.msg());
    }
    else {
        getWrapper()->error( NO_VALID_ID, SOCKET_EXCEPTION.code(),
            SOCKET_EXCEPTION.msg() + strerror(errno));
    }

    // reset errno
    errno = 0;
    eDisconnect();

    return false;
}

// Handles the result `res` of an SSL I/O call. Returns true when the call did
// not really fail (or may simply be retried), false after reporting an error.
bool EClientSocketSSL::handleSocketError(int res)
{
    if (!handleSSLError(res)) {
        // `res` now holds an SSL_get_error() category, which is not a valid
        // argument for the ERR_* string functions; take the message from the
        // OpenSSL error queue instead and format it into a local buffer.
        char text[256];
        ERR_error_string_n(ERR_get_error(), text, sizeof(text));

        getWrapper()->error(NO_VALID_ID, SSL_FAIL.code(), SSL_FAIL.msg() + text);

        return false;
    }

    return handleSocketErrorInternal(res);
}
bool EClientSocketSSL::handleSocketError() { int res = ERR_get_error(); if (res) { getWrapper()->error(NO_VALID_ID, SSL_FAIL.code(), SSL_FAIL.msg() + ERR_error_string(res, 0)); return false; } return handleSocketErrorInternal(); } /////////////////////////////////////////////////////////// // callbacks from socket void EClientSocketSSL::onSend() { if( !handleSocketError()) return; getTransport()->sendBufferedData(); } void EClientSocketSSL::onClose() { if( !handleSocketError()) return; eDisconnect(); getWrapper()->connectionClosed(); } void EClientSocketSSL::onError() { handleSocketError(); }