Skip to content

Commit de4e1f0

Browse files
committed
feat: patch __init__ and __call__
1 parent 982af59 commit de4e1f0

File tree

3 files changed

+63
-11
lines changed

3 files changed

+63
-11
lines changed

maxray/__init__.py

+46-5
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,54 @@ def callable_allowed_for_transform(x, ctx: NodeContext):
8080
)
8181

8282

83+
def instance_init_allowed_for_transform(x, ctx: NodeContext):
84+
"""
85+
Decides whether the __init__ method can be transformed.
86+
"""
87+
return (
88+
type(x) is type
89+
and hasattr(x, "__init__")
90+
and not hasattr(x, "_MAXRAY_TRANSFORMED")
91+
)
92+
93+
94+
def instance_call_allowed_for_transform(x, ctx: NodeContext):
95+
"""
96+
Decides whether the __call__ method can be transformed.
97+
"""
98+
return (
99+
type(x) is type
100+
and hasattr(x, "__call__")
101+
and not hasattr(x, "_MAXRAY_TRANSFORMED")
102+
)
103+
104+
83105
def _maxray_walker_handler(x, ctx: NodeContext):
84106
# We ignore writer calls triggered by code execution in other writers to prevent easily getting stuck in recursive hell
85107
if _GLOBAL_WRITER_ACTIVE_FLAG.get():
86108
return x
87109

88-
# 1. logic to recursively patch callables
110+
# 1. logic to recursively patch callables
111+
# 1a. special-case callables: __init__ and __call__
112+
if instance_init_allowed_for_transform(x, ctx):
113+
# TODO: should we somehow delay doing this until before an actual call?
114+
match recompile_fn_with_transform(
115+
x.__init__, _maxray_walker_handler, special_use_instance_type=x
116+
):
117+
case Ok(init_patch):
118+
logger.info(f"Patching __init__ for class {x}")
119+
setattr(x, "__init__", init_patch)
120+
121+
if instance_call_allowed_for_transform(x, ctx):
122+
# TODO: should we somehow delay doing this until before an actual call?
123+
match recompile_fn_with_transform(
124+
x.__call__, _maxray_walker_handler, special_use_instance_type=x
125+
):
126+
case Ok(call_patch):
127+
logger.info(f"Patching __call__ for class {x}")
128+
setattr(x, "__call__", call_patch)
129+
130+
# 1b. normal functions or bound methods or method descriptors like @classmethod and @staticmethod
89131
if callable_allowed_for_transform(x, ctx):
90132
# TODO: don't cache objects w/ __call__
91133
if x in _MAXRAY_FN_CACHE:
@@ -94,25 +136,24 @@ def _maxray_walker_handler(x, ctx: NodeContext):
94136
hook.active_call_state.get() and hook.descend_predicate(x, ctx)
95137
for hook in _MAXRAY_REGISTERED_HOOKS
96138
):
97-
# user-defined filters for which nodes to descend into
139+
# user-defined filters for which nodes (not) to descend into
98140
pass
99141
else:
100142
match recompile_fn_with_transform(x, _maxray_walker_handler):
101143
case Ok(x_trans):
102144
# NOTE: x_trans now has _MAXRAY_TRANSFORMED field to True
103145
if inspect.ismethod(x):
104146
# Two cases: descriptor vs bound method
105-
# TODO: handle callables and .__call__ patching
106147
match x.__self__:
107148
case type():
108149
# Descriptor
109-
logger.warning(
150+
logger.debug(
110151
f"monkey-patching descriptor method {x.__name__} on type {x.__self__}"
111152
)
112153
parent_cls = x.__self__
113154
case _:
114155
# Bound method
115-
logger.warning(
156+
logger.debug(
116157
f"monkey-patching bound method {x.__name__} on type {type(x.__self__)}"
117158
)
118159
parent_cls = type(x.__self__)

maxray/transforms.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def get_fn_name(fn):
357357
try:
358358
name = repr(fn)
359359
except Exception:
360-
name = "<unrepresentable function>"
360+
name = f"<unrepresentable function of type {type(fn)}>"
361361

362362
return f"{name} @ {id(fn)}"
363363

@@ -369,10 +369,18 @@ def recompile_fn_with_transform(
369369
ast_post_callback=None,
370370
initial_scope={},
371371
pass_scope=False,
372+
special_use_instance_type=None,
372373
) -> Result[Callable, str]:
373374
"""
374375
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
375376
"""
377+
# TODO: use non-overridable __getattribute__ instead?
378+
if not hasattr(source_fn, "__name__"): # Safety check against weird functions
379+
return Err(f"There is no __name__ for function {get_fn_name(source_fn)}")
380+
381+
if source_fn.__name__ == "<lambda>":
382+
return Err("Cannot safely recompile lambda functions")
383+
376384
# handle `functools.wraps`
377385
if hasattr(source_fn, "__wrapped__"):
378386
# SOUNDNESS: failure when decorators aren't applied at the definition site (will look for the original definition, ignoring any transformations that have been applied before the wrap but after definition)
@@ -404,10 +412,8 @@ def recompile_fn_with_transform(
404412
return Err(
405413
f"No source code for probable built-in function {get_fn_name(source_fn)}"
406414
)
407-
408-
# TODO: use non-overridable __getattribute__ instead?
409-
if not hasattr(source_fn, "__name__"): # Safety check against weird functions
410-
return Err(f"There is no __name__ for function {get_fn_name(source_fn)}")
415+
except SyntaxError:
416+
return Err(f"Syntax error in function {get_fn_name(source_fn)}")
411417

412418
if "super()" in source:
413419
# TODO: we could replace calls to super() with super(__class__, self)?
@@ -438,6 +444,11 @@ def recompile_fn_with_transform(
438444
# Bound method
439445
parent_cls = type(source_fn.__self__)
440446

447+
# yeah yeah an unbound __init__ isn't actually a method but we can basically treat it as one
448+
if special_use_instance_type is not None:
449+
fn_is_method = True
450+
parent_cls = special_use_instance_type
451+
441452
fn_call_counter = ContextVar("maxray_call_counter", default=0)
442453
fn_context = FnContext(
443454
source_fn,

tests/test_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def uh():
402402
z = X()
403403
return z()
404404

405-
assert uh() == 3
405+
assert uh() == 4
406406

407407

408408
def test_junk_annotations():

0 commit comments

Comments
 (0)