@@ -36,7 +36,8 @@ class FnContext:
36
36
call_count : ContextVar [int ]
37
37
38
38
def __repr__ (self ):
39
- return f"{ self .module } /{ self .name } /{ self .call_count .get ()} "
39
+ # Call count not included in repr so the same source location can be "grouped by" over multiple calls
40
+ return f"{ self .module } /{ self .name } "
40
41
41
42
42
43
@dataclass
@@ -57,6 +58,8 @@ class NodeContext:
57
58
58
59
location : tuple [int , int , int , int ]
59
60
61
+ local_scope : Any = None
62
+
60
63
def __repr__ (self ):
61
64
return f"{ self .fn_context } /{ self .id } "
62
65
@@ -70,6 +73,7 @@ def __init__(
70
73
instance_type : str | None ,
71
74
dedent_chars : int = 0 ,
72
75
record_call_counts : bool = True ,
76
+ pass_locals_on_return : bool = False ,
73
77
):
74
78
"""
75
79
If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None.
@@ -80,6 +84,7 @@ def __init__(
80
84
self .instance_type = instance_type
81
85
self .dedent_chars = dedent_chars
82
86
self .record_call_counts = record_call_counts
87
+ self .pass_locals_on_return = pass_locals_on_return
83
88
84
89
# the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions.
85
90
self .fn_count = 0
@@ -113,7 +118,7 @@ def recover_source(self, pre_node):
113
118
return self .safe_unparse (pre_node )
114
119
return segment
115
120
116
- def build_transform_node (self , node , label , node_source = None ):
121
+ def build_transform_node (self , node , label , node_source = None , pass_locals = False ):
117
122
"""
118
123
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
119
124
"""
@@ -122,22 +127,33 @@ def build_transform_node(self, node, label, node_source=None):
122
127
123
128
line_offset = self .fn_context .impl_fn .__code__ .co_firstlineno - 2
124
129
col_offset = self .dedent_chars
130
+
131
+ context_args = [
132
+ ast .Constant (label ),
133
+ ast .Constant (node_source ),
134
+ # Name is injected into the exec scope by `recompile_fn_with_transform`
135
+ ast .Name (id = "_MAXRAY_FN_CONTEXT" , ctx = ast .Load ()),
136
+ ast .Constant (
137
+ (
138
+ line_offset + node .lineno ,
139
+ line_offset + node .end_lineno ,
140
+ node .col_offset + col_offset ,
141
+ node .end_col_offset + col_offset ,
142
+ )
143
+ ),
144
+ ]
145
+
146
+ if pass_locals :
147
+ context_args .append (
148
+ ast .Call (
149
+ func = ast .Name (id = "_MAXRAY_BUILTINS_LOCALS" , ctx = ast .Load ()),
150
+ args = [],
151
+ keywords = [],
152
+ ),
153
+ )
125
154
context_node = ast .Call (
126
155
func = ast .Name (id = NodeContext .__name__ , ctx = ast .Load ()),
127
- args = [
128
- ast .Constant (label ),
129
- ast .Constant (node_source ),
130
- # Name is injected into the exec scope by `recompile_fn_with_transform`
131
- ast .Name (id = "_MAXRAY_FN_CONTEXT" , ctx = ast .Load ()),
132
- ast .Constant (
133
- (
134
- line_offset + node .lineno ,
135
- line_offset + node .end_lineno ,
136
- node .col_offset + col_offset ,
137
- node .end_col_offset + col_offset ,
138
- )
139
- ),
140
- ],
156
+ args = context_args ,
141
157
keywords = [],
142
158
)
143
159
@@ -196,6 +212,28 @@ def visit_Assign(self, node: ast.Assign) -> Any:
196
212
# node.value = self.build_transform_node(new_node, f"assign/(multiple)")
197
213
return node
198
214
215
+ def visit_Return (self , node : ast .Return ) -> Any :
216
+ node_pre = deepcopy (node )
217
+
218
+ if node .value is None :
219
+ node .value = ast .Constant (None )
220
+
221
+ # Note: For a plain `return` statement, there's no source for a thing that *isn't* returned
222
+ value_source_pre = self .recover_source (node .value )
223
+
224
+ node = self .generic_visit (node )
225
+
226
+ # TODO: Check source locations are correct here
227
+ ast .fix_missing_locations (node )
228
+ node .value = self .build_transform_node (
229
+ node .value ,
230
+ f"return/{ value_source_pre } " ,
231
+ node_source = value_source_pre ,
232
+ pass_locals = self .pass_locals_on_return ,
233
+ )
234
+
235
+ return ast .copy_location (node , node_pre )
236
+
199
237
def visit_Call (self , node ):
200
238
source_pre = self .recover_source (node )
201
239
@@ -204,14 +242,6 @@ def visit_Call(self, node):
204
242
node = self .generic_visit (node ) # mutates
205
243
206
244
# the function/callable instance itself is observed by Name/Attribute/... nodes
207
-
208
- target = node .func
209
- match target :
210
- case ast .Name ():
211
- logger .debug (f"Visiting call to function { target .id } " )
212
- case ast .Attribute ():
213
- logger .debug (f"Visiting call to attribute { target .attr } " )
214
-
215
245
return ast .copy_location (
216
246
self .build_transform_node (
217
247
node , f"call/{ source_pre } " , node_source = source_pre
@@ -338,6 +368,7 @@ def recompile_fn_with_transform(
338
368
ast_pre_callback = None ,
339
369
ast_post_callback = None ,
340
370
initial_scope = {},
371
+ pass_scope = False ,
341
372
) -> Result [Callable , str ]:
342
373
"""
343
374
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
@@ -421,6 +452,7 @@ def recompile_fn_with_transform(
421
452
fn_context ,
422
453
instance_type = parent_cls .__name__ if fn_is_method else None ,
423
454
dedent_chars = dedent_chars ,
455
+ pass_locals_on_return = pass_scope ,
424
456
).visit (fn_ast )
425
457
ast .fix_missing_locations (transformed_fn_ast )
426
458
@@ -433,6 +465,7 @@ def recompile_fn_with_transform(
433
465
"_MAXRAY_FN_CONTEXT" : fn_context ,
434
466
"_MAXRAY_CALL_COUNTER" : fn_call_counter ,
435
467
"_MAXRAY_DECORATE_WITH_COUNTER" : count_calls_with ,
468
+ "_MAXRAY_BUILTINS_LOCALS" : locals ,
436
469
}
437
470
scope .update (initial_scope )
438
471
0 commit comments