diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index a0f2cd1699..6cb4f00205 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -67,6 +67,7 @@ typedef struct _mp_obj_ssl_context_t { mbedtls_pk_context pkey; int authmode; int *ciphersuites; + mp_obj_t handler; } mp_obj_ssl_context_t; // This corresponds to an SSLSocket object. @@ -188,6 +189,16 @@ STATIC void ssl_check_async_handshake_failure(mp_obj_ssl_socket_t *sslsock, int } } +STATIC int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uint32_t *flags) { + mp_obj_ssl_context_t *o = ptr; + if (o->handler == mp_const_none) { + return 0; + } + mp_obj_array_t cert; + mp_obj_memoryview_init(&cert, 'B', 0, crt->raw.len, crt->raw.p); + return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth))); +} + /******************************************************************************/ // SSLContext type. @@ -213,6 +224,7 @@ STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args mbedtls_x509_crt_init(&self->cert); mbedtls_pk_init(&self->pkey); self->ciphersuites = NULL; + self->handler = mp_const_none; #ifdef MBEDTLS_DEBUG_C // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose @@ -243,6 +255,7 @@ STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args self->authmode = MBEDTLS_SSL_VERIFY_NONE; } mbedtls_ssl_conf_authmode(&self->conf, self->authmode); + mbedtls_ssl_conf_verify(&self->conf, &ssl_sock_cert_verify, self); mbedtls_ssl_conf_rng(&self->conf, mbedtls_ctr_drbg_random, &self->ctr_drbg); #ifdef MBEDTLS_DEBUG_C mbedtls_ssl_conf_dbg(&self->conf, mbedtls_debug, NULL); @@ -257,6 +270,8 @@ STATIC void ssl_context_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { // Load attribute. if (attr == MP_QSTR_verify_mode) { dest[0] = MP_OBJ_NEW_SMALL_INT(self->authmode); + } else if (attr == MP_QSTR_verify_callback) { + dest[0] = self->handler; } else { // Continue lookup in locals_dict. dest[1] = MP_OBJ_SENTINEL; @@ -267,6 +282,9 @@ STATIC void ssl_context_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { self->authmode = mp_obj_get_int(dest[1]); dest[0] = MP_OBJ_NULL; mbedtls_ssl_conf_authmode(&self->conf, self->authmode); + } else if (attr == MP_QSTR_verify_callback) { + dest[0] = MP_OBJ_NULL; + self->handler = dest[1]; } } } diff --git a/tests/multi_net/sslcontext_verify_callback.py b/tests/multi_net/sslcontext_verify_callback.py new file mode 100644 index 0000000000..74b9738245 --- /dev/null +++ b/tests/multi_net/sslcontext_verify_callback.py @@ -0,0 +1,68 @@ +# Test creating an SSL connection and getting the peer certificate. + +try: + import io + import os + import socket + import tls +except ImportError: + print("SKIP") + raise SystemExit + +PORT = 8000 + +# These are test certificates. See tests/README.md for details. +cert = cafile = "ec_cert.der" +key = "ec_key.der" + +try: + with open(cafile, "rb") as f: + cadata = f.read() + with open(key, "rb") as f: + key = f.read() +except OSError: + print("SKIP") + raise SystemExit + + +def verify_callback(cert, depth): + print(cert.hex()) + return 0 + + +# Server +def instance0(): + multitest.globals(IP=multitest.get_network_ip()) + s = socket.socket() + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) + s.listen(1) + multitest.next() + s2, _ = s.accept() + server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER) + server_ctx.load_cert_chain(cadata, key) + s2 = server_ctx.wrap_socket(s2, server_side=True) + print(s2.read(16)) + s2.write(b"server to client") + s2.close() + s.close() + + +# Client +def instance1(): + s_test = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT) + if not hasattr(s_test, "verify_callback"): + print("SKIP") + raise SystemExit + + multitest.next() + s = socket.socket() + s.connect(socket.getaddrinfo(IP, PORT)[0][-1]) + client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT) + client_ctx.verify_mode = tls.CERT_REQUIRED + client_ctx.verify_callback = verify_callback + client_ctx.load_verify_locations(cadata) + s = client_ctx.wrap_socket(s, server_hostname="micropython.local") + s.write(b"client to server") + print(s.read(16)) + s.close() diff --git a/tests/multi_net/sslcontext_verify_callback.py.exp b/tests/multi_net/sslcontext_verify_callback.py.exp new file mode 100644 index 0000000000..e7a0ab0b46 --- /dev/null +++ b/tests/multi_net/sslcontext_verify_callback.py.exp @@ -0,0 +1,5 @@ +--- instance0 --- +b'client to server' +--- instance1 --- +308201d330820179a00302010202144315a7cd8f69febe2640314e7c97d60a2523ad15300a06082a8648ce3d040302303f311a301806035504030c116d6963726f707974686f6e2e6c6f63616c31143012060355040a0c0b4d6963726f507974686f6e310b3009060355040613024155301e170d3234303131343034353335335a170d3235303131333034353335335a303f311a301806035504030c116d6963726f707974686f6e2e6c6f63616c31143012060355040a0c0b4d6963726f507974686f6e310b30090603550406130241553059301306072a8648ce3d020106082a8648ce3d0301070342000449b7f5fa687cb25a9464c397508149992f445c860bcf7002958eb4337636c6af840cd4c8cf3b96f2384860d8ae3ee3fa135dba051e8605e62bd871689c6af43ca3533051301d0603551d0e0416041441b3ae171d91e330411d8543ba45e0f2d5b2951b301f0603551d2304183016801441b3ae171d91e330411d8543ba45e0f2d5b2951b300f0603551d130101ff040530030101ff300a06082a8648ce3d04030203480030450220587f61c34739d6fab5802a674dcc54443ae9c87da374078c4ee1cd83f4ad1694022100cfc45dcf264888c6ba2c36e78bd27bb67856d7879a052dd7aa7ecf7215f7b992 +b'server to client' diff --git a/tests/net_hosted/ssl_verify_callback.py b/tests/net_hosted/ssl_verify_callback.py new file mode 100644 index 0000000000..0dba4e4fdd --- /dev/null +++ b/tests/net_hosted/ssl_verify_callback.py @@ -0,0 +1,37 @@ +# test ssl verify_callback + +import io +import socket +import tls + + +def verify_callback(cert, depth): + print("verify_callback:", type(cert), len(cert) > 100, depth) + return 0 + + +def verify_callback_fail(cert, depth): + print("verify_callback_fail:", type(cert), len(cert) > 100, depth) + return 1 + + +def test(peer_addr): + context = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT) + context.verify_mode = tls.CERT_OPTIONAL + context.verify_callback = verify_callback + s = socket.socket() + s.connect(peer_addr) + s = context.wrap_socket(s) + s.close() + + context.verify_callback = verify_callback_fail + s = socket.socket() + s.connect(peer_addr) + try: + s = context.wrap_socket(s) + except OSError as e: + print(e.args[1]) + + +if __name__ == "__main__": + test(socket.getaddrinfo("micropython.org", 443)[0][-1]) diff --git a/tests/net_hosted/ssl_verify_callback.py.exp b/tests/net_hosted/ssl_verify_callback.py.exp new file mode 100644 index 0000000000..e27dcbb9d5 --- /dev/null +++ b/tests/net_hosted/ssl_verify_callback.py.exp @@ -0,0 +1,5 @@ +verify_callback: True 2 +verify_callback: True 1 +verify_callback: True 0 +verify_callback_fail: True 2 +MBEDTLS_ERR_ERROR_GENERIC_ERROR