From 29e9db0c587145a8823227635734924896cdc4d1 Mon Sep 17 00:00:00 2001 From: Damien George Date: Sat, 12 Dec 2015 13:42:51 +0000 Subject: [PATCH] py: Fix compiler to handle lambdas used as default arguments. Addresses issue #1709. --- py/compile.c | 12 ++++++++++++ tests/basics/fun_defargs.py | 9 +++++++++ tests/basics/fun_kwonly.py | 7 +++++++ 3 files changed, 28 insertions(+) diff --git a/py/compile.c b/py/compile.c index 841b8f90c0..c84d23e943 100644 --- a/py/compile.c +++ b/py/compile.c @@ -662,6 +662,13 @@ STATIC void compile_funcdef_lambdef_param(compiler_t *comp, mp_parse_node_t pn) } STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_node_t pn_params, pn_kind_t pn_list_kind) { + // When we call compile_funcdef_lambdef_param below it can compile an arbitrary + // expression for default arguments, which may contain a lambda. The lambda will + // call here in a nested way, so we must save and restore the relevant state. + bool orig_have_star = comp->have_star; + uint16_t orig_num_dict_params = comp->num_dict_params; + uint16_t orig_num_default_params = comp->num_default_params; + // compile default parameters comp->have_star = false; comp->num_dict_params = 0; @@ -681,6 +688,11 @@ STATIC void compile_funcdef_lambdef(compiler_t *comp, scope_t *scope, mp_parse_n // make the function close_over_variables_etc(comp, scope, comp->num_default_params, comp->num_dict_params); + + // restore state + comp->have_star = orig_have_star; + comp->num_dict_params = orig_num_dict_params; + comp->num_default_params = orig_num_default_params; } // leaves function object on stack diff --git a/tests/basics/fun_defargs.py b/tests/basics/fun_defargs.py index ed25f5739d..1466c44094 100644 --- a/tests/basics/fun_defargs.py +++ b/tests/basics/fun_defargs.py @@ -1,3 +1,5 @@ +# testing default args to a function + def fun1(val=5): print(val) @@ -18,3 +20,10 @@ try: fun2(1, 2, 3, 4) except TypeError: print("TypeError") + +# lambda as default arg (exposes nested behaviour in compiler) +def f(x=lambda:1): + return x() +print(f()) +print(f(f)) +print(f(lambda:2)) diff --git a/tests/basics/fun_kwonly.py b/tests/basics/fun_kwonly.py index bdff3a8210..7694c8ddca 100644 --- a/tests/basics/fun_kwonly.py +++ b/tests/basics/fun_kwonly.py @@ -57,3 +57,10 @@ def f(a, *b, c): f(1, c=2) f(1, 2, c=3) f(a=1, c=3) + +# lambda as kw-only arg (exposes nested behaviour in compiler) +def f(*, x=lambda:1): + return x() +print(f()) +print(f(x=f)) +print(f(x=lambda:2))