|
| 1 | +from collections.abc import Generator |
| 2 | +from typing import Any, Optional |
| 3 | + |
| 4 | +from langchain.schema import AIMessage, HumanMessage |
| 5 | +from langchain.schema import FunctionMessage as LCFunctionMessage |
| 6 | +from langchain.schema import SystemMessage as LCSystemMessage |
| 7 | +from langchain_core.messages import ToolMessage as LCToolMessage |
| 8 | + |
| 9 | +from .data import AssistantMessage, BaseMessage, FunctionMessageData, SystemMessageData, ToolCall, ToolMessageData, UnknownMessage, UserMessage |
| 10 | +from .loggers import ExternalLogger |
| 11 | + |
| 12 | + |
class MessageStreamTracer:
    """Pass-through tracer that records structured data for streamed messages.

    Every chunk of the wrapped stream is inspected. When a message can be
    extracted from it, the message is converted to a ``BaseMessage`` subclass,
    appended to ``self.traces`` and, if a logger was supplied, forwarded to
    ``logger.log``. The chunk itself is always yielded unchanged.
    """

    def __init__(self, logger: Optional["ExternalLogger"] = None):
        """Create a tracer.

        Args:
            logger: Optional sink whose ``log`` method receives every
                structured message that is extracted.
        """
        self.traces: list["BaseMessage"] = []
        self.logger = logger

    def process_stream(self, message_stream: Generator) -> Generator:
        """Process the stream of messages from the LangGraph agent,
        extract structured data, and pass through the messages.

        Args:
            message_stream: Iterable of chunks to observe.

        Yields:
            Each chunk of ``message_stream``, unmodified and in order.
        """
        for chunk in message_stream:
            structured_data = self.extract_structured_data(chunk)

            # ``None`` is the documented "nothing useful" marker, so test for
            # it explicitly instead of relying on the object's truthiness.
            if structured_data is not None:
                self.traces.append(structured_data)

                # If there's an external logger, send the data there
                if self.logger:
                    self.logger.log(structured_data)

            # Pass through the chunk to maintain the original stream behavior
            yield chunk

    def extract_structured_data(self, chunk: dict[str, Any]) -> Optional["BaseMessage"]:
        """Extract structured data from a message chunk.

        Only the last message found in the chunk is converted.

        Args:
            chunk: A stream chunk. Anything that is not a ``dict`` is ignored.

        Returns:
            ``None`` if the chunk doesn't contain useful information,
            otherwise a ``BaseMessage`` subclass instance chosen by the
            message type.
        """
        messages = self._find_messages(chunk)
        if not messages:
            return None

        # Get the latest message
        latest_message = messages[-1]
        if not latest_message:
            return None

        message_type = self._get_message_type(latest_message)
        content = self._get_message_content(latest_message)

        # Create the appropriate message type
        if message_type == "user":
            return UserMessage(type=message_type, content=content)
        if message_type == "system":
            return SystemMessageData(type=message_type, content=content)
        if message_type == "assistant":
            tool_calls = [
                ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id"))
                for tc in self._extract_tool_calls(latest_message)
            ]
            return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
        if message_type == "tool":
            return ToolMessageData(
                type=message_type,
                content=content,
                tool_name=getattr(latest_message, "name", None),
                tool_response=content,
                tool_id=getattr(latest_message, "tool_call_id", None),
            )
        if message_type == "function":
            return FunctionMessageData(type=message_type, content=content)
        return UnknownMessage(type=message_type, content=content)

    def _find_messages(self, chunk: Any) -> list[Any]:
        """Return the message list carried by ``chunk``, or ``[]`` if none.

        The ``"messages"`` key is preferred. Otherwise the first value that is
        a non-empty list whose items all expose a ``type`` attribute is used.
        """
        # Non-dict chunks have no ``.get``; treat them as carrying no
        # messages rather than raising and breaking the pass-through stream.
        if not isinstance(chunk, dict):
            return []

        messages = chunk.get("messages", [])
        if messages:
            return messages

        # Sometimes the message might be in a different format
        for value in chunk.values():
            # Require a non-empty list: ``all()`` over nothing is True, which
            # would otherwise accept empty lists and lists of plain values.
            if isinstance(value, list) and value and all(hasattr(item, "type") for item in value):
                return value
        return []

    def _get_message_type(self, message) -> str:
        """Determine the type of message.

        Known LangChain message classes map to fixed names. Any other object
        falls back to its own truthy ``type`` attribute, or ``"unknown"``.
        """
        if isinstance(message, HumanMessage):
            return "user"
        if isinstance(message, AIMessage):
            return "assistant"
        if isinstance(message, LCSystemMessage):
            return "system"
        if isinstance(message, LCFunctionMessage):
            return "function"
        if isinstance(message, LCToolMessage):
            return "tool"
        if hasattr(message, "type") and message.type:
            return message.type
        return "unknown"

    def _get_message_content(self, message) -> Any:
        """Extract content from a message.

        Returns ``message.content`` as-is when present, otherwise
        ``message.message.content``, otherwise ``str(message)``.
        """
        if hasattr(message, "content"):
            return message.content
        if hasattr(message, "message") and hasattr(message.message, "content"):
            return message.message.content
        return str(message)

    def _extract_tool_calls(self, message) -> list[dict[str, Any]]:
        """Extract tool calls from an assistant message.

        Reads ``message.additional_kwargs``. A ``"tool_calls"`` entry takes
        precedence; a ``"function_call"`` dict is used only when no
        ``"tool_calls"`` key is present.

        Returns:
            A list of dicts with ``"name"``, ``"arguments"`` and ``"id"``
            keys; empty when nothing is found.
        """
        # Tolerate a missing or ``None`` ``additional_kwargs``.
        additional_kwargs = getattr(message, "additional_kwargs", None) or {}

        if "tool_calls" in additional_kwargs:
            tool_calls = []
            # ``or []`` / ``or {}`` guard against explicit ``None`` values.
            for tc in additional_kwargs["tool_calls"] or []:
                function = tc.get("function") or {}
                tool_calls.append(
                    {
                        "name": function.get("name"),
                        "arguments": function.get("arguments"),
                        "id": tc.get("id"),
                    }
                )
            return tool_calls

        # Also check for function_call which is used in some models
        if "function_call" in additional_kwargs:
            fc = additional_kwargs["function_call"]
            if isinstance(fc, dict):
                return [
                    {
                        "name": fc.get("name"),
                        "arguments": fc.get("arguments"),
                        "id": "function_call_1",  # Assigning a default ID
                    }
                ]

        return []

    def get_traces(self) -> list["BaseMessage"]:
        """Get all collected traces (the tracer's own list, not a copy)."""
        return self.traces

    def clear_traces(self) -> None:
        """Clear all traces by rebinding ``self.traces`` to a new empty list."""
        self.traces = []
0 commit comments