@@ -18,19 +18,21 @@ def __init_subclass__(cls, **kwargs) -> None:
18
18
cls .format = with_callbacks (cls .format )
19
19
cls .parse = with_callbacks (cls .parse )
20
20
21
- def __call__ (self , lm , lm_kwargs , signature , demos , inputs , predict = None ):
21
+ def __call__ (self , lm , lm_kwargs , signature , demos , inputs ):
22
22
inputs_ = self .format (signature , demos , inputs )
23
23
inputs_ = dict (prompt = inputs_ ) if isinstance (inputs_ , str ) else dict (messages = inputs_ )
24
24
25
25
stream_listeners = settings .stream_listeners or []
26
+ caller_predict = settings .caller_predict
26
27
stream = settings .send_stream is not None
27
28
if stream and len (stream_listeners ) > 0 :
28
- stream = any (stream_listener .predict == predict for stream_listener in stream_listeners )
29
+ stream = any (stream_listener .predict == caller_predict for stream_listener in stream_listeners )
29
30
30
31
if stream :
31
- with settings .context (stream_predict = predict ):
32
- outputs = lm (** inputs_ , ** lm_kwargs )
32
+ outputs = lm (** inputs_ , ** lm_kwargs )
33
33
else :
34
+ # Explicilty disable streaming if streaming is not enabled globally or the caller predict shouldn't be
35
+ # streamed.
34
36
with settings .context (send_stream = None ):
35
37
outputs = lm (** inputs_ , ** lm_kwargs )
36
38
0 commit comments