-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmain.rs
75 lines (66 loc) · 2.41 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
use pgvector::Vector;
use postgres::binary_copy::BinaryCopyInWriter;
use postgres::types::{Kind, Type};
use postgres::{Client, NoTls};
use rand::Rng;
use std::error::Error;
use std::io::{self, Write};
/// Bulk-loads randomly generated embeddings into a `vector` column using
/// PostgreSQL's binary `COPY` protocol.
///
/// Steps, in order: build the random data set in memory, connect to the local
/// `pgvector_example` database as the user named by the `USER` environment
/// variable, (re)create the `items` table, stream every embedding through a
/// binary copy writer, optionally build an HNSW index when the `INDEX`
/// environment variable is set, and finally run `ANALYZE`.
///
/// # Errors
///
/// Propagates any error from reading `USER`, from the database client, or
/// from flushing stdout while printing progress.
fn main() -> Result<(), Box<dyn Error>> {
    // Build the whole random data set up front.
    let rows = 1000000;
    let dimensions = 128;
    let mut rng = rand::rng();
    let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(rows);
    for _ in 0..rows {
        let values: Vec<f32> = (0..dimensions).map(|_| rng.random()).collect();
        embeddings.push(values);
    }

    // Connect and make sure the extension is available.
    let user = std::env::var("USER")?;
    let mut client = Client::configure()
        .host("localhost")
        .dbname("pgvector_example")
        .user(user.as_str())
        .connect(NoTls)?;
    client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;

    // Start from a fresh table sized for the generated vectors.
    client.execute("DROP TABLE IF EXISTS items", &[])?;
    let create_table = format!("CREATE TABLE items (id bigserial, embedding vector({dimensions}))");
    client.execute(&create_table, &[])?;

    // Stream the rows through a binary COPY.
    println!("Loading {} rows", embeddings.len());
    let vector_type = get_type(&mut client, "vector")?;
    let copy_sink = client.copy_in("COPY items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
    let mut writer = BinaryCopyInWriter::new(copy_sink, &[vector_type]);
    embeddings
        .into_iter()
        .enumerate()
        .try_for_each(|(index, embedding)| -> Result<(), Box<dyn Error>> {
            // Emit one progress dot per 10000 rows, flushed so it shows immediately.
            if index % 10000 == 0 {
                print!(".");
                io::stdout().flush()?;
            }
            writer.write(&[&Vector::from(embedding)])?;
            Ok(())
        })?;
    writer.finish()?;
    println!("\nSuccess!");

    // Indexes are built only after the initial load, and only when the
    // INDEX environment variable is set.
    let build_index = std::env::var("INDEX").is_ok();
    if build_index {
        println!("Creating index");
        let session_settings = [
            "SET maintenance_work_mem = '8GB'",
            "SET max_parallel_maintenance_workers = 7",
        ];
        for setting in session_settings {
            client.execute(setting, &[])?;
        }
        client.execute(
            "CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops)",
            &[],
        )?;
    }

    // Refresh planner statistics now that the table is populated.
    client.execute("ANALYZE items", &[])?;

    Ok(())
}
/// Looks up a PostgreSQL type by name and builds the matching [`Type`].
///
/// Queries `pg_type` joined with `pg_namespace` for the OID and schema of the
/// type called `name`, then wraps them in a [`Type`] of kind
/// [`Kind::Simple`].
///
/// # Errors
///
/// Returns any error produced by `query_one` for the catalog lookup.
fn get_type(client: &mut Client, name: &str) -> Result<Type, Box<dyn Error>> {
    let lookup_sql = "SELECT pg_type.oid, nspname AS schema FROM pg_type INNER JOIN pg_namespace ON pg_namespace.oid = pg_type.typnamespace WHERE typname = $1";
    let row = client.query_one(lookup_sql, &[&name])?;

    let type_name = String::from(name);
    let oid: u32 = row.get("oid");
    let schema: String = row.get("schema");

    Ok(Type::new(type_name, oid, Kind::Simple, schema))
}