diff --git a/extmod/modlwip.c b/extmod/modlwip.c index b4fe46bd42..af0d4cbcdf 100644 --- a/extmod/modlwip.c +++ b/extmod/modlwip.c @@ -308,6 +308,7 @@ typedef struct _lwip_socket_obj_t { #define STATE_CONNECTING 2 #define STATE_CONNECTED 3 #define STATE_PEER_CLOSED 4 + #define STATE_ACTIVE_UDP 5 // Negative value is lwIP error int8_t state; } lwip_socket_obj_t; @@ -812,9 +813,13 @@ STATIC mp_obj_t lwip_socket_make_new(const mp_obj_type_t *type, size_t n_args, s lwip_socket_obj_t *socket = m_new_obj_with_finaliser(lwip_socket_obj_t); socket->base.type = &lwip_socket_type; + socket->timeout = -1; + socket->recv_offset = 0; socket->domain = MOD_NETWORK_AF_INET; socket->type = MOD_NETWORK_SOCK_STREAM; socket->callback = MP_OBJ_NULL; + socket->state = STATE_NEW; + if (n_args >= 1) { socket->domain = mp_obj_get_int(args[0]); if (n_args >= 2) { @@ -856,6 +861,7 @@ STATIC mp_obj_t lwip_socket_make_new(const mp_obj_type_t *type, size_t n_args, s break; } case MOD_NETWORK_SOCK_DGRAM: { + socket->state = STATE_ACTIVE_UDP; // Register our receive callback now. Since UDP sockets don't require binding or connection // before use, there's no other good time to do it. udp_recv(socket->pcb.udp, _lwip_udp_incoming, (void *)socket); @@ -871,9 +877,6 @@ STATIC mp_obj_t lwip_socket_make_new(const mp_obj_type_t *type, size_t n_args, s #endif } - socket->timeout = -1; - socket->state = STATE_NEW; - socket->recv_offset = 0; return MP_OBJ_FROM_PTR(socket); } diff --git a/tests/extmod/uselect_poll_udp.py b/tests/extmod/uselect_poll_udp.py new file mode 100644 index 0000000000..e7d7dfe341 --- /dev/null +++ b/tests/extmod/uselect_poll_udp.py @@ -0,0 +1,28 @@ +# test select.poll on UDP sockets + +try: + import usocket as socket, uselect as select +except ImportError: + try: + import socket, select + except ImportError: + print("SKIP") + raise SystemExit + + +s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +s.bind(socket.getaddrinfo('127.0.0.1', 8000)[0][-1]) +poll = select.poll() + +# UDP socket should not be readable +poll.register(s, select.POLLIN) +print(len(poll.poll(0))) + +# UDP socket should be writable +poll.modify(s, select.POLLOUT) +print(poll.poll(0)[0][1] == select.POLLOUT) + +# same test for select.select, but just skip it if the function isn't available +if hasattr(select, "select"): + r, w, e = select.select([s], [], [], 0) + assert not r and not w and not e