diff --git a/Socket.cpp b/Socket.cpp index 3b853e19ec..fd32a30cb8 100644 --- a/Socket.cpp +++ b/Socket.cpp @@ -19,109 +19,153 @@ Socket::Socket() : _iface(0) , _socket(0) - , _timeout(-1) + , _timeout(osWaitForever) { } Socket::~Socket() { - if (_socket) { - close(); - } + // Underlying close is thread safe + close(); } int Socket::open(NetworkStack *iface, nsapi_protocol_t proto) { + _lock.lock(); + + if (_iface != NULL) { + _lock.unlock(); + return NSAPI_ERROR_PARAMETER; + } _iface = iface; void *socket; int err = _iface->socket_open(&socket, proto); if (err) { + _lock.unlock(); return err; } _socket = socket; _iface->socket_attach(_socket, &Socket::thunk, this); + _lock.unlock(); + return 0; } int Socket::close() { - if (!_socket) { - return 0; + _lock.lock(); + + int ret = 0; + if (_socket) { + _iface->socket_attach(_socket, 0, 0); + + void * socket = _socket; + _socket = 0; + ret = _iface->socket_close(socket); } - - _iface->socket_attach(_socket, 0, 0); - - void *volatile socket = _socket; - _socket = 0; - return _iface->socket_close(socket); + + // Wakeup anything in a blocking operation + // on this socket + socket_event(); + + _lock.unlock(); + return ret; } int Socket::bind(uint16_t port) { + // Underlying bind is thread safe SocketAddress addr(0, port); return bind(addr); } int Socket::bind(const char *address, uint16_t port) { + // Underlying bind is thread safe SocketAddress addr(address, port); return bind(addr); } int Socket::bind(const SocketAddress &address) { - if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + _lock.lock(); + + int ret = NSAPI_ERROR_NO_SOCKET; + if (_socket) { + ret = _iface->socket_bind(_socket, address); } - return _iface->socket_bind(_socket, address); + _lock.unlock(); + return ret; } void Socket::set_blocking(bool blocking) { + // Socket::set_timeout is thread safe set_timeout(blocking ? -1 : 0); } void Socket::set_timeout(int timeout) { - _timeout = timeout; + _lock.lock(); + + if (timeout >= 0) { + _timeout = (uint32_t)timeout; + } else { + _timeout = osWaitForever; + } + + _lock.unlock(); } int Socket::setsockopt(int level, int optname, const void *optval, unsigned optlen) { - if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + _lock.lock(); + + int ret = NSAPI_ERROR_NO_SOCKET; + if (_socket) { + ret = _iface->setsockopt(_socket, level, optname, optval, optlen); } - return _iface->setsockopt(_socket, level, optname, optval, optlen); + _lock.unlock(); + return ret; } int Socket::getsockopt(int level, int optname, void *optval, unsigned *optlen) { - if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + _lock.lock(); + + int ret = NSAPI_ERROR_NO_SOCKET; + if (_socket) { + ret = _iface->getsockopt(_socket, level, optname, optval, optlen); } - return _iface->getsockopt(_socket, level, optname, optval, optlen); + _lock.unlock(); + return ret; } -void Socket::wakeup() +void Socket::attach(FunctionPointer callback) { + _lock.lock(); + + _callback = callback; + + _lock.unlock(); } void Socket::thunk(void *data) { Socket *self = (Socket *)data; - if (self->_callback) { - self->_callback(); - } + self->socket_event(); } -void Socket::attach(FunctionPointer callback) +void Socket::socket_event(void) { - _callback = callback; + if (_callback) { + _callback(); + } } diff --git a/Socket.h b/Socket.h index 64e7325214..bee2d11a74 100644 --- a/Socket.h +++ b/Socket.h @@ -19,6 +19,7 @@ #include "SocketAddress.h" #include "NetworkStack.h" +#include "Mutex.h" /** Abstract socket class */ @@ -97,10 +98,11 @@ public: * * Initially all sockets have unbounded timeouts. NSAPI_ERROR_WOULD_BLOCK * is returned if a blocking operation takes longer than the specified - * timeout. A timeout of -1 removes the timeout from the socket. + * timeout. A timeout of 0 removes the timeout from the socket. A negative + * value give the socket an unbounded timeout. * - * set_timeout(-1) is equivalent to set_blocking(false) - * set_timeout(0) is equivalent to set_blocking(true) + * set_timeout(0) is equivalent to set_blocking(false) + * set_timeout(-1) is equivalent to set_blocking(true) * * @param timeout Timeout in milliseconds */ @@ -169,12 +171,13 @@ protected: int open(NetworkStack *iface, nsapi_protocol_t proto); static void thunk(void *); - static void wakeup(); + virtual void socket_event(void); NetworkStack *_iface; void *_socket; - int _timeout; + uint32_t _timeout; FunctionPointer _callback; + rtos::Mutex _lock; }; #endif diff --git a/TCPServer.cpp b/TCPServer.cpp index 84c47a09ae..4d4c9f5390 100644 --- a/TCPServer.cpp +++ b/TCPServer.cpp @@ -17,11 +17,11 @@ #include "TCPServer.h" #include "Timer.h" -TCPServer::TCPServer() +TCPServer::TCPServer(): _accept_sem(0) { } -TCPServer::TCPServer(NetworkStack *iface) +TCPServer::TCPServer(NetworkStack *iface): _accept_sem(0) { open(iface); } @@ -33,43 +33,69 @@ int TCPServer::open(NetworkStack *iface) int TCPServer::listen(int backlog) { - if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + _lock.lock(); + + int ret = NSAPI_ERROR_NO_SOCKET; + if (_socket) { + ret = _iface->socket_listen(_socket, backlog); } - return _iface->socket_listen(_socket, backlog); + _lock.unlock(); + return ret; } int TCPServer::accept(TCPSocket *connection) { - mbed::Timer timer; - timer.start(); - mbed::Timeout timeout; - if (_timeout >= 0) { - timeout.attach_us(&Socket::wakeup, _timeout * 1000); - } - - if (connection->_socket) { - connection->close(); - } + _lock.lock(); + int ret = NSAPI_ERROR_NO_SOCKET; while (true) { if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + ret = NSAPI_ERROR_NO_SOCKET; + break; } void *socket; - int err = _iface->socket_accept(&socket, _socket); - if (!err) { + ret = _iface->socket_accept(&socket, _socket); + if (0 == ret) { + connection->_lock.lock(); + + if (connection->_socket) { + connection->close(); + } + connection->_iface = _iface; connection->_socket = socket; + _iface->socket_attach(socket, &Socket::thunk, connection); + + connection->_lock.unlock(); + break; } - if (err != NSAPI_ERROR_WOULD_BLOCK - || (_timeout >= 0 && timer.read_ms() >= _timeout)) { - return err; - } + if (NSAPI_ERROR_WOULD_BLOCK == ret) { + int32_t count; - __WFI(); + _lock.unlock(); + count = _accept_sem.wait(_timeout); + _lock.lock(); + + if (count < 1) { + ret = NSAPI_ERROR_WOULD_BLOCK; + break; + } + } } + + _lock.unlock(); + return ret; +} + +void TCPServer::socket_event() +{ + int32_t status = _accept_sem.wait(0); + if (status <= 1) { + _accept_sem.release(); + } + + Socket::socket_event(); } diff --git a/TCPServer.h b/TCPServer.h index 0bb7d15a01..9155f78e67 100644 --- a/TCPServer.h +++ b/TCPServer.h @@ -20,6 +20,7 @@ #include "Socket.h" #include "TCPSocket.h" #include "NetworkStack.h" +#include "Semaphore.h" /** TCP socket server */ @@ -74,6 +75,9 @@ public: * @return 0 on success, negative error code on failure */ int accept(TCPSocket *connection); +protected: + virtual void socket_event(void); + rtos::Semaphore _accept_sem; }; #endif diff --git a/TCPSocket.cpp b/TCPSocket.cpp index 4ec5ce6ab1..14c0b61736 100644 --- a/TCPSocket.cpp +++ b/TCPSocket.cpp @@ -17,83 +17,140 @@ #include "TCPSocket.h" #include "Timer.h" -TCPSocket::TCPSocket() +TCPSocket::TCPSocket(): _read_sem(0), _write_sem(0) { } -TCPSocket::TCPSocket(NetworkStack *iface) +TCPSocket::TCPSocket(NetworkStack *iface): _read_sem(0), _write_sem(0) { + // TCPSocket::open is thread safe open(iface); } int TCPSocket::open(NetworkStack *iface) { + // Socket::open is thread safe return Socket::open(iface, NSAPI_TCP); } int TCPSocket::connect(const SocketAddress &addr) { - if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + _lock.lock(); + + int ret = NSAPI_ERROR_NO_SOCKET; + if (_socket) { + ret = _iface->socket_connect(_socket, addr); } - return _iface->socket_connect(_socket, addr); + _lock.unlock(); + return ret; } int TCPSocket::connect(const char *host, uint16_t port) { + _lock.lock(); + SocketAddress addr(_iface, host, port); - if (!addr) { - return NSAPI_ERROR_DNS_FAILURE; + int ret = NSAPI_ERROR_DNS_FAILURE; + if (addr) { + ret = connect(addr); } - return connect(addr); + _lock.unlock(); + return ret; } int TCPSocket::send(const void *data, unsigned size) { - mbed::Timer timer; - timer.start(); - mbed::Timeout timeout; - if (_timeout >= 0) { - timeout.attach_us(&Socket::wakeup, _timeout * 1000); + if (osOK != _write_lock.lock(_timeout)) { + return NSAPI_ERROR_WOULD_BLOCK; } + _lock.lock(); + int ret; while (true) { if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; + ret = NSAPI_ERROR_NO_SOCKET; + break; } int sent = _iface->socket_send(_socket, data, size); - if (sent != NSAPI_ERROR_WOULD_BLOCK - || (_timeout >= 0 && timer.read_ms() >= _timeout)) { - return sent; - } + if ((0 == _timeout) || (NSAPI_ERROR_WOULD_BLOCK != sent)) { + ret = sent; + break; + } else { + int32_t count; - __WFI(); + // Release lock before blocking so other threads + // accessing this object aren't blocked + _lock.unlock(); + count = _write_sem.wait(_timeout); + _lock.lock(); + + if (count < 1) { + // Semaphore wait timed out so break out and return + ret = NSAPI_ERROR_WOULD_BLOCK; + break; + } + } } + + _lock.unlock(); + _write_lock.unlock(); + return ret; } int TCPSocket::recv(void *data, unsigned size) { - mbed::Timer timer; - timer.start(); - mbed::Timeout timeout; - if (_timeout >= 0) { - timeout.attach_us(&Socket::wakeup, _timeout * 1000); + if (osOK != _read_lock.lock(_timeout)) { + return NSAPI_ERROR_WOULD_BLOCK; } + _lock.lock(); + int ret; while (true) { if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; - } - - int recv = _iface->socket_recv(_socket, data, size); - if (recv != NSAPI_ERROR_WOULD_BLOCK - || (_timeout >= 0 && timer.read_ms() >= _timeout)) { - return recv; + ret = NSAPI_ERROR_NO_SOCKET; + break; } - __WFI(); + int recv = _iface->socket_recv(_socket, data, size); + if ((0 == _timeout) || (NSAPI_ERROR_WOULD_BLOCK != recv)) { + ret = recv; + break; + } else { + int32_t count; + + // Release lock before blocking so other threads + // accessing this object aren't blocked + _lock.unlock(); + count = _read_sem.wait(_timeout); + _lock.lock(); + + if (count < 1) { + // Semaphore wait timed out so break out and return + ret = NSAPI_ERROR_WOULD_BLOCK; + break; + } + } } + + _lock.unlock(); + _read_lock.unlock(); + return ret; +} + +void TCPSocket::socket_event() +{ + int32_t count; + count = _write_sem.wait(0); + if (count <= 1) { + _write_sem.release(); + } + count = _read_sem.wait(0); + if (count <= 1) { + _read_sem.release(); + } + + Socket::socket_event(); } diff --git a/TCPSocket.h b/TCPSocket.h index 2a37f776fb..5341174344 100644 --- a/TCPSocket.h +++ b/TCPSocket.h @@ -19,6 +19,7 @@ #include "Socket.h" #include "NetworkStack.h" +#include "Semaphore.h" /** TCP socket connection */ @@ -101,7 +102,12 @@ public: */ int recv(void *data, unsigned size); -private: +protected: + virtual void socket_event(void); + rtos::Mutex _read_lock; + rtos::Semaphore _read_sem; + rtos::Mutex _write_lock; + rtos::Semaphore _write_sem; friend class TCPServer; }; diff --git a/UDPSocket.cpp b/UDPSocket.cpp index 9494323fbd..dbff20fbf0 100644 --- a/UDPSocket.cpp +++ b/UDPSocket.cpp @@ -17,11 +17,11 @@ #include "UDPSocket.h" #include "Timer.h" -UDPSocket::UDPSocket() +UDPSocket::UDPSocket(): _read_sem(0), _write_sem(0) { } -UDPSocket::UDPSocket(NetworkStack *iface) +UDPSocket::UDPSocket(NetworkStack *iface): _read_sem(0), _write_sem(0) { open(iface); } @@ -38,53 +38,103 @@ int UDPSocket::sendto(const char *host, uint16_t port, const void *data, unsigne return NSAPI_ERROR_DNS_FAILURE; } - return sendto(addr, data, size); + // sendto is thread safe + int ret = sendto(addr, data, size); + + return ret; } int UDPSocket::sendto(const SocketAddress &address, const void *data, unsigned size) { - mbed::Timer timer; - timer.start(); - mbed::Timeout timeout; - if (_timeout >= 0) { - timeout.attach_us(&Socket::wakeup, _timeout * 1000); + if (osOK != _write_lock.lock(_timeout)) { + return NSAPI_ERROR_WOULD_BLOCK; } + _lock.lock(); + int ret; while (true) { if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; - } - - int sent = _iface->socket_sendto(_socket, address, data, size); - if (sent != NSAPI_ERROR_WOULD_BLOCK - || (_timeout >= 0 && timer.read_ms() >= _timeout)) { - return sent; + ret = NSAPI_ERROR_NO_SOCKET; + break; } - __WFI(); + int sent = _iface->socket_sendto(_socket, address, data, size); + if ((0 == _timeout) || (NSAPI_ERROR_WOULD_BLOCK != sent)) { + ret = sent; + break; + } else { + int32_t count; + + // Release lock before blocking so other threads + // accessing this object aren't blocked + _lock.unlock(); + count = _write_sem.wait(_timeout); + _lock.lock(); + + if (count < 1) { + // Semaphore wait timed out so break out and return + ret = NSAPI_ERROR_WOULD_BLOCK; + break; + } + } } + + _lock.unlock(); + _write_lock.unlock(); + return ret; } int UDPSocket::recvfrom(SocketAddress *address, void *buffer, unsigned size) { - mbed::Timer timer; - timer.start(); - mbed::Timeout timeout; - if (_timeout >= 0) { - timeout.attach_us(&Socket::wakeup, _timeout * 1000); + if (osOK != _read_lock.lock(_timeout)) { + return NSAPI_ERROR_WOULD_BLOCK; } + _lock.lock(); + int ret; while (true) { if (!_socket) { - return NSAPI_ERROR_NO_SOCKET; - } - - int recv = _iface->socket_recvfrom(_socket, address, buffer, size); - if (recv != NSAPI_ERROR_WOULD_BLOCK - || (_timeout >= 0 && timer.read_ms() >= _timeout)) { - return recv; + ret = NSAPI_ERROR_NO_SOCKET; + break; } - __WFI(); + int recv = _iface->socket_recvfrom(_socket, address, buffer, size); + if ((0 == _timeout) || (NSAPI_ERROR_WOULD_BLOCK != recv)) { + ret = recv; + break; + } else { + int32_t count; + + // Release lock before blocking so other threads + // accessing this object aren't blocked + _lock.unlock(); + count = _read_sem.wait(_timeout); + _lock.lock(); + + if (count < 1) { + // Semaphore wait timed out so break out and return + ret = NSAPI_ERROR_WOULD_BLOCK; + break; + } + } } + + _lock.unlock(); + _read_lock.unlock(); + return ret; +} + +void UDPSocket::socket_event() +{ + int32_t count; + count = _write_sem.wait(0); + if (count <= 1) { + _write_sem.release(); + } + count = _read_sem.wait(0); + if (count <= 1) { + _read_sem.release(); + } + + Socket::socket_event(); } diff --git a/UDPSocket.h b/UDPSocket.h index b2a8454d9f..d5152d6348 100644 --- a/UDPSocket.h +++ b/UDPSocket.h @@ -19,6 +19,7 @@ #include "Socket.h" #include "NetworkStack.h" +#include "Semaphore.h" /** UDP socket */ @@ -100,6 +101,12 @@ public: * code on failure */ int recvfrom(SocketAddress *address, void *data, unsigned size); +protected: + virtual void socket_event(void); + rtos::Mutex _read_lock; + rtos::Semaphore _read_sem; + rtos::Mutex _write_lock; + rtos::Semaphore _write_sem; }; #endif