diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index 6db6ac1958..ce889bc759 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -70,6 +70,12 @@ typedef struct _mp_obj_ssl_context_t { mp_obj_t handler; } mp_obj_ssl_context_t; +// This corresponds to an SSLSession object. +typedef struct _mp_obj_ssl_session_t { + mp_obj_base_t base; + mbedtls_ssl_session session; +} mp_obj_ssl_session_t; + // This corresponds to an SSLSocket object. typedef struct _mp_obj_ssl_socket_t { mp_obj_base_t base; @@ -81,13 +87,14 @@ typedef struct _mp_obj_ssl_socket_t { int last_error; // The last error code, if any } mp_obj_ssl_socket_t; +static const mp_obj_type_t ssl_session_type; static const mp_obj_type_t ssl_context_type; static const mp_obj_type_t ssl_socket_type; static const MP_DEFINE_STR_OBJ(mbedtls_version_obj, MBEDTLS_VERSION_STRING_FULL); static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, - bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname); + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session); /******************************************************************************/ // Helper functions. @@ -199,6 +206,60 @@ static int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uin return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth))); } +/******************************************************************************/ +// SSLSession type. + +static mp_obj_t ssl_session_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { + mp_arg_check_num(n_args, n_kw, 1, 1, false); + + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ); + + mp_obj_ssl_session_t *self = m_new_obj(mp_obj_ssl_session_t); + self->base.type = type_in; + + mbedtls_ssl_session_init(&self->session); + int ret = mbedtls_ssl_session_load(&self->session, bufinfo.buf, bufinfo.len); + if (ret != 0) { + mbedtls_raise_error(ret); + } + + return MP_OBJ_FROM_PTR(self); +} + +static mp_obj_t ssl_session_serialize(mp_obj_t self_in) { + mp_obj_ssl_session_t *self = MP_OBJ_TO_PTR(self_in); + size_t len; + vstr_t vstr; + mbedtls_ssl_session_save(&self->session, NULL, 0, &len); + vstr_init_len(&vstr, len); + mbedtls_ssl_session_save(&self->session, (unsigned char *)vstr.buf, len, &len); + return mp_obj_new_bytes_from_vstr(&vstr); +} +static MP_DEFINE_CONST_FUN_OBJ_1(ssl_session_serialize_obj, ssl_session_serialize); + +static mp_int_t ssl_session_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) { + if (flags != MP_BUFFER_READ) { + return 1; + } + mp_get_buffer_raise(ssl_session_serialize(self_in), bufinfo, flags); + return 0; +} + +static const mp_rom_map_elem_t ssl_session_locals_dict_table[] = { + { MP_ROM_QSTR(MP_QSTR_serialize), MP_ROM_PTR(&ssl_session_serialize_obj) }, +}; +static MP_DEFINE_CONST_DICT(ssl_session_locals_dict, ssl_session_locals_dict_table); + +static MP_DEFINE_CONST_OBJ_TYPE( + ssl_session_type, + MP_QSTR_SSLSession, + MP_TYPE_FLAG_NONE, + make_new, ssl_session_make_new, + buffer, ssl_session_get_buffer, + locals_dict, &ssl_session_locals_dict + ); + /******************************************************************************/ // SSLContext type. @@ -402,11 +463,12 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations); static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { - enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname }; + enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_session }; static const mp_arg_t allowed_args[] = { { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, { MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, { MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, + { MP_QSTR_session, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, }; // Parse arguments. @@ -417,7 +479,7 @@ static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, // Create and return the new SSLSocket object. return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool, - args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj); + args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj, args[ARG_session].u_obj); } static MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket); @@ -481,7 +543,7 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { } static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, - bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session) { // Verify the socket object has the full stream protocol mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); @@ -519,6 +581,14 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname")); } + if (ssl_session != mp_const_none) { + mp_obj_ssl_session_t *session = MP_OBJ_TO_PTR(ssl_session); + ret = mbedtls_ssl_set_session(&o->ssl, &session->session); + if (ret != 0) { + goto cleanup; + } + } + mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL); if (do_handshake_on_connect) { @@ -716,6 +786,36 @@ static mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i return ret; } +static void ssl_socket_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { + mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); + if (dest[0] == MP_OBJ_NULL) { + // Load attribute. + if (attr == MP_QSTR_session) { + mp_obj_ssl_session_t *o = m_new_obj(mp_obj_ssl_session_t); + o->base.type = &ssl_session_type; + mbedtls_ssl_session_init(&o->session); + int ret = mbedtls_ssl_get_session(&self->ssl, &o->session); + if (ret != 0) { + mbedtls_raise_error(ret); + } + dest[0] = MP_OBJ_FROM_PTR(o); + } else { + // Continue lookup in locals_dict. + dest[1] = MP_OBJ_SENTINEL; + } + } else if (dest[1] != MP_OBJ_NULL) { + // Store attribute. + if (attr == MP_QSTR_session) { + mp_obj_ssl_session_t *ssl_session = MP_OBJ_TO_PTR(dest[1]); + dest[0] = MP_OBJ_NULL; + int ret = mbedtls_ssl_set_session(&self->ssl, &ssl_session->session); + if (ret != 0) { + mbedtls_raise_error(ret); + } + } + } +} + static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) }, { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, @@ -747,6 +847,7 @@ static MP_DEFINE_CONST_OBJ_TYPE( MP_QSTR_SSLSocket, MP_TYPE_FLAG_NONE, protocol, &ssl_socket_stream_p, + attr, ssl_socket_attr, locals_dict, &ssl_socket_locals_dict ); @@ -758,6 +859,7 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = { // Classes. { MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) }, + { MP_ROM_QSTR(MP_QSTR_SSLSession), MP_ROM_PTR(&ssl_session_type) }, // Constants. { MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)},