Commit be224a8

Merge pull request #36 from jmont-dev/stop_streaming
Add support for stopping generation during a stream.
2 parents 76cb17a + 3b006dd commit be224a8

File tree: 5 files changed, +121 -46 lines

Diff for: README.md (+17 -6)

@@ -227,17 +227,20 @@ You can use a streaming generation to bind a callback function that is invoked e

```C++

-void on_receive_response(const ollama::response& response)
+bool on_receive_response(const ollama::response& response)
{
    // Print the token received
    std::cout << response << std::flush;

    // The server will set "done" to true for the last response
    if (response.as_json()["done"]==true) std::cout << std::endl;
+
+    // Return true to continue streaming this response; return false to stop immediately.
+    return true;
}

// This function will be called every token
-std::function<void(const ollama::response&)> response_callback = on_receive_response;
+std::function<bool(const ollama::response&)> response_callback = on_receive_response;

// Bind the callback to the generation
ollama::generate("llama3:8b", "Why is the sky blue?", response_callback);
@@ -251,16 +254,19 @@ You can launch a streaming call in a thread if you don't want it to block the pr

std::atomic<bool> done{false};

-void on_receive_response(const ollama::response& response)
+bool on_receive_response(const ollama::response& response)
{
    std::cout << response << std::flush;

    if (response.as_json()["done"]==true) { done=true; std::cout << std::endl;}
+
+    // Return true to continue streaming this response; return false to stop immediately.
+    return !done;
}

// Use std::function to define a callback from an existing function
// You can also use a lambda with an equivalent signature
-std::function<void(const ollama::response&)> response_callback = on_receive_response;
+std::function<bool(const ollama::response&)> response_callback = on_receive_response;

// You can launch the generation in a thread with a callback to use it asynchronously.
std::thread new_thread( [response_callback]{
@@ -270,6 +276,8 @@ std::thread new_thread( [response_callback]{
while (!done) { std::this_thread::sleep_for(std::chrono::microseconds(100) ); }
new_thread.join();
```
+The return value of the function determines whether to continue streaming or stop. This is useful in cases where you want to stop immediately instead of waiting for an entire response to return.
+
### Using Images
Generations can include images for vision-enabled models such as `llava`. The `ollama::image` class can load an image from a file and encode it as a [base64](https://en.wikipedia.org/wiki/Base64) string.

@@ -352,14 +360,17 @@ The default chat generation does not stream tokens and will return the entire re

```C++

-void on_receive_response(const ollama::response& response)
+bool on_receive_response(const ollama::response& response)
{
    std::cout << response << std::flush;

    if (response.as_json()["done"]==true) std::cout << std::endl;
+
+    // Return true to continue streaming, or false to stop immediately
+    return true;
}

-std::function<void(const ollama::response&)> response_callback = on_receive_response;
+std::function<bool(const ollama::response&)> response_callback = on_receive_response;

ollama::message message("user", "Why is the sky blue?");
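
Beyond returning `true` until the server reports `done`, the callback can now end a stream on any condition of its own. The sketch below is illustrative only and not part of this commit: it stops a streaming generation after an arbitrary budget of 32 tokens by returning `false`, and assumes a running local Ollama server with the `llama3:8b` model as in the README examples above.

```C++
#include "ollama.hpp"

#include <functional>
#include <iostream>

int main()
{
    // Arbitrary illustrative budget: stop the stream after 32 tokens.
    int tokens_received = 0;

    std::function<bool(const ollama::response&)> stop_after_budget =
        [&tokens_received](const ollama::response& response) -> bool
    {
        std::cout << response << std::flush;
        ++tokens_received;

        // Returning false tells the library to stop streaming immediately.
        return tokens_received < 32;
    };

    ollama::generate("llama3:8b", "Why is the sky blue?", stop_after_budget);
}
```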

Diff for: examples/main.cpp (+5 -2)

@@ -12,11 +12,13 @@ using json = nlohmann::json;

std::atomic<bool> done{false};

-void on_receive_response(const ollama::response& response)
+bool on_receive_response(const ollama::response& response)
{
    std::cout << response << std::flush;

    if (response.as_json()["done"]==true) { done=true; std::cout << std::endl;}
+
+    return !done; // Return true to continue streaming this response; return false to stop immediately.
}

// Install ollama, llama3, and llava first to run this demo
@@ -130,13 +132,14 @@ int main()
    // Perform a simple generation which includes model options.
    std::cout << ollama::generate("llama3:8b", "Why is the sky green?", options) << std::endl;

-   std::function<void(const ollama::response&)> response_callback = on_receive_response;
+   std::function<bool(const ollama::response&)> response_callback = on_receive_response;
    ollama::generate("llama3:8b", "Why is the sky orange?", response_callback);

    // You can launch the generation in a thread with a callback to use it asynchronously.
    std::thread new_thread( [response_callback]{ ollama::generate("llama3:8b", "Why is the sky gray?", response_callback); } );

    // Prevent the main thread from exiting while we wait for an asynchronous response.
+   // Alternatively, we can set done=true to stop this thread immediately.
    while (!done) { std::this_thread::sleep_for(std::chrono::microseconds(100) ); }
    new_thread.join();
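
As the new comment in the example notes, because the callback returns `!done`, any thread can end the stream by setting the flag. A minimal sketch of that pattern follows; it reuses the example's callback, and the one-second wait is an illustrative assumption, as is a running local Ollama server with `llama3:8b`.

```C++
#include "ollama.hpp"

#include <atomic>
#include <chrono>
#include <functional>
#include <iostream>
#include <thread>

std::atomic<bool> done{false};

bool on_receive_response(const ollama::response& response)
{
    std::cout << response << std::flush;
    if (response.as_json()["done"]==true) { done=true; std::cout << std::endl; }

    // The stream ends as soon as done is set, whether by the server or by another thread.
    return !done;
}

int main()
{
    std::function<bool(const ollama::response&)> response_callback = on_receive_response;

    // Stream the generation on a worker thread.
    std::thread new_thread( [response_callback]{ ollama::generate("llama3:8b", "Why is the sky gray?", response_callback); } );

    // Illustrative: let the stream run for one second, then ask it to stop.
    std::this_thread::sleep_for(std::chrono::seconds(1));
    done = true;

    new_thread.join();
}
```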

Diff for: include/ollama.hpp (+22 -16)

@@ -419,22 +419,22 @@ class Ollama
        return response;
    }

-   bool generate(const std::string& model,const std::string& prompt, ollama::response& context, std::function<void(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
+   bool generate(const std::string& model,const std::string& prompt, ollama::response& context, std::function<bool(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
    {
        ollama::request request(model, prompt, options, true, images);
        if ( context.as_json().contains("context") ) request["context"] = context.as_json()["context"];
        return generate(request, on_receive_token);
    }

-   bool generate(const std::string& model,const std::string& prompt, std::function<void(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
+   bool generate(const std::string& model,const std::string& prompt, std::function<bool(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
    {
        ollama::request request(model, prompt, options, true, images);
        return generate(request, on_receive_token);
    }


    // Generate a streaming reply where a user-defined callback function is invoked when each token is received.
-   bool generate(ollama::request& request, std::function<void(const ollama::response&)> on_receive_token)
+   bool generate(ollama::request& request, std::function<bool(const ollama::response&)> on_receive_token)
    {
        request["stream"] = true;

@@ -446,22 +446,25 @@ class Ollama
        auto stream_callback = [on_receive_token, partial_responses](const char *data, size_t data_length)->bool{

            std::string message(data, data_length);
+           bool continue_stream = true;
+
            if (ollama::log_replies) std::cout << message << std::endl;
            try
            {
                partial_responses->push_back(message);
                std::string total_response = std::accumulate(partial_responses->begin(), partial_responses->end(), std::string(""));
                ollama::response response(total_response);
                partial_responses->clear();
-               on_receive_token(response);
+               continue_stream = on_receive_token(response);
            }
            catch (const ollama::invalid_json_exception& e) { /* Partial response was received. Will do nothing and attempt to concatenate with the next response. */ }

-           return true;
+           return continue_stream;
        };

        if (auto res = this->cli->Post("/api/generate", request_string, "application/json", stream_callback)) { return true; }
-       else { if (ollama::use_exceptions) throw ollama::exception( "No response from server returned at URL"+this->server_url+" Error: "+httplib::to_string( res.error() ) ); }
+       else if (res.error()==httplib::Error::Canceled) { /* Request cancelled by user. */ return true; }
+       else { if (ollama::use_exceptions) throw ollama::exception( "No response from server returned at URL "+this->server_url+" Error: "+httplib::to_string( res.error() ) ); }

        return false;
    }
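
The cancellation path relies on cpp-httplib's content receiver contract: when the receiver lambda returns `false`, the client aborts the transfer and the result reports `httplib::Error::Canceled`, which the new branch above treats as a deliberate, successful stop rather than a failure. The stripped-down sketch below shows that mechanism in isolation, outside the library wrapper; the local server URL, the request body, and the five-chunk cutoff are illustrative assumptions, not code from this commit.

```C++
#include "httplib.h"

#include <iostream>
#include <string>

int main()
{
    // Illustrative assumption: an Ollama server listening on the default local port.
    httplib::Client cli("http://localhost:11434");

    std::string request_string = R"({"model": "llama3:8b", "prompt": "Why is the sky blue?", "stream": true})";

    size_t chunks_seen = 0;

    // Content receiver: returning false aborts the streaming POST.
    auto stream_callback = [&chunks_seen](const char *data, size_t data_length) -> bool {
        std::cout << std::string(data, data_length);
        return ++chunks_seen < 5;   // arbitrary cutoff for the sketch
    };

    if (auto res = cli.Post("/api/generate", request_string, "application/json", stream_callback)) { /* stream ran to completion */ }
    else if (res.error() == httplib::Error::Canceled) { /* stopped early by the callback; not an error */ }
    else { std::cout << "Error: " << httplib::to_string(res.error()) << std::endl; }

    return 0;
}
```
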
@@ -498,14 +501,14 @@ class Ollama
        return response;
    }

-   bool chat(const std::string& model, const ollama::messages& messages, std::function<void(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::string& format="json", const std::string& keep_alive_duration="5m")
+   bool chat(const std::string& model, const ollama::messages& messages, std::function<bool(const ollama::response&)> on_receive_token, const json& options=nullptr, const std::string& format="json", const std::string& keep_alive_duration="5m")
    {
        ollama::request request(model, messages, options, true, format, keep_alive_duration);
        return chat(request, on_receive_token);
    }


-   bool chat(ollama::request& request, std::function<void(const ollama::response&)> on_receive_token)
+   bool chat(ollama::request& request, std::function<bool(const ollama::response&)> on_receive_token)
    {
        ollama::response response;
        request["stream"] = true;
@@ -518,6 +521,8 @@ class Ollama
        auto stream_callback = [on_receive_token, partial_responses](const char *data, size_t data_length)->bool{

            std::string message(data, data_length);
+           bool continue_stream = true;
+
            if (ollama::log_replies) std::cout << message << std::endl;
            try
            {
@@ -527,14 +532,15 @@ class Ollama
                partial_responses->clear();

                if ( response.has_error() ) { if (ollama::use_exceptions) throw ollama::exception("Ollama response returned error: "+response.get_error() ); }
-               on_receive_token(response);
+               continue_stream = on_receive_token(response);
            }
            catch (const ollama::invalid_json_exception& e) { /* Partial response was received. Will do nothing and attempt to concatenate with the next response. */ }

-           return true;
+           return continue_stream;
        };

        if (auto res = this->cli->Post("/api/chat", request_string, "application/json", stream_callback)) { return true; }
+       else if (res.error()==httplib::Error::Canceled) { /* Request cancelled by user. */ return true; }
        else { if (ollama::use_exceptions) throw ollama::exception( "No response from server returned at URL"+this->server_url+" Error: "+httplib::to_string( res.error() ) ); }

        return false;
@@ -872,7 +878,7 @@ class Ollama
    private:

/*
-   bool send_request(const ollama::request& request, std::function<void(const ollama::response&)> on_receive_response=nullptr)
+   bool send_request(const ollama::request& request, std::function<bool(const ollama::response&)> on_receive_response=nullptr)
    {

        return true;
@@ -910,17 +916,17 @@ namespace ollama
        return ollama.generate(request);
    }

-   inline bool generate(const std::string& model,const std::string& prompt, std::function<void(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
+   inline bool generate(const std::string& model,const std::string& prompt, std::function<bool(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
    {
        return ollama.generate(model, prompt, on_receive_response, options, images);
    }

-   inline bool generate(const std::string& model,const std::string& prompt, ollama::response& context, std::function<void(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
+   inline bool generate(const std::string& model,const std::string& prompt, ollama::response& context, std::function<bool(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::vector<std::string>& images=std::vector<std::string>())
    {
        return ollama.generate(model, prompt, context, on_receive_response, options, images);
    }

-   inline bool generate(ollama::request& request, std::function<void(const ollama::response&)> on_receive_response)
+   inline bool generate(ollama::request& request, std::function<bool(const ollama::response&)> on_receive_response)
    {
        return ollama.generate(request, on_receive_response);
    }
@@ -935,12 +941,12 @@ namespace ollama
        return ollama.chat(request);
    }

-   inline bool chat(const std::string& model, const ollama::messages& messages, std::function<void(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::string& format="json", const std::string& keep_alive_duration="5m")
+   inline bool chat(const std::string& model, const ollama::messages& messages, std::function<bool(const ollama::response&)> on_receive_response, const json& options=nullptr, const std::string& format="json", const std::string& keep_alive_duration="5m")
    {
        return ollama.chat(model, messages, on_receive_response, options, format, keep_alive_duration);
    }

-   inline bool chat(ollama::request& request, std::function<void(const ollama::response&)> on_receive_response)
+   inline bool chat(ollama::request& request, std::function<bool(const ollama::response&)> on_receive_response)
    {
        return ollama.chat(request, on_receive_response);
    }
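
The same stop semantics now apply to streaming chat through these free-function wrappers. Below is an illustrative sketch, not part of the commit, that streams a chat reply and ends it early after an arbitrary budget of 50 tokens; it assumes a running local Ollama server, the `llama3:8b` model, and that `ollama::messages` can be brace-initialized from `ollama::message` objects as in the library's README.

```C++
#include "ollama.hpp"

#include <functional>
#include <iostream>

int main()
{
    ollama::message message("user", "Why is the sky blue?");
    ollama::messages messages = {message};

    int tokens_received = 0;

    std::function<bool(const ollama::response&)> response_callback =
        [&tokens_received](const ollama::response& response) -> bool
    {
        std::cout << response << std::flush;

        // Let the stream finish naturally when the server reports it is done.
        if (response.as_json()["done"]==true) { std::cout << std::endl; return true; }

        // Otherwise stop early once the illustrative 50-token budget is spent.
        return ++tokens_received < 50;
    };

    ollama::chat("llama3:8b", messages, response_callback);
}
```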

0 commit comments

Comments
 (0)