Skip to content

Commit 869db4c

Browse files
committed
feat: propagate more assignment info in NodeContext.props
1 parent 05ff21b commit 869db4c

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

maxray/transforms.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import sys
1111
import builtins
1212

13-
from dataclasses import dataclass
13+
from dataclasses import dataclass, field
1414
from copy import deepcopy
1515
from result import Result, Ok, Err
1616
from contextvars import ContextVar
@@ -68,6 +68,8 @@ class NodeContext:
6868

6969
caller_id: Any = None
7070

71+
props: dict = field(default_factory=lambda: {})
72+
7173
def __repr__(self):
7274
return f"{self.fn_context}/{self.id}"
7375

@@ -85,6 +87,10 @@ def to_dict(self):
8587
},
8688
}
8789

90+
def _set_assigned(self, targets: list[str]):
91+
self.props["assigned"] = {"targets": targets}
92+
return self
93+
8894

8995
class RewriteRuntimeHelper:
9096
"""
@@ -144,6 +150,31 @@ def write_patch_mro(self):
144150
class RewriteFailed(Exception): ...
145151

146152

153+
class RewriteTransformCall(ast.Call):
154+
@staticmethod
155+
def build(
156+
transform_func_name: str, source_node, node_context_args, node_context_kwargs
157+
):
158+
context_node = ast.Call(
159+
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
160+
args=node_context_args,
161+
keywords=node_context_kwargs,
162+
)
163+
164+
return RewriteTransformCall(
165+
func=ast.Name(id=transform_func_name, ctx=ast.Load()),
166+
args=[source_node, context_node],
167+
keywords=[],
168+
)
169+
170+
def assigned(self, assign_targets: list[str]):
171+
self.args[1] = ast.Call(
172+
ast.Attribute(self.args[1], "_set_assigned", ctx=ast.Load()),
173+
args=[ast.List([ast.Constant(s) for s in assign_targets], ctx=ast.Load())],
174+
keywords=[],
175+
)
176+
177+
147178
class FnRewriter(ast.NodeTransformer):
148179
def __init__(
149180
self,
@@ -248,25 +279,12 @@ def build_transform_node(
248279
ast.keyword(
249280
arg="local_scope",
250281
value=self.runtime_helper.write_locals(),
251-
# ast.Call(
252-
# func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
253-
# args=[],
254-
# keywords=[],
255-
# ),
256282
)
257283
)
258284

259-
context_node = ast.Call(
260-
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
261-
args=context_args,
262-
keywords=keyword_args,
263-
)
264-
ret = ast.Call(
265-
func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()),
266-
args=[node, context_node],
267-
keywords=[],
285+
return RewriteTransformCall.build(
286+
self.transform_fn.__name__, node, context_args, keyword_args
268287
)
269-
return ret
270288

271289
def visit_Name(self, node):
272290
source_pre = self.recover_source(node)
@@ -345,10 +363,12 @@ def visit_match_case(self, node: ast.match_case) -> Any:
345363
def visit_Assign(self, node: ast.Assign) -> Any:
346364
node = deepcopy(node)
347365
new_node = self.generic_visit(node)
348-
assert isinstance(new_node, ast.Assign)
349-
# node = new_node
350-
# node.value = self.build_transform_node(new_node, f"assign/(multiple)")
351-
return node
366+
match new_node:
367+
case ast.Assign(targets=targets, value=RewriteTransformCall() as rtc):
368+
target_reprs = [self.recover_source(t) for t in targets]
369+
rtc.assigned(target_reprs)
370+
371+
return new_node
352372

353373
def visit_Subscript(self, node: ast.Subscript) -> Any:
354374
if isinstance(node.ctx, ast.Load):

0 commit comments

Comments
 (0)