diff --git a/docs/library/ussl.rst b/docs/library/ussl.rst index ffe146331c..14e3f3ad14 100644 --- a/docs/library/ussl.rst +++ b/docs/library/ussl.rst @@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side. Functions --------- -.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None) - +.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True) Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type), and returns an instance of ssl.SSLSocket, which wraps the underlying stream in an SSL context. Returned object has the usual `stream` interface methods like - ``read()``, ``write()``, etc. In MicroPython, the returned object does not expose - socket interface and methods like ``recv()``, ``send()``. In particular, a - server-side SSL socket should be created from a normal socket returned from + ``read()``, ``write()``, etc. + A server-side SSL socket should be created from a normal socket returned from :meth:`~usocket.socket.accept()` on a non-SSL listening server socket. + - *do_handshake* determines whether the handshake is done as part of the ``wrap_socket`` + or whether it is deferred to be done as part of the initial reads or writes + (there is no ``do_handshake`` method as in CPython). + For blocking sockets doing the handshake immediately is standard. For non-blocking + sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode) + the handshake should generally be deferred because otherwise ``wrap_socket`` blocks + until it completes. Note that in AXTLS the handshake can be deferred until the first + read or write but it then blocks until completion. + Depending on the underlying module implementation in a particular :term:`MicroPython port`, some or all keyword arguments above may be not supported. @@ -31,6 +38,11 @@ Functions Some implementations of ``ussl`` module do NOT validate server certificates, which makes an SSL connection established prone to man-in-the-middle attacks. + CPython's ``wrap_socket`` returns an ``SSLSocket`` object which has methods typical + for sockets, such as ``send``, ``recv``, etc. MicroPython's ``wrap_socket`` + returns an object more similar to CPython's ``SSLObject`` which does not have + these socket methods. + Exceptions ---------- diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c index da5941a55b..9d59342067 100644 --- a/extmod/modussl_axtls.c +++ b/extmod/modussl_axtls.c @@ -167,10 +167,15 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext); if (args->do_handshake.u_bool) { - int res = ssl_handshake_status(o->ssl_sock); + int r = ssl_handshake_status(o->ssl_sock); - if (res != SSL_OK) { - ussl_raise_error(res); + if (r != SSL_OK) { + if (r == SSL_CLOSE_NOTIFY) { // EOF + r = MP_ENOTCONN; + } else if (r == SSL_EAGAIN) { + r = MP_EAGAIN; + } + ussl_raise_error(r); } } @@ -242,8 +247,24 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz return MP_STREAM_ERROR; } - mp_int_t r = ssl_write(o->ssl_sock, buf, size); + mp_int_t r; +eagain: + r = ssl_write(o->ssl_sock, buf, size); + if (r == 0) { + // see comment in ussl_socket_read above + if (o->blocking) { + goto eagain; + } else { + r = SSL_EAGAIN; + } + } if (r < 0) { + if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) { + return 0; // EOF + } + if (r == SSL_EAGAIN) { + r = MP_EAGAIN; + } *errcode = r; return MP_STREAM_ERROR; } diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index 1677dc6e1c..277af37c7c 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { } } +// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { mp_obj_t sock = *(mp_obj_t *)ctx; @@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { mbedtls_pk_init(&o->pkey); mbedtls_ctr_drbg_init(&o->ctr_drbg); #ifdef MBEDTLS_DEBUG_C - // Debug level (0-4) + // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose mbedtls_debug_set_threshold(0); #endif diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index 61761d8194..5135e31631 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -558,7 +558,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) { MP_THREAD_GIL_EXIT(); int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen); MP_THREAD_GIL_ENTER(); - if (r < 0 && errno != EWOULDBLOCK) { + // lwip returns EINPROGRESS when trying to send right after a non-blocking connect + if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) { mp_raise_OSError(errno); } if (r > 0) { @@ -567,7 +568,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) { check_for_exceptions(); } if (sentlen == 0) { - mp_raise_OSError(MP_ETIMEDOUT); + mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT); } return sentlen; } @@ -650,7 +651,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_ if (r > 0) { return r; } - if (r < 0 && errno != EWOULDBLOCK) { + // lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect + if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) { *errcode = errno; return MP_STREAM_ERROR; } diff --git a/tests/net_hosted/accept_timeout.py b/tests/net_hosted/accept_timeout.py index ff989110ae..5f528d557d 100644 --- a/tests/net_hosted/accept_timeout.py +++ b/tests/net_hosted/accept_timeout.py @@ -1,9 +1,9 @@ # test that socket.accept() on a socket with timeout raises ETIMEDOUT try: - import usocket as socket + import uerrno as errno, usocket as socket except: - import socket + import errno, socket try: socket.socket.settimeout @@ -18,5 +18,5 @@ s.listen(1) try: s.accept() except OSError as er: - print(er.args[0] in (110, "timed out")) # 110 is ETIMEDOUT; CPython uses a string + print(er.args[0] in (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno s.close() diff --git a/tests/net_hosted/connect_nonblock_xfer.py b/tests/net_hosted/connect_nonblock_xfer.py new file mode 100644 index 0000000000..feb648ea0a --- /dev/null +++ b/tests/net_hosted/connect_nonblock_xfer.py @@ -0,0 +1,147 @@ +# test that socket.connect() on a non-blocking socket raises EINPROGRESS +# and that an immediate write/send/read/recv does the right thing + +try: + import sys, time + import uerrno as errno, usocket as socket, ussl as ssl +except: + import socket, errno, ssl +isMP = sys.implementation.name == "micropython" + + +def dp(e): + # uncomment next line for development and testing, to print the actual exceptions + # print(repr(e)) + pass + + +# do_connect establishes the socket and wraps it if tls is True. +# If handshake is true, the initial connect (and TLS handshake) is +# allowed to be performed before returning. +def do_connect(peer_addr, tls, handshake): + s = socket.socket() + s.setblocking(False) + try: + # print("Connecting to", peer_addr) + s.connect(peer_addr) + except OSError as er: + print("connect:", er.args[0] == errno.EINPROGRESS) + if er.args[0] != errno.EINPROGRESS: + print(" got", er.args[0]) + # wrap with ssl/tls if desired + if tls: + try: + if sys.implementation.name == "micropython": + s = ssl.wrap_socket(s, do_handshake=handshake) + else: + s = ssl.wrap_socket(s, do_handshake_on_connect=handshake) + print("wrap: True") + except Exception as e: + dp(e) + print("wrap:", e) + elif handshake: + # just sleep a little bit, this allows any connect() errors to happen + time.sleep(0.2) + return s + + +# test runs the test against a specific peer address. +def test(peer_addr, tls=False, handshake=False): + # MicroPython plain sockets have read/write, but CPython's don't + # MicroPython TLS sockets and CPython's have read/write + # hasRW captures this wonderful state of affairs + hasRW = isMP or tls + + # MicroPython plain sockets and CPython's have send/recv + # MicroPython TLS sockets don't have send/recv, but CPython's do + # hasSR captures this wonderful state of affairs + hasSR = not (isMP and tls) + + # connect + send + if hasSR: + s = do_connect(peer_addr, tls, handshake) + # send -> 4 or EAGAIN + try: + ret = s.send(b"1234") + print("send:", handshake and ret == 4) + except OSError as er: + # + dp(er) + print("send:", er.args[0] in (errno.EAGAIN, errno.EINPROGRESS)) + s.close() + else: # fake it... + print("connect:", True) + if tls: + print("wrap:", True) + print("send:", True) + + # connect + write + if hasRW: + s = do_connect(peer_addr, tls, handshake) + # write -> None + try: + ret = s.write(b"1234") + print("write:", ret in (4, None)) # SSL may accept 4 into buffer + except OSError as er: + dp(er) + print("write:", False) # should not raise + except ValueError as er: # CPython + dp(er) + print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.") + s.close() + else: # fake it... + print("connect:", True) + if tls: + print("wrap:", True) + print("write:", True) + + if hasSR: + # connect + recv + s = do_connect(peer_addr, tls, handshake) + # recv -> EAGAIN + try: + print("recv:", s.recv(10)) + except OSError as er: + dp(er) + print("recv:", er.args[0] == errno.EAGAIN) + s.close() + else: # fake it... + print("connect:", True) + if tls: + print("wrap:", True) + print("recv:", True) + + # connect + read + if hasRW: + s = do_connect(peer_addr, tls, handshake) + # read -> None + try: + ret = s.read(10) + print("read:", ret is None) + except OSError as er: + dp(er) + print("read:", False) # should not raise + except ValueError as er: # CPython + dp(er) + print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.") + s.close() + else: # fake it... + print("connect:", True) + if tls: + print("wrap:", True) + print("read:", True) + + +if __name__ == "__main__": + # these tests use a non-existent test IP address, this way the connect takes forever and + # we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737) + print("--- Plain sockets to nowhere ---") + test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False) + print("--- SSL sockets to nowhere ---") + # this test fails with AXTLS because do_handshake=False blocks on first read/write and + # there it times out until the connect is aborted + test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False) + print("--- Plain sockets ---") + test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True) + print("--- SSL sockets ---") + test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True) diff --git a/tests/net_inet/ssl_errors.py b/tests/net_inet/ssl_errors.py new file mode 100644 index 0000000000..fd281b1c49 --- /dev/null +++ b/tests/net_inet/ssl_errors.py @@ -0,0 +1,51 @@ +# test that socket.connect() on a non-blocking socket raises EINPROGRESS +# and that an immediate write/send/read/recv does the right thing + +import sys + +try: + import uerrno as errno, usocket as socket, ussl as ssl +except: + import errno, socket, ssl + + +def test(addr, hostname, block=True): + print("---", hostname or addr) + s = socket.socket() + s.setblocking(block) + try: + s.connect(addr) + print("connected") + except OSError as e: + if e.args[0] != errno.EINPROGRESS: + raise + print("EINPROGRESS") + + try: + if sys.implementation.name == "micropython": + s = ssl.wrap_socket(s, do_handshake=block) + else: + s = ssl.wrap_socket(s, do_handshake_on_connect=block) + print("wrap: True") + except OSError: + print("wrap: error") + + if not block: + try: + while s.write(b"0") is None: + pass + except (ValueError, OSError): # CPython raises ValueError, MicroPython raises OSError + print("write: error") + s.close() + + +if __name__ == "__main__": + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr, None) + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr, None, False) + # connect to server with self-signed cert, oops! + addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1] + test(addr, "test.mosquitto.org") diff --git a/tests/net_inet/test_tls_nonblock.py b/tests/net_inet/test_tls_nonblock.py new file mode 100644 index 0000000000..c27ead3d50 --- /dev/null +++ b/tests/net_inet/test_tls_nonblock.py @@ -0,0 +1,116 @@ +try: + import usocket as socket, ussl as ssl, uerrno as errno, sys +except: + import socket, ssl, errno, sys, time, select + + +def test_one(site, opts): + ai = socket.getaddrinfo(site, 443) + addr = ai[0][-1] + print(addr) + + # Connect the raw socket + s = socket.socket() + s.setblocking(False) + try: + s.connect(addr) + raise OSError(-1, "connect blocks") + except OSError as e: + if e.args[0] != errno.EINPROGRESS: + raise + + if sys.implementation.name != "micropython": + # in CPython we have to wait, otherwise wrap_socket is not happy + select.select([], [s], []) + + try: + # Wrap with SSL + try: + if sys.implementation.name == "micropython": + s = ssl.wrap_socket(s, do_handshake=False) + else: + s = ssl.wrap_socket(s, do_handshake_on_connect=False) + except OSError as e: + if e.args[0] != errno.EINPROGRESS: + raise + print("wrapped") + + # CPython needs to be told to do the handshake + if sys.implementation.name != "micropython": + while True: + try: + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + time.sleep(0.1) + # print("shook hands") + + # Write HTTP request + out = b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin") + while len(out) > 0: + n = s.write(out) + if n is None: + continue + if n > 0: + out = out[n:] + elif n == 0: + raise OSError(-1, "unexpected EOF in write") + print("wrote") + + # Read response + resp = b"" + while True: + try: + b = s.read(128) + except OSError as err: + if err.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ: + continue + raise + if b is None: + continue + if len(b) > 0: + if len(resp) < 1024: + resp += b + elif len(b) == 0: + break + print("read") + + if resp[:7] != b"HTTP/1.": + raise ValueError("response doesn't start with HTTP/1.") + # print(resp) + + finally: + s.close() + + +SITES = [ + "google.com", + {"host": "www.google.com"}, + "micropython.org", + "pypi.org", + "api.telegram.org", + {"host": "api.pushbullet.com", "sni": True}, +] + + +def main(): + for site in SITES: + opts = {} + if isinstance(site, dict): + opts = site + site = opts["host"] + try: + test_one(site, opts) + print(site, "ok") + except Exception as e: + print(site, "error") + print("DONE") + + +main() diff --git a/tests/net_inet/test_tls_sites.py b/tests/net_inet/test_tls_sites.py index d2cb928c8d..3f945efb83 100644 --- a/tests/net_inet/test_tls_sites.py +++ b/tests/net_inet/test_tls_sites.py @@ -27,6 +27,8 @@ def test_one(site, opts): s.write(b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin")) resp = s.read(4096) + if resp[:7] != b"HTTP/1.": + raise ValueError("response doesn't start with HTTP/1.") # print(resp) finally: @@ -36,10 +38,10 @@ def test_one(site, opts): SITES = [ "google.com", "www.google.com", + "micropython.org", + "pypi.org", "api.telegram.org", {"host": "api.pushbullet.com", "sni": True}, - # "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", - {"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True}, ] diff --git a/tests/net_inet/test_tls_sites.py.exp b/tests/net_inet/test_tls_sites.py.exp index 2f3c113d2f..bc4a8dbd11 100644 --- a/tests/net_inet/test_tls_sites.py.exp +++ b/tests/net_inet/test_tls_sites.py.exp @@ -1,5 +1,6 @@ google.com ok www.google.com ok +micropython.org ok +pypi.org ok api.telegram.org ok api.pushbullet.com ok -w9rybpfril.execute-api.ap-southeast-2.amazonaws.com ok