
Commit fdd114f

Merge pull request #16 from jmont-dev/tool-support
Update embeddings to new endpoint; change tests to not rely on exact comparisons.
2 parents f4909c6 + 0aaa3bd commit fdd114f
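
In practical terms, the old /api/embeddings endpoint took a "prompt" field and replied with a singular "embedding", while the new /api/embed endpoint takes "input" plus an optional "truncate" flag and replies with a plural "embeddings". Below is a minimal sketch of the relevant request-body fields, using nlohmann::json as the header itself does; model and prompt values are placeholders, and keep_alive/options are omitted for brevity.

#include <iostream>
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main()
{
    // Request body for the old /api/embeddings endpoint (before this commit).
    json old_request = { {"model", "llama3:8b"}, {"prompt", "Why is the sky blue?"} };

    // Request body for the new /api/embed endpoint (after this commit).
    // "truncate" is the new optional flag; the updated from_embedding() defaults it to true.
    json new_request = { {"model", "llama3:8b"}, {"input", "Why is the sky blue?"}, {"truncate", true} };

    std::cout << old_request.dump(2) << "\n" << new_request.dump(2) << std::endl;
}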

File tree: 3 files changed (+46 / -50 lines)
Diff for: include/ollama.hpp (+10 / -9)
@@ -263,13 +263,14 @@ namespace ollama
         request(): json() {}
         ~request(){};

-        static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+        static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
         {
             ollama::request request(message_type::embedding);

-            request["model"] = name;
-            request["prompt"] = prompt;
+            request["model"] = model;
+            request["input"] = input;
             if (options!=nullptr) request["options"] = options["options"];
+            request["truncate"] = truncate;
             request["keep_alive"] = keep_alive_duration;

             return request;
@@ -295,7 +296,7 @@ namespace ollama

             if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
             else
-            if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+            if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
             else
             if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();

@@ -715,15 +716,15 @@ class Ollama
        return false;
    }

-    ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
    {
-        ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+        ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
        ollama::response response;

        std::string request_string = request.dump();
        if (ollama::log_requests) std::cout << request_string << std::endl;

-        if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+        if (auto res = cli->Post("/api/embed", request_string, "application/json"))
        {
            if (ollama::log_replies) std::cout << res->body << std::endl;

@@ -885,9 +886,9 @@ namespace ollama
        return ollama.push_model(model, allow_insecure);
    }

-    inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
    {
-        return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+        return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
    }

    inline void setReadTimeout(const int& seconds)
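
With these header changes, an embedding call POSTs to /api/embed and the reply carries the plural "embeddings" field. The following is a rough usage sketch of the updated generate_embeddings signature, not part of the commit itself; the model name is a placeholder and a local Ollama server on the default port is assumed.

#include <iostream>
#include "ollama.hpp"

int main()
{
    // Placeholder model; any locally pulled embedding-capable model works.
    const std::string model = "llama3:8b";

    // New signature: (model, input, options, truncate, keep_alive_duration).
    // Passing nullptr keeps the default options; truncate=true trims input that
    // exceeds the model's context length instead of returning an error.
    ollama::response response =
        ollama::generate_embeddings(model, "Why is the sky blue?", nullptr, true);

    // The /api/embed endpoint returns "embeddings" (plural), replacing the old
    // endpoint's singular "embedding" field.
    if (response.as_json().contains("embeddings"))
        std::cout << "Received " << response.as_json()["embeddings"].size()
                  << " embedding vector(s)" << std::endl;
    else
        std::cout << "No embeddings in reply: " << response.as_json().dump() << std::endl;
}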

Diff for: singleheader/ollama.hpp (+10 / -9)
The generated single header mirrors the include/ollama.hpp changes above verbatim, applied at the corresponding locations:
@@ -35053,13 +35053,14 @@ namespace ollama (from_embedding signature, "input" and "truncate" request fields)
@@ -35085,7 +35086,7 @@ namespace ollama (plural "embeddings" field in the response parsing)
@@ -35505,15 +35506,15 @@ class Ollama (generate_embeddings signature and the /api/embed endpoint)
@@ -35675,9 +35676,9 @@ namespace ollama (inline generate_embeddings wrapper)

Diff for: test/test.cpp (+26 / -32)
@@ -12,6 +12,8 @@
 // Note that this is static. We will use these options for other generations.
 static ollama::options options;

+static std::string test_model = "llama3:8b", image_test_model = "llava";
+
 TEST_SUITE("Ollama Tests") {

     TEST_CASE("Initialize Options") {
@@ -52,19 +54,19 @@ TEST_SUITE("Ollama Tests") {

     TEST_CASE("Load Model") {

-        CHECK( ollama::load_model("llama3:8b") );
+        CHECK( ollama::load_model(test_model) );
     }

     TEST_CASE("Pull, Copy, and Delete Models") {

         // Pull a model by specifying a model name.
-        CHECK( ollama::pull_model("llama3:8b") == true );
+        CHECK( ollama::pull_model(test_model) == true );

         // Copy a model by specifying a source model and destination model name.
-        CHECK( ollama::copy_model("llama3:8b", "llama3_copy") ==true );
+        CHECK( ollama::copy_model(test_model, test_model+"_copy") ==true );

         // Delete a model by specifying a model name.
-        CHECK( ollama::delete_model("llama3_copy") == true );
+        CHECK( ollama::delete_model(test_model+"_copy") == true );
     }

     TEST_CASE("Model Info") {
@@ -81,7 +83,7 @@ TEST_SUITE("Ollama Tests") {
         // List the models available locally in the ollama server
         std::vector<std::string> models = ollama::list_models();

-        bool contains_model = (std::find(models.begin(), models.end(), "llama3:8b") != models.end() );
+        bool contains_model = (std::find(models.begin(), models.end(), test_model) != models.end() );

         CHECK( contains_model );
     }
@@ -101,12 +103,9 @@ TEST_SUITE("Ollama Tests") {

     TEST_CASE("Basic Generation") {

-        ollama::response response = ollama::generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
-
-        std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
+        ollama::response response = ollama::generate(test_model, "Why is the sky blue?", options);

-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }


@@ -124,35 +123,34 @@ TEST_SUITE("Ollama Tests") {
     TEST_CASE("Streaming Generation") {

         std::function<void(const ollama::response&)> response_callback = on_receive_response;
-        ollama::generate("llama3:8b", "Why is the sky blue?", response_callback, options);
+        ollama::generate(test_model, "Why is the sky blue?", response_callback, options);

         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-        CHECK( streamed_response == expected_response );
+        CHECK( streamed_response != "" );
     }

     TEST_CASE("Non-Singleton Generation") {

         Ollama my_ollama_server("http://localhost:11434");

         // You can use all of the same functions from this instanced version of the class.
-        ollama::response response = my_ollama_server.generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
+        ollama::response response = my_ollama_server.generate(test_model, "Why is the sky blue?", options);

         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }

     TEST_CASE("Single-Message Chat") {

         ollama::message message("user", "Why is the sky blue?");

-        ollama::response response = ollama::chat("llama3:8b", message, options);
+        ollama::response response = ollama::chat(test_model, message, options);

         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";

-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }

     TEST_CASE("Multi-Message Chat") {
@@ -163,11 +161,11 @@ TEST_SUITE("Ollama Tests") {

         ollama::messages messages = {message1, message2, message3};

-        ollama::response response = ollama::chat("llama3:8b", messages, options);
+        ollama::response response = ollama::chat(test_model, messages, options);

         std::string expected_response = "";

-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }

     TEST_CASE("Chat with Streaming Response") {
@@ -182,7 +180,7 @@ TEST_SUITE("Ollama Tests") {

         ollama::message message("user", "Why is the sky blue?");

-        ollama::chat("llama3:8b", message, response_callback, options);
+        ollama::chat(test_model, message, response_callback, options);

         CHECK(streamed_response!="");
     }
@@ -195,12 +193,9 @@ TEST_SUITE("Ollama Tests") {

         ollama::image image = ollama::image::from_file("llama.jpg");

-        //ollama::images images={image};
-
-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, image);
-        std::string expected_response = " The image features a large, fluffy white llama";
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, image);

-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }

     TEST_CASE("Generation with Multiple Images") {
@@ -214,10 +209,10 @@ TEST_SUITE("Ollama Tests") {

         ollama::images images={image, base64_image};

-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, images);
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, images);
         std::string expected_response = " The image features a large, fluffy white and gray llama";

-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }

     TEST_CASE("Chat with Image") {
@@ -230,21 +225,20 @@ TEST_SUITE("Ollama Tests") {

         // We can optionally include images with each message. Vision-enabled models will be able to utilize these.
         ollama::message message_with_image("user", "What do you see in this image?", image);
-        ollama::response response = ollama::chat("llava", message_with_image, options);
+        ollama::response response = ollama::chat(image_test_model, message_with_image, options);

         std::string expected_response = " The image features a large, fluffy white llama";

-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }

     TEST_CASE("Embedding Generation") {

         options["num_predict"] = 18;

-        ollama::response response = ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");
-        //std::cout << response << std::endl;
+        ollama::response response = ollama::generate_embeddings(test_model, "Why is the sky blue?");

-        CHECK(response.as_json().contains("embedding") == true);
+        CHECK(response.as_json().contains("embeddings") == true);
     }

     TEST_CASE("Enable Debug Logging") {
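
The test changes all follow one pattern: hard-coded expected completions are dropped in favor of checking that the reply JSON contains the structurally expected field ("response", "message", or "embeddings"), since exact model output varies between runs and hardware. As a sketch only, an additional test in that style (assuming doctest's TEST_CASE/CHECK macros, the static test_model variable introduced above, and placement inside the existing TEST_SUITE) might look like:

// Sketch only: follows the key-presence style of the updated tests rather than
// comparing against exact model output.
TEST_CASE("Generation Reply Contains Metadata") {

    ollama::response response = ollama::generate(test_model, "Why is the sky blue?", options);

    // Assert on stable structural fields of the /api/generate reply instead of its text.
    CHECK( response.as_json().contains("response") == true );
    CHECK( response.as_json().contains("done") == true );
}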
