1
1
import ast
2
2
import inspect
3
3
import sys
4
+ import uuid
4
5
5
6
from textwrap import dedent
6
7
from pathlib import Path
7
- from dataclasses import dataclass
8
+ from dataclasses import dataclass , asdict
8
9
from copy import deepcopy
9
10
from result import Result , Ok , Err
10
11
from contextvars import ContextVar
@@ -60,9 +61,25 @@ class NodeContext:
60
61
61
62
local_scope : Any = None
62
63
64
+ caller_id : Any = None
65
+
63
66
def __repr__ (self ):
64
67
return f"{ self .fn_context } /{ self .id } "
65
68
69
def to_dict(self):
    """Return a plain-dict snapshot of this node context.

    The result has the keys ``id``, ``source``, ``location`` and
    ``caller_id`` (copied from the attributes of the same name), followed
    by ``fn_context``: a nested dict holding the owning function context's
    ``name``, ``module``, ``source_file`` and ``call_count``.  The call
    count is read through ``call_count.get()``, so the value stored is
    whatever that returns at the time of the call.
    """
    # Top-level attributes are copied first, in the order they are emitted.
    snapshot = {
        key: getattr(self, key)
        for key in ("id", "source", "location", "caller_id")
    }

    # Only a summary of the function context is exported, not the object itself.
    fn = self.fn_context
    fn_summary = {
        key: getattr(fn, key) for key in ("name", "module", "source_file")
    }
    fn_summary["call_count"] = fn.call_count.get()

    snapshot["fn_context"] = fn_summary
    return snapshot
82
+
66
83
67
84
class FnRewriter (ast .NodeTransformer ):
68
85
def __init__ (
@@ -120,10 +137,13 @@ def recover_source(self, pre_node):
120
137
return self .safe_unparse (pre_node )
121
138
return segment
122
139
123
- def build_transform_node (self , node , label , node_source = None , pass_locals = False ):
140
+ def build_transform_node (
141
+ self , node , label , node_source = None , pass_locals = False , extra_kwargs = None
142
+ ):
124
143
"""
125
144
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
126
145
"""
146
+ node = deepcopy (node )
127
147
if node_source is None :
128
148
node_source = self .safe_unparse (node )
129
149
@@ -145,28 +165,38 @@ def build_transform_node(self, node, label, node_source=None, pass_locals=False)
145
165
),
146
166
]
147
167
168
+ keyword_args = []
169
+
170
+ if extra_kwargs is not None :
171
+ keyword_args .extend (extra_kwargs )
172
+
148
173
if pass_locals :
149
- context_args .append (
150
- ast .Call (
151
- func = ast .Name (id = "_MAXRAY_BUILTINS_LOCALS" , ctx = ast .Load ()),
152
- args = [],
153
- keywords = [],
154
- ),
174
+ keyword_args .append (
175
+ ast .keyword (
176
+ arg = "local_scope" ,
177
+ value = ast .Call (
178
+ func = ast .Name (id = "_MAXRAY_BUILTINS_LOCALS" , ctx = ast .Load ()),
179
+ args = [],
180
+ keywords = [],
181
+ ),
182
+ )
155
183
)
184
+
156
185
context_node = ast .Call (
157
186
func = ast .Name (id = NodeContext .__name__ , ctx = ast .Load ()),
158
187
args = context_args ,
159
- keywords = [] ,
188
+ keywords = keyword_args ,
160
189
)
161
-
162
- return ast .Call (
190
+ ret = ast .Call (
163
191
func = ast .Name (id = self .transform_fn .__name__ , ctx = ast .Load ()),
164
192
args = [node , context_node ],
165
193
keywords = [],
166
194
)
195
+ return ret
167
196
168
197
def visit_Name (self , node ):
169
198
source_pre = self .recover_source (node )
199
+ node = deepcopy (node )
170
200
171
201
match node .ctx :
172
202
case ast .Load ():
@@ -189,6 +219,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
189
219
> Private name mangling: When an identifier that textually occurs in a class definition begins with two or more underscore characters and does not end in two or more underscores, it is considered a private name of that class. Private names are transformed to a longer form before code is generated for them. The transformation inserts the class name, with leading underscores removed and a single underscore inserted, in front of the name. For example, the identifier __spam occurring in a class named Ham will be transformed to _Ham__spam. This transformation is independent of the syntactical context in which the identifier is used. If the transformed name is extremely long (longer than 255 characters), implementation defined truncation may happen. If the class name consists only of underscores, no transformation is done.
190
220
"""
191
221
source_pre = self .recover_source (node )
222
+ node = deepcopy (node )
192
223
193
224
if self .is_method () and self .is_private_class_name (node .attr ):
194
225
node .attr = f"_{ self .instance_type } { node .attr } "
@@ -208,6 +239,7 @@ def visit_match_case(self, node: ast.match_case) -> Any:
208
239
return node
209
240
210
241
def visit_Assign (self , node : ast .Assign ) -> Any :
242
+ node = deepcopy (node )
211
243
new_node = self .generic_visit (node )
212
244
assert isinstance (new_node , ast .Assign )
213
245
# node = new_node
@@ -216,6 +248,7 @@ def visit_Assign(self, node: ast.Assign) -> Any:
216
248
217
249
def visit_Return (self , node : ast .Return ) -> Any :
218
250
node_pre = deepcopy (node )
251
+ node = deepcopy (node )
219
252
220
253
if node .value is None :
221
254
node .value = ast .Constant (None )
@@ -236,9 +269,37 @@ def visit_Return(self, node: ast.Return) -> Any:
236
269
237
270
return ast .copy_location (node , node_pre )
238
271
272
+ @staticmethod
273
+ def temp_binding (node , by_name : str ):
274
+ return ast .fix_missing_locations (
275
+ ast .Call (
276
+ ast .Name ("_MAXRAY_SET_TEMP" , ctx = ast .Load ()),
277
+ [node , ast .Constant (by_name )],
278
+ keywords = [],
279
+ )
280
+ )
281
+ # Cannot use walrus expressions in nested list comprehension context
282
+ # return ast.fix_missing_locations(
283
+ # ast.Subscript(
284
+ # value=ast.Tuple(
285
+ # elts=[
286
+ # ast.NamedExpr(
287
+ # target=ast.Name(id=by_name, ctx=ast.Store()),
288
+ # value=node,
289
+ # ),
290
+ # ast.Name(id=by_name, ctx=ast.Load()),
291
+ # ],
292
+ # ctx=ast.Load(),
293
+ # ),
294
+ # slice=ast.Constant(-1),
295
+ # ctx=ast.Load(),
296
+ # )
297
+ # )
298
+
239
299
def visit_Call (self , node ):
240
300
source_pre = self .recover_source (node )
241
301
node_pre = deepcopy (node )
302
+ node = deepcopy (node )
242
303
243
304
match node :
244
305
case ast .Call (func = ast .Name (id = "super" ), args = []):
@@ -262,17 +323,44 @@ def visit_Call(self, node):
262
323
ast .fix_missing_locations (node )
263
324
264
325
node = self .generic_visit (node ) # mutates
326
+ # Want to keep track of which function we're calling
327
+ # `node.func` is now likely a call to `_maxray_walker_handler`
328
+ node .func = self .temp_binding (node .func , "_MAXRAY_TEMP_VAR" )
329
+ # We can't use `node_pre.func` because evaluating it has side effects
330
+
331
+ # TODO: maybe assign the super() proxy and do the MRO patching at the start of the function
332
+ extra_kwargs = []
333
+ if "super" not in source_pre :
334
+ extra_kwargs .append (
335
+ ast .keyword (
336
+ "caller_id" ,
337
+ # value=ast.Name("_MAXRAY_TEMP_VAR", ctx=ast.Load()),
338
+ value = ast .Call (
339
+ func = ast .Name ("getattr" , ctx = ast .Load ()),
340
+ args = [
341
+ ast .Name ("_MAXRAY_TEMP_VAR" , ctx = ast .Load ()),
342
+ ast .Constant ("_MAXRAY_TRANSFORM_ID" ),
343
+ ast .Constant (None ),
344
+ ],
345
+ keywords = [],
346
+ ),
347
+ )
348
+ )
265
349
266
350
# the function/callable instance itself is observed by Name/Attribute/... nodes
267
351
return ast .copy_location (
268
352
self .build_transform_node (
269
- node , f"call/{ source_pre } " , node_source = source_pre
353
+ node ,
354
+ f"call/{ source_pre } " ,
355
+ node_source = source_pre ,
356
+ extra_kwargs = extra_kwargs ,
270
357
),
271
358
node_pre ,
272
359
)
273
360
274
361
def visit_FunctionDef (self , node : ast .FunctionDef ):
275
362
pre_node = deepcopy (node )
363
+ node = deepcopy (node )
276
364
self .fn_count += 1
277
365
278
366
# Only overwrite the name of our "target function"
@@ -283,9 +371,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
283
371
is_transform_root = self .fn_count == 1 and self .is_maxray_root
284
372
285
373
if is_transform_root :
286
- logger .info (
287
- f"Wiped decorators at level { self .fn_count } for { self .fn_context .impl_fn } : { node .decorator_list } "
288
- )
289
374
# If we didn't clear, decorators would be applied twice - screwing up routing handling in libraries like `quart`: `@app.post("/generate")`
290
375
node .decorator_list = []
291
376
@@ -381,6 +466,17 @@ def recompile_fn_with_transform(
381
466
"""
382
467
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
383
468
"""
469
+
470
+ # Could store as bytes but extra memory insignificant
471
+ # Even failed transforms should get ids assigned and returned as part of Err for tracking
472
+ # TODO: assign as field, and use instead of cache?
473
+ # TODO: hash based on properties so consistent across program runs
474
+ maxray_assigned_id = str (uuid .uuid4 ())
475
+ try :
476
+ source_fn ._MAXRAY_TRANSFORM_ID = maxray_assigned_id
477
+ except AttributeError :
478
+ pass
479
+
384
480
# TODO: use non-overridable __getattribute__ instead?
385
481
if not hasattr (source_fn , "__name__" ): # Safety check against weird functions
386
482
return Err (f"There is no __name__ for function { get_fn_name (source_fn )} " )
@@ -494,6 +590,12 @@ def patch_mro(super_type: super):
494
590
}
495
591
scope .update (initial_scope )
496
592
593
+ # BUG: this will NOT work with threading - could use ContextVar if no performance impact?
594
+ def set_temp (val , name : str ):
595
+ scope [name ] = val
596
+ return val
597
+
598
+ scope ["_MAXRAY_SET_TEMP" ] = set_temp
497
599
# Add class-private names to scope (though only should be usable as a default argument)
498
600
# TODO: should apply to all definitions within a class scope - so @staticmethod descriptors as well...
499
601
if fn_is_method :
@@ -546,9 +648,9 @@ def extract_cell(cell):
546
648
logger .error (
547
649
f"Failed to compile function `{ source_fn .__name__ } ` at '{ sourcefile } ' in its module { module } "
548
650
)
549
- logger .debug (f"Relevant original source code:\n { source } " )
550
- logger .debug (f"Corresponding AST:\n { FnRewriter .safe_show_ast (fn_ast )} " )
551
- logger .debug (
651
+ logger .trace (f"Relevant original source code:\n { source } " )
652
+ logger .trace (f"Corresponding AST:\n { FnRewriter .safe_show_ast (fn_ast )} " )
653
+ logger .trace (
552
654
f"Transformed code we attempted to compile:\n { FnRewriter .safe_unparse (transformed_fn_ast )} "
553
655
)
554
656
@@ -598,6 +700,7 @@ def extract_cell(cell):
598
700
599
701
# way to keep track of which functions we've already transformed
600
702
transformed_fn ._MAXRAY_TRANSFORMED = True
703
+ transformed_fn ._MAXRAY_TRANSFORM_ID = maxray_assigned_id
601
704
602
705
return Ok (transformed_fn )
603
706
0 commit comments