10
10
import sys
11
11
import builtins
12
12
13
- from dataclasses import dataclass
13
+ from dataclasses import dataclass , field
14
14
from copy import deepcopy
15
15
from result import Result , Ok , Err
16
16
from contextvars import ContextVar
@@ -68,6 +68,8 @@ class NodeContext:
68
68
69
69
caller_id : Any = None
70
70
71
+ props : dict = field (default_factory = lambda : {})
72
+
71
73
def __repr__ (self ):
72
74
return f"{ self .fn_context } /{ self .id } "
73
75
@@ -85,6 +87,10 @@ def to_dict(self):
85
87
},
86
88
}
87
89
90
+ def _set_assigned (self , targets : list [str ]):
91
+ self .props ["assigned" ] = {"targets" : targets }
92
+ return self
93
+
88
94
89
95
class RewriteRuntimeHelper :
90
96
"""
@@ -144,6 +150,31 @@ def write_patch_mro(self):
144
150
class RewriteFailed (Exception ): ...
145
151
146
152
153
+ class RewriteTransformCall (ast .Call ):
154
+ @staticmethod
155
+ def build (
156
+ transform_func_name : str , source_node , node_context_args , node_context_kwargs
157
+ ):
158
+ context_node = ast .Call (
159
+ func = ast .Name (id = NodeContext .__name__ , ctx = ast .Load ()),
160
+ args = node_context_args ,
161
+ keywords = node_context_kwargs ,
162
+ )
163
+
164
+ return RewriteTransformCall (
165
+ func = ast .Name (id = transform_func_name , ctx = ast .Load ()),
166
+ args = [source_node , context_node ],
167
+ keywords = [],
168
+ )
169
+
170
+ def assigned (self , assign_targets : list [str ]):
171
+ self .args [1 ] = ast .Call (
172
+ ast .Attribute (self .args [1 ], "_set_assigned" , ctx = ast .Load ()),
173
+ args = [ast .List ([ast .Constant (s ) for s in assign_targets ], ctx = ast .Load ())],
174
+ keywords = [],
175
+ )
176
+
177
+
147
178
class FnRewriter (ast .NodeTransformer ):
148
179
def __init__ (
149
180
self ,
@@ -248,25 +279,12 @@ def build_transform_node(
248
279
ast .keyword (
249
280
arg = "local_scope" ,
250
281
value = self .runtime_helper .write_locals (),
251
- # ast.Call(
252
- # func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
253
- # args=[],
254
- # keywords=[],
255
- # ),
256
282
)
257
283
)
258
284
259
- context_node = ast .Call (
260
- func = ast .Name (id = NodeContext .__name__ , ctx = ast .Load ()),
261
- args = context_args ,
262
- keywords = keyword_args ,
263
- )
264
- ret = ast .Call (
265
- func = ast .Name (id = self .transform_fn .__name__ , ctx = ast .Load ()),
266
- args = [node , context_node ],
267
- keywords = [],
285
+ return RewriteTransformCall .build (
286
+ self .transform_fn .__name__ , node , context_args , keyword_args
268
287
)
269
- return ret
270
288
271
289
def visit_Name (self , node ):
272
290
source_pre = self .recover_source (node )
@@ -345,10 +363,12 @@ def visit_match_case(self, node: ast.match_case) -> Any:
345
363
def visit_Assign (self , node : ast .Assign ) -> Any :
346
364
node = deepcopy (node )
347
365
new_node = self .generic_visit (node )
348
- assert isinstance (new_node , ast .Assign )
349
- # node = new_node
350
- # node.value = self.build_transform_node(new_node, f"assign/(multiple)")
351
- return node
366
+ match new_node :
367
+ case ast .Assign (targets = targets , value = RewriteTransformCall () as rtc ):
368
+ target_reprs = [self .recover_source (t ) for t in targets ]
369
+ rtc .assigned (target_reprs )
370
+
371
+ return new_node
352
372
353
373
def visit_Subscript (self , node : ast .Subscript ) -> Any :
354
374
if isinstance (node .ctx , ast .Load ):
0 commit comments