Skip to content

Commit bbdd60d

Browse files
authored
Revert "Revert "Adding Schema for Tool Outputs" (#892)"
This reverts commit 0ffdde0.
1 parent 0ffdde0 commit bbdd60d

File tree

13 files changed

+411
-71
lines changed

13 files changed

+411
-71
lines changed

codegen-examples/examples/swebench_agent_run/local_run.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"metadata": {},
3333
"outputs": [],
3434
"source": [
35-
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
35+
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
3636
]
3737
},
3838
{

src/codegen/agents/data.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class ToolMessageData(BaseMessage):
5252
tool_name: Optional[str] = None
5353
tool_response: Optional[str] = None
5454
tool_id: Optional[str] = None
55+
status: Optional[str] = None
5556

5657

5758
@dataclass

src/codegen/agents/tracer.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,14 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
7171
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
7272
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
7373
elif message_type == "tool":
74-
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
74+
return ToolMessageData(
75+
type=message_type,
76+
content=content,
77+
tool_name=getattr(latest_message, "name", None),
78+
tool_response=getattr(latest_message, "artifact", content),
79+
tool_id=getattr(latest_message, "tool_call_id", None),
80+
status=getattr(latest_message, "status", None),
81+
)
7582
elif message_type == "function":
7683
return FunctionMessageData(type=message_type, content=content)
7784
else:

src/codegen/extensions/langchain/graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def reasoner(self, state: GraphState) -> dict[str, Any]:
100100
messages.append(HumanMessage(content=query))
101101

102102
result = self.model.invoke([self.system_message, *messages])
103-
if isinstance(result, AIMessage):
103+
if isinstance(result, AIMessage) and not result.tool_calls:
104104
updated_messages = [*messages, result]
105105
return {"messages": updated_messages, "final_answer": result.content}
106106

@@ -455,7 +455,7 @@ def get_field_descriptions(tool_obj):
455455
return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters."
456456

457457
# For other types of errors
458-
return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters."
458+
return f"Error executing tool: {exception!s}\n\nPlease check your tool usage and try again with the correct parameters."
459459

460460
# Add nodes
461461
builder.add_node("reasoner", self.reasoner, retry=retry_policy)

src/codegen/extensions/langchain/tools.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Langchain tools for workspace operations."""
22

33
from collections.abc import Callable
4-
from typing import ClassVar, Literal
4+
from typing import Annotated, ClassVar, Literal, Optional
55

6+
from langchain_core.messages import ToolMessage
7+
from langchain_core.tools import InjectedToolCallId
68
from langchain_core.tools.base import BaseTool
79
from pydantic import BaseModel, Field
810

@@ -52,10 +54,11 @@ class ViewFileInput(BaseModel):
5254
"""Input for viewing a file."""
5355

5456
filepath: str = Field(..., description="Path to the file relative to workspace root")
55-
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
56-
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
57-
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 250")
58-
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")
57+
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
58+
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
59+
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 250")
60+
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
61+
tool_call_id: Annotated[str, InjectedToolCallId]
5962

6063

6164
class ViewFileTool(BaseTool):
@@ -73,12 +76,13 @@ def __init__(self, codebase: Codebase) -> None:
7376

7477
def _run(
7578
self,
79+
tool_call_id: str,
7680
filepath: str,
77-
start_line: int | None = None,
78-
end_line: int | None = None,
79-
max_lines: int | None = None,
80-
line_numbers: bool | None = True,
81-
) -> str:
81+
start_line: Optional[int] = None,
82+
end_line: Optional[int] = None,
83+
max_lines: Optional[int] = None,
84+
line_numbers: Optional[bool] = True,
85+
) -> ToolMessage:
8286
result = view_file(
8387
self.codebase,
8488
filepath,
@@ -88,14 +92,15 @@ def _run(
8892
max_lines=max_lines if max_lines is not None else 250,
8993
)
9094

91-
return result.render()
95+
return result.render(tool_call_id)
9296

9397

9498
class ListDirectoryInput(BaseModel):
9599
"""Input for listing directory contents."""
96100

97101
dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
98102
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
103+
tool_call_id: Annotated[str, InjectedToolCallId]
99104

100105

101106
class ListDirectoryTool(BaseTool):
@@ -109,9 +114,9 @@ class ListDirectoryTool(BaseTool):
109114
def __init__(self, codebase: Codebase) -> None:
110115
super().__init__(codebase=codebase)
111116

112-
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
117+
def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
113118
result = list_directory(self.codebase, dirpath, depth)
114-
return result.render()
119+
return result.render(tool_call_id)
115120

116121

117122
class SearchInput(BaseModel):
@@ -126,6 +131,7 @@ class SearchInput(BaseModel):
126131
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
127132
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
128133
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
134+
tool_call_id: Annotated[str, InjectedToolCallId]
129135

130136

131137
class SearchTool(BaseTool):
@@ -139,16 +145,17 @@ class SearchTool(BaseTool):
139145
def __init__(self, codebase: Codebase) -> None:
140146
super().__init__(codebase=codebase)
141147

142-
def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
148+
def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
143149
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
144-
return result.render()
150+
return result.render(tool_call_id)
145151

146152

147153
class EditFileInput(BaseModel):
148154
"""Input for editing a file."""
149155

150156
filepath: str = Field(..., description="Path to the file to edit")
151157
content: str = Field(..., description="New content for the file")
158+
tool_call_id: Annotated[str, InjectedToolCallId]
152159

153160

154161
class EditFileTool(BaseTool):
@@ -181,9 +188,9 @@ class EditFileTool(BaseTool):
181188
def __init__(self, codebase: Codebase) -> None:
182189
super().__init__(codebase=codebase)
183190

184-
def _run(self, filepath: str, content: str) -> str:
191+
def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
185192
result = edit_file(self.codebase, filepath, content)
186-
return result.render()
193+
return result.render(tool_call_id)
187194

188195

189196
class CreateFileInput(BaseModel):
@@ -340,6 +347,7 @@ class SemanticEditInput(BaseModel):
340347
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
341348
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
342349
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
350+
tool_call_id: Annotated[str, InjectedToolCallId]
343351

344352

345353
class SemanticEditTool(BaseTool):
@@ -353,10 +361,10 @@ class SemanticEditTool(BaseTool):
353361
def __init__(self, codebase: Codebase) -> None:
354362
super().__init__(codebase=codebase)
355363

356-
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
364+
def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
357365
# Create the the draft editor mini llm
358366
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
359-
return result.render()
367+
return result.render(tool_call_id)
360368

361369

362370
class RenameFileInput(BaseModel):

src/codegen/extensions/tools/edit_file.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,53 @@
11
"""Tool for editing file contents."""
22

3-
from typing import ClassVar
3+
from typing import TYPE_CHECKING, ClassVar, Optional
44

5+
from langchain_core.messages import ToolMessage
56
from pydantic import Field
67

78
from codegen.sdk.core.codebase import Codebase
89

910
from .observation import Observation
1011
from .replacement_edit import generate_diff
1112

13+
if TYPE_CHECKING:
14+
from .tool_output_types import EditFileArtifacts
15+
1216

1317
class EditFileObservation(Observation):
1418
"""Response from editing a file."""
1519

1620
filepath: str = Field(
1721
description="Path to the edited file",
1822
)
19-
diff: str = Field(
23+
diff: Optional[str] = Field(
24+
default=None,
2025
description="Unified diff showing the changes made",
2126
)
2227

2328
str_template: ClassVar[str] = "Edited file {filepath}"
2429

25-
def render(self) -> str:
30+
def render(self, tool_call_id: str) -> ToolMessage:
2631
"""Render edit results in a clean format."""
27-
return f"""[EDIT FILE]: {self.filepath}
28-
29-
{self.diff}"""
32+
if self.status == "error":
33+
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
34+
return ToolMessage(
35+
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
36+
status=self.status,
37+
tool_name="edit_file",
38+
artifact=artifacts_error,
39+
tool_call_id=tool_call_id,
40+
)
41+
42+
artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}
43+
44+
return ToolMessage(
45+
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
46+
status=self.status,
47+
tool_name="edit_file",
48+
artifact=artifacts_success,
49+
tool_call_id=tool_call_id,
50+
)
3051

3152

3253
def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:

src/codegen/extensions/tools/list_directory.py

+62-9
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from typing import ClassVar
44

5+
from langchain_core.messages import ToolMessage
56
from pydantic import Field
67

8+
from codegen.extensions.tools.observation import Observation
9+
from codegen.extensions.tools.tool_output_types import ListDirectoryArtifacts
710
from codegen.sdk.core.codebase import Codebase
811
from codegen.sdk.core.directory import Directory
912

10-
from .observation import Observation
11-
1213

1314
class DirectoryInfo(Observation):
1415
"""Information about a directory."""
@@ -31,6 +32,14 @@ class DirectoryInfo(Observation):
3132
default=False,
3233
description="Whether this is a leaf node (at max depth)",
3334
)
35+
depth: int = Field(
36+
default=0,
37+
description="Current depth in the tree",
38+
)
39+
max_depth: int = Field(
40+
default=1,
41+
description="Maximum depth allowed",
42+
)
3443

3544
str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)"
3645

@@ -41,7 +50,7 @@ def _get_details(self) -> dict[str, int]:
4150
"dir_count": len(self.subdirectories),
4251
}
4352

44-
def render(self) -> str:
53+
def render_as_string(self) -> str:
4554
"""Render directory listing as a file tree."""
4655
lines = [
4756
f"[LIST DIRECTORY]: {self.path}",
@@ -97,6 +106,26 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st
97106

98107
return "\n".join(lines)
99108

109+
def to_artifacts(self) -> ListDirectoryArtifacts:
110+
"""Convert directory info to artifacts for UI."""
111+
artifacts: ListDirectoryArtifacts = {
112+
"dirpath": self.path,
113+
"name": self.name,
114+
"is_leaf": self.is_leaf,
115+
"depth": self.depth,
116+
"max_depth": self.max_depth,
117+
}
118+
119+
if self.files is not None:
120+
artifacts["files"] = self.files
121+
artifacts["file_paths"] = [f"{self.path}/{f}" for f in self.files]
122+
123+
if self.subdirectories:
124+
artifacts["subdirs"] = [d.name for d in self.subdirectories]
125+
artifacts["subdir_paths"] = [d.path for d in self.subdirectories]
126+
127+
return artifacts
128+
100129

101130
class ListDirectoryObservation(Observation):
102131
"""Response from listing directory contents."""
@@ -107,9 +136,29 @@ class ListDirectoryObservation(Observation):
107136

108137
str_template: ClassVar[str] = "{directory_info}"
109138

110-
def render(self) -> str:
111-
"""Render directory listing."""
112-
return self.directory_info.render()
139+
def render(self, tool_call_id: str) -> ToolMessage:
140+
"""Render directory listing with artifacts for UI."""
141+
if self.status == "error":
142+
error_artifacts: ListDirectoryArtifacts = {
143+
"dirpath": self.directory_info.path,
144+
"name": self.directory_info.name,
145+
"error": self.error,
146+
}
147+
return ToolMessage(
148+
content=f"[ERROR LISTING DIRECTORY]: {self.directory_info.path}: {self.error}",
149+
status=self.status,
150+
tool_name="list_directory",
151+
artifact=error_artifacts,
152+
tool_call_id=tool_call_id,
153+
)
154+
155+
return ToolMessage(
156+
content=self.directory_info.render_as_string(),
157+
status=self.status,
158+
tool_name="list_directory",
159+
artifact=self.directory_info.to_artifacts(),
160+
tool_call_id=tool_call_id,
161+
)
113162

114163

115164
def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> ListDirectoryObservation:
@@ -136,7 +185,7 @@ def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> List
136185
),
137186
)
138187

139-
def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
188+
def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -> DirectoryInfo:
140189
"""Helper function to get directory info recursively."""
141190
# Get direct files (always include files unless at max depth)
142191
all_files = []
@@ -151,7 +200,7 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
151200
if current_depth > 1 or current_depth == -1:
152201
# For deeper traversal, get full directory info
153202
new_depth = current_depth - 1 if current_depth > 1 else -1
154-
subdirs.append(get_directory_info(subdir, new_depth))
203+
subdirs.append(get_directory_info(subdir, new_depth, max_depth))
155204
else:
156205
# At max depth, return a leaf node
157206
subdirs.append(
@@ -161,6 +210,8 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
161210
path=subdir.dirpath,
162211
files=None, # Don't include files at max depth
163212
is_leaf=True,
213+
depth=current_depth,
214+
max_depth=max_depth,
164215
)
165216
)
166217

@@ -170,9 +221,11 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
170221
path=dir_obj.dirpath,
171222
files=sorted(all_files),
172223
subdirectories=subdirs,
224+
depth=current_depth,
225+
max_depth=max_depth,
173226
)
174227

175-
dir_info = get_directory_info(directory, depth)
228+
dir_info = get_directory_info(directory, depth, depth)
176229
return ListDirectoryObservation(
177230
status="success",
178231
directory_info=dir_info,

0 commit comments

Comments
 (0)