diff --git a/py/objtype.c b/py/objtype.c index f812a0e86c..7689e42b25 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -46,6 +46,8 @@ #define DEBUG_printf(...) (void)0 #endif +STATIC mp_obj_t static_class_method_make_new(mp_obj_t self_in, uint n_args, uint n_kw, const mp_obj_t *args); + /******************************************************************************/ // instance object @@ -749,6 +751,8 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) assert(MP_OBJ_IS_TYPE(bases_tuple, &mp_type_tuple)); // Micro Python restriction, for now assert(MP_OBJ_IS_TYPE(locals_dict, &mp_type_dict)); // Micro Python restriction, for now + // TODO might need to make a copy of locals_dict; at least that's how CPython does it + // Basic validation of base classes uint len; mp_obj_t *items; @@ -783,6 +787,16 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "multiple bases have instance lay-out conflict")); } + mp_map_t *locals_map = mp_obj_dict_get_map(o->locals_dict); + mp_map_elem_t *elem = mp_map_lookup(locals_map, MP_OBJ_NEW_QSTR(MP_QSTR___new__), MP_MAP_LOOKUP); + if (elem != NULL) { + // __new__ slot exists; check if it is a function + if (MP_OBJ_IS_TYPE(elem->value, &mp_type_fun_native) || MP_OBJ_IS_TYPE(elem->value, &mp_type_fun_bc)) { + // __new__ is a function, wrap it in a staticmethod decorator + elem->value = static_class_method_make_new((mp_obj_t)&mp_type_staticmethod, 1, 0, &elem->value); + } + } + return o; } diff --git a/tests/basics/class_new.py b/tests/basics/class_new.py index 7fedcab6c2..7e84dccf40 100644 --- a/tests/basics/class_new.py +++ b/tests/basics/class_new.py @@ -1,6 +1,4 @@ class A: - - @staticmethod def __new__(cls): print("A.__new__") return super(cls, A).__new__(cls) @@ -9,13 +7,21 @@ class A: pass def meth(self): - pass + print('A.meth') #print(A.__new__) #print(A.__init__) a = A() +a.meth() + +a = A.__new__(A) +a.meth() #print(a.meth) #print(a.__init__) #print(a.__new__) + +# __new__ should automatically be a staticmethod, so this should work +a = a.__new__(A) +a.meth()