@@ -86,11 +86,67 @@ def to_dict(self):
86
86
}
87
87
88
88
89
+ class RewriteRuntimeHelper :
90
+ """
91
+ Implementation for rewrite functionality that needs to execute code and maintain state at runtime (rather than only applying known AST transforms)
92
+
93
+ - `read_*` are called at runtime
94
+ - `write_*` are called during AST rewrite
95
+ """
96
+
97
+ SYMBOL = "_MAXRAY_REWRITE_RUNTIME"
98
+
99
+ def __init__ (self , fn_context : FnContext ):
100
+ self .fn_context = fn_context
101
+
102
+ def expand_scope (self ):
103
+ # TODO: exclude these methods from being patched/transformed
104
+ return {
105
+ self .SYMBOL : self ,
106
+ "_MAXRAY_INNER_NOTRANSFORM" : self .read_inner_notransform ,
107
+ "_MAXRAY_PATCH_MRO" : self .read_patch_mro ,
108
+ }
109
+
110
+ @property
111
+ def read_locals (self ):
112
+ return builtins .locals
113
+
114
+ def write_locals (self ):
115
+ return ast .Call (
116
+ func = ast .Attribute (
117
+ ast .Name (id = self .SYMBOL , ctx = ast .Load ()), "read_locals" , ctx = ast .Load ()
118
+ ),
119
+ args = [],
120
+ keywords = [],
121
+ )
122
+
123
+ def read_inner_notransform (self , f ):
124
+ set_property_on_functionlike (f , "_MAXRAY_NOTRANSFORM" , True )
125
+ return f
126
+
127
+ def write_inner_notransform (self ):
128
+ return ast .Name ("_MAXRAY_INNER_NOTRANSFORM" , ctx = ast .Load ())
129
+
130
+ def read_patch_mro (self , super_type : "super" ): # type: ignore
131
+ # TODO: use `ctx` to find current method to patch in addition to dunders, also apply actual transform
132
+ for parent_type in super_type .__self_class__ .mro ():
133
+ if not hasattr (parent_type , "__init__" ) or hasattr (
134
+ parent_type .__init__ , "_MAXRAY_TRANSFORMED"
135
+ ): # Seems to have side effects when picked up for patching?
136
+ continue
137
+
138
+ return super_type
139
+
140
+ def write_patch_mro (self ):
141
+ return ast .Name ("_MAXRAY_PATCH_MRO" , ctx = ast .Load ())
142
+
143
+
89
144
class FnRewriter (ast .NodeTransformer ):
90
145
def __init__ (
91
146
self ,
92
147
transform_fn ,
93
148
fn_context : FnContext ,
149
+ runtime_helper : RewriteRuntimeHelper ,
94
150
* ,
95
151
instance_type : str | None ,
96
152
dedent_chars : int = 0 ,
@@ -104,6 +160,7 @@ def __init__(
104
160
105
161
self .transform_fn = transform_fn
106
162
self .fn_context = fn_context
163
+ self .runtime_helper = runtime_helper
107
164
self .instance_type = instance_type
108
165
self .dedent_chars = dedent_chars
109
166
self .record_call_counts = record_call_counts
@@ -187,11 +244,12 @@ def build_transform_node(
187
244
keyword_args .append (
188
245
ast .keyword (
189
246
arg = "local_scope" ,
190
- value = ast .Call (
191
- func = ast .Name (id = "_MAXRAY_BUILTINS_LOCALS" , ctx = ast .Load ()),
192
- args = [],
193
- keywords = [],
194
- ),
247
+ value = self .runtime_helper .write_locals (),
248
+ # ast.Call(
249
+ # func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
250
+ # args=[],
251
+ # keywords=[],
252
+ # ),
195
253
)
196
254
)
197
255
@@ -252,16 +310,19 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
252
310
253
311
# if self.is_method() and self.is_private_class_name(node.attr):
254
312
# does the ast.Load() check need to be pulled up here?
313
+ # TODO: support failing/raising an exception if we can't guess any instance name
255
314
if self .is_private_class_name (node .attr ):
256
315
# currently we do a bad job of actually checking if it's supposed to be a method-like so this is just a hopeful guess
257
- qualname_type_guess = self .fn_context .name .split ("." )[- 2 ]
316
+ qualname_components = self .fn_context .name .split ("." )
317
+ if len (qualname_components ) < 2 :
318
+ logger .error (f"{ qualname_components } { self .safe_unparse (node )} " )
258
319
resolve_type_name = (
259
320
self .instance_type
260
321
if self .instance_type is not None
261
- else qualname_type_guess
322
+ else qualname_components [ - 2 ]
262
323
)
263
324
node .attr = f"_{ resolve_type_name .lstrip ('_' )} { node .attr } "
264
- logger .warning ("Replaced with mangled private name" )
325
+ logger .warning (f "Replaced with mangled private name: { node . attr } " )
265
326
266
327
if isinstance (node .ctx , ast .Load ):
267
328
node = self .generic_visit (node )
@@ -378,7 +439,7 @@ def visit_Call(self, node):
378
439
ast .Name ("cls" , ctx = ast .Load ()),
379
440
]
380
441
node = ast .Call (
381
- func = ast . Name ( "_MAXRAY_PATCH_MRO" , ctx = ast . Load () ),
442
+ func = self . runtime_helper . write_patch_mro ( ),
382
443
args = [node ],
383
444
keywords = [],
384
445
)
@@ -431,9 +492,7 @@ def transform_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
431
492
node .name = f"{ node .name } _{ self .instance_type } _{ node .name } "
432
493
self .defined_fn_name = node .name
433
494
else :
434
- node .decorator_list .insert (
435
- 0 , ast .Name ("_MAXRAY_DECORATE_INNER_NOTRANSFORM" , ctx = ast .Load ())
436
- )
495
+ node .decorator_list .insert (0 , self .runtime_helper .write_inner_notransform ())
437
496
438
497
# Decorators are evaluated sequentially: decorators applied *before* our one (should?) get ignored while decorators applied *after* work correctly
439
498
is_transform_root = self .fn_count == 1 and self .is_maxray_root
@@ -627,10 +686,13 @@ def recompile_fn_with_transform(
627
686
fn_call_counter ,
628
687
compile_id = with_source_fn .compile_id ,
629
688
)
689
+ runtime_helper = RewriteRuntimeHelper (fn_context )
690
+
630
691
instance_type = parent_cls .__name__ if fn_is_method else None
631
692
fn_rewriter = FnRewriter (
632
693
transform_fn ,
633
694
fn_context ,
695
+ runtime_helper ,
634
696
instance_type = instance_type ,
635
697
dedent_chars = with_source_fn .source_dedent_chars ,
636
698
pass_locals_on_return = pass_scope ,
@@ -642,32 +704,15 @@ def recompile_fn_with_transform(
642
704
if ast_post_callback is not None :
643
705
ast_post_callback (transformed_fn_ast )
644
706
645
- def patch_mro (super_type : super ):
646
- for parent_type in super_type .__self_class__ .mro ():
647
- # Ok that's weird - this function gets picked up by the maxray decorator and seems to correctly patch the parent types - so despite looking like this function does absolutely nothing, it actually *has* side-effects
648
- if not hasattr (parent_type , "__init__" ) or hasattr (
649
- parent_type .__init__ , "_MAXRAY_TRANSFORMED"
650
- ):
651
- continue
652
-
653
- return super_type
654
-
655
- # TODO: does this make a difference? are superclasses even getting patched properly?
656
- # patch_mro._MAXRAY_NOTRANSFORM = True
657
-
658
707
scope_layers = {
659
708
"core" : {
660
709
transform_fn .__name__ : transform_fn ,
661
710
NodeContext .__name__ : NodeContext ,
662
711
"_MAXRAY_FN_CONTEXT" : fn_context ,
663
712
"_MAXRAY_CALL_COUNTER" : fn_call_counter ,
664
713
"_MAXRAY_DECORATE_WITH_COUNTER" : count_calls_with ,
665
- "_MAXRAY_DECORATE_INNER_NOTRANSFORM" : ensure_notransform ,
666
- "_MAXRAY_BUILTINS_LOCALS" : locals ,
667
- "_MAXRAY_PATCH_MRO" : patch_mro ,
668
- "_MAXRAY_MODULE_GLOBALS" : vars (
669
- module
670
- ), # TODO: on fail need to use different module
714
+ "_MAXRAY_MODULE_GLOBALS" : vars (module ),
715
+ ** runtime_helper .expand_scope (),
671
716
},
672
717
"override" : override_scope ,
673
718
"class_local" : {},
@@ -852,8 +897,3 @@ def lookup_module(fd: FunctionData):
852
897
return file_matched
853
898
else :
854
899
return module_matched
855
-
856
-
857
- def ensure_notransform (f ):
858
- set_property_on_functionlike (f , "_MAXRAY_NOTRANSFORM" , True )
859
- return f
0 commit comments