Skip to content

Commit 90ae0ee

Browse files
committed
refactor: NodeContext -> RayContext
1 parent 031fec4 commit 90ae0ee

12 files changed

+365
-306
lines changed

maxray/__init__.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from .transforms import recompile_fn_with_transform, NodeContext
1+
from .nodes import NodeContext, RayContext
2+
from .transforms import recompile_fn_with_transform
23
from .function_store import FunctionStore, set_property_on_functionlike
34

45
import inspect
@@ -106,7 +107,14 @@ class TransformSettings:
106107
Whether to always populate `NodeContext.local_scope` with the evaluation of `locals()` at every node.
107108
"""
108109

109-
def update(self, forbid_modules, restrict_modules, pass_local_scopes):
110+
preserve_values: bool = True
111+
"""
112+
If False, will attempt to unpack/destructure `x` values to support pattern matching features.
113+
"""
114+
115+
def update(
116+
self, forbid_modules, restrict_modules, pass_local_scopes, preserve_values
117+
):
110118
def or_empty_set(s):
111119
return frozenset([] if s is None else s)
112120

@@ -119,6 +127,7 @@ def or_empty_set(s):
119127
forbid_modules=self.forbid_modules.union(forbid_modules),
120128
restrict_modules=restrict_modules,
121129
pass_local_scopes=self.pass_local_scopes or pass_local_scopes,
130+
preserve_values=self.preserve_values and preserve_values,
122131
)
123132

124133

@@ -372,6 +381,11 @@ def _maxray_walker_handler(x, ctx: NodeContext):
372381
# 2. run the active hooks
373382
global_write_active_token = _GLOBAL_WRITER_ACTIVE_FLAG.set(True)
374383
try:
384+
ray = RayContext(
385+
x, ctx, unpack_assignments=not _MAXRAY_TRANSFORM_SETTINGS.preserve_values
386+
)
387+
x = ray.value()
388+
375389
for walk_hook in _MAXRAY_REGISTERED_HOOKS:
376390
# Our recompiled fn sets and unsets a contextvar whenever it is active
377391
if not walk_hook.active_call_state.get():
@@ -380,9 +394,9 @@ def _maxray_walker_handler(x, ctx: NodeContext):
380394
# Set the writer active flag
381395
write_active_token = walk_hook.writer_active_call_state.set(True)
382396
if walk_hook.mutable:
383-
x = walk_hook.impl_fn(x, ctx)
397+
x = walk_hook.impl_fn(x, ray)
384398
else:
385-
walk_hook.impl_fn(x, ctx)
399+
walk_hook.impl_fn(x, ray)
386400
walk_hook.writer_active_call_state.reset(write_active_token)
387401
finally:
388402
_GLOBAL_WRITER_ACTIVE_FLAG.reset(global_write_active_token)
@@ -394,14 +408,15 @@ def _maxray_walker_handler(x, ctx: NodeContext):
394408

395409

396410
def maxray(
397-
writer: Callable[[Any, NodeContext], Any],
411+
writer: Callable[[Any, RayContext], Any],
398412
*,
399413
mutable=True,
400414
forbid_modules=frozenset(),
401415
restrict_modules=None,
402-
pass_scope=False,
403416
initial_scope={},
404417
assume_transformed=False,
418+
pass_scope=False,
419+
preserve_values=True,
405420
) -> Callable[[T], T]:
406421
"""
407422
A transform that recursively hooks into all further calls made within the function, so that `writer` will (in theory) observe every single expression evaluated by the Python interpreter occurring as part of the decorated function call.
@@ -419,6 +434,7 @@ def maxray(
419434
forbid_modules=forbid_modules,
420435
restrict_modules=restrict_modules,
421436
pass_local_scopes=pass_scope,
437+
preserve_values=preserve_values,
422438
)
423439

424440
# TODO: allow configuring injection of variables into exec scope

maxray/__main__.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,36 @@
44

55
import click
66

7+
from typing import Optional
8+
79

810
@click.group()
911
def cli():
1012
pass
1113

1214

1315
@cli.command()
14-
@click.argument("file", type=str)
16+
@click.option("--over", type=str)
17+
@click.option("--new", type=str)
1518
@click.option("-f", "--force", is_flag=True)
1619
@click.option("--runner", is_flag=True)
17-
def template(file: str, force: bool, runner: bool):
18-
path = Path(file).resolve(True)
19-
assert path.suffix == ".py"
20+
def template(over: Optional[str], new: Optional[str], force: bool, runner: bool):
21+
save_path_args = [f for f in [over, new] if f is not None]
22+
23+
if not save_path_args or len(save_path_args) > 1:
24+
raise ValueError("Must specify exactly one of --new or --over")
25+
26+
save_path = Path(save_path_args[0]).resolve(False)
27+
if over is not None:
28+
save_path = save_path.with_name(f"over_{save_path.name}")
29+
30+
assert save_path.suffix == ".py"
2031

2132
spec = find_spec("maxray.inators.template")
2233
assert spec is not None
2334
assert spec.origin is not None
2435

25-
template_path = path.with_name(f"over_{path.name}")
36+
template_path = save_path
2637
if not force:
2738
assert not template_path.exists(), f"{template_path} exists!"
2839

maxray/capture/logs.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ class CaptureLogs:
2424
"loc_col_start": pa.int32(),
2525
"loc_col_end": pa.int32(),
2626
# Function calls
27-
"context_call_id": pa.string(),
28-
"target_call_id": pa.string(),
27+
"fn_compile_id": pa.string(),
2928
"fn": pa.string(),
3029
"fn_call_count": pa.int32(),
3130
# Source info/code
@@ -58,16 +57,13 @@ class CaptureLogs:
5857
@staticmethod
5958
def extractor(x, ctx: NodeContext):
6059
if isinstance(instance := CaptureLogs.instance.get(None), CaptureLogs):
61-
if ctx.caller_id is not None:
62-
instance.fn_sources[ctx.caller_id] = ctx.fn_context
6360
if ctx.source != "self":
6461
instance.builder("loc_line_start").append(ctx.location[0])
6562
instance.builder("loc_line_end").append(ctx.location[1])
6663
instance.builder("loc_col_start").append(ctx.location[2])
6764
instance.builder("loc_col_end").append(ctx.location[3])
6865

69-
instance.builder("context_call_id").append(ctx.fn_context.compile_id)
70-
instance.builder("target_call_id").append(ctx.caller_id)
66+
instance.builder("fn_compile_id").append(ctx.fn_context.compile_id)
7167
instance.builder("source_file").append(ctx.fn_context.source_file)
7268
instance.builder("source_module").append(ctx.fn_context.module)
7369
instance.builder("source").append(ctx.source)
@@ -106,9 +102,6 @@ def schema():
106102
)
107103

108104
def __init__(self, stream_to_arrow_file=None, flush_every_records: int = 10_000):
109-
# Maps function UUIDs (_MAXRAY_TRANSFORM_ID) to FnContext instances
110-
self.fn_sources = {}
111-
112105
self.builders = {}
113106

114107
if stream_to_arrow_file is not None:

maxray/inators/core.py

+11-100
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

3-
from maxray.transforms import NodeContext
3+
from maxray.nodes import NodeContext, RayContext
44
from .display import Display
55

66
import ipdb
77
import attrs
88

99
import json
10-
from result import Result, Ok, Err
1110
from contextvars import ContextVar
1211
from dataclasses import dataclass
1312
from pathlib import Path
@@ -32,7 +31,11 @@ def __getitem__(self, key):
3231
return self._existing_keys[key][1]
3332

3433
def __setitem__(self, key, value):
35-
v, _old_value = self._existing_keys[key]
34+
if key in self._existing_keys:
35+
v, _old_value = self._existing_keys[key]
36+
else:
37+
v = 1
38+
3639
self._existing_keys[key] = (v, value)
3740

3841
def define_once(self, key, factory, /, v: int = 0):
@@ -243,70 +246,21 @@ def __getattr__(self, rewrite_cls_name: str) -> RewriteContext:
243246
return self.by_class[rewrite_cls_name]
244247

245248

246-
def unpack_assign_context(x, ctx):
247-
match ctx.props:
248-
case {"assigned": {"targets": targets}}:
249-
if len(targets) > 1:
250-
if inspect.isgenerator(x) or isinstance(x, (map, filter)):
251-
# Greedily consume iterators before assignment
252-
unpacked_x = tuple(iter(x))
253-
else:
254-
# Otherwise for chained assignment like a = b, c = it, code may rely on `a` being of the original type
255-
unpacked_x = x
256-
257-
# TODO: doesn't work for starred assignments: x, *y, z = iterable
258-
assigned = {target: val for target, val in zip(targets, unpacked_x)}
259-
return unpacked_x, assigned
260-
261-
elif len(targets) == 1:
262-
return x, {targets[0]: x}
263-
else:
264-
return x, {}
265-
case _:
266-
return x, {}
267-
268-
269249
class LoggingEncoder(json.JSONEncoder):
270250
def default(self, o):
271251
if attrs.has(type(o)):
272252
return attrs.asdict(o)
273253
return super().default(o)
274254

275255

276-
class Ray:
256+
class Ray(RayContext):
277257
"""
278258
Captures the state of a point (syntax node) in the source code of the original program.
279259
280260
One instance is created for each point in the program, that is then passed to multiple handlers.
281261
"""
282262

283-
INSTANCE = ContextVar("RY")
284-
285-
def __init__(self, x, ctx: NodeContext, *, unpack_assignments: bool):
286-
if unpack_assignments:
287-
self._x, self._assigned = unpack_assign_context(x, ctx)
288-
self.ctx = ctx
289-
290-
# TODO: staticmethod to get
291-
292-
@contextmanager
293-
def _bind(self, x, ctx):
294-
# TODO: assert no previous active instance
295-
reset = Ray.INSTANCE.set(self)
296-
try:
297-
yield Ray.INSTANCE.get()
298-
finally:
299-
Ray.INSTANCE.reset(reset)
300-
301-
@staticmethod
302-
def try_get() -> Result[Ray, None]:
303-
try:
304-
return Ok(Ray.INSTANCE.get())
305-
except LookupError:
306-
return Err(None)
307-
308-
@staticmethod
309-
def log(msg, *, level="INFO"):
263+
def log(self, msg, *, level="INFO"):
310264
"""
311265
Logs to Rerun with the current context if active.
312266
"""
@@ -319,52 +273,9 @@ def log(msg, *, level="INFO"):
319273
case _ if attrs.has(type(msg)):
320274
msg = str(msg)
321275

322-
match Ray.try_get():
323-
case Ok(ray):
324-
location = Path(ray.ctx.fn_context.source_file).name
325-
line = ray.ctx.location[0]
326-
rr.log(f"log/{location}:{line}", rr.TextLog(msg, level=level))
327-
case Err():
328-
rr.log("log/somewhere", rr.TextLog(msg, level=level))
329-
330-
def value(self):
331-
return self._x
332-
333-
def locals(self):
334-
match self.ctx.local_scope:
335-
case dict():
336-
return self.ctx.local_scope
337-
case None:
338-
return {}
339-
case _:
340-
raise TypeError(
341-
f"Unexpected type {type(self.ctx.local_scope)} for local scope"
342-
)
343-
344-
def assigned(self):
345-
return self._assigned
346-
347-
def iterated(self):
348-
match self.ctx.props:
349-
case {"iterated": {"target": target}}:
350-
return [target]
351-
case _:
352-
return []
353-
354-
def returned(self): ...
355-
356-
def entered(self):
357-
"""
358-
Returns:
359-
{
360-
source: Source code of the LHS being entered
361-
as_var: Variable binding after "as", if present
362-
"""
363-
match self.ctx.props:
364-
case {"entered": entered}:
365-
return entered
366-
case _:
367-
return {}
276+
location = Path(self.ctx.fn_context.source_file).name
277+
line = self.ctx.location[0]
278+
rr.log(f"log/{location}:{line}", rr.TextLog(msg, level=level))
368279

369280
def contextmanager(self, fn: Callable[[Ray], Iterator[Any]]):
370281
return contextmanager(partial(fn, self))

0 commit comments

Comments
 (0)