-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.swift
88 lines (72 loc) · 2.73 KB
/
main.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import Foundation
import PostgresNIO
// The Cohere API key is required for every embedding request below, so bail
// out early with a hint when it is not configured.
guard let apiKey = ProcessInfo.processInfo.environment["CO_API_KEY"] else {
    print("Set CO_API_KEY")
    exit(1)
}

// Connection settings for the example database: a server on localhost:5432,
// no password, TLS disabled.
let config = PostgresClient.Configuration(
    host: "localhost",
    port: 5432,
    // Connect as the current OS user. `USER` may be missing from the
    // environment (containers, service managers), so fall back to Foundation's
    // lookup instead of crashing on a force-unwrap.
    username: ProcessInfo.processInfo.environment["USER"] ?? NSUserName(),
    password: nil,
    database: "pgvector_example",
    tls: .disable
)
let client = PostgresClient(configuration: config)
/// Request payload that `embed` serializes as the JSON body of its POST.
///
/// Properties are declared in camelCase; `embed` encodes them with a
/// snake-case key strategy, so e.g. `inputType` goes out as `input_type`.
struct ApiData {
    /// Texts to embed.
    var texts: [String]
    /// Identifier of the embedding model to use.
    var model: String
    /// Input type passed through to the API (this file uses
    /// `"search_document"` and `"search_query"`).
    var inputType: String
    /// Embedding representations to request (this file asks for `"ubinary"`).
    var embeddingTypes: [String]
}

// Encoding is compiler-synthesized from the stored properties above.
extension ApiData: Encodable {}
/// The part of the embed endpoint's JSON reply that this program reads:
/// the top-level `embeddings` object.
struct EmbedResponse {
    /// Container holding the returned embeddings.
    var embeddings: EmbeddingsObject
}

// Decoding is compiler-synthesized from the stored property above.
extension EmbedResponse: Decodable {}
/// The `embeddings` object of the reply, restricted to the `ubinary` field.
struct EmbeddingsObject {
    /// One byte array per embedded text; `embed` expands each byte into
    /// eight binary digits.
    var ubinary: [[UInt8]]
}

// Decoding is compiler-synthesized from the stored property above.
extension EmbeddingsObject: Decodable {}
/// Renders packed bytes as a string of `0`/`1` characters, eight per byte,
/// most significant bit first.
private func bitString(from bytes: [UInt8]) -> String {
    bytes.map { byte -> String in
        let digits = String(byte, radix: 2)
        // `String(_:radix:)` drops leading zeros, so left-pad to a full byte.
        return String(repeating: "0", count: 8 - digits.count) + digits
    }.joined()
}

/// Requests `ubinary` embeddings for `texts` from the Cohere embed endpoint
/// and returns each one as a string of binary digits.
///
/// - Parameters:
///   - texts: Texts to embed. An empty array returns `[]` without a request.
///   - inputType: Value sent as `input_type` (e.g. `"search_document"`).
///   - apiKey: Cohere API key, sent as a bearer token.
/// - Returns: One bit string per input text, in the order the API returned
///   them; each byte of an embedding contributes eight characters.
/// - Throws: Transport errors from `URLSession`, `URLError(.badServerResponse)`
///   for a non-2xx HTTP status, decoding errors for an unexpected body, and
///   `URLError(.cannotParseResponse)` when the number of embeddings does not
///   match the number of texts.
func embed(texts: [String], inputType: String, apiKey: String) async throws -> [String] {
    // Nothing to embed: skip the network round trip.
    guard !texts.isEmpty else { return [] }

    // Force-unwrap is safe: the literal is a well-formed URL.
    let url = URL(string: "https://api.cohere.com/v1/embed")!
    let data = ApiData(
        texts: texts,
        model: "embed-english-v3.0",
        inputType: inputType,
        embeddingTypes: ["ubinary"]
    )

    var request = URLRequest(url: url)
    request.httpMethod = "POST"
    request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
    request.setValue("application/json", forHTTPHeaderField: "Content-Type")

    // The payload is declared in camelCase; encode its keys as snake_case.
    let encoder = JSONEncoder()
    encoder.keyEncodingStrategy = .convertToSnakeCase
    request.httpBody = try encoder.encode(data)

    let (body, response) = try await URLSession.shared.data(for: request)

    // Report HTTP failures (bad key, rate limit, ...) directly instead of
    // letting them surface as a confusing decoding error.
    if let http = response as? HTTPURLResponse, !(200..<300).contains(http.statusCode) {
        let detail = String(decoding: body, as: UTF8.self)
        throw URLError(
            .badServerResponse,
            userInfo: [NSLocalizedDescriptionKey: "Embed request failed with HTTP \(http.statusCode): \(detail)"]
        )
    }

    let decoded = try JSONDecoder().decode(EmbedResponse.self, from: body)
    let vectors = decoded.embeddings.ubinary

    // Callers pair results with inputs, so a count mismatch must not pass
    // silently.
    guard vectors.count == texts.count else {
        throw URLError(
            .cannotParseResponse,
            userInfo: [NSLocalizedDescriptionKey: "Expected \(texts.count) embeddings, received \(vectors.count)"]
        )
    }

    return vectors.map(bitString(from:))
}
// Main flow: start the Postgres client as a child task, rebuild the
// `documents` table, store one binary embedding per document, then list the
// documents ordered by their distance to the embedded search text.
try await withThrowingTaskGroup(of: Void.self) { group in
    // Keep the client's run loop going in the background while queries are
    // issued (presumably required for queries to be served — see PostgresNIO docs).
    group.addTask {
        await client.run()
    }

    // Schema setup: start from a fresh table on every run.
    try await client.query("CREATE EXTENSION IF NOT EXISTS vector")
    try await client.query("DROP TABLE IF EXISTS documents")
    try await client.query("CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding bit(1024))")

    let documents = [
        "The dog is barking",
        "The cat is purring",
        "The bear is growling",
    ]
    let documentEmbeddings = try await embed(texts: documents, inputType: "search_document", apiKey: apiKey)

    // Store each document next to its embedding; interpolated values are
    // passed to the query API rather than concatenated by hand here.
    for (content, embedding) in zip(documents, documentEmbeddings) {
        try await client.query("INSERT INTO documents (content, embedding) VALUES (\(content), \(embedding)::bit(1024))")
    }

    // Embed the search text as a query and take its single embedding.
    let searchText = "forest"
    let searchEmbeddings = try await embed(texts: [searchText], inputType: "search_query", apiKey: apiKey)
    let queryEmbedding = searchEmbeddings[0]

    // Order by the `<~>` operator (Hamming distance in pgvector) and print
    // up to five matches.
    let matches = try await client.query("SELECT content FROM documents ORDER BY embedding <~> \(queryEmbedding)::bit(1024) LIMIT 5")
    for try await match in matches {
        print(match)
    }

    // Stop the client's run loop so the group can finish.
    group.cancelAll()
}