From a75ca8a1c07d7da1d0f97189cbe3238f65dda2df Mon Sep 17 00:00:00 2001 From: Angus Gratton Date: Wed, 10 Jan 2024 14:11:00 +1100 Subject: [PATCH] esp32/modsocket: Use all supplied arguments to socket.getaddrinfo(). - Completes a longstanding TODO in the code, to not ignore the optional family, type, proto and flags arguments to socket.getaddrinfo(). - Note that passing family=socket.AF_INET6 will now cause queries to fail (OSError -202). Previously this argument was ignored so IPV4 results were returned instead. - Optional 'type' argument is now always copied into the result. If not set, results have type SOCK_STREAM. - Fixes inconsistency where previously querying mDNS local suffix (.local) hostnames returned results with socket type 0 (invalid), but all other queries returned results with socket type SOCK_STREAM (regardless of 'type' argument). - Optional proto argument is now returned in the result tuple, if supplied. - Optional flags argument is now passed through to lwIP. lwIP has handling for AI_NUMERICHOST, AI_V4MAPPED, AI_PASSIVE (untested, constants for these are not currently exposed in the esp32 socket module). - Also fixes a possible memory leak in an obscure code path (lwip_getaddrinfo apparently sometimes returns a result structure with address "0.0.0.0" instead of failing, and this structure would have been leaked.) This work was funded through GitHub Sponsors. Signed-off-by: Angus Gratton --- ports/esp32/modsocket.c | 161 ++++++++++++++++++++++++---------------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index ca8c63030c..af94940156 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -156,64 +156,67 @@ static inline void check_for_exceptions(void) { mp_handle_pending(true); } -// This function mimics lwip_getaddrinfo, with added support for mDNS queries -static int _socket_getaddrinfo3(const char *nodename, const char *servname, +#if MICROPY_HW_ENABLE_MDNS_QUERIES +// This function mimics lwip_getaddrinfo, but makes an mDNS query +STATIC int mdns_getaddrinfo(const char *host_str, const char *port_str, const struct addrinfo *hints, struct addrinfo **res) { - - #if MICROPY_HW_ENABLE_MDNS_QUERIES - int nodename_len = strlen(nodename); + int host_len = strlen(host_str); const int local_len = sizeof(MDNS_LOCAL_SUFFIX) - 1; - if (nodename_len > local_len - && strcasecmp(nodename + nodename_len - local_len, MDNS_LOCAL_SUFFIX) == 0) { - // mDNS query - char nodename_no_local[nodename_len - local_len + 1]; - memcpy(nodename_no_local, nodename, nodename_len - local_len); - nodename_no_local[nodename_len - local_len] = '\0'; - - esp_ip4_addr_t addr = {0}; - - esp_err_t err = mdns_query_a(nodename_no_local, MDNS_QUERY_TIMEOUT_MS, &addr); - if (err != ESP_OK) { - if (err == ESP_ERR_NOT_FOUND) { - *res = NULL; - return 0; - } - *res = NULL; - return err; - } - - struct addrinfo *ai = memp_malloc(MEMP_NETDB); - if (ai == NULL) { - *res = NULL; - return EAI_MEMORY; - } - memset(ai, 0, sizeof(struct addrinfo) + sizeof(struct sockaddr_storage)); - - struct sockaddr_in *sa = (struct sockaddr_in *)((uint8_t *)ai + sizeof(struct addrinfo)); - inet_addr_from_ip4addr(&sa->sin_addr, &addr); - sa->sin_family = AF_INET; - sa->sin_len = sizeof(struct sockaddr_in); - sa->sin_port = lwip_htons((u16_t)atoi(servname)); - ai->ai_family = AF_INET; - ai->ai_canonname = ((char *)sa + sizeof(struct sockaddr_storage)); - memcpy(ai->ai_canonname, nodename, nodename_len + 1); - ai->ai_addrlen = sizeof(struct sockaddr_storage); - ai->ai_addr = (struct sockaddr *)sa; - - *res = ai; + if (host_len <= local_len || + strcasecmp(host_str + host_len - local_len, MDNS_LOCAL_SUFFIX) != 0) { return 0; } - #endif - // Normal query - return lwip_getaddrinfo(nodename, servname, hints, res); + // mDNS query + char host_no_local[host_len - local_len + 1]; + memcpy(host_no_local, host_str, host_len - local_len); + host_no_local[host_len - local_len] = '\0'; + + esp_ip4_addr_t addr = {0}; + + esp_err_t err = mdns_query_a(host_no_local, MDNS_QUERY_TIMEOUT_MS, &addr); + if (err != ESP_OK) { + if (err == ESP_ERR_NOT_FOUND) { + *res = NULL; + return 0; + } + *res = NULL; + return err; + } + + struct addrinfo *ai = memp_malloc(MEMP_NETDB); + if (ai == NULL) { + *res = NULL; + return EAI_MEMORY; + } + memset(ai, 0, sizeof(struct addrinfo) + sizeof(struct sockaddr_storage)); + + struct sockaddr_in *sa = (struct sockaddr_in *)((uint8_t *)ai + sizeof(struct addrinfo)); + inet_addr_from_ip4addr(&sa->sin_addr, &addr); + sa->sin_family = AF_INET; + sa->sin_len = sizeof(struct sockaddr_in); + sa->sin_port = lwip_htons((u16_t)atoi(port_str)); + ai->ai_family = AF_INET; + ai->ai_canonname = ((char *)sa + sizeof(struct sockaddr_storage)); + memcpy(ai->ai_canonname, host_str, host_len + 1); + ai->ai_addrlen = sizeof(struct sockaddr_storage); + ai->ai_addr = (struct sockaddr *)sa; + ai->ai_socktype = SOCK_STREAM; + if (hints) { + ai->ai_socktype = hints->ai_socktype; + ai->ai_protocol = hints->ai_protocol; + } + + *res = ai; + return 0; } +#endif // MICROPY_HW_ENABLE_MDNS_QUERIES -static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struct addrinfo **resp) { - const struct addrinfo hints = { - .ai_family = AF_INET, - .ai_socktype = SOCK_STREAM, - }; +static void _getaddrinfo_inner(const mp_obj_t host, const mp_obj_t portx, + const struct addrinfo *hints, struct addrinfo **res) { + int retval = 0; + + *res = NULL; mp_obj_t port = portx; if (mp_obj_is_integer(port)) { @@ -231,27 +234,37 @@ static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struc } MP_THREAD_GIL_EXIT(); - int res = _socket_getaddrinfo3(host_str, port_str, &hints, resp); + + #if MICROPY_HW_ENABLE_MDNS_QUERIES + retval = mdns_getaddrinfo(host_str, port_str, hints, res); + #endif + + if (retval == 0 && *res == NULL) { + // Normal query + retval = lwip_getaddrinfo(host_str, port_str, hints, res); + } + MP_THREAD_GIL_ENTER(); // Per docs: instead of raising gaierror getaddrinfo raises negative error number - if (res != 0) { - mp_raise_OSError(res > 0 ? -res : res); + if (retval != 0) { + mp_raise_OSError(retval > 0 ? -retval : retval); } // Somehow LwIP returns a resolution of 0.0.0.0 for failed lookups, traced it as far back // as netconn_gethostbyname_addrtype returning OK instead of error. - if (*resp == NULL || - (strcmp(resp[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) { + if (*res == NULL || + (strcmp(res[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) { + lwip_freeaddrinfo(*res); mp_raise_OSError(-2); // name or service not known } - return res; + assert(retval == 0 && *res != NULL); } STATIC void _socket_getaddrinfo(const mp_obj_t addrtuple, struct addrinfo **resp) { mp_obj_t *elem; mp_obj_get_array_fixed_n(addrtuple, 2, &elem); - _socket_getaddrinfo2(elem[0], elem[1], resp); + _getaddrinfo_inner(elem[0], elem[1], NULL, resp); } STATIC mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { @@ -897,10 +910,32 @@ STATIC MP_DEFINE_CONST_OBJ_TYPE( ); STATIC mp_obj_t esp_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) { - // TODO support additional args beyond the first two - + struct addrinfo hints = { }; struct addrinfo *res = NULL; - _socket_getaddrinfo2(args[0], args[1], &res); + + // Optional args: family=0, type=0, proto=0, flags=0, where 0 is "least narrow" + if (n_args > 2) { + hints.ai_family = mp_obj_get_int(args[2]); + } + if (n_args > 3) { + hints.ai_socktype = mp_obj_get_int(args[3]); + } + if (hints.ai_socktype == 0) { + // This is slightly different to CPython with POSIX getaddrinfo. In + // CPython, calling socket.getaddrinfo() with socktype=0 returns any/all + // supported SocketKind values. Here, lwip_getaddrinfo() will echo + // whatever socktype was supplied to the caller. Rather than returning 0 + // (invalid in a result), make it SOCK_STREAM. + hints.ai_socktype = SOCK_STREAM; + } + if (n_args > 4) { + hints.ai_protocol = mp_obj_get_int(args[4]); + } + if (n_args > 5) { + hints.ai_flags = mp_obj_get_int(args[5]); + } + + _getaddrinfo_inner(args[0], args[1], &hints, &res); mp_obj_t ret_list = mp_obj_new_list(0, NULL); for (struct addrinfo *resi = res; resi; resi = resi->ai_next) { @@ -927,9 +962,7 @@ STATIC mp_obj_t esp_socket_getaddrinfo(size_t n_args, const mp_obj_t *args) { mp_obj_list_append(ret_list, mp_obj_new_tuple(5, addrinfo_objs)); } - if (res) { - lwip_freeaddrinfo(res); - } + lwip_freeaddrinfo(res); return ret_list; } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(esp_socket_getaddrinfo_obj, 2, 6, esp_socket_getaddrinfo);