Skip to content

Commit 34694be

Browse files
committed
refactor: move runtime impl into RuntimeHelper
1 parent 3835f26 commit 34694be

File tree

1 file changed

+76
-36
lines changed

1 file changed

+76
-36
lines changed

maxray/transforms.py

+76-36
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,67 @@ def to_dict(self):
8686
}
8787

8888

89+
class RewriteRuntimeHelper:
90+
"""
91+
Implementation for rewrite functionality that needs to execute code and maintain state at runtime (rather than only applying known AST transforms)
92+
93+
- `read_*` are called at runtime
94+
- `write_*` are called during AST rewrite
95+
"""
96+
97+
SYMBOL = "_MAXRAY_REWRITE_RUNTIME"
98+
99+
def __init__(self, fn_context: FnContext):
100+
self.fn_context = fn_context
101+
102+
def expand_scope(self):
103+
# TODO: exclude these methods from being patched/transformed
104+
return {
105+
self.SYMBOL: self,
106+
"_MAXRAY_INNER_NOTRANSFORM": self.read_inner_notransform,
107+
"_MAXRAY_PATCH_MRO": self.read_patch_mro,
108+
}
109+
110+
@property
111+
def read_locals(self):
112+
return builtins.locals
113+
114+
def write_locals(self):
115+
return ast.Call(
116+
func=ast.Attribute(
117+
ast.Name(id=self.SYMBOL, ctx=ast.Load()), "read_locals", ctx=ast.Load()
118+
),
119+
args=[],
120+
keywords=[],
121+
)
122+
123+
def read_inner_notransform(self, f):
124+
set_property_on_functionlike(f, "_MAXRAY_NOTRANSFORM", True)
125+
return f
126+
127+
def write_inner_notransform(self):
128+
return ast.Name("_MAXRAY_INNER_NOTRANSFORM", ctx=ast.Load())
129+
130+
def read_patch_mro(self, super_type: "super"): # type: ignore
131+
# TODO: use `ctx` to find current method to patch in addition to dunders, also apply actual transform
132+
for parent_type in super_type.__self_class__.mro():
133+
if not hasattr(parent_type, "__init__") or hasattr(
134+
parent_type.__init__, "_MAXRAY_TRANSFORMED"
135+
): # Seems to have side effects when picked up for patching?
136+
continue
137+
138+
return super_type
139+
140+
def write_patch_mro(self):
141+
return ast.Name("_MAXRAY_PATCH_MRO", ctx=ast.Load())
142+
143+
89144
class FnRewriter(ast.NodeTransformer):
90145
def __init__(
91146
self,
92147
transform_fn,
93148
fn_context: FnContext,
149+
runtime_helper: RewriteRuntimeHelper,
94150
*,
95151
instance_type: str | None,
96152
dedent_chars: int = 0,
@@ -104,6 +160,7 @@ def __init__(
104160

105161
self.transform_fn = transform_fn
106162
self.fn_context = fn_context
163+
self.runtime_helper = runtime_helper
107164
self.instance_type = instance_type
108165
self.dedent_chars = dedent_chars
109166
self.record_call_counts = record_call_counts
@@ -187,11 +244,12 @@ def build_transform_node(
187244
keyword_args.append(
188245
ast.keyword(
189246
arg="local_scope",
190-
value=ast.Call(
191-
func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
192-
args=[],
193-
keywords=[],
194-
),
247+
value=self.runtime_helper.write_locals(),
248+
# ast.Call(
249+
# func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
250+
# args=[],
251+
# keywords=[],
252+
# ),
195253
)
196254
)
197255

@@ -252,16 +310,19 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
252310

253311
# if self.is_method() and self.is_private_class_name(node.attr):
254312
# does the ast.Load() check need to be pulled up here?
313+
# TODO: support failing/raising an exception if we can't guess any instance name
255314
if self.is_private_class_name(node.attr):
256315
# currently we do a bad job of actually checking if it's supposed to be a method-like so this is just a hopeful guess
257-
qualname_type_guess = self.fn_context.name.split(".")[-2]
316+
qualname_components = self.fn_context.name.split(".")
317+
if len(qualname_components) < 2:
318+
logger.error(f"{qualname_components} {self.safe_unparse(node)}")
258319
resolve_type_name = (
259320
self.instance_type
260321
if self.instance_type is not None
261-
else qualname_type_guess
322+
else qualname_components[-2]
262323
)
263324
node.attr = f"_{resolve_type_name.lstrip('_')}{node.attr}"
264-
logger.warning("Replaced with mangled private name")
325+
logger.warning(f"Replaced with mangled private name: {node.attr}")
265326

266327
if isinstance(node.ctx, ast.Load):
267328
node = self.generic_visit(node)
@@ -378,7 +439,7 @@ def visit_Call(self, node):
378439
ast.Name("cls", ctx=ast.Load()),
379440
]
380441
node = ast.Call(
381-
func=ast.Name("_MAXRAY_PATCH_MRO", ctx=ast.Load()),
442+
func=self.runtime_helper.write_patch_mro(),
382443
args=[node],
383444
keywords=[],
384445
)
@@ -431,9 +492,7 @@ def transform_function_def(self, node: ast.FunctionDef | ast.AsyncFunctionDef):
431492
node.name = f"{node.name}_{self.instance_type}_{node.name}"
432493
self.defined_fn_name = node.name
433494
else:
434-
node.decorator_list.insert(
435-
0, ast.Name("_MAXRAY_DECORATE_INNER_NOTRANSFORM", ctx=ast.Load())
436-
)
495+
node.decorator_list.insert(0, self.runtime_helper.write_inner_notransform())
437496

438497
# Decorators are evaluated sequentially: decorators applied *before* our one (should?) get ignored while decorators applied *after* work correctly
439498
is_transform_root = self.fn_count == 1 and self.is_maxray_root
@@ -627,10 +686,13 @@ def recompile_fn_with_transform(
627686
fn_call_counter,
628687
compile_id=with_source_fn.compile_id,
629688
)
689+
runtime_helper = RewriteRuntimeHelper(fn_context)
690+
630691
instance_type = parent_cls.__name__ if fn_is_method else None
631692
fn_rewriter = FnRewriter(
632693
transform_fn,
633694
fn_context,
695+
runtime_helper,
634696
instance_type=instance_type,
635697
dedent_chars=with_source_fn.source_dedent_chars,
636698
pass_locals_on_return=pass_scope,
@@ -642,32 +704,15 @@ def recompile_fn_with_transform(
642704
if ast_post_callback is not None:
643705
ast_post_callback(transformed_fn_ast)
644706

645-
def patch_mro(super_type: super):
646-
for parent_type in super_type.__self_class__.mro():
647-
# Ok that's weird - this function gets picked up by the maxray decorator and seems to correctly patch the parent types - so despite looking like this function does absolutely nothing, it actually *has* side-effects
648-
if not hasattr(parent_type, "__init__") or hasattr(
649-
parent_type.__init__, "_MAXRAY_TRANSFORMED"
650-
):
651-
continue
652-
653-
return super_type
654-
655-
# TODO: does this make a difference? are superclasses even getting patched properly?
656-
# patch_mro._MAXRAY_NOTRANSFORM = True
657-
658707
scope_layers = {
659708
"core": {
660709
transform_fn.__name__: transform_fn,
661710
NodeContext.__name__: NodeContext,
662711
"_MAXRAY_FN_CONTEXT": fn_context,
663712
"_MAXRAY_CALL_COUNTER": fn_call_counter,
664713
"_MAXRAY_DECORATE_WITH_COUNTER": count_calls_with,
665-
"_MAXRAY_DECORATE_INNER_NOTRANSFORM": ensure_notransform,
666-
"_MAXRAY_BUILTINS_LOCALS": locals,
667-
"_MAXRAY_PATCH_MRO": patch_mro,
668-
"_MAXRAY_MODULE_GLOBALS": vars(
669-
module
670-
), # TODO: on fail need to use different module
714+
"_MAXRAY_MODULE_GLOBALS": vars(module),
715+
**runtime_helper.expand_scope(),
671716
},
672717
"override": override_scope,
673718
"class_local": {},
@@ -852,8 +897,3 @@ def lookup_module(fd: FunctionData):
852897
return file_matched
853898
else:
854899
return module_matched
855-
856-
857-
def ensure_notransform(f):
858-
set_property_on_functionlike(f, "_MAXRAY_NOTRANSFORM", True)
859-
return f

0 commit comments

Comments
 (0)