-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmain.rs
73 lines (64 loc) · 2.08 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use pgvector::Vector;
use postgres::{Client, NoTls};
use serde_json::Value;
use std::error::Error;
/// Demo: store a few sentences with their OpenAI embeddings in Postgres (pgvector) and
/// print them ordered by cosine distance (`<=>`) to the embedding of the query "forest".
///
/// Connects to the `pgvector_example` database on `localhost` as the user named by `$USER`,
/// enables the `vector` extension, (re)creates the `documents` table, inserts the sample
/// rows, and then runs a nearest-neighbour query limited to 5 rows.
///
/// # Errors
///
/// Returns an error if `USER` is unset, if the database connection or any statement fails,
/// if embedding generation fails, or if no embedding is returned for the query.
fn main() -> Result<(), Box<dyn Error>> {
    let user = std::env::var("USER").map_err(|_| "Set USER to the Postgres user name")?;
    let mut client = Client::configure()
        .host("localhost")
        .dbname("pgvector_example")
        .user(user.as_str())
        .connect(NoTls)?;

    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
    client.execute("DROP TABLE IF EXISTS documents", &[])?;
    client.execute(
        "CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding vector(1536))",
        &[],
    )?;

    let input = [
        "The dog is barking",
        "The cat is purring",
        "The bear is growling",
    ];
    let embeddings = embed(&input)?;
    for (content, embedding) in input.iter().zip(embeddings) {
        client.execute(
            "INSERT INTO documents (content, embedding) VALUES ($1, $2)",
            &[&content, &Vector::from(embedding)],
        )?;
    }

    let query = "forest";
    // Surface an empty API result as an error instead of panicking.
    let query_embedding = embed(&[query])?
        .into_iter()
        .next()
        .ok_or("no embedding returned for the query")?;

    for row in client.query(
        "SELECT content FROM documents ORDER BY embedding <=> $1 LIMIT 5",
        &[&Vector::from(query_embedding)],
    )? {
        // `try_get` reports a NULL/type mismatch as an error rather than panicking.
        let content: &str = row.try_get(0)?;
        println!("{}", content);
    }
    Ok(())
}
/// Requests embeddings for every string in `input` from the OpenAI embeddings endpoint
/// using the `text-embedding-3-small` model.
///
/// Returns one `Vec<f32>` per input string, in the order the API lists them in `data`.
/// The API sends the values as JSON numbers. Each value is read as `f64` and then narrowed
/// to `f32`, which is the element type the caller stores.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `OPENAI_API_KEY` is unset.
/// - The HTTP request fails.
/// - The body is not valid JSON.
/// - The response lacks the expected `data[].embedding` numeric arrays.
/// - The number of embeddings returned differs from `input.len()`.
fn embed(input: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
    // Nothing to embed: skip the network round-trip entirely.
    if input.is_empty() {
        return Ok(Vec::new());
    }

    let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| "Set OPENAI_API_KEY")?;
    let response: Value = ureq::post("https://api.openai.com/v1/embeddings")
        .header("Authorization", &format!("Bearer {}", api_key))
        .send_json(serde_json::json!({
            "input": input,
            "model": "text-embedding-3-small",
        }))?
        .body_mut()
        .read_json()?;

    // The response comes from an external service, so report shape mismatches as errors
    // instead of panicking.
    let data = response["data"]
        .as_array()
        .ok_or("unexpected embeddings response: missing `data` array")?;

    let embeddings = data
        .iter()
        .map(|item| -> Result<Vec<f32>, &'static str> {
            item["embedding"]
                .as_array()
                .ok_or("unexpected embeddings response: missing `embedding` array")?
                .iter()
                .map(|x| {
                    x.as_f64()
                        // Intentional narrowing: the stored vectors are f32.
                        .map(|f| f as f32)
                        .ok_or("unexpected embeddings response: non-numeric embedding value")
                })
                .collect()
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Callers zip the results with `input`; a count mismatch would otherwise drop rows silently.
    if embeddings.len() != input.len() {
        return Err(format!(
            "expected {} embeddings, got {}",
            input.len(),
            embeddings.len()
        )
        .into());
    }

    Ok(embeddings)
}