Skip to content

Commit 8afe9ad

Browse files
committed
feat: track return nodes, optionally collecting the local scope on exit
1 parent e842dfe commit 8afe9ad

File tree

3 files changed

+112
-40
lines changed

3 files changed

+112
-40
lines changed

maxray/__init__.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ def inner(fn):
2626
return inner
2727

2828

29-
def xray(walker):
29+
def xray(walker, **kwargs):
3030
"""
3131
Immutable version of `maxray` - expressions are passed to `walker` but its return value is ignored and the original code execution is left unchanged.
3232
"""
33-
return maxray(walker, mutable=False)
33+
return maxray(walker, **kwargs, mutable=False)
3434

3535

3636
_GLOBAL_SKIP_MODULES = {
@@ -65,6 +65,9 @@ class W_erHook:
6565

6666

6767
def callable_allowed_for_transform(x, ctx: NodeContext):
68+
if getattr(x, "__module__", None) in _GLOBAL_SKIP_MODULES:
69+
return False
70+
6871
module_path = ctx.fn_context.module.split(".")
6972
if module_path[0] in _GLOBAL_SKIP_MODULES:
7073
return False
@@ -155,7 +158,11 @@ def _maxray_walker_handler(x, ctx: NodeContext):
155158

156159

157160
def maxray(
158-
writer: Callable[[Any, NodeContext], Any], skip_modules=frozenset(), *, mutable=True
161+
writer: Callable[[Any, NodeContext], Any],
162+
skip_modules=frozenset(),
163+
*,
164+
mutable=True,
165+
pass_scope=False,
159166
):
160167
"""
161168
A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call.
@@ -198,7 +205,10 @@ def recursive_transform(fn):
198205
fn_transform = fn
199206
else:
200207
match recompile_fn_with_transform(
201-
fn, _maxray_walker_handler, initial_scope=caller_locals
208+
fn,
209+
_maxray_walker_handler,
210+
initial_scope=caller_locals,
211+
pass_scope=pass_scope,
202212
):
203213
case Ok(fn_transform):
204214
pass

maxray/transforms.py

+57-24
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class FnContext:
3636
call_count: ContextVar[int]
3737

3838
def __repr__(self):
39-
return f"{self.module}/{self.name}/{self.call_count.get()}"
39+
# Call count not included in repr so the same source location can be "grouped by" over multiple calls
40+
return f"{self.module}/{self.name}"
4041

4142

4243
@dataclass
@@ -57,6 +58,8 @@ class NodeContext:
5758

5859
location: tuple[int, int, int, int]
5960

61+
local_scope: Any = None
62+
6063
def __repr__(self):
6164
return f"{self.fn_context}/{self.id}"
6265

@@ -70,6 +73,7 @@ def __init__(
7073
instance_type: str | None,
7174
dedent_chars: int = 0,
7275
record_call_counts: bool = True,
76+
pass_locals_on_return: bool = False,
7377
):
7478
"""
7579
If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None.
@@ -80,6 +84,7 @@ def __init__(
8084
self.instance_type = instance_type
8185
self.dedent_chars = dedent_chars
8286
self.record_call_counts = record_call_counts
87+
self.pass_locals_on_return = pass_locals_on_return
8388

8489
# the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions.
8590
self.fn_count = 0
@@ -113,7 +118,7 @@ def recover_source(self, pre_node):
113118
return self.safe_unparse(pre_node)
114119
return segment
115120

116-
def build_transform_node(self, node, label, node_source=None):
121+
def build_transform_node(self, node, label, node_source=None, pass_locals=False):
117122
"""
118123
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
119124
"""
@@ -122,22 +127,33 @@ def build_transform_node(self, node, label, node_source=None):
122127

123128
line_offset = self.fn_context.impl_fn.__code__.co_firstlineno - 2
124129
col_offset = self.dedent_chars
130+
131+
context_args = [
132+
ast.Constant(label),
133+
ast.Constant(node_source),
134+
# Name is injected into the exec scope by `recompile_fn_with_transform`
135+
ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()),
136+
ast.Constant(
137+
(
138+
line_offset + node.lineno,
139+
line_offset + node.end_lineno,
140+
node.col_offset + col_offset,
141+
node.end_col_offset + col_offset,
142+
)
143+
),
144+
]
145+
146+
if pass_locals:
147+
context_args.append(
148+
ast.Call(
149+
func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
150+
args=[],
151+
keywords=[],
152+
),
153+
)
125154
context_node = ast.Call(
126155
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
127-
args=[
128-
ast.Constant(label),
129-
ast.Constant(node_source),
130-
# Name is injected into the exec scope by `recompile_fn_with_transform`
131-
ast.Name(id="_MAXRAY_FN_CONTEXT", ctx=ast.Load()),
132-
ast.Constant(
133-
(
134-
line_offset + node.lineno,
135-
line_offset + node.end_lineno,
136-
node.col_offset + col_offset,
137-
node.end_col_offset + col_offset,
138-
)
139-
),
140-
],
156+
args=context_args,
141157
keywords=[],
142158
)
143159

@@ -196,6 +212,28 @@ def visit_Assign(self, node: ast.Assign) -> Any:
196212
# node.value = self.build_transform_node(new_node, f"assign/(multiple)")
197213
return node
198214

215+
def visit_Return(self, node: ast.Return) -> Any:
216+
node_pre = deepcopy(node)
217+
218+
if node.value is None:
219+
node.value = ast.Constant(None)
220+
221+
# Note: For a plain `return` statement, there's no source for a thing that *isn't* returned
222+
value_source_pre = self.recover_source(node.value)
223+
224+
node = self.generic_visit(node)
225+
226+
# TODO: Check source locations are correct here
227+
ast.fix_missing_locations(node)
228+
node.value = self.build_transform_node(
229+
node.value,
230+
f"return/{value_source_pre}",
231+
node_source=value_source_pre,
232+
pass_locals=self.pass_locals_on_return,
233+
)
234+
235+
return ast.copy_location(node, node_pre)
236+
199237
def visit_Call(self, node):
200238
source_pre = self.recover_source(node)
201239

@@ -204,14 +242,6 @@ def visit_Call(self, node):
204242
node = self.generic_visit(node) # mutates
205243

206244
# the function/callable instance itself is observed by Name/Attribute/... nodes
207-
208-
target = node.func
209-
match target:
210-
case ast.Name():
211-
logger.debug(f"Visiting call to function {target.id}")
212-
case ast.Attribute():
213-
logger.debug(f"Visiting call to attribute {target.attr}")
214-
215245
return ast.copy_location(
216246
self.build_transform_node(
217247
node, f"call/{source_pre}", node_source=source_pre
@@ -338,6 +368,7 @@ def recompile_fn_with_transform(
338368
ast_pre_callback=None,
339369
ast_post_callback=None,
340370
initial_scope={},
371+
pass_scope=False,
341372
) -> Result[Callable, str]:
342373
"""
343374
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
@@ -421,6 +452,7 @@ def recompile_fn_with_transform(
421452
fn_context,
422453
instance_type=parent_cls.__name__ if fn_is_method else None,
423454
dedent_chars=dedent_chars,
455+
pass_locals_on_return=pass_scope,
424456
).visit(fn_ast)
425457
ast.fix_missing_locations(transformed_fn_ast)
426458

@@ -433,6 +465,7 @@ def recompile_fn_with_transform(
433465
"_MAXRAY_FN_CONTEXT": fn_context,
434466
"_MAXRAY_CALL_COUNTER": fn_call_counter,
435467
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
468+
"_MAXRAY_BUILTINS_LOCALS": locals,
436469
}
437470
scope.update(initial_scope)
438471

tests/test_transforms.py

+41-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_basic():
1919
def f(x):
2020
return x
2121

22-
assert f(3) == 4
22+
assert f(3) == 5
2323

2424

2525
def test_type_hints():
@@ -29,7 +29,7 @@ def test_type_hints():
2929
def f(x: Any):
3030
return x
3131

32-
assert f(3) == 4
32+
assert f(3) == 5
3333

3434

3535
def test_closure_capture():
@@ -39,7 +39,7 @@ def test_closure_capture():
3939
def f(x):
4040
return x + z
4141

42-
assert f(3) == 6
42+
assert f(3) == 7
4343

4444

4545
def test_closure_capture_mutate():
@@ -62,7 +62,7 @@ def test_global_capture():
6262
def g(x):
6363
return x + GLOB_CONST
6464

65-
assert g(3) == 10
65+
assert g(3) == 11
6666

6767

6868
def test_nested_def():
@@ -76,8 +76,8 @@ def g(x):
7676

7777
return g
7878

79-
assert outer()(3) == 4
80-
assert outer()(3) == 4
79+
assert outer()(3) == 5
80+
assert outer()(3) == 5
8181

8282

8383
def test_recursive():
@@ -150,7 +150,7 @@ def f():
150150
pass
151151
return x
152152

153-
assert f() == 4
153+
assert f() == 8
154154

155155

156156
def test_property_access():
@@ -165,7 +165,7 @@ class A:
165165
def g():
166166
return obj.x
167167

168-
assert g() == 2
168+
assert g() == 3
169169

170170

171171
def test_method():
@@ -377,7 +377,7 @@ def dec(f):
377377
def f(x):
378378
return x
379379

380-
assert f(2) == 3
380+
assert f(2) == 4
381381
assert len(decor_count) == 1
382382

383383
# Works properly when applied last: is wiped for the transform, but is subsequently applied properly to the transformed function
@@ -386,7 +386,7 @@ def f(x):
386386
def f(x):
387387
return x
388388

389-
assert f(2) == 1
389+
assert f(2) == 0
390390
assert len(decor_count) == 2
391391

392392

@@ -402,7 +402,7 @@ def uh():
402402
z = X()
403403
return z()
404404

405-
assert uh() == 2
405+
assert uh() == 3
406406

407407

408408
def test_junk_annotations():
@@ -413,7 +413,7 @@ def inner(x: ASDF = 0, *, y: SDFSDF = 100) -> AAAAAAAAAAA:
413413

414414
return inner(2)
415415

416-
assert outer() == 105
416+
assert outer() == 107
417417

418418

419419
def test_call_counts():
@@ -454,6 +454,35 @@ def f(x):
454454
assert calls == [1, 2, 3, 3, 2, 1]
455455

456456

457+
def test_empty_return():
458+
@xray(dbg)
459+
def empty_returns():
460+
return
461+
462+
assert empty_returns() is None
463+
464+
465+
def test_scope_passed():
466+
found_scope = None
467+
468+
def get_scope(x, ctx):
469+
nonlocal found_scope
470+
if ctx.local_scope is not None:
471+
assert found_scope is None
472+
found_scope = ctx.local_scope
473+
return x
474+
475+
@xray(get_scope, pass_scope=True)
476+
def f(n):
477+
z = 3
478+
return n
479+
480+
assert f(1) == 1
481+
482+
assert "z" in found_scope
483+
assert found_scope["z"] == 3
484+
485+
457486
def test_wrap_unsound():
458487
# TODO
459488
pass

0 commit comments

Comments
 (0)