extmod/modtls_mbedtls: Implement SSLSession support.

Signed-off-by: Daniël van de Giessen <daniel@dvdgiessen.nl>
pull/12780/head
Daniël van de Giessen 2023-09-13 15:08:26 +02:00
rodzic 9d27183bde
commit 82ac8e6535
Nie znaleziono w bazie danych klucza dla tego podpisu
ID klucza GPG: 9F0EF4D3441C8163
1 zmienionych plików z 106 dodań i 4 usunięć

Wyświetl plik

@ -70,6 +70,12 @@ typedef struct _mp_obj_ssl_context_t {
mp_obj_t handler;
} mp_obj_ssl_context_t;
// This corresponds to an SSLSession object.
typedef struct _mp_obj_ssl_session_t {
mp_obj_base_t base;
mbedtls_ssl_session session;
} mp_obj_ssl_session_t;
// This corresponds to an SSLSocket object.
typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
@ -81,13 +87,14 @@ typedef struct _mp_obj_ssl_socket_t {
int last_error; // The last error code, if any
} mp_obj_ssl_socket_t;
static const mp_obj_type_t ssl_session_type;
static const mp_obj_type_t ssl_context_type;
static const mp_obj_type_t ssl_socket_type;
static const MP_DEFINE_STR_OBJ(mbedtls_version_obj, MBEDTLS_VERSION_STRING_FULL);
static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session);
/******************************************************************************/
// Helper functions.
@ -199,6 +206,60 @@ static int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uin
return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth)));
}
/******************************************************************************/
// SSLSession type.
static mp_obj_t ssl_session_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
mp_arg_check_num(n_args, n_kw, 1, 1, false);
mp_buffer_info_t bufinfo;
mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ);
mp_obj_ssl_session_t *self = m_new_obj(mp_obj_ssl_session_t);
self->base.type = type_in;
mbedtls_ssl_session_init(&self->session);
int ret = mbedtls_ssl_session_load(&self->session, bufinfo.buf, bufinfo.len);
if (ret != 0) {
mbedtls_raise_error(ret);
}
return MP_OBJ_FROM_PTR(self);
}
static mp_obj_t ssl_session_serialize(mp_obj_t self_in) {
mp_obj_ssl_session_t *self = MP_OBJ_TO_PTR(self_in);
size_t len;
vstr_t vstr;
mbedtls_ssl_session_save(&self->session, NULL, 0, &len);
vstr_init_len(&vstr, len);
mbedtls_ssl_session_save(&self->session, (unsigned char *)vstr.buf, len, &len);
return mp_obj_new_bytes_from_vstr(&vstr);
}
static MP_DEFINE_CONST_FUN_OBJ_1(ssl_session_serialize_obj, ssl_session_serialize);
static mp_int_t ssl_session_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) {
if (flags != MP_BUFFER_READ) {
return 1;
}
mp_get_buffer_raise(ssl_session_serialize(self_in), bufinfo, flags);
return 0;
}
static const mp_rom_map_elem_t ssl_session_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_serialize), MP_ROM_PTR(&ssl_session_serialize_obj) },
};
static MP_DEFINE_CONST_DICT(ssl_session_locals_dict, ssl_session_locals_dict_table);
static MP_DEFINE_CONST_OBJ_TYPE(
ssl_session_type,
MP_QSTR_SSLSession,
MP_TYPE_FLAG_NONE,
make_new, ssl_session_make_new,
buffer, ssl_session_get_buffer,
locals_dict, &ssl_session_locals_dict
);
/******************************************************************************/
// SSLContext type.
@ -402,11 +463,12 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations);
static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_session };
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
{ MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
{ MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
{ MP_QSTR_session, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
};
// Parse arguments.
@ -417,7 +479,7 @@ static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args,
// Create and return the new SSLSocket object.
return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj, args[ARG_session].u_obj);
}
static MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
@ -481,7 +543,7 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
}
static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session) {
// Verify the socket object has the full stream protocol
mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL);
@ -519,6 +581,14 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname"));
}
if (ssl_session != mp_const_none) {
mp_obj_ssl_session_t *session = MP_OBJ_TO_PTR(ssl_session);
ret = mbedtls_ssl_set_session(&o->ssl, &session->session);
if (ret != 0) {
goto cleanup;
}
}
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
if (do_handshake_on_connect) {
@ -716,6 +786,36 @@ static mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
return ret;
}
static void ssl_socket_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
if (dest[0] == MP_OBJ_NULL) {
// Load attribute.
if (attr == MP_QSTR_session) {
mp_obj_ssl_session_t *o = m_new_obj(mp_obj_ssl_session_t);
o->base.type = &ssl_session_type;
mbedtls_ssl_session_init(&o->session);
int ret = mbedtls_ssl_get_session(&self->ssl, &o->session);
if (ret != 0) {
mbedtls_raise_error(ret);
}
dest[0] = MP_OBJ_FROM_PTR(o);
} else {
// Continue lookup in locals_dict.
dest[1] = MP_OBJ_SENTINEL;
}
} else if (dest[1] != MP_OBJ_NULL) {
// Store attribute.
if (attr == MP_QSTR_session) {
mp_obj_ssl_session_t *ssl_session = MP_OBJ_TO_PTR(dest[1]);
dest[0] = MP_OBJ_NULL;
int ret = mbedtls_ssl_set_session(&self->ssl, &ssl_session->session);
if (ret != 0) {
mbedtls_raise_error(ret);
}
}
}
}
static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
@ -747,6 +847,7 @@ static MP_DEFINE_CONST_OBJ_TYPE(
MP_QSTR_SSLSocket,
MP_TYPE_FLAG_NONE,
protocol, &ssl_socket_stream_p,
attr, ssl_socket_attr,
locals_dict, &ssl_socket_locals_dict
);
@ -758,6 +859,7 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
// Classes.
{ MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
{ MP_ROM_QSTR(MP_QSTR_SSLSession), MP_ROM_PTR(&ssl_session_type) },
// Constants.
{ MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)},