Skip to content

Commit 0ffdde0

Browse files
authored
Revert "Adding Schema for Tool Outputs" (#892)
Reverts #888
1 parent fe98b92 commit 0ffdde0

File tree

13 files changed

+71
-411
lines changed

13 files changed

+71
-411
lines changed

codegen-examples/examples/swebench_agent_run/local_run.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"metadata": {},
3333
"outputs": [],
3434
"source": [
35-
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
35+
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
3636
]
3737
},
3838
{

src/codegen/agents/data.py

-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class ToolMessageData(BaseMessage):
5252
tool_name: Optional[str] = None
5353
tool_response: Optional[str] = None
5454
tool_id: Optional[str] = None
55-
status: Optional[str] = None
5655

5756

5857
@dataclass

src/codegen/agents/tracer.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,7 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
7171
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
7272
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
7373
elif message_type == "tool":
74-
return ToolMessageData(
75-
type=message_type,
76-
content=content,
77-
tool_name=getattr(latest_message, "name", None),
78-
tool_response=getattr(latest_message, "artifact", content),
79-
tool_id=getattr(latest_message, "tool_call_id", None),
80-
status=getattr(latest_message, "status", None),
81-
)
74+
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
8275
elif message_type == "function":
8376
return FunctionMessageData(type=message_type, content=content)
8477
else:

src/codegen/extensions/langchain/graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def reasoner(self, state: GraphState) -> dict[str, Any]:
100100
messages.append(HumanMessage(content=query))
101101

102102
result = self.model.invoke([self.system_message, *messages])
103-
if isinstance(result, AIMessage) and not result.tool_calls:
103+
if isinstance(result, AIMessage):
104104
updated_messages = [*messages, result]
105105
return {"messages": updated_messages, "final_answer": result.content}
106106

@@ -455,7 +455,7 @@ def get_field_descriptions(tool_obj):
455455
return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters."
456456

457457
# For other types of errors
458-
return f"Error executing tool: {exception!s}\n\nPlease check your tool usage and try again with the correct parameters."
458+
return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters."
459459

460460
# Add nodes
461461
builder.add_node("reasoner", self.reasoner, retry=retry_policy)

src/codegen/extensions/langchain/tools.py

+19-27
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""Langchain tools for workspace operations."""
22

33
from collections.abc import Callable
4-
from typing import Annotated, ClassVar, Literal, Optional
4+
from typing import ClassVar, Literal
55

6-
from langchain_core.messages import ToolMessage
7-
from langchain_core.tools import InjectedToolCallId
86
from langchain_core.tools.base import BaseTool
97
from pydantic import BaseModel, Field
108

@@ -54,11 +52,10 @@ class ViewFileInput(BaseModel):
5452
"""Input for viewing a file."""
5553

5654
filepath: str = Field(..., description="Path to the file relative to workspace root")
57-
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
58-
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
59-
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 250")
60-
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
61-
tool_call_id: Annotated[str, InjectedToolCallId]
55+
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
56+
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
57+
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 250")
58+
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")
6259

6360

6461
class ViewFileTool(BaseTool):
@@ -76,13 +73,12 @@ def __init__(self, codebase: Codebase) -> None:
7673

7774
def _run(
7875
self,
79-
tool_call_id: str,
8076
filepath: str,
81-
start_line: Optional[int] = None,
82-
end_line: Optional[int] = None,
83-
max_lines: Optional[int] = None,
84-
line_numbers: Optional[bool] = True,
85-
) -> ToolMessage:
77+
start_line: int | None = None,
78+
end_line: int | None = None,
79+
max_lines: int | None = None,
80+
line_numbers: bool | None = True,
81+
) -> str:
8682
result = view_file(
8783
self.codebase,
8884
filepath,
@@ -92,15 +88,14 @@ def _run(
9288
max_lines=max_lines if max_lines is not None else 250,
9389
)
9490

95-
return result.render(tool_call_id)
91+
return result.render()
9692

9793

9894
class ListDirectoryInput(BaseModel):
9995
"""Input for listing directory contents."""
10096

10197
dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
10298
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
103-
tool_call_id: Annotated[str, InjectedToolCallId]
10499

105100

106101
class ListDirectoryTool(BaseTool):
@@ -114,9 +109,9 @@ class ListDirectoryTool(BaseTool):
114109
def __init__(self, codebase: Codebase) -> None:
115110
super().__init__(codebase=codebase)
116111

117-
def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
112+
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
118113
result = list_directory(self.codebase, dirpath, depth)
119-
return result.render(tool_call_id)
114+
return result.render()
120115

121116

122117
class SearchInput(BaseModel):
@@ -131,7 +126,6 @@ class SearchInput(BaseModel):
131126
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
132127
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
133128
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
134-
tool_call_id: Annotated[str, InjectedToolCallId]
135129

136130

137131
class SearchTool(BaseTool):
@@ -145,17 +139,16 @@ class SearchTool(BaseTool):
145139
def __init__(self, codebase: Codebase) -> None:
146140
super().__init__(codebase=codebase)
147141

148-
def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
142+
def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
149143
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
150-
return result.render(tool_call_id)
144+
return result.render()
151145

152146

153147
class EditFileInput(BaseModel):
154148
"""Input for editing a file."""
155149

156150
filepath: str = Field(..., description="Path to the file to edit")
157151
content: str = Field(..., description="New content for the file")
158-
tool_call_id: Annotated[str, InjectedToolCallId]
159152

160153

161154
class EditFileTool(BaseTool):
@@ -188,9 +181,9 @@ class EditFileTool(BaseTool):
188181
def __init__(self, codebase: Codebase) -> None:
189182
super().__init__(codebase=codebase)
190183

191-
def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
184+
def _run(self, filepath: str, content: str) -> str:
192185
result = edit_file(self.codebase, filepath, content)
193-
return result.render(tool_call_id)
186+
return result.render()
194187

195188

196189
class CreateFileInput(BaseModel):
@@ -347,7 +340,6 @@ class SemanticEditInput(BaseModel):
347340
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
348341
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
349342
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
350-
tool_call_id: Annotated[str, InjectedToolCallId]
351343

352344

353345
class SemanticEditTool(BaseTool):
@@ -361,10 +353,10 @@ class SemanticEditTool(BaseTool):
361353
def __init__(self, codebase: Codebase) -> None:
362354
super().__init__(codebase=codebase)
363355

364-
def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
356+
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
365357
# Create the the draft editor mini llm
366358
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
367-
return result.render(tool_call_id)
359+
return result.render()
368360

369361

370362
class RenameFileInput(BaseModel):

src/codegen/extensions/tools/edit_file.py

+6-27
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,32 @@
11
"""Tool for editing file contents."""
22

3-
from typing import TYPE_CHECKING, ClassVar, Optional
3+
from typing import ClassVar
44

5-
from langchain_core.messages import ToolMessage
65
from pydantic import Field
76

87
from codegen.sdk.core.codebase import Codebase
98

109
from .observation import Observation
1110
from .replacement_edit import generate_diff
1211

13-
if TYPE_CHECKING:
14-
from .tool_output_types import EditFileArtifacts
15-
1612

1713
class EditFileObservation(Observation):
1814
"""Response from editing a file."""
1915

2016
filepath: str = Field(
2117
description="Path to the edited file",
2218
)
23-
diff: Optional[str] = Field(
24-
default=None,
19+
diff: str = Field(
2520
description="Unified diff showing the changes made",
2621
)
2722

2823
str_template: ClassVar[str] = "Edited file {filepath}"
2924

30-
def render(self, tool_call_id: str) -> ToolMessage:
25+
def render(self) -> str:
3126
"""Render edit results in a clean format."""
32-
if self.status == "error":
33-
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
34-
return ToolMessage(
35-
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
36-
status=self.status,
37-
tool_name="edit_file",
38-
artifact=artifacts_error,
39-
tool_call_id=tool_call_id,
40-
)
41-
42-
artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}
43-
44-
return ToolMessage(
45-
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
46-
status=self.status,
47-
tool_name="edit_file",
48-
artifact=artifacts_success,
49-
tool_call_id=tool_call_id,
50-
)
27+
return f"""[EDIT FILE]: {self.filepath}
28+
29+
{self.diff}"""
5130

5231

5332
def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:

src/codegen/extensions/tools/list_directory.py

+9-62
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
from typing import ClassVar
44

5-
from langchain_core.messages import ToolMessage
65
from pydantic import Field
76

8-
from codegen.extensions.tools.observation import Observation
9-
from codegen.extensions.tools.tool_output_types import ListDirectoryArtifacts
107
from codegen.sdk.core.codebase import Codebase
118
from codegen.sdk.core.directory import Directory
129

10+
from .observation import Observation
11+
1312

1413
class DirectoryInfo(Observation):
1514
"""Information about a directory."""
@@ -32,14 +31,6 @@ class DirectoryInfo(Observation):
3231
default=False,
3332
description="Whether this is a leaf node (at max depth)",
3433
)
35-
depth: int = Field(
36-
default=0,
37-
description="Current depth in the tree",
38-
)
39-
max_depth: int = Field(
40-
default=1,
41-
description="Maximum depth allowed",
42-
)
4334

4435
str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)"
4536

@@ -50,7 +41,7 @@ def _get_details(self) -> dict[str, int]:
5041
"dir_count": len(self.subdirectories),
5142
}
5243

53-
def render_as_string(self) -> str:
44+
def render(self) -> str:
5445
"""Render directory listing as a file tree."""
5546
lines = [
5647
f"[LIST DIRECTORY]: {self.path}",
@@ -106,26 +97,6 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st
10697

10798
return "\n".join(lines)
10899

109-
def to_artifacts(self) -> ListDirectoryArtifacts:
110-
"""Convert directory info to artifacts for UI."""
111-
artifacts: ListDirectoryArtifacts = {
112-
"dirpath": self.path,
113-
"name": self.name,
114-
"is_leaf": self.is_leaf,
115-
"depth": self.depth,
116-
"max_depth": self.max_depth,
117-
}
118-
119-
if self.files is not None:
120-
artifacts["files"] = self.files
121-
artifacts["file_paths"] = [f"{self.path}/{f}" for f in self.files]
122-
123-
if self.subdirectories:
124-
artifacts["subdirs"] = [d.name for d in self.subdirectories]
125-
artifacts["subdir_paths"] = [d.path for d in self.subdirectories]
126-
127-
return artifacts
128-
129100

130101
class ListDirectoryObservation(Observation):
131102
"""Response from listing directory contents."""
@@ -136,29 +107,9 @@ class ListDirectoryObservation(Observation):
136107

137108
str_template: ClassVar[str] = "{directory_info}"
138109

139-
def render(self, tool_call_id: str) -> ToolMessage:
140-
"""Render directory listing with artifacts for UI."""
141-
if self.status == "error":
142-
error_artifacts: ListDirectoryArtifacts = {
143-
"dirpath": self.directory_info.path,
144-
"name": self.directory_info.name,
145-
"error": self.error,
146-
}
147-
return ToolMessage(
148-
content=f"[ERROR LISTING DIRECTORY]: {self.directory_info.path}: {self.error}",
149-
status=self.status,
150-
tool_name="list_directory",
151-
artifact=error_artifacts,
152-
tool_call_id=tool_call_id,
153-
)
154-
155-
return ToolMessage(
156-
content=self.directory_info.render_as_string(),
157-
status=self.status,
158-
tool_name="list_directory",
159-
artifact=self.directory_info.to_artifacts(),
160-
tool_call_id=tool_call_id,
161-
)
110+
def render(self) -> str:
111+
"""Render directory listing."""
112+
return self.directory_info.render()
162113

163114

164115
def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> ListDirectoryObservation:
@@ -185,7 +136,7 @@ def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> List
185136
),
186137
)
187138

188-
def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -> DirectoryInfo:
139+
def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
189140
"""Helper function to get directory info recursively."""
190141
# Get direct files (always include files unless at max depth)
191142
all_files = []
@@ -200,7 +151,7 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
200151
if current_depth > 1 or current_depth == -1:
201152
# For deeper traversal, get full directory info
202153
new_depth = current_depth - 1 if current_depth > 1 else -1
203-
subdirs.append(get_directory_info(subdir, new_depth, max_depth))
154+
subdirs.append(get_directory_info(subdir, new_depth))
204155
else:
205156
# At max depth, return a leaf node
206157
subdirs.append(
@@ -210,8 +161,6 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
210161
path=subdir.dirpath,
211162
files=None, # Don't include files at max depth
212163
is_leaf=True,
213-
depth=current_depth,
214-
max_depth=max_depth,
215164
)
216165
)
217166

@@ -221,11 +170,9 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
221170
path=dir_obj.dirpath,
222171
files=sorted(all_files),
223172
subdirectories=subdirs,
224-
depth=current_depth,
225-
max_depth=max_depth,
226173
)
227174

228-
dir_info = get_directory_info(directory, depth, depth)
175+
dir_info = get_directory_info(directory, depth)
229176
return ListDirectoryObservation(
230177
status="success",
231178
directory_info=dir_info,

0 commit comments

Comments
 (0)