Skip to content

Commit 9681a9e

Browse files
chloediaStanGirard
andauthored
fix: langfuse talk to model (#3535)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --------- Co-authored-by: Stan Girard <[email protected]>
1 parent d835fc6 commit 9681a9e

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

core/quivr_core/rag/quivr_rag.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
1313
from langchain_core.vectorstores import VectorStore
1414

15+
from quivr_core.llm import LLMEndpoint
1516
from quivr_core.rag.entities.chat import ChatHistory
1617
from quivr_core.rag.entities.config import RetrievalConfig
17-
from quivr_core.llm import LLMEndpoint
1818
from quivr_core.rag.entities.models import (
1919
ParsedRAGChunkResponse,
2020
ParsedRAGResponse,
@@ -24,6 +24,7 @@
2424
)
2525
from quivr_core.rag.prompts import custom_prompts
2626
from quivr_core.rag.utils import (
27+
LangfuseService,
2728
combine_documents,
2829
format_file_list,
2930
get_chunk_metadata,
@@ -32,6 +33,8 @@
3233
)
3334

3435
logger = logging.getLogger("quivr_core")
36+
langfuse_service = LangfuseService()
37+
langfuse_handler = langfuse_service.get_handler()
3538

3639

3740
class IdempotentCompressor(BaseDocumentCompressor):
@@ -173,7 +176,7 @@ def answer(
173176
"chat_history": history,
174177
"custom_instructions": (self.retrieval_config.prompt),
175178
},
176-
config={"metadata": metadata},
179+
config={"metadata": metadata, "callbacks": [langfuse_handler]},
177180
)
178181
response = parse_response(
179182
raw_llm_response, self.retrieval_config.llm_config.model
@@ -206,7 +209,7 @@ async def answer_astream(
206209
"chat_history": history,
207210
"custom_personality": (self.retrieval_config.prompt),
208211
},
209-
config={"metadata": metadata},
212+
config={"metadata": metadata, "callbacks": [langfuse_handler]},
210213
):
211214
# Could receive this anywhere so we need to save it for the last chunk
212215
if "docs" in chunk:

core/quivr_core/rag/quivr_rag_langgraph.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
from langgraph.types import Send
2929
from pydantic import BaseModel, Field
3030

31-
from langfuse.callback import CallbackHandler
32-
3331
from quivr_core.llm import LLMEndpoint
3432
from quivr_core.llm_tools.llm_tools import LLMToolFactory
3533
from quivr_core.rag.entities.chat import ChatHistory
@@ -41,6 +39,7 @@
4139
)
4240
from quivr_core.rag.prompts import custom_prompts
4341
from quivr_core.rag.utils import (
42+
LangfuseService,
4443
collect_tools,
4544
combine_documents,
4645
format_file_list,
@@ -50,8 +49,8 @@
5049

5150
logger = logging.getLogger("quivr_core")
5251

53-
# Initialize Langfuse CallbackHandler for Langchain (tracing)
54-
langfuse_handler = CallbackHandler()
52+
langfuse_service = LangfuseService()
53+
langfuse_handler = langfuse_service.get_handler()
5554

5655

5756
class SplittedInput(BaseModel):
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
502501
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
503502

504503
# Replace each question with its condensed version
505-
for response, task_id in zip(responses, task_ids):
504+
for response, task_id in zip(responses, task_ids, strict=False):
506505
tasks.set_definition(task_id, response.content)
507506

508507
return {**state, "tasks": tasks}
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
558557
)
559558
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
560559

561-
for response, task_id in zip(responses, task_ids):
560+
for response, task_id in zip(responses, task_ids, strict=False):
562561
tasks.set_completion(task_id, response.is_task_completable)
563562
if not response.is_task_completable and response.tool:
564563
tasks.set_tool(task_id, response.tool)
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
599598
)
600599
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
601600

602-
for response, task_id in zip(responses, task_ids):
601+
for response, task_id in zip(responses, task_ids, strict=False):
603602
_docs = tool_wrapper.format_output(response)
604603
_docs = self.filter_chunks_by_relevance(_docs)
605604
tasks.set_docs(task_id, _docs)
@@ -652,7 +651,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
652651
task_ids = [task[1] for task in async_jobs] if async_jobs else []
653652

654653
# Process responses and associate docs with tasks
655-
for response, task_id in zip(responses, task_ids):
654+
for response, task_id in zip(responses, task_ids, strict=False):
656655
_docs = self.filter_chunks_by_relevance(response)
657656
tasks.set_docs(task_id, _docs) # Associate docs with the specific task
658657

@@ -715,7 +714,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
715714
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
716715

717716
_n = []
718-
for response, task_id in zip(responses, task_ids):
717+
for response, task_id in zip(responses, task_ids, strict=False):
719718
_docs = self.filter_chunks_by_relevance(response)
720719
_n.append(len(_docs))
721720
tasks.set_docs(task_id, _docs)

core/quivr_core/rag/utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
55
from langchain_core.messages.ai import AIMessageChunk
66
from langchain_core.prompts import format_document
7+
from langfuse.callback import CallbackHandler
78

89
from quivr_core.rag.entities.config import WorkflowConfig
910
from quivr_core.rag.entities.models import (
@@ -195,3 +196,11 @@ def collect_tools(workflow_config: WorkflowConfig):
195196
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"
196197

197198
return validated_tools, activated_tools
199+
200+
201+
class LangfuseService:
202+
def __init__(self):
203+
self.langfuse_handler = CallbackHandler()
204+
205+
def get_handler(self):
206+
return self.langfuse_handler

0 commit comments

Comments
 (0)