-
Notifications
You must be signed in to change notification settings - Fork 74
/
Copy pathrrf.py
52 lines (47 loc) · 1.84 KB
/
rrf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from pgvector.psycopg import register_vector
import psycopg
from sentence_transformers import SentenceTransformer
# Connect to the example database. autocommit=True is passed so the
# statements below are not wrapped in an explicit transaction.
conn = psycopg.connect(dbname='pgvector_example', autocommit=True)

# Enable the pgvector extension, then register its adapters on this
# connection. Keep this order: the extension is created before
# register_vector() is called.
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
register_vector(conn)

# Rebuild the demo table from scratch on every run, and add a GIN index over
# the English tsvector expression that the keyword half of the search uses.
# NOTE(review): vector(384) presumably matches the output size of the
# embedding model loaded later in this script -- confirm if the model changes.
schema_statements = (
    'DROP TABLE IF EXISTS documents',
    'CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(384))',
    "CREATE INDEX ON documents USING GIN (to_tsvector('english', content))",
)
for statement in schema_statements:
    conn.execute(statement)
# Tiny demo corpus to index.
sentences = [
    'The dog is barking',
    'The cat is purring',
    'The bear is growling',
]

# Encode the whole corpus in a single call; iterating the result yields one
# embedding per sentence, in the same order.
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
embeddings = model.encode(sentences)

# Insert each sentence together with its embedding. Every zipped pair is
# already the (content, embedding) parameter tuple the statement expects.
insert_sql = 'INSERT INTO documents (content, embedding) VALUES (%s, %s)'
for params in zip(sentences, embeddings):
    conn.execute(insert_sql, params)
# Hybrid search using Reciprocal Rank Fusion (RRF).
#
# Named parameters:
#   %(embedding)s -- query embedding, compared with the `<=>` distance operator
#   %(query)s     -- raw query text, parsed with plainto_tsquery('english', ...)
#   %(k)s         -- RRF constant added to every rank in the denominator
#
# semantic_search ranks documents by embedding distance; keyword_search ranks
# full-text matches by ts_rank_cd. Each list is capped at 20 rows. The FULL
# OUTER JOIN keeps documents that appear in only one list; COALESCE makes the
# missing side contribute 0.0. The final score is the sum of 1 / (k + rank)
# over both lists, and the 5 best-scoring ids are returned.
#
# Both CTEs alias their RANK() column explicitly as `rank`, because the outer
# SELECT reads <cte>.rank; this avoids depending on PostgreSQL's automatic
# naming of unaliased function-call columns.
sql = """
WITH semantic_search AS (
SELECT id, RANK () OVER (ORDER BY embedding <=> %(embedding)s) AS rank
FROM documents
ORDER BY embedding <=> %(embedding)s
LIMIT 20
),
keyword_search AS (
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC) AS rank
FROM documents, plainto_tsquery('english', %(query)s) query
WHERE to_tsvector('english', content) @@ query
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
LIMIT 20
)
SELECT
COALESCE(semantic_search.id, keyword_search.id) AS id,
COALESCE(1.0 / (%(k)s + semantic_search.rank), 0.0) +
COALESCE(1.0 / (%(k)s + keyword_search.rank), 0.0) AS score
FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
ORDER BY score DESC
LIMIT 5
"""
# Search text; it is used both as the keyword query and, once encoded, as the
# embedding for the semantic half of the search.
query = 'growling bear'
embedding = model.encode(query)

# RRF constant: added to each rank in the denominator of the score terms.
k = 60

# Bind the named placeholders of the hybrid query and fetch all result rows.
search_params = {'query': query, 'embedding': embedding, 'k': k}
results = conn.execute(sql, search_params).fetchall()

# Each row carries two columns: the document id and its fused score.
for doc_id, rrf_score in results:
    print('document:', doc_id, 'RRF score:', rrf_score)