Skip to content

Commit bc9f739

Browse files
authored
Add logger interface support (#857)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: rushilpatel0 <[email protected]>
1 parent 8fbad1a commit bc9f739

File tree

5 files changed

+331
-3
lines changed

5 files changed

+331
-3
lines changed

src/codegen/agents/code_agent.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from langgraph.graph.graph import CompiledGraph
99
from langsmith import Client
1010

11+
from codegen.agents.loggers import ExternalLogger
12+
from codegen.agents.tracer import MessageStreamTracer
1113
from codegen.extensions.langchain.agent import create_codebase_agent
1214
from codegen.extensions.langchain.utils.get_langsmith_url import (
1315
find_and_print_langsmith_run_url,
@@ -30,6 +32,7 @@ class CodeAgent:
3032
run_id: str | None = None
3133
instance_id: str | None = None
3234
difficulty: int | None = None
35+
logger: Optional[ExternalLogger] = None
3336

3437
def __init__(
3538
self,
@@ -42,6 +45,7 @@ def __init__(
4245
metadata: Optional[dict] = {},
4346
agent_config: Optional[AgentConfig] = None,
4447
thread_id: Optional[str] = None,
48+
logger: Optional[ExternalLogger] = None,
4549
**kwargs,
4650
):
4751
"""Initialize a CodeAgent.
@@ -92,6 +96,9 @@ def __init__(
9296
# Initialize tags for agent trace
9397
self.tags = [*tags, self.model_name]
9498

99+
# set logger if provided
100+
self.logger = logger
101+
95102
# Initialize metadata for agent trace
96103
self.metadata = {
97104
"project": self.project_name,
@@ -123,19 +130,26 @@ def run(self, prompt: str) -> str:
123130

124131
config = RunnableConfig(configurable={"thread_id": self.thread_id}, tags=self.tags, metadata=self.metadata, recursion_limit=200)
125132
# we stream the steps instead of invoke because it allows us to access intermediate nodes
133+
126134
stream = self.agent.stream(input, config=config, stream_mode="values")
127135

136+
_tracer = MessageStreamTracer(logger=self.logger)
137+
138+
# Process the stream with the tracer
139+
traced_stream = _tracer.process_stream(stream)
140+
128141
# Keep track of run IDs from the stream
129142
run_ids = []
130143

131-
for s in stream:
144+
for s in traced_stream:
132145
if len(s["messages"]) == 0 or isinstance(s["messages"][-1], HumanMessage):
133146
message = HumanMessage(content=prompt)
134147
else:
135148
message = s["messages"][-1]
136149

137150
if isinstance(message, tuple):
138-
print(message)
151+
# print(message)
152+
pass
139153
else:
140154
if isinstance(message, AIMessage) and isinstance(message.content, list) and len(message.content) > 0 and "text" in message.content[0]:
141155
AIMessage(message.content[0]["text"]).pretty_print()
@@ -149,7 +163,7 @@ def run(self, prompt: str) -> str:
149163
# Get the last message content
150164
result = s["final_answer"]
151165

152-
# Try to find run IDs in the LangSmith client's recent runs
166+
# # Try to find run IDs in the LangSmith client's recent runs
153167
try:
154168
# Find and print the LangSmith run URL
155169
find_and_print_langsmith_run_url(self.langsmith_client, self.project_name)

src/codegen/agents/data.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from dataclasses import dataclass, field
2+
from datetime import UTC, datetime
3+
from typing import Literal, Optional, Union
4+
5+
6+
# Base dataclass for all message types
7+
@dataclass
8+
class BaseMessage:
9+
"""Base class for all message types."""
10+
11+
type: str
12+
timestamp: str = field(default_factory=lambda: datetime.now(tz=UTC).isoformat())
13+
content: str = ""
14+
15+
16+
@dataclass
17+
class UserMessage(BaseMessage):
18+
"""Represents a message from the user."""
19+
20+
type: Literal["user"] = field(default="user")
21+
22+
23+
@dataclass
24+
class SystemMessageData(BaseMessage):
25+
"""Represents a system message."""
26+
27+
type: Literal["system"] = field(default="system")
28+
29+
30+
@dataclass
31+
class ToolCall:
32+
"""Represents a tool call within an assistant message."""
33+
34+
name: Optional[str] = None
35+
arguments: Optional[str] = None
36+
id: Optional[str] = None
37+
38+
39+
@dataclass
40+
class AssistantMessage(BaseMessage):
41+
"""Represents a message from the assistant."""
42+
43+
type: Literal["assistant"] = field(default="assistant")
44+
tool_calls: list[ToolCall] = field(default_factory=list)
45+
46+
47+
@dataclass
48+
class ToolMessageData(BaseMessage):
49+
"""Represents a tool response message."""
50+
51+
type: Literal["tool"] = field(default="tool")
52+
tool_name: Optional[str] = None
53+
tool_response: Optional[str] = None
54+
tool_id: Optional[str] = None
55+
56+
57+
@dataclass
58+
class FunctionMessageData(BaseMessage):
59+
"""Represents a function message."""
60+
61+
type: Literal["function"] = field(default="function")
62+
63+
64+
@dataclass
65+
class UnknownMessage(BaseMessage):
66+
"""Represents an unknown message type."""
67+
68+
type: Literal["unknown"] = field(default="unknown")
69+
70+
71+
type AgentRunMessage = Union[UserMessage, SystemMessageData, AssistantMessage, ToolMessageData, FunctionMessageData, UnknownMessage]

src/codegen/agents/loggers.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Protocol
2+
3+
from .data import AgentRunMessage
4+
5+
6+
# Define the interface for ExternalLogger
7+
class ExternalLogger(Protocol):
8+
"""Protocol defining the interface for external loggers."""
9+
10+
def log(self, data: AgentRunMessage) -> None:
11+
"""Log structured data to an external system.
12+
13+
Args:
14+
data: The structured data to log, either as a dictionary or a BaseMessage
15+
"""
16+
pass

src/codegen/agents/scratch.ipynb

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from codegen.agents.code_agent import CodeAgent\n",
10+
"\n",
11+
"\n",
12+
"CodeAgent"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": null,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"from codegen.sdk.core.codebase import Codebase\n",
22+
"\n",
23+
"\n",
24+
"codebase = Codebase.from_repo(\"codegen-sh/Kevin-s-Adventure-Game\")"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"metadata": {},
31+
"outputs": [],
32+
"source": [
33+
"from typing import Any, Dict, Union\n",
34+
"from codegen.agents.data import BaseMessage\n",
35+
"from codegen.agents.loggers import ExternalLogger\n",
36+
"\n",
37+
"\n",
38+
"class ConsoleLogger(ExternalLogger):\n",
39+
" def log(self, data: Union[Dict[str, Any], BaseMessage]) -> None:\n",
40+
" print(data.content)"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"metadata": {},
47+
"outputs": [],
48+
"source": [
49+
"agent = CodeAgent(codebase)\n",
50+
"agent.run(\"What is the main character's name? also show the source code where you find the answer\", logger=ConsoleLogger())"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"agent.run(\"What is the main character's name?\")"
60+
]
61+
},
62+
{
63+
"cell_type": "code",
64+
"execution_count": null,
65+
"metadata": {},
66+
"outputs": [],
67+
"source": []
68+
}
69+
],
70+
"metadata": {
71+
"kernelspec": {
72+
"display_name": ".venv",
73+
"language": "python",
74+
"name": "python3"
75+
},
76+
"language_info": {
77+
"codemirror_mode": {
78+
"name": "ipython",
79+
"version": 3
80+
},
81+
"file_extension": ".py",
82+
"mimetype": "text/x-python",
83+
"name": "python",
84+
"nbconvert_exporter": "python",
85+
"pygments_lexer": "ipython3",
86+
"version": "3.13.0"
87+
}
88+
},
89+
"nbformat": 4,
90+
"nbformat_minor": 2
91+
}

src/codegen/agents/tracer.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from collections.abc import Generator
2+
from typing import Any, Optional
3+
4+
from langchain.schema import AIMessage, HumanMessage
5+
from langchain.schema import FunctionMessage as LCFunctionMessage
6+
from langchain.schema import SystemMessage as LCSystemMessage
7+
from langchain_core.messages import ToolMessage as LCToolMessage
8+
9+
from .data import AssistantMessage, BaseMessage, FunctionMessageData, SystemMessageData, ToolCall, ToolMessageData, UnknownMessage, UserMessage
10+
from .loggers import ExternalLogger
11+
12+
13+
class MessageStreamTracer:
14+
def __init__(self, logger: Optional[ExternalLogger] = None):
15+
self.traces = []
16+
self.logger = logger
17+
18+
def process_stream(self, message_stream: Generator) -> Generator:
19+
"""Process the stream of messages from the LangGraph agent,
20+
extract structured data, and pass through the messages.
21+
"""
22+
for chunk in message_stream:
23+
# Process the chunk
24+
structured_data = self.extract_structured_data(chunk)
25+
26+
# Log the structured data
27+
if structured_data:
28+
self.traces.append(structured_data)
29+
30+
# If there's an external logger, send the data there
31+
if self.logger:
32+
self.logger.log(structured_data)
33+
34+
# Pass through the chunk to maintain the original stream behavior
35+
yield chunk
36+
37+
def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage]:
38+
"""Extract structured data from a message chunk.
39+
Returns None if the chunk doesn't contain useful information.
40+
Returns a BaseMessage subclass instance based on the message type.
41+
"""
42+
# Get the messages from the chunk if available
43+
messages = chunk.get("messages", [])
44+
if not messages and isinstance(chunk, dict):
45+
# Sometimes the message might be in a different format
46+
for key, value in chunk.items():
47+
if isinstance(value, list) and all(hasattr(item, "type") for item in value if hasattr(item, "__dict__")):
48+
messages = value
49+
break
50+
51+
if not messages:
52+
return None
53+
54+
# Get the latest message
55+
latest_message = messages[-1] if messages else None
56+
57+
if not latest_message:
58+
return None
59+
60+
# Determine message type
61+
message_type = self._get_message_type(latest_message)
62+
content = self._get_message_content(latest_message)
63+
64+
# Create the appropriate message type
65+
if message_type == "user":
66+
return UserMessage(type=message_type, content=content)
67+
elif message_type == "system":
68+
return SystemMessageData(type=message_type, content=content)
69+
elif message_type == "assistant":
70+
tool_calls_data = self._extract_tool_calls(latest_message)
71+
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
72+
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
73+
elif message_type == "tool":
74+
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
75+
elif message_type == "function":
76+
return FunctionMessageData(type=message_type, content=content)
77+
else:
78+
return UnknownMessage(type=message_type, content=content)
79+
80+
def _get_message_type(self, message) -> str:
81+
"""Determine the type of message."""
82+
if isinstance(message, HumanMessage):
83+
return "user"
84+
elif isinstance(message, AIMessage):
85+
return "assistant"
86+
elif isinstance(message, LCSystemMessage):
87+
return "system"
88+
elif isinstance(message, LCFunctionMessage):
89+
return "function"
90+
elif isinstance(message, LCToolMessage):
91+
return "tool"
92+
elif hasattr(message, "type") and message.type:
93+
return message.type
94+
else:
95+
return "unknown"
96+
97+
def _get_message_content(self, message) -> str:
98+
"""Extract content from a message."""
99+
if hasattr(message, "content"):
100+
return message.content
101+
elif hasattr(message, "message") and hasattr(message.message, "content"):
102+
return message.message.content
103+
else:
104+
return str(message)
105+
106+
def _extract_tool_calls(self, message) -> list[dict[str, Any]]:
107+
"""Extract tool calls from an assistant message."""
108+
tool_calls = []
109+
110+
# Check different possible locations for tool calls
111+
if hasattr(message, "additional_kwargs") and "tool_calls" in message.additional_kwargs:
112+
raw_tool_calls = message.additional_kwargs["tool_calls"]
113+
for tc in raw_tool_calls:
114+
tool_calls.append({"name": tc.get("function", {}).get("name"), "arguments": tc.get("function", {}).get("arguments"), "id": tc.get("id")})
115+
116+
# Also check for function_call which is used in some models
117+
elif hasattr(message, "additional_kwargs") and "function_call" in message.additional_kwargs:
118+
fc = message.additional_kwargs["function_call"]
119+
if isinstance(fc, dict):
120+
tool_calls.append(
121+
{
122+
"name": fc.get("name"),
123+
"arguments": fc.get("arguments"),
124+
"id": "function_call_1", # Assigning a default ID
125+
}
126+
)
127+
128+
return tool_calls
129+
130+
def get_traces(self) -> list[BaseMessage]:
131+
"""Get all collected traces."""
132+
return self.traces
133+
134+
def clear_traces(self) -> None:
135+
"""Clear all traces."""
136+
self.traces = []

0 commit comments

Comments
 (0)