diff --git a/py/objtype.c b/py/objtype.c index 3d22b8ec45..75d53b8279 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -338,6 +338,16 @@ STATIC mp_obj_t class_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { } } +STATIC mp_obj_t class_call(mp_obj_t self_in, uint n_args, uint n_kw, const mp_obj_t *args) { + mp_obj_class_t *self = self_in; + mp_obj_t member = mp_obj_class_lookup(self->base.type, MP_QSTR___call__); + if (member == MP_OBJ_NULL) { + return member; + } + mp_obj_t meth = mp_obj_new_bound_meth(member, self); + return mp_call_function_n_kw(meth, n_args, n_kw, args); +} + /******************************************************************************/ // type object // - the struct is mp_obj_type_t and is defined in obj.h so const types can be made @@ -474,6 +484,7 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) o->load_attr = class_load_attr; o->store_attr = class_store_attr; o->subscr = class_subscr; + o->call = class_call; o->bases_tuple = bases_tuple; o->locals_dict = locals_dict; return o; diff --git a/py/qstrdefs.h b/py/qstrdefs.h index e970f58bef..784bf59d23 100644 --- a/py/qstrdefs.h +++ b/py/qstrdefs.h @@ -31,6 +31,7 @@ Q(__repr__) Q(__str__) Q(__getattr__) Q(__del__) +Q(__call__) Q(micropython) Q(byte_code) diff --git a/py/runtime.c b/py/runtime.c index 30db01cd53..b56740a022 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -481,10 +481,13 @@ mp_obj_t mp_call_function_n_kw(mp_obj_t fun_in, uint n_args, uint n_kw, const mp // do the call if (type->call != NULL) { - return type->call(fun_in, n_args, n_kw, args); - } else { - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "'%s' object is not callable", mp_obj_get_type_str(fun_in))); + mp_obj_t res = type->call(fun_in, n_args, n_kw, args); + if (res != NULL) { + return res; + } } + + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "'%s' object is not callable", mp_obj_get_type_str(fun_in))); } // args contains: fun self/NULL arg(0) ... arg(n_args-2) arg(n_args-1) kw_key(0) kw_val(0) ... kw_key(n_kw-1) kw_val(n_kw-1) diff --git a/tests/basics/class_call.py b/tests/basics/class_call.py new file mode 100644 index 0000000000..b7a3d70f9e --- /dev/null +++ b/tests/basics/class_call.py @@ -0,0 +1,18 @@ +class C1: + def __call__(self, val): + print('call', val) + return 'item' + +class C2: + + def __getattr__(self, k): + pass + +c1 = C1() +print(c1(1)) + +c2 = C2() +try: + print(c2(1)) +except TypeError: + print("TypeError")