Skip to content

Commit cd2d1c9

Browse files
committed
feat: set the config.py
1 parent 4c97bf4 commit cd2d1c9

File tree

5 files changed

+271
-0
lines changed

5 files changed

+271
-0
lines changed

Diff for: 3-feature-pipeline/config.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pydantic_settings import BaseSettings
2+
3+
4+
class Settings(BaseSettings):
5+
# CometML config
6+
COMET_API_KEY: str | None = None
7+
COMET_WORKSPACE: str | None = None
8+
COMET_PROJECT: str | None = None
9+
10+
# Embeddings config
11+
EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2"
12+
EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 256
13+
EMBEDDING_SIZE: int = 384
14+
EMBEDDING_MODEL_DEVICE: str = "cpu"
15+
16+
# OpenAI
17+
OPENAI_MODEL_ID: str = "gpt-4-1106-preview"
18+
OPENAI_API_KEY: str | None = None
19+
20+
# MQ config
21+
RABBITMQ_DEFAULT_USERNAME: str = "guest"
22+
RABBITMQ_DEFAULT_PASSWORD: str = "guest"
23+
RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker
24+
RABBITMQ_PORT: int = 5672
25+
RABBITMQ_QUEUE_NAME: str = "default"
26+
27+
# QdrantDB config
28+
QDRANT_DATABASE_HOST: str = "qdrant" # or localhost if running outside Docker
29+
QDRANT_DATABASE_PORT: int = 6333
30+
USE_QDRANT_CLOUD: bool = False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY
31+
QDRANT_CLOUD_URL: str | None = None
32+
QDRANT_APIKEY: str | None = None
33+
34+
35+
settings = Settings()

Diff for: 3-feature-pipeline/db.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from qdrant_client import QdrantClient, models
2+
from qdrant_client.http.exceptions import UnexpectedResponse
3+
from qdrant_client.http.models import Batch, Distance, VectorParams
4+
5+
from utils.logging import get_logger
6+
from config import settings
7+
8+
logger = get_logger(__name__)
9+
10+
11+
class QdrantDatabaseConnector:
12+
_instance: QdrantClient | None = None
13+
14+
def __init__(self) -> None:
15+
if self._instance is None:
16+
try:
17+
if settings.USE_QDRANT_CLOUD:
18+
self._instance = QdrantClient(
19+
url=settings.QDRANT_CLOUD_URL,
20+
api_key=settings.QDRANT_APIKEY,
21+
)
22+
else:
23+
self._instance = QdrantClient(
24+
host=settings.QDRANT_DATABASE_HOST,
25+
port=settings.QDRANT_DATABASE_PORT,
26+
)
27+
except UnexpectedResponse:
28+
logger.exception(
29+
"Couldn't connect to Qdrant.",
30+
host=settings.QDRANT_DATABASE_HOST,
31+
port=settings.QDRANT_DATABASE_PORT,
32+
url=settings.QDRANT_CLOUD_URL,
33+
)
34+
35+
raise
36+
37+
def get_collection(self, collection_name: str):
38+
return self._instance.get_collection(collection_name=collection_name)
39+
40+
def create_non_vector_collection(self, collection_name: str):
41+
self._instance.create_collection(
42+
collection_name=collection_name, vectors_config={}
43+
)
44+
45+
def create_vector_collection(self, collection_name: str):
46+
self._instance.create_collection(
47+
collection_name=collection_name,
48+
vectors_config=VectorParams(
49+
size=settings.EMBEDDING_SIZE, distance=Distance.COSINE
50+
),
51+
)
52+
53+
def write_data(self, collection_name: str, points: Batch):
54+
try:
55+
self._instance.upsert(collection_name=collection_name, points=points)
56+
except Exception:
57+
logger.exception("An error occurred while inserting data.")
58+
59+
raise
60+
61+
def search(
62+
self,
63+
collection_name: str,
64+
query_vector: list,
65+
query_filter: models.Filter | None = None,
66+
limit: int = 3,
67+
) -> list:
68+
return self._instance.search(
69+
collection_name=collection_name,
70+
query_vector=query_vector,
71+
query_filter=query_filter,
72+
limit=limit,
73+
)
74+
75+
def scroll(self, collection_name: str, limit: int):
76+
return self._instance.scroll(collection_name=collection_name, limit=limit)
77+
78+
def close(self):
79+
if self._instance:
80+
self._instance.close()
81+
82+
logger.info("Connected to database has been closed.")

Diff for: 3-feature-pipeline/main.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import bytewax.operators as op
2+
from bytewax.dataflow import Dataflow
3+
4+
from db import QdrantDatabaseConnector
5+
6+
from data_flow.stream_input import RabbitMQSource
7+
from data_flow.stream_output import QdrantOutput
8+
from data_logic.dispatchers import (
9+
ChunkingDispatcher,
10+
CleaningDispatcher,
11+
EmbeddingDispatcher,
12+
RawDispatcher,
13+
)
14+
15+
connection = QdrantDatabaseConnector()
16+
17+
flow = Dataflow("Streaming ingestion pipeline")
18+
stream = op.input("input", flow, RabbitMQSource())
19+
stream = op.map("raw dispatch", stream, RawDispatcher.handle_mq_message)
20+
stream = op.map("clean dispatch", stream, CleaningDispatcher.dispatch_cleaner)
21+
op.output(
22+
"cleaned data insert to qdrant",
23+
stream,
24+
QdrantOutput(connection=connection, sink_type="clean"),
25+
)
26+
stream = op.flat_map("chunk dispatch", stream, ChunkingDispatcher.dispatch_chunker)
27+
stream = op.map(
28+
"embedded chunk dispatch", stream, EmbeddingDispatcher.dispatch_embedder
29+
)
30+
op.output(
31+
"embedded data insert to qdrant",
32+
stream,
33+
QdrantOutput(connection=connection, sink_type="vector"),
34+
)

Diff for: 3-feature-pipeline/mq.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import pika
2+
3+
from utils.logging import get_logger
4+
from config import settings
5+
6+
logger = get_logger(__name__)
7+
8+
9+
class RabbitMQConnection:
10+
_instance = None
11+
12+
def __new__(
13+
cls,
14+
host: str | None = None,
15+
port: int | None = None,
16+
username: str | None = None,
17+
password: str | None = None,
18+
virtual_host: str = "/",
19+
):
20+
if not cls._instance:
21+
cls._instance = super().__new__(cls)
22+
23+
return cls._instance
24+
25+
def __init__(
26+
self,
27+
host: str | None = None,
28+
port: int | None = None,
29+
username: str | None = None,
30+
password: str | None = None,
31+
virtual_host: str = "/",
32+
fail_silently: bool = False,
33+
**kwargs,
34+
):
35+
self.host = host or settings.RABBITMQ_HOST
36+
self.port = port or settings.RABBITMQ_PORT
37+
self.username = username or settings.RABBITMQ_DEFAULT_USERNAME
38+
self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD
39+
self.virtual_host = virtual_host
40+
self.fail_silently = fail_silently
41+
self._connection = None
42+
43+
def __enter__(self):
44+
self.connect()
45+
return self
46+
47+
def __exit__(self, exc_type, exc_val, exc_tb):
48+
self.close()
49+
50+
def connect(self):
51+
try:
52+
credentials = pika.PlainCredentials(self.username, self.password)
53+
self._connection = pika.BlockingConnection(
54+
pika.ConnectionParameters(
55+
host=self.host,
56+
port=self.port,
57+
virtual_host=self.virtual_host,
58+
credentials=credentials,
59+
)
60+
)
61+
except pika.exceptions.AMQPConnectionError as e:
62+
logger.exception("Failed to connect to RabbitMQ.")
63+
64+
if not self.fail_silently:
65+
raise e
66+
67+
def publish_message(self, data: str, queue: str):
68+
channel = self.get_channel()
69+
channel.queue_declare(
70+
queue=queue, durable=True, exclusive=False, auto_delete=False
71+
)
72+
channel.confirm_delivery()
73+
74+
try:
75+
channel.basic_publish(
76+
exchange="", routing_key=queue, body=data, mandatory=True
77+
)
78+
logger.info(
79+
"Sent message successfully.", queue_type="RabbitMQ", queue_name=queue
80+
)
81+
except pika.exceptions.UnroutableError:
82+
logger.info(
83+
"Failed to send the message.", queue_type="RabbitMQ", queue_name=queue
84+
)
85+
86+
def is_connected(self) -> bool:
87+
return self._connection is not None and self._connection.is_open
88+
89+
def get_channel(self):
90+
if self.is_connected():
91+
return self._connection.channel()
92+
93+
def close(self):
94+
if self.is_connected():
95+
self._connection.close()
96+
self._connection = None
97+
98+
logger.info("Closed RabbitMQ connection.")

Diff for: 3-feature-pipeline/retriever.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from dotenv import load_dotenv
2+
from langchain.globals import set_verbose
3+
from rag.retriever import VectorRetriever
4+
5+
from utils.logging import get_logger
6+
7+
set_verbose(True)
8+
9+
logger = get_logger(__name__)
10+
11+
if __name__ == "__main__":
12+
load_dotenv()
13+
query = """
14+
Could you please draft a LinkedIn post discussing RAG systems?
15+
I'm particularly interested in how RAG works and how it is integrated with vector DBs and large language models (LLMs).
16+
"""
17+
retriever = VectorRetriever(query=query)
18+
hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=5)
19+
20+
reranked_hits = retriever.rerank(hits=hits, keep_top_k=5)
21+
for rank, hit in enumerate(reranked_hits):
22+
logger.info(f"{rank}: {hit}")

0 commit comments

Comments
 (0)