diff --git a/extmod/modssl_axtls.c b/extmod/modssl_axtls.c index de6e0ce5dd..d169d89a2c 100644 --- a/extmod/modssl_axtls.c +++ b/extmod/modssl_axtls.c @@ -4,6 +4,7 @@ * The MIT License (MIT) * * Copyright (c) 2015-2019 Paul Sokolovsky + * Copyright (c) 2023 Damien P. George * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,6 +36,17 @@ #include "ssl.h" +#define PROTOCOL_TLS_CLIENT (0) +#define PROTOCOL_TLS_SERVER (1) + +// This corresponds to an SSLContext object. +typedef struct _mp_obj_ssl_context_t { + mp_obj_base_t base; + mp_obj_t key; + mp_obj_t cert; +} mp_obj_ssl_context_t; + +// This corresponds to an SSLSocket object. typedef struct _mp_obj_ssl_socket_t { mp_obj_base_t base; mp_obj_t sock; @@ -53,8 +65,15 @@ struct ssl_args { mp_arg_val_t do_handshake; }; +STATIC const mp_obj_type_t ssl_context_type; STATIC const mp_obj_type_t ssl_socket_type; +STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname); + +/******************************************************************************/ +// Helper functions. + // Table of error strings corresponding to SSL_xxx error codes. STATIC const char *const ssl_error_tab1[] = { "NOT_OK", @@ -116,8 +135,71 @@ STATIC NORETURN void ssl_raise_error(int err) { nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args)); } +/******************************************************************************/ +// SSLContext type. + +STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { + mp_arg_check_num(n_args, n_kw, 1, 1, false); + + // The "protocol" argument is ignored in this implementation. + + // Create SSLContext object. + #if MICROPY_PY_SSL_FINALISER + mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t); + #else + mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t); + #endif + self->base.type = type_in; + self->key = mp_const_none; + self->cert = mp_const_none; + + return MP_OBJ_FROM_PTR(self); +} + +STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) { + self->key = key_obj; + self->cert = cert_obj; +} + +STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname }; + static const mp_arg_t allowed_args[] = { + { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, + { MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, + { MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, + }; + + // Parse arguments. + mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]); + mp_obj_t sock = pos_args[1]; + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + // Create and return the new SSLSocket object. + return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool, + args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket); + +STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = { + { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) }, +}; +STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table); + +STATIC MP_DEFINE_CONST_OBJ_TYPE( + ssl_context_type, + MP_QSTR_SSLContext, + MP_TYPE_FLAG_NONE, + make_new, ssl_context_make_new, + locals_dict, &ssl_context_locals_dict + ); + +/******************************************************************************/ +// SSLSocket type. + +STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { -STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) { #if MICROPY_PY_SSL_FINALISER mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t); #else @@ -130,43 +212,43 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) o->blocking = true; uint32_t options = SSL_SERVER_VERIFY_LATER; - if (!args->do_handshake.u_bool) { + if (!do_handshake_on_connect) { options |= SSL_CONNECT_IN_PARTS; } - if (args->key.u_obj != mp_const_none) { + if (ssl_context->key != mp_const_none) { options |= SSL_NO_DEFAULT_KEY; } if ((o->ssl_ctx = ssl_ctx_new(options, SSL_DEFAULT_CLNT_SESS)) == NULL) { mp_raise_OSError(MP_EINVAL); } - if (args->key.u_obj != mp_const_none) { + if (ssl_context->key != mp_const_none) { size_t len; - const byte *data = (const byte *)mp_obj_str_get_data(args->key.u_obj, &len); + const byte *data = (const byte *)mp_obj_str_get_data(ssl_context->key, &len); int res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_RSA_KEY, data, len, NULL); if (res != SSL_OK) { mp_raise_ValueError(MP_ERROR_TEXT("invalid key")); } - data = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &len); + data = (const byte *)mp_obj_str_get_data(ssl_context->cert, &len); res = ssl_obj_memory_load(o->ssl_ctx, SSL_OBJ_X509_CERT, data, len, NULL); if (res != SSL_OK) { mp_raise_ValueError(MP_ERROR_TEXT("invalid cert")); } } - if (args->server_side.u_bool) { + if (server_side) { o->ssl_sock = ssl_server_new(o->ssl_ctx, (long)sock); } else { SSL_EXTENSIONS *ext = ssl_ext_new(); - if (args->server_hostname.u_obj != mp_const_none) { - ext->host_name = (char *)mp_obj_str_get_str(args->server_hostname.u_obj); + if (server_hostname != mp_const_none) { + ext->host_name = (char *)mp_obj_str_get_str(server_hostname); } o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext); - if (args->do_handshake.u_bool) { + if (do_handshake_on_connect) { int r = ssl_handshake_status(o->ssl_sock); if (r != SSL_OK) { @@ -178,18 +260,11 @@ STATIC mp_obj_ssl_socket_t *ssl_socket_new(mp_obj_t sock, struct ssl_args *args) ssl_raise_error(r); } } - } return o; } -STATIC void ssl_socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) { - (void)kind; - mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); - mp_printf(print, "<_SSLSocket %p>", self->ssl_sock); -} - STATIC mp_uint_t ssl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); @@ -305,7 +380,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) }, #endif }; - STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table); STATIC const mp_stream_p_t ssl_socket_stream_p = { @@ -316,16 +390,23 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = { STATIC MP_DEFINE_CONST_OBJ_TYPE( ssl_socket_type, - // Save on qstr's, reuse same as for module - MP_QSTR_ssl, + MP_QSTR_SSLSocket, MP_TYPE_FLAG_NONE, - print, ssl_socket_print, protocol, &ssl_socket_stream_p, locals_dict, &ssl_socket_locals_dict ); +/******************************************************************************/ +// ssl module. + STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { - // TODO: Implement more args + enum { + ARG_key, + ARG_cert, + ARG_server_side, + ARG_server_hostname, + ARG_do_handshake, + }; static const mp_arg_t allowed_args[] = { { MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, { MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, @@ -334,22 +415,40 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_ { MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, }; - // TODO: Check that sock implements stream protocol + // Parse arguments. mp_obj_t sock = pos_args[0]; + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); - struct ssl_args args; - mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, - MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args); + // Create SSLContext. + mp_int_t protocol = args[ARG_server_side].u_bool ? PROTOCOL_TLS_SERVER : PROTOCOL_TLS_CLIENT; + mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) }; + mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args)); - return MP_OBJ_FROM_PTR(ssl_socket_new(sock, &args)); + // Load key and cert if given. + if (args[ARG_key].u_obj != mp_const_none) { + ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj); + } + + // Create and return the new SSLSocket object. + return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool, + args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj); } STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket); STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) }, - { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) }, -}; + // Functions. + { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) }, + + // Classes. + { MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) }, + + // Constants. + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(PROTOCOL_TLS_CLIENT) }, + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(PROTOCOL_TLS_SERVER) }, +}; STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table); const mp_obj_module_t mp_module_ssl = { diff --git a/extmod/modssl_mbedtls.c b/extmod/modssl_mbedtls.c index e346f986f7..83f6f907f4 100644 --- a/extmod/modssl_mbedtls.c +++ b/extmod/modssl_mbedtls.c @@ -5,6 +5,7 @@ * * Copyright (c) 2016 Linaro Ltd. * Copyright (c) 2019 Paul Sokolovsky + * Copyright (c) 2023 Damien P. George * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,33 +49,38 @@ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR) -typedef struct _mp_obj_ssl_socket_t { +// This corresponds to an SSLContext object. +typedef struct _mp_obj_ssl_context_t { mp_obj_base_t base; - mp_obj_t sock; mbedtls_entropy_context entropy; mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_ssl_context ssl; mbedtls_ssl_config conf; mbedtls_x509_crt cacert; mbedtls_x509_crt cert; mbedtls_pk_context pkey; + int authmode; +} mp_obj_ssl_context_t; + +// This corresponds to an SSLSocket object. +typedef struct _mp_obj_ssl_socket_t { + mp_obj_base_t base; + mp_obj_ssl_context_t *ssl_context; + mp_obj_t sock; + mbedtls_ssl_context ssl; uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next int last_error; // The last error code, if any } mp_obj_ssl_socket_t; -struct ssl_args { - mp_arg_val_t key; - mp_arg_val_t cert; - mp_arg_val_t server_side; - mp_arg_val_t server_hostname; - mp_arg_val_t cert_reqs; - mp_arg_val_t cadata; - mp_arg_val_t do_handshake; -}; - +STATIC const mp_obj_type_t ssl_context_type; STATIC const mp_obj_type_t ssl_socket_type; +STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname); + +/******************************************************************************/ +// Helper functions. + #ifdef MBEDTLS_DEBUG_C STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) { (void)ctx; @@ -84,6 +90,15 @@ STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, cons #endif STATIC NORETURN void mbedtls_raise_error(int err) { + // Handle special cases. + if (err == MBEDTLS_ERR_SSL_ALLOC_FAILED) { + mp_raise_OSError(MP_ENOMEM); + } else if (err == MBEDTLS_ERR_PK_BAD_INPUT_DATA) { + mp_raise_ValueError(MP_ERROR_TEXT("invalid key")); + } else if (err == MBEDTLS_ERR_X509_BAD_INPUT_DATA) { + mp_raise_ValueError(MP_ERROR_TEXT("invalid cert")); + } + // _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the // underlying socket into negative codes to pass them through mbedtls. Here we turn them // positive again so they get interpreted as the OSError they really are. The @@ -123,6 +138,178 @@ STATIC NORETURN void mbedtls_raise_error(int err) { #endif } +/******************************************************************************/ +// SSLContext type. + +STATIC mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { + mp_arg_check_num(n_args, n_kw, 1, 1, false); + + // This is the "protocol" argument. + mp_int_t endpoint = mp_obj_get_int(args[0]); + + // Create SSLContext object. + #if MICROPY_PY_SSL_FINALISER + mp_obj_ssl_context_t *self = m_new_obj_with_finaliser(mp_obj_ssl_context_t); + #else + mp_obj_ssl_context_t *self = m_new_obj(mp_obj_ssl_context_t); + #endif + self->base.type = type_in; + + // Initialise mbedTLS state. + mbedtls_ssl_config_init(&self->conf); + mbedtls_entropy_init(&self->entropy); + mbedtls_ctr_drbg_init(&self->ctr_drbg); + mbedtls_x509_crt_init(&self->cacert); + mbedtls_x509_crt_init(&self->cert); + mbedtls_pk_init(&self->pkey); + + #ifdef MBEDTLS_DEBUG_C + // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose + mbedtls_debug_set_threshold(3); + #endif + + const byte seed[] = "upy"; + int ret = mbedtls_ctr_drbg_seed(&self->ctr_drbg, mbedtls_entropy_func, &self->entropy, seed, sizeof(seed)); + if (ret != 0) { + mbedtls_raise_error(ret); + } + + ret = mbedtls_ssl_config_defaults(&self->conf, endpoint, + MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + mbedtls_raise_error(ret); + } + + if (endpoint == MBEDTLS_SSL_IS_CLIENT) { + // The CPython default is MBEDTLS_SSL_VERIFY_REQUIRED, but to maintain + // backwards compatibility we use MBEDTLS_SSL_VERIFY_NONE for now. + self->authmode = MBEDTLS_SSL_VERIFY_NONE; + } else { + self->authmode = MBEDTLS_SSL_VERIFY_NONE; + } + mbedtls_ssl_conf_authmode(&self->conf, self->authmode); + mbedtls_ssl_conf_rng(&self->conf, mbedtls_ctr_drbg_random, &self->ctr_drbg); + #ifdef MBEDTLS_DEBUG_C + mbedtls_ssl_conf_dbg(&self->conf, mbedtls_debug, NULL); + #endif + + return MP_OBJ_FROM_PTR(self); +} + +STATIC void ssl_context_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { + mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in); + if (dest[0] == MP_OBJ_NULL) { + // Load attribute. + if (attr == MP_QSTR_verify_mode) { + dest[0] = MP_OBJ_NEW_SMALL_INT(self->authmode); + } else { + // Continue lookup in locals_dict. + dest[1] = MP_OBJ_SENTINEL; + } + } else if (dest[1] != MP_OBJ_NULL) { + // Store attribute. + if (attr == MP_QSTR_verify_mode) { + self->authmode = mp_obj_get_int(dest[1]); + dest[0] = MP_OBJ_NULL; + mbedtls_ssl_conf_authmode(&self->conf, self->authmode); + } + } +} + +#if MICROPY_PY_SSL_FINALISER +STATIC mp_obj_t ssl_context___del__(mp_obj_t self_in) { + mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in); + mbedtls_pk_free(&self->pkey); + mbedtls_x509_crt_free(&self->cert); + mbedtls_x509_crt_free(&self->cacert); + mbedtls_ctr_drbg_free(&self->ctr_drbg); + mbedtls_entropy_free(&self->entropy); + mbedtls_ssl_config_free(&self->conf); + return mp_const_none; +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_context___del___obj, ssl_context___del__); +#endif + +STATIC void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) { + size_t key_len; + const byte *key = (const byte *)mp_obj_str_get_data(key_obj, &key_len); + // len should include terminating null + int ret; + #if MBEDTLS_VERSION_NUMBER >= 0x03000000 + ret = mbedtls_pk_parse_key(&self->pkey, key, key_len + 1, NULL, 0, mbedtls_ctr_drbg_random, &self->ctr_drbg); + #else + ret = mbedtls_pk_parse_key(&self->pkey, key, key_len + 1, NULL, 0); + #endif + if (ret != 0) { + mbedtls_raise_error(MBEDTLS_ERR_PK_BAD_INPUT_DATA); // use general error for all key errors + } + + size_t cert_len; + const byte *cert = (const byte *)mp_obj_str_get_data(cert_obj, &cert_len); + // len should include terminating null + ret = mbedtls_x509_crt_parse(&self->cert, cert, cert_len + 1); + if (ret != 0) { + mbedtls_raise_error(MBEDTLS_ERR_X509_BAD_INPUT_DATA); // use general error for all cert errors + } + + ret = mbedtls_ssl_conf_own_cert(&self->conf, &self->cert, &self->pkey); + if (ret != 0) { + mbedtls_raise_error(ret); + } +} + +STATIC void ssl_context_load_cadata(mp_obj_ssl_context_t *self, mp_obj_t cadata_obj) { + size_t cacert_len; + const byte *cacert = (const byte *)mp_obj_str_get_data(cadata_obj, &cacert_len); + // len should include terminating null + int ret = mbedtls_x509_crt_parse(&self->cacert, cacert, cacert_len + 1); + if (ret != 0) { + mbedtls_raise_error(MBEDTLS_ERR_X509_BAD_INPUT_DATA); // use general error for all cert errors + } + + mbedtls_ssl_conf_ca_chain(&self->conf, &self->cacert, NULL); +} + +STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname }; + static const mp_arg_t allowed_args[] = { + { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, + { MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, + { MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, + }; + + // Parse arguments. + mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(pos_args[0]); + mp_obj_t sock = pos_args[1]; + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 2, pos_args + 2, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); + + // Create and return the new SSLSocket object. + return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool, + args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket); + +STATIC const mp_rom_map_elem_t ssl_context_locals_dict_table[] = { + #if MICROPY_PY_SSL_FINALISER + { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&ssl_context___del___obj) }, + #endif + { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) }, +}; +STATIC MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table); + +STATIC MP_DEFINE_CONST_OBJ_TYPE( + ssl_context_type, + MP_QSTR_SSLContext, + MP_TYPE_FLAG_NONE, + make_new, ssl_context_make_new, + attr, ssl_context_attr, + locals_dict, &ssl_context_locals_dict + ); + +/******************************************************************************/ +// SSLSocket type. + STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { mp_obj_t sock = *(mp_obj_t *)ctx; @@ -158,8 +345,9 @@ STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { } } +STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { -STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { // Verify the socket object has the full stream protocol mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); @@ -175,44 +363,14 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { int ret; mbedtls_ssl_init(&o->ssl); - mbedtls_ssl_config_init(&o->conf); - mbedtls_x509_crt_init(&o->cacert); - mbedtls_x509_crt_init(&o->cert); - mbedtls_pk_init(&o->pkey); - mbedtls_ctr_drbg_init(&o->ctr_drbg); - #ifdef MBEDTLS_DEBUG_C - // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose - mbedtls_debug_set_threshold(3); - #endif - mbedtls_entropy_init(&o->entropy); - const byte seed[] = "upy"; - ret = mbedtls_ctr_drbg_seed(&o->ctr_drbg, mbedtls_entropy_func, &o->entropy, seed, sizeof(seed)); + ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf); if (ret != 0) { goto cleanup; } - ret = mbedtls_ssl_config_defaults(&o->conf, - args->server_side.u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, - MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT); - if (ret != 0) { - goto cleanup; - } - - mbedtls_ssl_conf_authmode(&o->conf, args->cert_reqs.u_int); - mbedtls_ssl_conf_rng(&o->conf, mbedtls_ctr_drbg_random, &o->ctr_drbg); - #ifdef MBEDTLS_DEBUG_C - mbedtls_ssl_conf_dbg(&o->conf, mbedtls_debug, NULL); - #endif - - ret = mbedtls_ssl_setup(&o->ssl, &o->conf); - if (ret != 0) { - goto cleanup; - } - - if (args->server_hostname.u_obj != mp_const_none) { - const char *sni = mp_obj_str_get_str(args->server_hostname.u_obj); + if (server_hostname != mp_const_none) { + const char *sni = mp_obj_str_get_str(server_hostname); ret = mbedtls_ssl_set_hostname(&o->ssl, sni); if (ret != 0) { goto cleanup; @@ -221,49 +379,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL); - if (args->key.u_obj != mp_const_none) { - size_t key_len; - const byte *key = (const byte *)mp_obj_str_get_data(args->key.u_obj, &key_len); - // len should include terminating null - #if MBEDTLS_VERSION_NUMBER >= 0x03000000 - ret = mbedtls_pk_parse_key(&o->pkey, key, key_len + 1, NULL, 0, mbedtls_ctr_drbg_random, &o->ctr_drbg); - #else - ret = mbedtls_pk_parse_key(&o->pkey, key, key_len + 1, NULL, 0); - #endif - if (ret != 0) { - ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA; // use general error for all key errors - goto cleanup; - } - - size_t cert_len; - const byte *cert = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &cert_len); - // len should include terminating null - ret = mbedtls_x509_crt_parse(&o->cert, cert, cert_len + 1); - if (ret != 0) { - ret = MBEDTLS_ERR_X509_BAD_INPUT_DATA; // use general error for all cert errors - goto cleanup; - } - - ret = mbedtls_ssl_conf_own_cert(&o->conf, &o->cert, &o->pkey); - if (ret != 0) { - goto cleanup; - } - } - - if (args->cadata.u_obj != mp_const_none) { - size_t cacert_len; - const byte *cacert = (const byte *)mp_obj_str_get_data(args->cadata.u_obj, &cacert_len); - // len should include terminating null - ret = mbedtls_x509_crt_parse(&o->cacert, cacert, cacert_len + 1); - if (ret != 0) { - ret = MBEDTLS_ERR_X509_BAD_INPUT_DATA; // use general error for all cert errors - goto cleanup; - } - - mbedtls_ssl_conf_ca_chain(&o->conf, &o->cacert, NULL); - } - - if (args->do_handshake.u_bool) { + if (do_handshake_on_connect) { while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) { if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { goto cleanup; @@ -274,26 +390,11 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { } } - return o; + return MP_OBJ_FROM_PTR(o); cleanup: - mbedtls_pk_free(&o->pkey); - mbedtls_x509_crt_free(&o->cert); - mbedtls_x509_crt_free(&o->cacert); mbedtls_ssl_free(&o->ssl); - mbedtls_ssl_config_free(&o->conf); - mbedtls_ctr_drbg_free(&o->ctr_drbg); - mbedtls_entropy_free(&o->entropy); - - if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) { - mp_raise_OSError(MP_ENOMEM); - } else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA) { - mp_raise_ValueError(MP_ERROR_TEXT("invalid key")); - } else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) { - mp_raise_ValueError(MP_ERROR_TEXT("invalid cert")); - } else { - mbedtls_raise_error(ret); - } + mbedtls_raise_error(ret); } STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) { @@ -309,12 +410,6 @@ STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) { } STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert); -STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) { - (void)kind; - mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); - mp_printf(print, "<_SSLSocket %p>", self); -} - STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); o->poll_mask = 0; @@ -397,13 +492,7 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i if (request == MP_STREAM_CLOSE) { self->sock = MP_OBJ_NULL; - mbedtls_pk_free(&self->pkey); - mbedtls_x509_crt_free(&self->cert); - mbedtls_x509_crt_free(&self->cacert); mbedtls_ssl_free(&self->ssl); - mbedtls_ssl_config_free(&self->conf); - mbedtls_ctr_drbg_free(&self->ctr_drbg); - mbedtls_entropy_free(&self->entropy); } else if (request == MP_STREAM_POLL) { // If the library signaled us that it needs reading or writing, only check that direction, // but save what the caller asked because we need to restore it later @@ -454,7 +543,6 @@ STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = { #endif { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) }, }; - STATIC MP_DEFINE_CONST_DICT(ssl_socket_locals_dict, ssl_socket_locals_dict_table); STATIC const mp_stream_p_t ssl_socket_stream_p = { @@ -465,16 +553,25 @@ STATIC const mp_stream_p_t ssl_socket_stream_p = { STATIC MP_DEFINE_CONST_OBJ_TYPE( ssl_socket_type, - // Save on qstr's, reuse same as for module - MP_QSTR_ssl, + MP_QSTR_SSLSocket, MP_TYPE_FLAG_NONE, - print, socket_print, protocol, &ssl_socket_stream_p, locals_dict, &ssl_socket_locals_dict ); +/******************************************************************************/ +// ssl module. + STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { - // TODO: Implement more args + enum { + ARG_key, + ARG_cert, + ARG_server_side, + ARG_server_hostname, + ARG_cert_reqs, + ARG_cadata, + ARG_do_handshake, + }; static const mp_arg_t allowed_args[] = { { MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, { MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, @@ -485,25 +582,52 @@ STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_ { MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, }; - // TODO: Check that sock implements stream protocol + // Parse arguments. mp_obj_t sock = pos_args[0]; + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); - struct ssl_args args; - mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, - MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args); + // Create SSLContext. + mp_int_t protocol = args[ARG_server_side].u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT; + mp_obj_t ssl_context_args[1] = { MP_OBJ_NEW_SMALL_INT(protocol) }; + mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(ssl_context_make_new(&ssl_context_type, 1, 0, ssl_context_args)); - return MP_OBJ_FROM_PTR(socket_new(sock, &args)); + // Load key and cert if given. + if (args[ARG_key].u_obj != mp_const_none) { + ssl_context_load_key(ssl_context, args[ARG_key].u_obj, args[ARG_cert].u_obj); + } + + // Set the verify_mode. + mp_obj_t dest[2] = { MP_OBJ_SENTINEL, MP_OBJ_NEW_SMALL_INT(args[ARG_cert_reqs].u_int) }; + ssl_context_attr(MP_OBJ_FROM_PTR(ssl_context), MP_QSTR_verify_mode, dest); + + // Load cadata if given. + if (args[ARG_cadata].u_obj != mp_const_none) { + ssl_context_load_cadata(ssl_context, args[ARG_cadata].u_obj); + } + + // Create and return the new SSLSocket object. + return ssl_socket_make_new(ssl_context, sock, args[ARG_server_side].u_bool, + args[ARG_do_handshake].u_bool, args[ARG_server_hostname].u_obj); } STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket); STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ssl) }, + + // Functions. { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) }, + + // Classes. + { MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) }, + + // Constants. + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MBEDTLS_SSL_IS_CLIENT) }, + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MBEDTLS_SSL_IS_SERVER) }, { MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) }, { MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) }, { MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) }, }; - STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table); const mp_obj_module_t mp_module_ssl = { diff --git a/tests/extmod/ssl_basic.py b/tests/extmod/ssl_basic.py index d035798c98..95e66e0cae 100644 --- a/tests/extmod/ssl_basic.py +++ b/tests/extmod/ssl_basic.py @@ -33,7 +33,7 @@ except OSError as er: ss = ssl.wrap_socket(TestSocket(), server_side=1, do_handshake=0) # print -print(repr(ss)[:12]) +print(ss) # setblocking() propagates call to the underlying stream object ss.setblocking(False) diff --git a/tests/extmod/ssl_basic.py.exp b/tests/extmod/ssl_basic.py.exp index 30bfd3a436..1a10a18614 100644 --- a/tests/extmod/ssl_basic.py.exp +++ b/tests/extmod/ssl_basic.py.exp @@ -1,5 +1,5 @@ OSError: client -<_SSLSocket + TestSocket.setblocking(False) TestSocket.setblocking(True) TestSocket.ioctl 4 0 diff --git a/tests/net_inet/test_tls_sites.py b/tests/net_inet/test_tls_sites.py index f9a3dc86d2..4f457b3abc 100644 --- a/tests/net_inet/test_tls_sites.py +++ b/tests/net_inet/test_tls_sites.py @@ -3,7 +3,7 @@ import ssl # CPython only supports server_hostname with SSLContext if hasattr(ssl, "SSLContext"): - ssl = ssl.SSLContext() + ssl = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) def test_one(site, opts):