From bcb6ca4d5e926d9571d150fb045c5ac4b53f8ecd Mon Sep 17 00:00:00 2001 From: Damien George Date: Tue, 3 Jun 2014 12:53:44 +0100 Subject: [PATCH] py: Implement full behaviour of dict.update(), and dict(). Add keyword args to dict.update(), and ability to take a dictionary as argument. dict() class constructor can now use dict.update() directly. This patch loses fast path for dict(other_dict), but is that really needed? Any anyway, this idiom will now re-hash the dictionary, so is arguably more memory efficient. Addresses issue #647. --- py/objdict.c | 105 ++++++++++++++++----------------- tests/basics/dict_construct.py | 16 +++++ tests/basics/dict_update.py | 6 ++ 3 files changed, 74 insertions(+), 53 deletions(-) create mode 100644 tests/basics/dict_construct.py diff --git a/py/objdict.c b/py/objdict.c index f41eacd939..8a0a08772a 100644 --- a/py/objdict.c +++ b/py/objdict.c @@ -40,7 +40,7 @@ STATIC mp_obj_t mp_obj_new_dict_iterator(mp_obj_dict_t *dict, int cur); STATIC mp_map_elem_t *dict_it_iternext_elem(mp_obj_t self_in); -STATIC mp_obj_t dict_copy(mp_obj_t self_in); +STATIC mp_obj_t dict_update(uint n_args, const mp_obj_t *args, mp_map_t *kwargs); STATIC void dict_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t self_in, mp_print_kind_t kind) { mp_obj_dict_t *self = self_in; @@ -61,40 +61,13 @@ STATIC void dict_print(void (*print)(void *env, const char *fmt, ...), void *env } STATIC mp_obj_t dict_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args) { - mp_obj_t dict; - switch (n_args) { - case 0: - dict = mp_obj_new_dict(0); - break; - - case 1: { - if (MP_OBJ_IS_TYPE(args[0], &mp_type_dict)) { - return dict_copy(args[0]); - } - // TODO create dict from an arbitrary mapping! - - // Make dict from iterable of pairs - mp_obj_t iterable = mp_getiter(args[0]); - mp_obj_t dict = mp_obj_new_dict(0); - // TODO: support arbitrary seq as a pair - mp_obj_t item; - while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { - mp_obj_t *sub_items; - mp_obj_get_array_fixed_n(item, 2, &sub_items); - mp_obj_dict_store(dict, sub_items[0], sub_items[1]); - } - return dict; - } - - default: - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "dict takes at most 1 argument")); + mp_obj_t dict = mp_obj_new_dict(0); + if (n_args > 0 || n_kw > 0) { + mp_obj_t args2[2] = {dict, args[0]}; // args[0] is always valid, even if it's not a positional arg + mp_map_t kwargs; + mp_map_init_fixed_table(&kwargs, n_kw, args + n_args); + dict_update(n_args + 1, args2, &kwargs); // dict_update will check that n_args + 1 == 1 or 2 } - - // add to the new dict any keyword args - for (const mp_obj_t *a = args + n_args; n_kw > 0; n_kw--, a += 2) { - mp_obj_dict_store(dict, a[0], a[1]); - } - return dict; } @@ -348,31 +321,57 @@ STATIC mp_obj_t dict_popitem(mp_obj_t self_in) { } STATIC MP_DEFINE_CONST_FUN_OBJ_1(dict_popitem_obj, dict_popitem); -STATIC mp_obj_t dict_update(mp_obj_t self_in, mp_obj_t iterable) { - assert(MP_OBJ_IS_TYPE(self_in, &mp_type_dict)); - mp_obj_dict_t *self = self_in; - /* TODO: check for the "keys" method */ - mp_obj_t iter = mp_getiter(iterable); - mp_obj_t next = NULL; - while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { - mp_obj_t inneriter = mp_getiter(next); - mp_obj_t key = mp_iternext(inneriter); - mp_obj_t value = mp_iternext(inneriter); - mp_obj_t stop = mp_iternext(inneriter); - if (key == MP_OBJ_STOP_ITERATION - || value == MP_OBJ_STOP_ITERATION - || stop != MP_OBJ_STOP_ITERATION) { - nlr_raise(mp_obj_new_exception_msg( - &mp_type_ValueError, - "dictionary update sequence has the wrong length")); +STATIC mp_obj_t dict_update(uint n_args, const mp_obj_t *args, mp_map_t *kwargs) { + assert(MP_OBJ_IS_TYPE(args[0], &mp_type_dict)); + mp_obj_dict_t *self = args[0]; + + mp_arg_check_num(n_args, kwargs->used, 1, 2, true); + + if (n_args == 2) { + // given a positional argument + + if (MP_OBJ_IS_TYPE(args[1], &mp_type_dict)) { + // update from other dictionary (make sure other is not self) + if (args[1] != self) { + // TODO don't allocate heap object for this iterator + mp_obj_t *dict_iter = mp_obj_new_dict_iterator(args[1], 0); + mp_map_elem_t *elem = NULL; + while ((elem = dict_it_iternext_elem(dict_iter)) != MP_OBJ_STOP_ITERATION) { + mp_map_lookup(&self->map, elem->key, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND)->value = elem->value; + } + } } else { - mp_map_lookup(&self->map, key, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND)->value = value; + // update from a generic iterable of pairs + mp_obj_t iter = mp_getiter(args[1]); + mp_obj_t next = NULL; + while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { + mp_obj_t inneriter = mp_getiter(next); + mp_obj_t key = mp_iternext(inneriter); + mp_obj_t value = mp_iternext(inneriter); + mp_obj_t stop = mp_iternext(inneriter); + if (key == MP_OBJ_STOP_ITERATION + || value == MP_OBJ_STOP_ITERATION + || stop != MP_OBJ_STOP_ITERATION) { + nlr_raise(mp_obj_new_exception_msg( + &mp_type_ValueError, + "dictionary update sequence has the wrong length")); + } else { + mp_map_lookup(&self->map, key, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND)->value = value; + } + } + } + } + + // update the dict with any keyword args + for (machine_uint_t i = 0; i < kwargs->alloc; i++) { + if (MP_MAP_SLOT_IS_FILLED(kwargs, i)) { + mp_map_lookup(&self->map, kwargs->table[i].key, MP_MAP_LOOKUP_ADD_IF_NOT_FOUND)->value = kwargs->table[i].value; } } return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(dict_update_obj, dict_update); +STATIC MP_DEFINE_CONST_FUN_OBJ_KW(dict_update_obj, 1, dict_update); /******************************************************************************/ diff --git a/tests/basics/dict_construct.py b/tests/basics/dict_construct.py new file mode 100644 index 0000000000..0035e9c0f9 --- /dev/null +++ b/tests/basics/dict_construct.py @@ -0,0 +1,16 @@ +# dict constructor + +d = dict() +print(d) + +d = dict({1:2}) +print(d) + +d = dict(a=1) +print(d) + +d = dict({1:2}, a=3) +print(d[1], d['a']) + +d = dict([(1, 2)], a=3, b=4) +print(d[1], d['a'], d['b']) diff --git a/tests/basics/dict_update.py b/tests/basics/dict_update.py index 46d1f41b5f..ab1a63304a 100644 --- a/tests/basics/dict_update.py +++ b/tests/basics/dict_update.py @@ -8,3 +8,9 @@ print(len(d)) d.update([(1,4)]) print(d[1]) print(len(d)) + +# using keywords +d.update(a=5) +print(d['a']) +d.update([(1,5)], b=6) +print(d[1], d['b'])