|
28 | 28 | from langgraph.types import Send
|
29 | 29 | from pydantic import BaseModel, Field
|
30 | 30 |
|
31 |
| -from langfuse.callback import CallbackHandler |
32 |
| - |
33 | 31 | from quivr_core.llm import LLMEndpoint
|
34 | 32 | from quivr_core.llm_tools.llm_tools import LLMToolFactory
|
35 | 33 | from quivr_core.rag.entities.chat import ChatHistory
|
|
41 | 39 | )
|
42 | 40 | from quivr_core.rag.prompts import custom_prompts
|
43 | 41 | from quivr_core.rag.utils import (
|
| 42 | + LangfuseService, |
44 | 43 | collect_tools,
|
45 | 44 | combine_documents,
|
46 | 45 | format_file_list,
|
|
50 | 49 |
|
51 | 50 | logger = logging.getLogger("quivr_core")
|
52 | 51 |
|
53 |
| -# Initialize Langfuse CallbackHandler for Langchain (tracing) |
54 |
| -langfuse_handler = CallbackHandler() |
| 52 | +langfuse_service = LangfuseService() |
| 53 | +langfuse_handler = langfuse_service.get_handler() |
55 | 54 |
|
56 | 55 |
|
57 | 56 | class SplittedInput(BaseModel):
|
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
|
502 | 501 | task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
|
503 | 502 |
|
504 | 503 | # Replace each question with its condensed version
|
505 |
| - for response, task_id in zip(responses, task_ids): |
| 504 | + for response, task_id in zip(responses, task_ids, strict=False): |
506 | 505 | tasks.set_definition(task_id, response.content)
|
507 | 506 |
|
508 | 507 | return {**state, "tasks": tasks}
|
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
|
558 | 557 | )
|
559 | 558 | task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
|
560 | 559 |
|
561 |
| - for response, task_id in zip(responses, task_ids): |
| 560 | + for response, task_id in zip(responses, task_ids, strict=False): |
562 | 561 | tasks.set_completion(task_id, response.is_task_completable)
|
563 | 562 | if not response.is_task_completable and response.tool:
|
564 | 563 | tasks.set_tool(task_id, response.tool)
|
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
|
599 | 598 | )
|
600 | 599 | task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
|
601 | 600 |
|
602 |
| - for response, task_id in zip(responses, task_ids): |
| 601 | + for response, task_id in zip(responses, task_ids, strict=False): |
603 | 602 | _docs = tool_wrapper.format_output(response)
|
604 | 603 | _docs = self.filter_chunks_by_relevance(_docs)
|
605 | 604 | tasks.set_docs(task_id, _docs)
|
@@ -652,7 +651,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
|
652 | 651 | task_ids = [task[1] for task in async_jobs] if async_jobs else []
|
653 | 652 |
|
654 | 653 | # Process responses and associate docs with tasks
|
655 |
| - for response, task_id in zip(responses, task_ids): |
| 654 | + for response, task_id in zip(responses, task_ids, strict=False): |
656 | 655 | _docs = self.filter_chunks_by_relevance(response)
|
657 | 656 | tasks.set_docs(task_id, _docs) # Associate docs with the specific task
|
658 | 657 |
|
@@ -715,7 +714,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
|
715 | 714 | task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
|
716 | 715 |
|
717 | 716 | _n = []
|
718 |
| - for response, task_id in zip(responses, task_ids): |
| 717 | + for response, task_id in zip(responses, task_ids, strict=False): |
719 | 718 | _docs = self.filter_chunks_by_relevance(response)
|
720 | 719 | _n.append(len(_docs))
|
721 | 720 | tasks.set_docs(task_id, _docs)
|
|
0 commit comments