diff --git a/pythonwhat/parsing.py b/pythonwhat/parsing.py index cd9102d4..e017c127 100644 --- a/pythonwhat/parsing.py +++ b/pythonwhat/parsing.py @@ -1,4 +1,6 @@ import ast +import re + from pythonwhat.utils_ast import wrap_in_module from collections.abc import Sequence, Mapping from collections import OrderedDict @@ -326,6 +328,7 @@ def visit_Dict(self, node): def visit_Call(self, node): if self.call_lookup_active: self.visit(node.func) + self.gen_name += "()" else: self.call_lookup_active = True self.visit( @@ -333,6 +336,7 @@ def visit_Call(self, node): ) # Need to visit func to start recording the current function name. if self.gen_name: + self.gen_name = re.sub(r"(?:\(\))+(.)", "\\1", self.gen_name) if self.gen_name not in self.out: self.out[self.gen_name] = [] diff --git a/tests/test_check_function.py b/tests/test_check_function.py index 6bb44d3e..85636fc5 100644 --- a/tests/test_check_function.py +++ b/tests/test_check_function.py @@ -520,3 +520,30 @@ def test_function_call_in_comparison(code): sct = "Ex().check_function('len')" res = helper.run({"DC_CODE": code, "DC_SOLUTION": code, "DC_SCT": sct}) assert res["correct"] + + +@pytest.mark.parametrize( + "sct", + [ + "Ex().check_function('numpy.array')", + "Ex().check_function('hof').check_args(0).has_equal_value(override=1)", + "Ex().check_function('hof()').check_args(0).has_equal_value(override=2)", + ], +) +def test_ho_function(sct): + + code = """ +import numpy as np +np.array([]) + +def hof(arg1): + def inner(arg2): + return arg1, arg2 + + return inner + +hof(1)(2) + """ + + res = helper.run({"DC_CODE": code, "DC_SOLUTION": code, "DC_SCT": sct}) + assert res["correct"]