Skip to content

Commit d04eb5d

Browse files
remove extra predict field
1 parent 540bfab commit d04eb5d

File tree

5 files changed

+17
-13
lines changed

5 files changed

+17
-13
lines changed

dspy/adapters/base.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,21 @@ def __init_subclass__(cls, **kwargs) -> None:
1818
cls.format = with_callbacks(cls.format)
1919
cls.parse = with_callbacks(cls.parse)
2020

21-
def __call__(self, lm, lm_kwargs, signature, demos, inputs, predict=None):
21+
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
2222
inputs_ = self.format(signature, demos, inputs)
2323
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)
2424

2525
stream_listeners = settings.stream_listeners or []
26+
caller_predict = settings.caller_predict
2627
stream = settings.send_stream is not None
2728
if stream and len(stream_listeners) > 0:
28-
stream = any(stream_listener.predict == predict for stream_listener in stream_listeners)
29+
stream = any(stream_listener.predict == caller_predict for stream_listener in stream_listeners)
2930

3031
if stream:
31-
with settings.context(stream_predict=predict):
32-
outputs = lm(**inputs_, **lm_kwargs)
32+
outputs = lm(**inputs_, **lm_kwargs)
3333
else:
34+
# Explicilty disable streaming if streaming is not enabled globally or the caller predict shouldn't be
35+
# streamed.
3436
with settings.context(send_stream=None):
3537
outputs = lm(**inputs_, **lm_kwargs)
3638

dspy/adapters/json_adapter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ class JSONAdapter(Adapter):
3232
def __init__(self):
3333
pass
3434

35-
def __call__(self, lm, lm_kwargs, signature, demos, inputs, predict=None):
35+
def __call__(self, lm, lm_kwargs, signature, demos, inputs):
3636
inputs = self.format(signature, demos, inputs)
3737
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
3838

3939
stream_listeners = settings.stream_listeners or []
4040
if len(stream_listeners) > 0:
41-
raise ValueError("Stream listener is not supported for JsonAdapter, please use ChatAdapter instead.")
41+
raise ValueError("Stream listener is not yet supported for JsonAdapter, please use ChatAdapter instead.")
4242

4343
try:
4444
provider = lm.model.split("/", 1)[0] or "openai"

dspy/clients/lm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac
341341
)
342342

343343
stream = dspy.settings.send_stream
344-
stream_predict = dspy.settings.stream_predict
344+
caller_predict = dspy.settings.caller_predict
345345
if stream is None:
346346
# If `streamify` is not used, or if the exact predict doesn't need to be streamed,
347347
# we can just return the completion without streaming.
@@ -353,7 +353,7 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac
353353

354354
# The stream is already opened, and will be closed by the caller.
355355
stream = cast(MemoryObjectSendStream, stream)
356-
stream_predict_id = id(stream_predict) if stream_predict else None
356+
caller_predict_id = id(caller_predict) if caller_predict else None
357357

358358
@syncify
359359
async def stream_completion():
@@ -365,9 +365,9 @@ async def stream_completion():
365365
)
366366
chunks = []
367367
async for chunk in response:
368-
if stream_predict_id:
368+
if caller_predict_id:
369369
# Add the predict id to the chunk so that the stream listener can identify which predict produces it.
370-
chunk.predict_id = stream_predict_id
370+
chunk.predict_id = caller_predict_id
371371
chunks.append(chunk)
372372
await stream.send(chunk)
373373
return litellm.stream_chunk_builder(chunks)

dspy/dsp/utils/settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
async_max_workers=8,
2121
send_stream=None,
2222
disable_history=False,
23-
stream_predict=None,
23+
caller_predict=None,
2424
stream_listeners=[],
2525
)
2626

dspy/predict/predict.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from pydantic import BaseModel
44

5+
from dspy.clients.lm import LM
6+
from dspy.dsp.utils.settings import settings
57
from dspy.predict.parameter import Parameter
68
from dspy.primitives.prediction import Prediction
79
from dspy.primitives.program import Module
810
from dspy.signatures.signature import ensure_signature
911
from dspy.utils.callback import with_callbacks
10-
from dspy.clients.lm import LM
1112

1213

1314
class Predict(Module, Parameter):
@@ -98,7 +99,8 @@ def forward(self, **kwargs):
9899
import dspy
99100

100101
adapter = dspy.settings.adapter or dspy.ChatAdapter()
101-
completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs, predict=self)
102+
with settings.context(caller_predict=self):
103+
completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)
102104

103105
pred = Prediction.from_completions(completions, signature=signature)
104106

0 commit comments

Comments
 (0)