Implement DTLSSocketWrapper and fix non-blocking connections on TLSSocket

DTLSSocketWrapper is equivalent of TLSSocketWrapper but uses datagram mode
and timers for handling Mbed TLS timeouts.

Non-blocking connections were not working earlier, now fixed for both
secure socket modes.
pull/8659/head
Seppo Takalo 2018-11-01 13:12:40 +02:00
parent 9aef9d3661
commit d22adbdb26
6 changed files with 269 additions and 51 deletions

View File

@ -0,0 +1,55 @@
#include "DTLSSocketWrapper.h"
#include "platform/Callback.h"
#include "drivers/Timer.h"
#include "events/mbed_events.h"
#if defined(MBEDTLS_SSL_CLI_C)
DTLSSocketWrapper::DTLSSocketWrapper(Socket *transport, const char *hostname, control_transport control)
: TLSSocketWrapper(transport, hostname, control)
{
mbedtls_ssl_conf_transport( get_ssl_config(), MBEDTLS_SSL_TRANSPORT_DATAGRAM);
mbedtls_ssl_set_timer_cb( get_ssl_context(), this, timing_set_delay, timing_get_delay);
_timer_event_id = 0;
_timer_expired = false;
}
void DTLSSocketWrapper::timing_set_delay(void *ctx, uint32_t int_ms, uint32_t fin_ms)
{
DTLSSocketWrapper *context = static_cast<DTLSSocketWrapper *>(ctx);
if (context->_timer_event_id) {
mbed::mbed_event_queue()->cancel(context->_timer_event_id);
context->_timer_expired = false;
}
if (fin_ms == 0) {
context->_timer_event_id = 0;
return;
}
context->_timer_event_id = mbed::mbed_event_queue()->call_in(fin_ms, context, &DTLSSocketWrapper::timer_event);
}
int DTLSSocketWrapper::timing_get_delay(void *ctx)
{
DTLSSocketWrapper *context = static_cast<DTLSSocketWrapper *>(ctx);
/* See documentation of "typedef int mbedtls_ssl_get_timer_t( void * ctx );" from ssl.h */
if (context->_timer_event_id == 0) {
return -1;
} else if (context->_timer_expired) {
return 2;
} else {
return 0;
}
}
void DTLSSocketWrapper::timer_event(void)
{
_timer_expired = true;
event();
}
#endif /* MBEDTLS_SSL_CLI_C */

View File

@ -0,0 +1,21 @@
#ifndef DTLSSOCKETWRAPPER_H
#define DTLSSOCKETWRAPPER_H
#include "TLSSocketWrapper.h"
// This class requires Mbed TLS SSL/TLS client code
#if defined(MBEDTLS_SSL_CLI_C)
class DTLSSocketWrapper : public TLSSocketWrapper {
public:
DTLSSocketWrapper(Socket *transport, const char *hostname = NULL, control_transport control = TRANSPORT_CONNECT_AND_CLOSE);
private:
static void timing_set_delay(void *ctx, uint32_t int_ms, uint32_t fin_ms);
static int timing_get_delay(void *ctx);
void timer_event();
int _timer_event_id;
bool _timer_expired:1;
};
#endif
#endif

View File

@ -25,14 +25,17 @@
nsapi_error_t TLSSocket::connect(const char *host, uint16_t port) nsapi_error_t TLSSocket::connect(const char *host, uint16_t port)
{ {
set_hostname(host); nsapi_error_t ret = NSAPI_ERROR_OK;
if (!is_handshake_started()) {
nsapi_error_t ret = tcp_socket.connect(host, port); ret = tcp_socket.connect(host, port);
if (ret) { if (ret == NSAPI_ERROR_OK || ret == NSAPI_ERROR_IN_PROGRESS) {
return ret; set_hostname(host);
}
if (ret != NSAPI_ERROR_OK && ret != NSAPI_ERROR_IS_CONNECTED) {
return ret;
}
} }
return TLSSocketWrapper::start_handshake(ret == NSAPI_ERROR_OK);
return TLSSocketWrapper::do_handshake();
} }
TLSSocket::~TLSSocket() TLSSocket::~TLSSocket()

View File

@ -16,18 +16,22 @@
*/ */
#include "TLSSocketWrapper.h" #include "TLSSocketWrapper.h"
#include "platform/Callback.h"
#include "drivers/Timer.h" #include "drivers/Timer.h"
#include "events/mbed_events.h"
#define TRACE_GROUP "TLSW" #define TRACE_GROUP "TLSW"
#include "mbed-trace/mbed_trace.h" #include "mbed-trace/mbed_trace.h"
#include "mbedtls/debug.h" #include "mbedtls/debug.h"
#include "mbed_error.h" #include "mbed_error.h"
#include "Kernel.h"
// This class requires Mbed TLS SSL/TLS client code // This class requires Mbed TLS SSL/TLS client code
#if defined(MBEDTLS_SSL_CLI_C) #if defined(MBEDTLS_SSL_CLI_C)
TLSSocketWrapper::TLSSocketWrapper(Socket *transport, const char *hostname, control_transport control) : TLSSocketWrapper::TLSSocketWrapper(Socket *transport, const char *hostname, control_transport control) :
_transport(transport), _transport(transport),
_timeout(-1),
#ifdef MBEDTLS_X509_CRT_PARSE_C #ifdef MBEDTLS_X509_CRT_PARSE_C
_cacert(NULL), _cacert(NULL),
_clicert(NULL), _clicert(NULL),
@ -35,6 +39,7 @@ TLSSocketWrapper::TLSSocketWrapper(Socket *transport, const char *hostname, cont
_ssl_conf(NULL), _ssl_conf(NULL),
_connect_transport(control == TRANSPORT_CONNECT || control == TRANSPORT_CONNECT_AND_CLOSE), _connect_transport(control == TRANSPORT_CONNECT || control == TRANSPORT_CONNECT_AND_CLOSE),
_close_transport(control == TRANSPORT_CLOSE || control == TRANSPORT_CONNECT_AND_CLOSE), _close_transport(control == TRANSPORT_CLOSE || control == TRANSPORT_CONNECT_AND_CLOSE),
_tls_initialized(false),
_handshake_completed(false), _handshake_completed(false),
_cacert_allocated(false), _cacert_allocated(false),
_clicert_allocated(false), _clicert_allocated(false),
@ -140,26 +145,33 @@ nsapi_error_t TLSSocketWrapper::set_client_cert_key(const void *client_cert, siz
} }
nsapi_error_t TLSSocketWrapper::do_handshake() nsapi_error_t TLSSocketWrapper::start_handshake(bool first_call)
{ {
nsapi_error_t _error;
const char DRBG_PERS[] = "mbed TLS client"; const char DRBG_PERS[] = "mbed TLS client";
int ret;
if (!_transport) { if (!_transport) {
return NSAPI_ERROR_NO_SOCKET; return NSAPI_ERROR_NO_SOCKET;
} }
_transport->set_blocking(true); if (_tls_initialized) {
return continue_handshake();
}
#ifdef MBEDTLS_X509_CRT_PARSE_C
/* Start the handshake, the rest will be done in onReceive() */
tr_info("Starting TLS handshake with %s", _ssl.hostname);
#else
tr_info("Starting TLS handshake");
#endif
/* /*
* Initialize TLS-related stuf. * Initialize TLS-related stuf.
*/ */
int ret;
if ((ret = mbedtls_ctr_drbg_seed(&_ctr_drbg, mbedtls_entropy_func, &_entropy, if ((ret = mbedtls_ctr_drbg_seed(&_ctr_drbg, mbedtls_entropy_func, &_entropy,
(const unsigned char *) DRBG_PERS, (const unsigned char *) DRBG_PERS,
sizeof(DRBG_PERS))) != 0) { sizeof(DRBG_PERS))) != 0) {
print_mbedtls_error("mbedtls_crt_drbg_init", ret); print_mbedtls_error("mbedtls_crt_drbg_init", ret);
_error = ret; return NSAPI_ERROR_PARAMETER;
return _error;
} }
tr_info("mbedtls_ssl_conf_rng()"); tr_info("mbedtls_ssl_conf_rng()");
@ -175,26 +187,58 @@ nsapi_error_t TLSSocketWrapper::do_handshake()
tr_info("mbedtls_ssl_setup()"); tr_info("mbedtls_ssl_setup()");
if ((ret = mbedtls_ssl_setup(&_ssl, get_ssl_config())) != 0) { if ((ret = mbedtls_ssl_setup(&_ssl, get_ssl_config())) != 0) {
print_mbedtls_error("mbedtls_ssl_setup", ret); print_mbedtls_error("mbedtls_ssl_setup", ret);
_error = ret; return NSAPI_ERROR_PARAMETER;
return _error;
} }
_transport->set_blocking(false);
_transport->sigio(mbed::callback(this, &TLSSocketWrapper::event));
mbedtls_ssl_set_bio(&_ssl, this, ssl_send, ssl_recv, NULL); mbedtls_ssl_set_bio(&_ssl, this, ssl_send, ssl_recv, NULL);
#ifdef MBEDTLS_X509_CRT_PARSE_C _tls_initialized = true;
/* Start the handshake, the rest will be done in onReceive() */
tr_info("Starting TLS handshake with %s", _ssl.hostname);
#else
tr_info("Starting TLS handshake");
#endif
do { ret = continue_handshake();
if (first_call) {
if (ret == NSAPI_ERROR_ALREADY ) {
ret = NSAPI_ERROR_IN_PROGRESS; // If first call should return IN_PROGRESS
}
if (ret == NSAPI_ERROR_IS_CONNECTED) {
ret = NSAPI_ERROR_OK; // If we happened to complete the request on the first call, return OK.
}
}
return ret;
}
nsapi_error_t TLSSocketWrapper::continue_handshake() {
int ret;
if (_handshake_completed) {
return NSAPI_ERROR_IS_CONNECTED;
}
if (!_tls_initialized) {
return NSAPI_ERROR_NO_CONNECTION;
}
while (true) {
ret = mbedtls_ssl_handshake(&_ssl); ret = mbedtls_ssl_handshake(&_ssl);
} while (ret != 0 && (ret == MBEDTLS_ERR_SSL_WANT_READ || if (_timeout && (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE)) {
ret == MBEDTLS_ERR_SSL_WANT_WRITE)); uint32_t flag;
flag = _event_flag.wait_any(1, _timeout);
if (flag & osFlagsError) {
break;
}
} else {
break;
}
}
if (ret < 0) { if (ret < 0) {
print_mbedtls_error("mbedtls_ssl_handshake", ret); print_mbedtls_error("mbedtls_ssl_handshake", ret);
return ret; if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
return NSAPI_ERROR_ALREADY;
} else {
return NSAPI_ERROR_AUTH_FAILURE;
}
} }
#ifdef MBEDTLS_X509_CRT_PARSE_C #ifdef MBEDTLS_X509_CRT_PARSE_C
@ -225,8 +269,7 @@ nsapi_error_t TLSSocketWrapper::do_handshake()
#endif #endif
_handshake_completed = true; _handshake_completed = true;
return NSAPI_ERROR_IS_CONNECTED;
return 0;
} }
@ -239,15 +282,42 @@ nsapi_error_t TLSSocketWrapper::send(const void *data, nsapi_size_t size)
} }
tr_debug("send %d", size); tr_debug("send %d", size);
ret = mbedtls_ssl_write(&_ssl, (const unsigned char *) data, size); while (true) {
if (!_handshake_completed) {
ret = continue_handshake();
if (ret != NSAPI_ERROR_IS_CONNECTED) {
if (ret == NSAPI_ERROR_ALREADY) {
ret = NSAPI_ERROR_NO_CONNECTION;
}
return ret;
}
}
ret = mbedtls_ssl_write(&_ssl, (const unsigned char *) data, size);
if (_timeout == 0) {
break;
} else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == MBEDTLS_ERR_SSL_WANT_READ) {
uint32_t flag;
flag = _event_flag.wait_any(1, _timeout);
if (flag & osFlagsError) {
// Timeout break
break;
}
} else {
break;
}
}
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE || if (ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
ret == MBEDTLS_ERR_SSL_WANT_READ) { ret == MBEDTLS_ERR_SSL_WANT_READ) {
// translate to socket error // translate to socket error
return NSAPI_ERROR_WOULD_BLOCK; return NSAPI_ERROR_WOULD_BLOCK;
} }
if (ret < 0) { if (ret < 0) {
print_mbedtls_error("mbedtls_ssl_write", ret); print_mbedtls_error("mbedtls_ssl_write", ret);
return NSAPI_ERROR_DEVICE_ERROR;
} }
return ret; // Assume "non negative errorcode" to be propagated from Socket layer return ret; // Assume "non negative errorcode" to be propagated from Socket layer
} }
@ -266,15 +336,39 @@ nsapi_size_or_error_t TLSSocketWrapper::recv(void *data, nsapi_size_t size)
return NSAPI_ERROR_NO_SOCKET; return NSAPI_ERROR_NO_SOCKET;
} }
ret = mbedtls_ssl_read(&_ssl, (unsigned char *) data, size); while (true) {
if (!_handshake_completed) {
ret = continue_handshake();
if (ret != NSAPI_ERROR_IS_CONNECTED) {
if (ret == NSAPI_ERROR_ALREADY) {
ret = NSAPI_ERROR_NO_CONNECTION;
}
return ret;
}
}
ret = mbedtls_ssl_read(&_ssl, (unsigned char *) data, size);
if (_timeout == 0) {
break;
} else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE || ret == MBEDTLS_ERR_SSL_WANT_READ) {
uint32_t flag;
flag = _event_flag.wait_any(1, _timeout);
if (flag & osFlagsError) {
// Timeout break
break;
}
} else {
break;
}
}
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE || if (ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
ret == MBEDTLS_ERR_SSL_WANT_READ) { ret == MBEDTLS_ERR_SSL_WANT_READ) {
// translate to socket error // translate to socket error
return NSAPI_ERROR_WOULD_BLOCK; return NSAPI_ERROR_WOULD_BLOCK;
} else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { } else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
/* MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY is not considered as error. /* MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY is not considered as error.
* Just ignre here. Once connection is closed, mbedtls_ssl_read() * Just ignore here. Once connection is closed, mbedtls_ssl_read()
* will return 0. * will return 0.
*/ */
return 0; return 0;
@ -470,6 +564,11 @@ void TLSSocketWrapper::set_ssl_config(mbedtls_ssl_config *conf)
_ssl_conf = conf; _ssl_conf = conf;
} }
mbedtls_ssl_context *TLSSocketWrapper::get_ssl_context()
{
return &_ssl;
}
nsapi_error_t TLSSocketWrapper::close() nsapi_error_t TLSSocketWrapper::close()
{ {
if (!_transport) { if (!_transport) {
@ -502,17 +601,18 @@ nsapi_error_t TLSSocketWrapper::close()
nsapi_error_t TLSSocketWrapper::connect(const SocketAddress &address) nsapi_error_t TLSSocketWrapper::connect(const SocketAddress &address)
{ {
nsapi_error_t ret = NSAPI_ERROR_OK;
if (!_transport) { if (!_transport) {
return NSAPI_ERROR_NO_SOCKET; return NSAPI_ERROR_NO_SOCKET;
} }
if (_connect_transport) { if (!is_handshake_started() && _connect_transport) {
nsapi_error_t ret = _transport->connect(address); ret = _transport->connect(address);
if (ret) { if (ret && ret != NSAPI_ERROR_IS_CONNECTED) {
return ret; return ret;
} }
} }
return do_handshake(); return start_handshake(ret == NSAPI_ERROR_OK);
} }
nsapi_error_t TLSSocketWrapper::bind(const SocketAddress &address) nsapi_error_t TLSSocketWrapper::bind(const SocketAddress &address)
@ -525,18 +625,17 @@ nsapi_error_t TLSSocketWrapper::bind(const SocketAddress &address)
void TLSSocketWrapper::set_blocking(bool blocking) void TLSSocketWrapper::set_blocking(bool blocking)
{ {
if (!_transport) { set_timeout(blocking?-1:0);
return;
}
_transport->set_blocking(blocking);
} }
void TLSSocketWrapper::set_timeout(int timeout) void TLSSocketWrapper::set_timeout(int timeout)
{ {
if (!_transport) { _timeout = timeout;
return; if (!is_handshake_started() && timeout!=-1 && _connect_transport) {
// If we have not yet connected the transport, we need to modify its blocking mode as well.
// After connection is initiated, it is already set to non blocking mode
_transport->set_timeout(timeout);
} }
_transport->set_timeout(timeout);
} }
void TLSSocketWrapper::sigio(mbed::Callback<void()> func) void TLSSocketWrapper::sigio(mbed::Callback<void()> func)
@ -544,8 +643,8 @@ void TLSSocketWrapper::sigio(mbed::Callback<void()> func)
if (!_transport) { if (!_transport) {
return; return;
} }
// Allow sigio() to propagate to upper level and handle errors on recv() and send() _sigio = func;
_transport->sigio(func); _transport->sigio(mbed::callback(this, &TLSSocketWrapper::event));
} }
nsapi_error_t TLSSocketWrapper::setsockopt(int level, int optname, const void *optval, unsigned optlen) nsapi_error_t TLSSocketWrapper::setsockopt(int level, int optname, const void *optval, unsigned optlen)
@ -577,6 +676,20 @@ nsapi_error_t TLSSocketWrapper::listen(int)
return NSAPI_ERROR_UNSUPPORTED; return NSAPI_ERROR_UNSUPPORTED;
} }
void TLSSocketWrapper::event()
{
_event_flag.set(1);
if (_sigio) {
_sigio();
}
}
bool TLSSocketWrapper::is_handshake_started() const
{
return _tls_initialized;
}
nsapi_error_t TLSSocketWrapper::getpeername(SocketAddress *address) nsapi_error_t TLSSocketWrapper::getpeername(SocketAddress *address)
{ {
if (!_handshake_completed) { if (!_handshake_completed) {

View File

@ -19,7 +19,8 @@
#define _MBED_HTTPS_TLS_SOCKET_WRAPPER_H_ #define _MBED_HTTPS_TLS_SOCKET_WRAPPER_H_
#include "netsocket/Socket.h" #include "netsocket/Socket.h"
#include "rtos/EventFlags.h"
#include "platform/Callback.h"
#include "mbedtls/platform.h" #include "mbedtls/platform.h"
#include "mbedtls/ssl.h" #include "mbedtls/ssl.h"
#include "mbedtls/entropy.h" #include "mbedtls/entropy.h"
@ -167,12 +168,12 @@ public:
*/ */
void set_ssl_config(mbedtls_ssl_config *conf); void set_ssl_config(mbedtls_ssl_config *conf);
protected: /** Get internal Mbed TLS contect structure.
/** * @return SSL context
* Helper for pretty-printing mbed TLS error codes
*/ */
static void print_mbedtls_error(const char *name, int err); mbedtls_ssl_context *get_ssl_context();
protected:
/** Initiates TLS Handshake /** Initiates TLS Handshake
* *
* Initiates a TLS handshake to a remote peer * Initiates a TLS handshake to a remote peer
@ -181,9 +182,28 @@ protected:
* Root CA certification must be set by set_ssl_ca_pem() before * Root CA certification must be set by set_ssl_ca_pem() before
* call this function. * call this function.
* *
* For non-blocking purposes, this functions needs to know whether this
* was a first call to Socket::connect() API so that NSAPI_ERROR_INPROGRESS
* does not happen twice.
*
* @parameter first_call is this a first call to Socket::connect() API.
* @return 0 on success, negative error code on failure * @return 0 on success, negative error code on failure
*/ */
nsapi_error_t do_handshake(); nsapi_error_t start_handshake(bool first_call);
bool is_handshake_started() const;
void event();
private:
/** Continue already initialised handshake */
nsapi_error_t continue_handshake();
/**
* Helper for pretty-printing mbed TLS error codes
*/
static void print_mbedtls_error(const char *name, int err);
#if MBED_CONF_TLS_SOCKET_DEBUG_LEVEL > 0 #if MBED_CONF_TLS_SOCKET_DEBUG_LEVEL > 0
/** /**
@ -211,13 +231,15 @@ protected:
*/ */
static int ssl_send(void *ctx, const unsigned char *buf, size_t len); static int ssl_send(void *ctx, const unsigned char *buf, size_t len);
private:
mbedtls_ssl_context _ssl; mbedtls_ssl_context _ssl;
mbedtls_pk_context _pkctx; mbedtls_pk_context _pkctx;
mbedtls_ctr_drbg_context _ctr_drbg; mbedtls_ctr_drbg_context _ctr_drbg;
mbedtls_entropy_context _entropy; mbedtls_entropy_context _entropy;
rtos::EventFlags _event_flag;
mbed::Callback<void()> _sigio;
Socket *_transport; Socket *_transport;
int _timeout;
#ifdef MBEDTLS_X509_CRT_PARSE_C #ifdef MBEDTLS_X509_CRT_PARSE_C
mbedtls_x509_crt *_cacert; mbedtls_x509_crt *_cacert;
@ -227,6 +249,7 @@ private:
bool _connect_transport: 1; bool _connect_transport: 1;
bool _close_transport: 1; bool _close_transport: 1;
bool _tls_initialized: 1;
bool _handshake_completed: 1; bool _handshake_completed: 1;
bool _cacert_allocated: 1; bool _cacert_allocated: 1;
bool _clicert_allocated: 1; bool _clicert_allocated: 1;

View File

@ -40,9 +40,12 @@
#include "netsocket/UDPSocket.h" #include "netsocket/UDPSocket.h"
#include "netsocket/TCPSocket.h" #include "netsocket/TCPSocket.h"
#include "netsocket/TCPServer.h" #include "netsocket/TCPServer.h"
#include "netsocket/TLSSocketWrapper.h"
#include "netsocket/DTLSSocketWrapper.h"
#include "netsocket/TLSSocket.h" #include "netsocket/TLSSocket.h"
#include "netsocket/DTLSSocket.h"
#endif #endif // __cplusplus
#endif #endif