Skip to content

Commit 4d60721

Browse files
authored
Merge pull request #2 from blepabyte/06-10-feat_assign_transformed_functions_a_uuid
feat: assign transformed functions a UUID
2 parents 9d0fbde + 9557d68 commit 4d60721

File tree

2 files changed

+171
-18
lines changed

2 files changed

+171
-18
lines changed

maxray/transforms.py

+121-18
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import ast
22
import inspect
33
import sys
4+
import uuid
45

56
from textwrap import dedent
67
from pathlib import Path
7-
from dataclasses import dataclass
8+
from dataclasses import dataclass, asdict
89
from copy import deepcopy
910
from result import Result, Ok, Err
1011
from contextvars import ContextVar
@@ -60,9 +61,25 @@ class NodeContext:
6061

6162
local_scope: Any = None
6263

64+
caller_id: Any = None
65+
6366
def __repr__(self):
6467
return f"{self.fn_context}/{self.id}"
6568

69+
def to_dict(self):
70+
return {
71+
"id": self.id,
72+
"source": self.source,
73+
"location": self.location,
74+
"caller_id": self.caller_id,
75+
"fn_context": {
76+
"name": self.fn_context.name,
77+
"module": self.fn_context.module,
78+
"source_file": self.fn_context.source_file,
79+
"call_count": self.fn_context.call_count.get(),
80+
},
81+
}
82+
6683

6784
class FnRewriter(ast.NodeTransformer):
6885
def __init__(
@@ -120,10 +137,13 @@ def recover_source(self, pre_node):
120137
return self.safe_unparse(pre_node)
121138
return segment
122139

123-
def build_transform_node(self, node, label, node_source=None, pass_locals=False):
140+
def build_transform_node(
141+
self, node, label, node_source=None, pass_locals=False, extra_kwargs=None
142+
):
124143
"""
125144
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
126145
"""
146+
node = deepcopy(node)
127147
if node_source is None:
128148
node_source = self.safe_unparse(node)
129149

@@ -145,28 +165,38 @@ def build_transform_node(self, node, label, node_source=None, pass_locals=False)
145165
),
146166
]
147167

168+
keyword_args = []
169+
170+
if extra_kwargs is not None:
171+
keyword_args.extend(extra_kwargs)
172+
148173
if pass_locals:
149-
context_args.append(
150-
ast.Call(
151-
func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
152-
args=[],
153-
keywords=[],
154-
),
174+
keyword_args.append(
175+
ast.keyword(
176+
arg="local_scope",
177+
value=ast.Call(
178+
func=ast.Name(id="_MAXRAY_BUILTINS_LOCALS", ctx=ast.Load()),
179+
args=[],
180+
keywords=[],
181+
),
182+
)
155183
)
184+
156185
context_node = ast.Call(
157186
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
158187
args=context_args,
159-
keywords=[],
188+
keywords=keyword_args,
160189
)
161-
162-
return ast.Call(
190+
ret = ast.Call(
163191
func=ast.Name(id=self.transform_fn.__name__, ctx=ast.Load()),
164192
args=[node, context_node],
165193
keywords=[],
166194
)
195+
return ret
167196

168197
def visit_Name(self, node):
169198
source_pre = self.recover_source(node)
199+
node = deepcopy(node)
170200

171201
match node.ctx:
172202
case ast.Load():
@@ -189,6 +219,7 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
189219
> Private name mangling: When an identifier that textually occurs in a class definition begins with two or more underscore characters and does not end in two or more underscores, it is considered a private name of that class. Private names are transformed to a longer form before code is generated for them. The transformation inserts the class name, with leading underscores removed and a single underscore inserted, in front of the name. For example, the identifier __spam occurring in a class named Ham will be transformed to _Ham__spam. This transformation is independent of the syntactical context in which the identifier is used. If the transformed name is extremely long (longer than 255 characters), implementation defined truncation may happen. If the class name consists only of underscores, no transformation is done.
190220
"""
191221
source_pre = self.recover_source(node)
222+
node = deepcopy(node)
192223

193224
if self.is_method() and self.is_private_class_name(node.attr):
194225
node.attr = f"_{self.instance_type}{node.attr}"
@@ -208,6 +239,7 @@ def visit_match_case(self, node: ast.match_case) -> Any:
208239
return node
209240

210241
def visit_Assign(self, node: ast.Assign) -> Any:
242+
node = deepcopy(node)
211243
new_node = self.generic_visit(node)
212244
assert isinstance(new_node, ast.Assign)
213245
# node = new_node
@@ -216,6 +248,7 @@ def visit_Assign(self, node: ast.Assign) -> Any:
216248

217249
def visit_Return(self, node: ast.Return) -> Any:
218250
node_pre = deepcopy(node)
251+
node = deepcopy(node)
219252

220253
if node.value is None:
221254
node.value = ast.Constant(None)
@@ -236,9 +269,37 @@ def visit_Return(self, node: ast.Return) -> Any:
236269

237270
return ast.copy_location(node, node_pre)
238271

272+
@staticmethod
273+
def temp_binding(node, by_name: str):
274+
return ast.fix_missing_locations(
275+
ast.Call(
276+
ast.Name("_MAXRAY_SET_TEMP", ctx=ast.Load()),
277+
[node, ast.Constant(by_name)],
278+
keywords=[],
279+
)
280+
)
281+
# Cannot use walrus expressions in nested list comprehension context
282+
# return ast.fix_missing_locations(
283+
# ast.Subscript(
284+
# value=ast.Tuple(
285+
# elts=[
286+
# ast.NamedExpr(
287+
# target=ast.Name(id=by_name, ctx=ast.Store()),
288+
# value=node,
289+
# ),
290+
# ast.Name(id=by_name, ctx=ast.Load()),
291+
# ],
292+
# ctx=ast.Load(),
293+
# ),
294+
# slice=ast.Constant(-1),
295+
# ctx=ast.Load(),
296+
# )
297+
# )
298+
239299
def visit_Call(self, node):
240300
source_pre = self.recover_source(node)
241301
node_pre = deepcopy(node)
302+
node = deepcopy(node)
242303

243304
match node:
244305
case ast.Call(func=ast.Name(id="super"), args=[]):
@@ -262,17 +323,44 @@ def visit_Call(self, node):
262323
ast.fix_missing_locations(node)
263324

264325
node = self.generic_visit(node) # mutates
326+
# Want to keep track of which function we're calling
327+
# `node.func` is now likely a call to `_maxray_walker_handler`
328+
node.func = self.temp_binding(node.func, "_MAXRAY_TEMP_VAR")
329+
# We can't use `node_pre.func` because evaluating it has side effects
330+
331+
# TODO: maybe assign the super() proxy and do the MRO patching at the start of the function
332+
extra_kwargs = []
333+
if "super" not in source_pre:
334+
extra_kwargs.append(
335+
ast.keyword(
336+
"caller_id",
337+
# value=ast.Name("_MAXRAY_TEMP_VAR", ctx=ast.Load()),
338+
value=ast.Call(
339+
func=ast.Name("getattr", ctx=ast.Load()),
340+
args=[
341+
ast.Name("_MAXRAY_TEMP_VAR", ctx=ast.Load()),
342+
ast.Constant("_MAXRAY_TRANSFORM_ID"),
343+
ast.Constant(None),
344+
],
345+
keywords=[],
346+
),
347+
)
348+
)
265349

266350
# the function/callable instance itself is observed by Name/Attribute/... nodes
267351
return ast.copy_location(
268352
self.build_transform_node(
269-
node, f"call/{source_pre}", node_source=source_pre
353+
node,
354+
f"call/{source_pre}",
355+
node_source=source_pre,
356+
extra_kwargs=extra_kwargs,
270357
),
271358
node_pre,
272359
)
273360

274361
def visit_FunctionDef(self, node: ast.FunctionDef):
275362
pre_node = deepcopy(node)
363+
node = deepcopy(node)
276364
self.fn_count += 1
277365

278366
# Only overwrite the name of our "target function"
@@ -283,9 +371,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
283371
is_transform_root = self.fn_count == 1 and self.is_maxray_root
284372

285373
if is_transform_root:
286-
logger.info(
287-
f"Wiped decorators at level {self.fn_count} for {self.fn_context.impl_fn}: {node.decorator_list}"
288-
)
289374
# If we didn't clear, decorators would be applied twice - screwing up routing handling in libraries like `quart`: `@app.post("/generate")`
290375
node.decorator_list = []
291376

@@ -381,6 +466,17 @@ def recompile_fn_with_transform(
381466
"""
382467
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
383468
"""
469+
470+
# Could store as bytes but extra memory insignificant
471+
# Even failed transforms should get ids assigned and returned as part of Err for tracking
472+
# TODO: assign as field, and use instead of cache?
473+
# TODO: hash based on properties so consistent across program runs
474+
maxray_assigned_id = str(uuid.uuid4())
475+
try:
476+
source_fn._MAXRAY_TRANSFORM_ID = maxray_assigned_id
477+
except AttributeError:
478+
pass
479+
384480
# TODO: use non-overridable __getattribute__ instead?
385481
if not hasattr(source_fn, "__name__"): # Safety check against weird functions
386482
return Err(f"There is no __name__ for function {get_fn_name(source_fn)}")
@@ -494,6 +590,12 @@ def patch_mro(super_type: super):
494590
}
495591
scope.update(initial_scope)
496592

593+
# BUG: this will NOT work with threading - could use ContextVar if no performance impact?
594+
def set_temp(val, name: str):
595+
scope[name] = val
596+
return val
597+
598+
scope["_MAXRAY_SET_TEMP"] = set_temp
497599
# Add class-private names to scope (though they should only be usable as a default argument)
498600
# TODO: should apply to all definitions within a class scope - so @staticmethod descriptors as well...
499601
if fn_is_method:
@@ -546,9 +648,9 @@ def extract_cell(cell):
546648
logger.error(
547649
f"Failed to compile function `{source_fn.__name__}` at '{sourcefile}' in its module {module}"
548650
)
549-
logger.debug(f"Relevant original source code:\n{source}")
550-
logger.debug(f"Corresponding AST:\n{FnRewriter.safe_show_ast(fn_ast)}")
551-
logger.debug(
651+
logger.trace(f"Relevant original source code:\n{source}")
652+
logger.trace(f"Corresponding AST:\n{FnRewriter.safe_show_ast(fn_ast)}")
653+
logger.trace(
552654
f"Transformed code we attempted to compile:\n{FnRewriter.safe_unparse(transformed_fn_ast)}"
553655
)
554656

@@ -598,6 +700,7 @@ def extract_cell(cell):
598700

599701
# way to keep track of which functions we've already transformed
600702
transformed_fn._MAXRAY_TRANSFORMED = True
703+
transformed_fn._MAXRAY_TRANSFORM_ID = maxray_assigned_id
601704

602705
return Ok(transformed_fn)
603706

tests/test_transforms.py

+50
Original file line numberDiff line numberDiff line change
@@ -571,3 +571,53 @@ def fff():
571571
return S1.foo()
572572

573573
assert fff() == 6
574+
575+
576+
def test_partialmethod():
577+
from functools import partialmethod
578+
579+
# TQDM threw an error from partialmethod via @env_wrap but can't seem to reproduce
580+
@xray(dbg)
581+
def run_part():
582+
class X:
583+
def set_state(self, active: bool):
584+
self.active = active
585+
586+
set_active = partialmethod(set_state, True)
587+
588+
x = X()
589+
x.set_active()
590+
591+
assert x.active
592+
593+
run_part()
594+
595+
596+
def test_caller_id():
597+
f1_id = None
598+
f2_id = None
599+
600+
def collect_ids(x, ctx: NodeContext):
601+
if ctx.source == "f1()":
602+
nonlocal f1_id
603+
f1_id = ctx.caller_id
604+
elif ctx.source == "f2()":
605+
nonlocal f2_id
606+
f2_id = ctx.caller_id
607+
608+
def f1():
609+
return 1
610+
611+
def f2():
612+
return 2
613+
614+
@xray(collect_ids)
615+
def func():
616+
f1()
617+
f2()
618+
619+
func()
620+
621+
assert f1._MAXRAY_TRANSFORM_ID == f1_id
622+
assert f2._MAXRAY_TRANSFORM_ID == f2_id
623+
assert f1_id != f2_id

0 commit comments

Comments
 (0)