Skip to content

Commit 95f3208

Browse files
committed
test
1 parent 22918dc commit 95f3208

File tree

4 files changed

+885
-69
lines changed

4 files changed

+885
-69
lines changed

tests/trace/test_op_generators.py

+240
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
from collections.abc import AsyncGenerator, Generator
2+
3+
import pytest
4+
5+
import weave
6+
7+
8+
@weave.op
9+
def basic_gen(x: int) -> Generator[int, None, None]:
10+
yield from range(x)
11+
12+
13+
@weave.op
14+
def inner(x: int) -> int:
15+
return x + 1
16+
17+
18+
@weave.op
19+
def nested_generator(x: int) -> Generator[int, None, None]:
20+
for i in range(x):
21+
yield inner(i)
22+
23+
24+
@weave.op
25+
def deeply_nested_generator(x: int) -> Generator[int, None, None]:
26+
for i in range(x):
27+
yield from nested_generator(i)
28+
29+
30+
@weave.op
31+
async def basic_async_gen(x: int) -> AsyncGenerator[int, None]:
32+
for i in range(x):
33+
yield i
34+
35+
36+
@weave.op
37+
async def inner_async(x: int) -> int:
38+
return x + 1
39+
40+
41+
@weave.op
42+
async def nested_async_generator(x: int) -> AsyncGenerator[int, None]:
43+
for i in range(x):
44+
yield await inner_async(i)
45+
46+
47+
@weave.op
48+
async def deeply_nested_async_generator(x: int) -> AsyncGenerator[int, None]:
49+
for i in range(x):
50+
async for j in await nested_async_generator(i):
51+
yield j
52+
53+
54+
def test_basic_gen(client):
55+
res = basic_gen(3)
56+
assert list(res) == [0, 1, 2]
57+
58+
calls = client.get_calls()
59+
assert len(calls) == 1
60+
61+
62+
def test_nested_generator(client):
63+
res = nested_generator(3)
64+
assert list(res) == [1, 2, 3]
65+
66+
calls = client.get_calls()
67+
assert len(calls) == 4
68+
69+
root_call = calls[0]
70+
assert "nested_generator" in root_call.op_name
71+
for i, call in enumerate(root_call.children()):
72+
assert "inner" in call.op_name
73+
assert call.inputs["x"] == i
74+
75+
76+
def test_deeply_nested_generator(client):
77+
res = deeply_nested_generator(4)
78+
# basic_gen(0) -> nothing
79+
# basic_gen(1) -> 1
80+
# basic_gen(2) -> 1, 2
81+
# basic_gen(3) -> 1, 2, 3
82+
assert list(res) == [1, 1, 2, 1, 2, 3]
83+
84+
calls = client.get_calls()
85+
assert len(calls) == 11
86+
87+
root_call = calls[0]
88+
assert "deeply_nested_generator" in root_call.op_name
89+
for i, call in enumerate(root_call.children()):
90+
assert "nested_generator" in call.op_name
91+
for j, call2 in enumerate(call.children()):
92+
assert "inner" in call2.op_name
93+
assert call2.inputs["x"] == j
94+
95+
96+
@pytest.mark.asyncio
97+
async def test_basic_async_gen(client):
98+
lst = []
99+
res = await basic_async_gen(3)
100+
async for i in res:
101+
lst.append(i)
102+
103+
assert lst == [0, 1, 2]
104+
105+
calls = client.get_calls()
106+
assert len(calls) == 1
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_nested_async_generator(client):
111+
lst = []
112+
res = await nested_async_generator(3)
113+
async for i in res:
114+
lst.append(i)
115+
116+
assert lst == [1, 2, 3]
117+
118+
calls = client.get_calls()
119+
assert len(calls) == 4
120+
121+
root_call = calls[0]
122+
assert "nested_async_generator" in root_call.op_name
123+
for i, call in enumerate(root_call.children()):
124+
assert "inner_async" in call.op_name
125+
assert call.inputs["x"] == i
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_deeply_nested_async_generator(client):
130+
lst = []
131+
res = await deeply_nested_async_generator(4)
132+
async for i in res:
133+
lst.append(i)
134+
135+
# basic_gen(0) -> nothing
136+
# basic_gen(1) -> 1
137+
# basic_gen(2) -> 1, 2
138+
# basic_gen(3) -> 1, 2, 3
139+
assert lst == [1, 1, 2, 1, 2, 3]
140+
141+
calls = client.get_calls()
142+
assert len(calls) == 11
143+
144+
root_call = calls[0]
145+
assert "deeply_nested_async_generator" in root_call.op_name
146+
for i, call in enumerate(root_call.children()):
147+
assert "nested_async_generator" in call.op_name
148+
for j, call2 in enumerate(call.children()):
149+
assert "inner_async" in call2.op_name
150+
assert call2.inputs["x"] == j
151+
152+
153+
def list_accumulator(acc, value):
154+
if acc is None:
155+
acc = []
156+
acc.append(value)
157+
return acc
158+
159+
160+
@weave.op(accumulator=list_accumulator)
161+
def basic_gen_with_accumulator(x: int) -> Generator[int, None, None]:
162+
yield from range(x)
163+
164+
165+
def test_generator_with_custom_accumulator(client):
166+
# Call the generator with the accumulator from the decorator
167+
res = basic_gen_with_accumulator(3)
168+
169+
# The generator still works as expected
170+
assert list(res) == [0, 1, 2]
171+
172+
# Get the call and check its output
173+
calls = client.get_calls()
174+
assert len(calls) == 1
175+
assert calls[0].output == [0, 1, 2]
176+
177+
178+
async def async_list_accumulator(acc, val):
179+
if acc is None:
180+
acc = []
181+
acc.append(val)
182+
return acc
183+
184+
185+
@weave.op(accumulator=async_list_accumulator)
186+
async def basic_async_gen_with_accumulator(x: int) -> AsyncGenerator[int, None]:
187+
for i in range(x):
188+
yield i
189+
190+
191+
@pytest.mark.asyncio
192+
async def test_async_generator_with_custom_accumulator(client):
193+
# Call the generator with the accumulator from the decorator
194+
res = await basic_async_gen_with_accumulator(3)
195+
196+
# The generator still works as expected
197+
assert [item async for item in res] == [0, 1, 2]
198+
199+
# Get the call and check its output
200+
calls = client.get_calls()
201+
assert len(calls) == 1
202+
assert calls[0].output == [0, 1, 2]
203+
204+
205+
@weave.op(accumulator=list_accumulator)
206+
def gen_with_decorator_accumulator(x: int) -> Generator[int, None, None]:
207+
yield from range(x)
208+
209+
210+
def test_generator_with_decorator_accumulator(client):
211+
# Call the generator with the accumulator from the decorator
212+
res = gen_with_decorator_accumulator(3)
213+
214+
# The generator still works as expected
215+
assert list(res) == [0, 1, 2]
216+
217+
# Get the call and check its output
218+
calls = client.get_calls()
219+
assert len(calls) == 1
220+
assert calls[0].output == [0, 1, 2]
221+
222+
223+
@weave.op(accumulator=async_list_accumulator)
224+
async def async_gen_with_decorator_accumulator(x: int) -> AsyncGenerator[int, None]:
225+
for i in range(x):
226+
yield i
227+
228+
229+
@pytest.mark.asyncio
230+
async def test_async_generator_with_decorator_accumulator(client):
231+
# Call the generator with the accumulator from the decorator
232+
res = await async_gen_with_decorator_accumulator(3)
233+
234+
# The generator still works as expected
235+
assert [item async for item in res] == [0, 1, 2]
236+
237+
# Get the call and check its output
238+
calls = client.get_calls()
239+
assert len(calls) == 1
240+
assert calls[0].output == [0, 1, 2]

0 commit comments

Comments
 (0)