Skip to content

Commit d835fc6

Browse files
feat: returning a description of each workflow node (#3539)
# Description: By returning a description of each node executed by a (LangGraph) workflow, we can show it to the user and thus inform them about the status of the task execution.
1 parent e0ccd3d commit d835fc6

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

core/quivr_core/rag/entities/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def resolve_special_edges(self):
398398

399399
class NodeConfig(QuivrBaseConfig):
    # Declarative configuration for one node of a (LangGraph) workflow.
    name: str
    # Optional human-readable label for this node; when set, it is surfaced to
    # the user instead of the raw node name (see _extract_node_name in
    # quivr_rag_langgraph.py).
    description: str | None = None
    # Names of nodes this node transitions to unconditionally, if any.
    edges: List[str] | None = None
    # Conditional branching transition — presumably mutually exclusive with
    # `edges`; TODO confirm against resolve_special_edges.
    conditional_edge: ConditionalEdgeConfig | None = None
    # Tool specifications available to this node, if any.
    tools: List[Dict[str, Any]] | None = None

core/quivr_core/rag/entities/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class RAGResponseMetadata(BaseModel):
7676
followup_questions: list[str] = Field(default_factory=list)
7777
sources: list[Any] = Field(default_factory=list)
7878
metadata_model: ChatLLMMetadata | None = None
79+
workflow_step: str | None = None
7980

8081

8182
class ParsedRAGResponse(BaseModel):

core/quivr_core/rag/quivr_rag_langgraph.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
TypedDict,
1414
)
1515
from uuid import UUID, uuid4
16-
1716
import openai
1817
from langchain.retrievers import ContextualCompressionRetriever
1918
from langchain_cohere import CohereRerank
@@ -38,6 +37,7 @@
3837
from quivr_core.rag.entities.models import (
3938
ParsedRAGChunkResponse,
4039
QuivrKnowledge,
40+
RAGResponseMetadata,
4141
)
4242
from quivr_core.rag.prompts import custom_prompts
4343
from quivr_core.rag.utils import (
@@ -950,6 +950,8 @@ async def answer_astream(
950950
version="v1",
951951
config={"metadata": metadata, "callbacks": [langfuse_handler]},
952952
):
953+
node_name = self._extract_node_name(event)
954+
953955
if self._is_final_node_with_docs(event):
954956
tasks = event["data"]["output"]["tasks"]
955957
docs = tasks.docs if tasks else []
@@ -965,9 +967,17 @@ async def answer_astream(
965967

966968
if new_content:
967969
chunk_metadata = get_chunk_metadata(rolling_message, docs)
970+
if node_name:
971+
chunk_metadata.workflow_step = node_name
968972
yield ParsedRAGChunkResponse(
969973
answer=new_content, metadata=chunk_metadata
970974
)
975+
else:
976+
if node_name:
977+
yield ParsedRAGChunkResponse(
978+
answer="",
979+
metadata=RAGResponseMetadata(workflow_step=node_name),
980+
)
971981

972982
# Yield final metadata chunk
973983
yield ParsedRAGChunkResponse(
@@ -991,6 +1001,17 @@ def _is_final_node_and_chat_model_stream(self, event: dict) -> bool:
9911001
and event["metadata"]["langgraph_node"] in self.final_nodes
9921002
)
9931003

1004+
def _extract_node_name(self, event: dict) -> str:
1005+
if "metadata" in event and "langgraph_node" in event["metadata"]:
1006+
name = event["metadata"]["langgraph_node"]
1007+
for node in self.retrieval_config.workflow_config.nodes:
1008+
if node.name == name:
1009+
if node.description:
1010+
return node.description
1011+
else:
1012+
return node.name
1013+
return ""
1014+
9941015
async def ainvoke_structured_output(
9951016
self, prompt: str, output_class: Type[BaseModel]
9961017
) -> Any:

0 commit comments

Comments
 (0)