diff --git a/TEST_APPS/device/socket_app/cmd_socket.cpp b/TEST_APPS/device/socket_app/cmd_socket.cpp index 17ecf4769a..7ecf176133 100644 --- a/TEST_APPS/device/socket_app/cmd_socket.cpp +++ b/TEST_APPS/device/socket_app/cmd_socket.cpp @@ -18,13 +18,13 @@ #include "UDPSocket.h" #include "TCPSocket.h" #include "TCPServer.h" +#include "TLSSocket.h" #include "NetworkInterface.h" #include "SocketAddress.h" #include "Queue.h" #include #include -#include #include #ifndef __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS @@ -42,23 +42,26 @@ #define SIGNAL_SIGIO 0x1 #define PACKET_SIZE_ARRAY_LEN 5 +#define CERT_BUFFER_SIZE 1648 + #define MAN_SOCKET "\r\nSOCKET API\r\n"\ "\r\n"\ "socket [options]\r\n\r\n"\ " new \r\n" \ - " type: UDPSocket|TCPSocket|TCPServer\r\n"\ + " type: UDPSocket|TCPSocket|TCPServer|TLSSocket [--cert_file |--cert_default]\r\n"\ " return socket id\r\n"\ " delete\r\n"\ " remote the space allocated for Socket\r\n"\ - " open\r\n"\ + " open [--if ] \r\n"\ + " interface (or use default interface) \r\n"\ " close\r\n"\ " bind [port] [addr ]\r\n"\ " set_blocking \r\n"\ " set_timeout \r\n"\ " register_sigio_cb\r\n"\ " set_RFC_864_pattern_check \r\n"\ - "\r\nFor UDPSocket\r\n"\ + " set_root_ca_cert --cert_url |--cert_file |--cert_default\r\n"\ " sendto (\"msg\" | --data_len )\r\n"\ " \"msg\" Send packet with defined string content\r\n"\ " --data_len Send packet with random content with size \r\n"\ @@ -66,7 +69,6 @@ " start_udp_receiver_thread --max_data_len [--packets ]\r\n"\ " --max_data_len Size of input buffer to fill up\r\n"\ " --packets Receive N number of packets, default 1\r\n"\ - "\r\nFor TCPSocket\r\n"\ " connect \r\n"\ " send (\"msg\" | --data_len )\r\n"\ " recv \r\n"\ @@ -79,29 +81,60 @@ " join_bg_traffic_thread\r\n"\ " setsockopt_keepalive \r\n"\ " getsockopt_keepalive\r\n"\ - "\r\nFor TCPServer\r\n"\ " listen [backlog]\r\n"\ + " accept\r\n" \ + " accept new connection and returns new socket ID\r\n" \ + "\r\nFor TCPServer\r\n"\ " accept \r\n"\ " accept new connection into socket. Requires to be pre-allocated.\r\n"\ "\r\nOther options\r\n"\ " print-mode [--string|--hex|--disabled] [--col-width ]" -class SInfo; -static Queue event_queue; -static int id_count = 0; + + +const char *cert = \ + "-----BEGIN CERTIFICATE-----\n" \ + "MIIEkjCCA3qgAwIBAgIQCgFBQgAAAVOFc2oLheynCDANBgkqhkiG9w0BAQsFADA/\n" \ + "MSQwIgYDVQQKExtEaWdpdGFsIFNpZ25hdHVyZSBUcnVzdCBDby4xFzAVBgNVBAMT\n" \ + "DkRTVCBSb290IENBIFgzMB4XDTE2MDMxNzE2NDA0NloXDTIxMDMxNzE2NDA0Nlow\n" \ + "SjELMAkGA1UEBhMCVVMxFjAUBgNVBAoTDUxldCdzIEVuY3J5cHQxIzAhBgNVBAMT\n" \ + "GkxldCdzIEVuY3J5cHQgQXV0aG9yaXR5IFgzMIIBIjANBgkqhkiG9w0BAQEFAAOC\n" \ + "AQ8AMIIBCgKCAQEAnNMM8FrlLke3cl03g7NoYzDq1zUmGSXhvb418XCSL7e4S0EF\n" \ + "q6meNQhY7LEqxGiHC6PjdeTm86dicbp5gWAf15Gan/PQeGdxyGkOlZHP/uaZ6WA8\n" \ + "SMx+yk13EiSdRxta67nsHjcAHJyse6cF6s5K671B5TaYucv9bTyWaN8jKkKQDIZ0\n" \ + "Z8h/pZq4UmEUEz9l6YKHy9v6Dlb2honzhT+Xhq+w3Brvaw2VFn3EK6BlspkENnWA\n" \ + "a6xK8xuQSXgvopZPKiAlKQTGdMDQMc2PMTiVFrqoM7hD8bEfwzB/onkxEz0tNvjj\n" \ + "/PIzark5McWvxI0NHWQWM6r6hCm21AvA2H3DkwIDAQABo4IBfTCCAXkwEgYDVR0T\n" \ + "AQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAYYwfwYIKwYBBQUHAQEEczBxMDIG\n" \ + "CCsGAQUFBzABhiZodHRwOi8vaXNyZy50cnVzdGlkLm9jc3AuaWRlbnRydXN0LmNv\n" \ + "bTA7BggrBgEFBQcwAoYvaHR0cDovL2FwcHMuaWRlbnRydXN0LmNvbS9yb290cy9k\n" \ + "c3Ryb290Y2F4My5wN2MwHwYDVR0jBBgwFoAUxKexpHsscfrb4UuQdf/EFWCFiRAw\n" \ + "VAYDVR0gBE0wSzAIBgZngQwBAgEwPwYLKwYBBAGC3xMBAQEwMDAuBggrBgEFBQcC\n" \ + "ARYiaHR0cDovL2Nwcy5yb290LXgxLmxldHNlbmNyeXB0Lm9yZzA8BgNVHR8ENTAz\n" \ + "MDGgL6AthitodHRwOi8vY3JsLmlkZW50cnVzdC5jb20vRFNUUk9PVENBWDNDUkwu\n" \ + "Y3JsMB0GA1UdDgQWBBSoSmpjBH3duubRObemRWXv86jsoTANBgkqhkiG9w0BAQsF\n" \ + "AAOCAQEA3TPXEfNjWDjdGBX7CVW+dla5cEilaUcne8IkCJLxWh9KEik3JHRRHGJo\n" \ + "uM2VcGfl96S8TihRzZvoroed6ti6WqEBmtzw3Wodatg+VyOeph4EYpr/1wXKtx8/\n" \ + "wApIvJSwtmVi4MFU5aMqrSDE6ea73Mj2tcMyo5jMd6jmeWUHK8so/joWUoHOUgwu\n" \ + "X4Po1QYz+3dszkDqMp4fklxBwXRsW10KXzPMTZ+sOPAveyxindmjkW8lGy+QsRlG\n" \ + "PfZ+G6Z6h7mjem0Y+iWlkYcV4PIWL1iwBi8saCbGS5jN2p8M+X+Q7UNKEkROb3N6\n" \ + "KOqkqm57TH2H3eDJAkSnh6/DNFu0Qg==\n" \ + "-----END CERTIFICATE-----\n"; class SInfo { public: enum SocketType { - TCP_CLIENT, + IP, TCP_SERVER, - UDP, - OTHER + OTHER, +#if defined(MBEDTLS_SSL_CLI_C) + TLS +#endif }; - SInfo(TCPSocket *sock): + SInfo(InternetSocket *sock, bool delete_on_exit = true): _id(id_count++), _sock(sock), - _type(SInfo::TCP_CLIENT), + _type(SInfo::IP), _blocking(true), _dataLen(0), _maxRecvLen(0), @@ -112,11 +145,12 @@ public: _senderThreadId(NULL), _receiverThreadId(NULL), _packetSizes(NULL), - _check_pattern(false) + _check_pattern(false), + _delete_on_exit(delete_on_exit) { - assert(sock); + MBED_ASSERT(sock); } - SInfo(TCPServer *sock): + SInfo(TCPServer *sock, bool delete_on_exit = true): _id(id_count++), _sock(sock), _type(SInfo::TCP_SERVER), @@ -130,27 +164,10 @@ public: _senderThreadId(NULL), _receiverThreadId(NULL), _packetSizes(NULL), - _check_pattern(false) + _check_pattern(false), + _delete_on_exit(delete_on_exit) { - assert(sock); - } - SInfo(UDPSocket *sock): - _id(id_count++), - _sock(sock), - _type(SInfo::UDP), - _blocking(true), - _dataLen(0), - _maxRecvLen(0), - _repeatBufferFill(1), - _receivedTotal(0), - _receiverThread(NULL), - _receiveBuffer(NULL), - _senderThreadId(NULL), - _receiverThreadId(NULL), - _packetSizes(NULL), - _check_pattern(false) - { - assert(sock); + MBED_ASSERT(sock); } SInfo(Socket *sock, bool delete_on_exit = true): _id(id_count++), @@ -166,10 +183,32 @@ public: _senderThreadId(NULL), _receiverThreadId(NULL), _packetSizes(NULL), - _check_pattern(false) + _check_pattern(false), + _delete_on_exit(delete_on_exit) { MBED_ASSERT(sock); } +#if defined(MBEDTLS_SSL_CLI_C) + SInfo(TLSSocket *sock, bool delete_on_exit = true): + _id(id_count++), + _sock(sock), + _type(SInfo::TLS), + _blocking(true), + _dataLen(0), + _maxRecvLen(0), + _repeatBufferFill(1), + _receivedTotal(0), + _receiverThread(NULL), + _receiveBuffer(NULL), + _senderThreadId(NULL), + _receiverThreadId(NULL), + _packetSizes(NULL), + _check_pattern(false), + _delete_on_exit(delete_on_exit) + { + MBED_ASSERT(sock); + } +#endif ~SInfo() { this->_sock->sigio(Callback()); @@ -180,7 +219,9 @@ public: if (this->_receiveBuffer) { delete this->_receiveBuffer; } - delete this->_sock; + if (_delete_on_exit) { + delete this->_sock; + } } int id() const { @@ -194,18 +235,20 @@ public: { return *(this->_sock); } - TCPSocket *tcp_socket() + InternetSocket *internetsocket() { - return this->_type == SInfo::TCP_CLIENT ? static_cast(this->_sock) : NULL; + return this->_type == SInfo::IP ? static_cast(this->_sock) : NULL; } TCPServer *tcp_server() { return this->_type == SInfo::TCP_SERVER ? static_cast(this->_sock) : NULL; } - UDPSocket *udp_socket() +#if defined(MBEDTLS_SSL_CLI_C) + TLSSocket *tls_socket() { - return this->_type == SInfo::UDP ? static_cast(this->_sock) : NULL; + return this->_type == SInfo::TLS ? static_cast(this->_sock) : NULL; } +#endif SInfo::SocketType type() const { return this->_type; @@ -304,17 +347,22 @@ public: { const char *str; switch (this->_type) { - case SInfo::TCP_CLIENT: - str = "TCPSocket"; + case SInfo::IP: + str = "InternetSocket"; break; case SInfo::TCP_SERVER: str = "TCPServer"; break; - case SInfo::UDP: - str = "UDPSocket"; + case SInfo::OTHER: + str = "Socket"; break; +#if defined(MBEDTLS_SSL_CLI_C) + case SInfo::TLS: + str = "TLSSocket"; + break; +#endif default: - assert(0); + MBED_ASSERT(0); break; } return str; @@ -328,39 +376,8 @@ public: socket().set_blocking(blocking); this->_blocking = blocking; } - bool can_connect() - { - return (this->type() == SInfo::TCP_CLIENT); - } - bool can_bind() - { - return (this->type() == SInfo::UDP || this->type() == SInfo::TCP_SERVER); - } - bool can_send() - { - return (this->type() == SInfo::TCP_CLIENT); - } - bool can_recv() - { - return (this->type() == SInfo::TCP_CLIENT); - } - bool can_sendto() - { - return (this->type() == SInfo::UDP); - } - bool can_recvfrom() - { - return (this->type() == SInfo::UDP); - } - bool can_listen() - { - return (this->type() == SInfo::TCP_SERVER); - } - bool can_accept() - { - return (this->type() == SInfo::TCP_SERVER); - } private: + static int id_count; const int _id; Socket *_sock; const SInfo::SocketType _type; @@ -376,9 +393,11 @@ private: int *_packetSizes; bool _available; bool _check_pattern; + bool _delete_on_exit; SInfo(); }; +int SInfo::id_count = 0; static std::vector m_sockets; @@ -426,6 +445,22 @@ static void generate_RFC_864_pattern(size_t offset, uint8_t *buf, size_t len, b } } +static int get_cert_from_file(const char *filename, char **cert) +{ + int filedesc = open(filename, O_RDONLY); + if (filedesc < 0) { + cmd_printf("Cannot open file: %s\r\n", filename); + return CMDLINE_RETCODE_FAIL; + } + + if (read(filedesc, *cert, CERT_BUFFER_SIZE) != CERT_BUFFER_SIZE) { + cmd_printf("Cannot read from file %s\r\n", filename); + return CMDLINE_RETCODE_FAIL; + } + + return CMDLINE_RETCODE_SUCCESS; +} + bool SInfo::check_pattern(void *buffer, size_t len) { static bool is_xinetd = false; @@ -465,6 +500,9 @@ static void sigio_handler(SInfo *info) void cmd_socket_init(void) { cmd_add("socket", cmd_socket, "socket", MAN_SOCKET); + cmd_alias_add("socket help", "socket -h"); + cmd_alias_add("socket --help", "socket -h"); + cmd_alias_add("ping server start", "socket echo-server new start"); } int handle_nsapi_error(const char *function, nsapi_error_t ret) @@ -509,13 +547,46 @@ static int del_sinfo(SInfo *info) return CMDLINE_RETCODE_FAIL; } +#if defined(MBEDTLS_SSL_CLI_C) +static int tls_set_cert(int argc, char *argv[], SInfo *info) +{ + static char read_cert[CERT_BUFFER_SIZE]; + char *ptr_cert = NULL; + char *src = NULL; + if (cmd_parameter_val(argc, argv, "--cert_file", &src)) { + tr_debug("Root ca certificate read from file: %s", src); + ptr_cert = read_cert; + if (get_cert_from_file(src, &ptr_cert) == CMDLINE_RETCODE_FAIL) { + cmd_printf("Cannot read from url: %s\r\n", src); + return CMDLINE_RETCODE_INVALID_PARAMETERS; + } + } else if (cmd_parameter_index(argc, argv, "--cert_default") != -1) { + cmd_printf("Using default certificate\r\n"); + ptr_cert = (char *)cert; + } else { + cmd_printf("No cert specified. Use set_root_ca_cert to set it.\r\n"); + // Do not return error, allow the certificate not to be set. + return CMDLINE_RETCODE_SUCCESS; + } + + int ret = info->tls_socket()->set_root_ca_cert(ptr_cert); + if (ret != NSAPI_ERROR_OK) { + cmd_printf("Invalid root certificate\r\n"); + return CMDLINE_RETCODE_FAIL; + } + + return CMDLINE_RETCODE_SUCCESS; +} +#endif + static int cmd_socket_new(int argc, char *argv[]) { const char *s; SInfo *info; + nsapi_error_t ret; - if (cmd_parameter_last(argc, argv)) { - s = cmd_parameter_last(argc, argv); + if (argc > 2) { + s = argv[2]; if (strcmp(s, "UDPSocket") == 0) { tr_debug("Creating a new UDPSocket"); info = new SInfo(new UDPSocket); @@ -525,6 +596,16 @@ static int cmd_socket_new(int argc, char *argv[]) } else if (strcmp(s, "TCPServer") == 0) { tr_debug("Creating a new TCPServer"); info = new SInfo(new TCPServer); +#if defined(MBEDTLS_SSL_CLI_C) + } else if (strcmp(s, "TLSSocket") == 0) { + tr_debug("Creating a new TLSSocket"); + info = new SInfo(new TLSSocket); + ret = tls_set_cert(argc, argv, info); + if (ret) { + delete info; + return ret; + } +#endif } else { cmd_printf("unsupported protocol: %s\r\n", s); return CMDLINE_RETCODE_INVALID_PARAMETERS; @@ -551,7 +632,7 @@ static void udp_receiver_thread(SInfo *info) info->setReceiverThreadId(ThisThread::get_id()); while (i < n) { - ret = static_cast(info->socket()).recvfrom(&addr, info->getReceiveBuffer() + received, info->getDataCount() - received); + ret = info->socket().recvfrom(&addr, info->getReceiveBuffer() + received, info->getDataCount() - received); if (ret > 0) { if (!info->check_pattern(info->getReceiveBuffer() + received, ret)) { return; @@ -561,9 +642,9 @@ static void udp_receiver_thread(SInfo *info) i++; info->setRecvTotal(info->getRecvTotal() + ret); } else if (ret == NSAPI_ERROR_WOULD_BLOCK) { - ThisThread::flags_wait_all(SIGNAL_SIGIO); + ThisThread::flags_wait_any(SIGNAL_SIGIO); } else { - handle_nsapi_size_or_error("Thread: UDPSocket::recvfrom()", ret); + handle_nsapi_size_or_error("Thread: Socket::recvfrom()", ret); return; } } @@ -642,7 +723,12 @@ static nsapi_size_or_error_t udp_sendto_command_handler(SInfo *info, int argc, c } } - nsapi_size_or_error_t ret = static_cast(info->socket()).sendto(host, port, data, len); + SocketAddress addr(NULL, port); + nsapi_size_or_error_t ret = get_interface()->gethostbyname(host, &addr); + if (ret) { + return handle_nsapi_size_or_error("NetworkInterface::gethostbyname()", ret); + } + ret = info->socket().sendto(addr, data, len); if (ret > 0) { cmd_printf("sent: %d bytes\r\n", ret); } @@ -650,7 +736,7 @@ static nsapi_size_or_error_t udp_sendto_command_handler(SInfo *info, int argc, c free(data); } - return handle_nsapi_size_or_error("UDPSocket::sendto()", ret); + return handle_nsapi_size_or_error("Socket::sendto()", ret); } static nsapi_size_or_error_t udp_recvfrom_command_handler(SInfo *info, int argc, char *argv[]) @@ -668,9 +754,9 @@ static nsapi_size_or_error_t udp_recvfrom_command_handler(SInfo *info, int argc, cmd_printf("malloc() failed\r\n"); return CMDLINE_RETCODE_FAIL; } - nsapi_size_or_error_t ret = static_cast(info->socket()).recvfrom(&addr, data, len); + nsapi_size_or_error_t ret = info->socket().recvfrom(&addr, data, len); if (ret > 0) { - cmd_printf("UDPSocket::recvfrom, addr=%s port=%d\r\n", addr.get_ip_address(), addr.get_port()); + cmd_printf("Socket::recvfrom, addr=%s port=%d\r\n", addr.get_ip_address(), addr.get_port()); cmd_printf("received: %d bytes\r\n", ret); print_data((const uint8_t *)data, len); if (!info->check_pattern(data, len)) { @@ -679,7 +765,7 @@ static nsapi_size_or_error_t udp_recvfrom_command_handler(SInfo *info, int argc, info->setRecvTotal(info->getRecvTotal() + ret); } free(data); - return handle_nsapi_size_or_error("UDPSocket::recvfrom()", ret); + return handle_nsapi_size_or_error("Socket::recvfrom()", ret); } static void tcp_receiver_thread(SInfo *info) @@ -695,7 +781,7 @@ static void tcp_receiver_thread(SInfo *info) for (i = 0; i < n; i++) { received = 0; while (received < bufferSize) { - ret = static_cast(info->socket()).recv(info->getReceiveBuffer() + received, recv_len - received); + ret = info->socket().recv(info->getReceiveBuffer() + received, recv_len - received); if (ret > 0) { if (!info->check_pattern(info->getReceiveBuffer() + received, ret)) { return; @@ -705,7 +791,7 @@ static void tcp_receiver_thread(SInfo *info) } else if (ret == NSAPI_ERROR_WOULD_BLOCK) { ThisThread::flags_wait_all(SIGNAL_SIGIO); } else { - handle_nsapi_size_or_error("Thread: TCPSocket::recv()", ret); + handle_nsapi_size_or_error("Thread: Socket::recv()", ret); return; } } @@ -772,7 +858,7 @@ static nsapi_size_or_error_t tcp_send_command_handler(SInfo *info, int argc, cha len = strlen(argv[3]); } - ret = static_cast(info->socket()).send(data, len); + ret = info->socket().send(data, len); if (ret > 0) { cmd_printf("sent: %d bytes\r\n", ret); @@ -780,7 +866,7 @@ static nsapi_size_or_error_t tcp_send_command_handler(SInfo *info, int argc, cha if (data != argv[3]) { free(data); } - return handle_nsapi_size_or_error("TCPSocket::send()", ret); + return handle_nsapi_size_or_error("Socket::send()", ret); } static nsapi_size_or_error_t tcp_recv_command_handler(SInfo *info, int argc, char *argv[]) @@ -798,7 +884,7 @@ static nsapi_size_or_error_t tcp_recv_command_handler(SInfo *info, int argc, cha return CMDLINE_RETCODE_FAIL; } - nsapi_size_or_error_t ret = static_cast(info->socket()).recv(data, len); + nsapi_size_or_error_t ret = info->socket().recv(data, len); if (ret > 0) { cmd_printf("received: %d bytes\r\n", ret); print_data((const uint8_t *)data, ret); @@ -808,7 +894,7 @@ static nsapi_size_or_error_t tcp_recv_command_handler(SInfo *info, int argc, cha info->setRecvTotal(info->getRecvTotal() + ret); } free(data); - return handle_nsapi_size_or_error("TCPSocket::recv()", ret); + return handle_nsapi_size_or_error("Socket::recv()", ret); } static nsapi_size_or_error_t recv_all(char *const rbuffer, const int expt_len, SInfo *const info) @@ -820,7 +906,7 @@ static nsapi_size_or_error_t recv_all(char *const rbuffer, const int expt_len, S rhead = rbuffer; while (rtotal < expt_len) { - rbytes = info->tcp_socket()->recv(rhead, expt_len); + rbytes = info->socket().recv(rhead, expt_len); if (rbytes <= 0) { // Connection closed abruptly rbuffer[rtotal] = '\0'; return rbytes; @@ -847,7 +933,7 @@ static void bg_traffic_thread(SInfo *info) break; } sbuffer[data_len - 1] = 'A' + (rand() % 26); - scount = info->tcp_socket()->send(sbuffer, data_len); + scount = info->socket().send(sbuffer, data_len); rtotal = recv_all(rbuffer, data_len, info); if (scount != rtotal || (strcmp(sbuffer, rbuffer) != 0)) { @@ -950,14 +1036,18 @@ static int cmd_socket(int argc, char *argv[]) } switch (info->type()) { - case SInfo::TCP_CLIENT: - return handle_nsapi_error("Socket::open()", info->tcp_socket()->open(interface)); - case SInfo::UDP: - return handle_nsapi_error("Socket::open()", info->udp_socket()->open(interface)); + case SInfo::IP: + return handle_nsapi_error("Socket::open()", info->internetsocket()->open(interface)); case SInfo::TCP_SERVER: - return handle_nsapi_error("Socket::open()", info->tcp_server()->open(interface)); + return handle_nsapi_error("TCPServer::open()", info->tcp_server()->open(interface)); +#if defined(MBEDTLS_SSL_CLI_C) + case SInfo::TLS: + return handle_nsapi_error("Socket::open()", info->tls_socket()->open(interface)); +#endif + default: + cmd_printf("Not a IP socket\r\n"); + return CMDLINE_RETCODE_FAIL; } - } else if (COMMAND_IS("close")) { return handle_nsapi_error("Socket::close()", info->socket().close()); @@ -1016,12 +1106,6 @@ static int cmd_socket(int argc, char *argv[]) * Commands related to UDPSocket: * sendto, recvfrom */ - if ((COMMAND_IS("sendto") || COMMAND_IS("recvfrom") || COMMAND_IS("start_udp_receiver_thread") - || COMMAND_IS("last_data_received")) && info->type() != SInfo::UDP) { - cmd_printf("Not UDPSocket\r\n"); - return CMDLINE_RETCODE_FAIL; - } - if (COMMAND_IS("sendto")) { return udp_sendto_command_handler(info, argc, argv); } else if (COMMAND_IS("recvfrom")) { @@ -1043,22 +1127,13 @@ static int cmd_socket(int argc, char *argv[]) } thread_clean_up(info); - return handle_nsapi_error("UDPSocket::last_data_received()", NSAPI_ERROR_OK); + return handle_nsapi_error("Socket::last_data_received()", NSAPI_ERROR_OK); } /* - * Commands related to TCPSocket + * Commands related to TCPSocket, TLSSocket * connect, send, recv */ - if ((COMMAND_IS("connect") || COMMAND_IS("recv") - || COMMAND_IS("start_tcp_receiver_thread") || COMMAND_IS("join_tcp_receiver_thread") - || COMMAND_IS("start_bg_traffic_thread") || COMMAND_IS("join_bg_traffic_thread") - || COMMAND_IS("setsockopt_keepalive") || COMMAND_IS("getsockopt_keepalive")) - && info->type() != SInfo::TCP_CLIENT) { - cmd_printf("Not TCPSocket\r\n"); - return CMDLINE_RETCODE_FAIL; - } - if (COMMAND_IS("connect")) { char *host; int32_t port; @@ -1076,7 +1151,18 @@ static int cmd_socket(int argc, char *argv[]) } cmd_printf("Host name: %s port: %" PRId32 "\r\n", host, port); - return handle_nsapi_error("TCPSocket::connect()", static_cast(info->socket()).connect(host, port)); + if (info->type() == SInfo::IP) { + SocketAddress addr(NULL, port); + nsapi_error_t ret = get_interface()->gethostbyname(host, &addr); + if (ret) { + return handle_nsapi_error("NetworkInterface::gethostbyname()", ret); + } + return handle_nsapi_error("Socket::connect()", info->socket().connect(addr)); +#if defined(MBEDTLS_SSL_CLI_C) + } else if (info->type() == SInfo::TLS) { + return handle_nsapi_error("TLSSocket::connect()", static_cast(info->socket()).connect(host, port)); +#endif + } } else if (COMMAND_IS("send")) { return tcp_send_command_handler(info, argc, argv); @@ -1121,7 +1207,7 @@ static int cmd_socket(int argc, char *argv[]) ret = info->socket().setsockopt(NSAPI_SOCKET, NSAPI_KEEPALIVE, &seconds, sizeof(seconds)); - return handle_nsapi_error("TCPSocket::setsockopt()", ret); + return handle_nsapi_error("Socket::setsockopt()", ret); } else if (COMMAND_IS("getsockopt_keepalive")) { int32_t optval; unsigned optlen = sizeof(optval); @@ -1133,29 +1219,30 @@ static int cmd_socket(int argc, char *argv[]) return CMDLINE_RETCODE_FAIL; } if (ret < 0) { - return handle_nsapi_error("TCPSocket::getsockopt()", ret); + return handle_nsapi_error("Socket::getsockopt()", ret); } - return handle_nsapi_size_or_error("TCPSocket::getsockopt()", optval); + return handle_nsapi_size_or_error("Socket::getsockopt()", optval); } /* - * Commands for TCPServer and TCPSocket + * Commands for TCPServer * listen, accept */ if (COMMAND_IS("listen")) { int32_t backlog; if (cmd_parameter_int(argc, argv, "listen", &backlog)) { - return handle_nsapi_error("Socket::listen()", info->socket().listen(backlog)); + return handle_nsapi_error("TCPServer::listen()", info->socket().listen(backlog)); } else { - return handle_nsapi_error("Socket::listen()", info->socket().listen()); + return handle_nsapi_error("TCPServer::listen()", info->socket().listen()); } } else if (COMMAND_IS("accept")) { nsapi_error_t ret; + if (info->type() != SInfo::TCP_SERVER) { Socket *new_sock = info->socket().accept(&ret); if (ret == NSAPI_ERROR_OK) { - SInfo *new_info = new SInfo(new_sock); + SInfo *new_info = new SInfo(new_sock, false); m_sockets.push_back(new_info); cmd_printf("Socket::accept() new socket sid: %d\r\n", new_info->id()); } @@ -1182,6 +1269,22 @@ static int cmd_socket(int argc, char *argv[]) return handle_nsapi_error("TCPServer::accept()", ret); } } + + + /* + * Commands for TLSSocket + * set_root_ca_cert + */ +#if defined(MBEDTLS_SSL_CLI_C) + if (COMMAND_IS("set_root_ca_cert")) { + if (info->type() != SInfo::TLS) { + cmd_printf("Not a TLS socket.\r\n"); + return CMDLINE_RETCODE_FAIL; + } + return handle_nsapi_error("TLSSocket::tls_set_cert", tls_set_cert(argc, argv, info)); + } +#endif + return CMDLINE_RETCODE_INVALID_PARAMETERS; } @@ -1197,7 +1300,7 @@ void print_data(const uint8_t *buf, int len) case PRINT_DISABLED: break; default: - assert(0); + MBED_ASSERT(0); } } diff --git a/TEST_APPS/testcases/netsocket/TCPSOCKET_ECHOTEST_BURST_SHORT.py b/TEST_APPS/testcases/netsocket/TCPSOCKET_ECHOTEST_BURST_SHORT.py index e913b1b4e5..521db0e516 100644 --- a/TEST_APPS/testcases/netsocket/TCPSOCKET_ECHOTEST_BURST_SHORT.py +++ b/TEST_APPS/testcases/netsocket/TCPSOCKET_ECHOTEST_BURST_SHORT.py @@ -59,7 +59,7 @@ class Testcase(Bench): packet = Randomize.random_string(max_len=size, min_len=size, chars=string.ascii_uppercase) sentData += packet response = self.command("dut1", "socket " + str(self.socket_id) + " send " + str(packet)) - response.verify_trace("TCPSocket::send() returned: " + str(size)) + response.verify_trace("Socket::send() returned: " + str(size)) received = 0 data = ""