Skip to content

Commit 2d48e9a

Browse files
committed
fix: keep exact original source code locations
1 parent 9facb09 commit 2d48e9a

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

maxray/transforms.py

+35-8
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def __repr__(self):
5858

5959
class FnRewriter(ast.NodeTransformer):
6060
def __init__(
61-
self, transform_fn, fn_context: FnContext, *, instance_type: str | None
61+
self,
62+
transform_fn,
63+
fn_context: FnContext,
64+
*,
65+
instance_type: str | None,
66+
dedent_chars: int = 0,
6267
):
6368
"""
6469
If we're transforming a method, instance type should be the __name__ of the class. Otherwise, None.
@@ -67,6 +72,7 @@ def __init__(
6772
self.transform_fn = transform_fn
6873
self.fn_context = fn_context
6974
self.instance_type = instance_type
75+
self.dedent_chars = dedent_chars
7076

7177
# the first `def` we encounter is the one that we're transforming. Subsequent ones will be nested/within class definitions.
7278
self.fn_count = 0
@@ -90,6 +96,12 @@ def is_private_class_name(identifier_name: str):
9096
and identifier_name.strip("_")
9197
)
9298

99+
def recover_source(self, pre_node):
100+
segment = ast.get_source_segment(self.fn_context.source, pre_node, padded=False)
101+
if segment is None:
102+
return self.safe_unparse(pre_node)
103+
return segment
104+
93105
def build_transform_node(self, node, label, node_source=None):
94106
"""
95107
Builds the "inspection" node that wraps the original source node - passing the (value, context) pair to `transform_fn`.
@@ -98,7 +110,7 @@ def build_transform_node(self, node, label, node_source=None):
98110
node_source = self.safe_unparse(node)
99111

100112
line_offset = self.fn_context.impl_fn.__code__.co_firstlineno - 2
101-
col_offset = 4
113+
col_offset = self.dedent_chars
102114
context_node = ast.Call(
103115
func=ast.Name(id=NodeContext.__name__, ctx=ast.Load()),
104116
args=[
@@ -125,6 +137,8 @@ def build_transform_node(self, node, label, node_source=None):
125137
)
126138

127139
def visit_Name(self, node):
140+
source_pre = self.recover_source(node)
141+
128142
match node.ctx:
129143
case ast.Load():
130144
# Variable is accessed
@@ -136,14 +150,16 @@ def visit_Name(self, node):
136150
logger.error(f"Unknown context {node.ctx}")
137151
return node
138152

139-
return self.build_transform_node(new_node, f"name/{node.id}")
153+
return self.build_transform_node(
154+
new_node, f"name/{node.id}", node_source=source_pre
155+
)
140156

141157
def visit_Attribute(self, node: ast.Attribute) -> Any:
142158
"""
143159
https://docs.python.org/3/reference/expressions.html#atom-identifiers
144160
> Private name mangling: When an identifier that textually occurs in a class definition begins with two or more underscore characters and does not end in two or more underscores, it is considered a private name of that class. Private names are transformed to a longer form before code is generated for them. The transformation inserts the class name, with leading underscores removed and a single underscore inserted, in front of the name. For example, the identifier __spam occurring in a class named Ham will be transformed to _Ham__spam. This transformation is independent of the syntactical context in which the identifier is used. If the transformed name is extremely long (longer than 255 characters), implementation defined truncation may happen. If the class name consists only of underscores, no transformation is done.
145161
"""
146-
source_pre = self.safe_unparse(node)
162+
source_pre = self.recover_source(node)
147163

148164
if self.is_method() and self.is_private_class_name(node.attr):
149165
node.attr = f"_{self.instance_type}{node.attr}"
@@ -164,8 +180,9 @@ def visit_Assign(self, node: ast.Assign) -> Any:
164180
return node
165181

166182
def visit_Call(self, node):
183+
source_pre = self.recover_source(node)
184+
167185
node_pre = deepcopy(node)
168-
source_pre = self.safe_unparse(node_pre)
169186

170187
node = self.generic_visit(node) # mutates
171188

@@ -248,10 +265,13 @@ def recompile_fn_with_transform(
248265
Recompiles `source_fn` so that essentially every node of its AST tree is wrapped by a call to `transform_fn` with the evaluated value along with context information about the source code.
249266
"""
250267
try:
251-
source = inspect.getsource(source_fn)
268+
original_source = inspect.getsource(source_fn)
252269

253-
# nested functions have excess indentation preventing compile; inspect.cleandoc(source) is an alternative
254-
source = dedent(source)
270+
# nested functions have excess indentation preventing compile; inspect.cleandoc(source) is an alternative but less reliable
271+
source = dedent(original_source)
272+
273+
# want to map back to correct original location
274+
dedent_chars = original_source.find("\n") - source.find("\n")
255275

256276
sourcefile = inspect.getsourcefile(source_fn)
257277
module = inspect.getmodule(source_fn)
@@ -311,6 +331,7 @@ def recompile_fn_with_transform(
311331
transform_fn,
312332
fn_context,
313333
instance_type=parent_cls.__name__ if fn_is_method else None,
334+
dedent_chars=dedent_chars,
314335
).visit(fn_ast)
315336
ast.fix_missing_locations(transformed_fn_ast)
316337

@@ -412,6 +433,12 @@ def extract_cell(cell):
412433
else:
413434
transformed_fn = scope[source_fn.__name__]
414435

436+
# a decorator doesn't actually have to return a function! (could be used solely for side effect) e.g. `@register_backend_lookup_factory` for `find_content_backend` in `awkward/contents/content.py`
437+
if not callable(transformed_fn):
438+
return Err(
439+
f"Resulting transform of definition of {get_fn_name(source_fn)} is not even callable (got {transform_fn}). Perhaps a decorator that returns None?"
440+
)
441+
415442
# unmangle the name again - it's possible some packages might use __name__ internally for registries and whatnot
416443
transformed_fn.__name__ = source_fn.__name__
417444

0 commit comments

Comments
 (0)