Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Adding Schema for Tool Outputs" #892

Merged
merged 1 commit into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
]
},
{
Expand Down
1 change: 0 additions & 1 deletion src/codegen/agents/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class ToolMessageData(BaseMessage):
tool_name: Optional[str] = None
tool_response: Optional[str] = None
tool_id: Optional[str] = None
status: Optional[str] = None


@dataclass
Expand Down
9 changes: 1 addition & 8 deletions src/codegen/agents/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,7 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
elif message_type == "tool":
return ToolMessageData(
type=message_type,
content=content,
tool_name=getattr(latest_message, "name", None),
tool_response=getattr(latest_message, "artifact", content),
tool_id=getattr(latest_message, "tool_call_id", None),
status=getattr(latest_message, "status", None),
)
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
elif message_type == "function":
return FunctionMessageData(type=message_type, content=content)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/extensions/langchain/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def reasoner(self, state: GraphState) -> dict[str, Any]:
messages.append(HumanMessage(content=query))

result = self.model.invoke([self.system_message, *messages])
if isinstance(result, AIMessage) and not result.tool_calls:
if isinstance(result, AIMessage):
updated_messages = [*messages, result]
return {"messages": updated_messages, "final_answer": result.content}

Expand Down Expand Up @@ -455,7 +455,7 @@ def get_field_descriptions(tool_obj):
return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters."

# For other types of errors
return f"Error executing tool: {exception!s}\n\nPlease check your tool usage and try again with the correct parameters."
return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters."

# Add nodes
builder.add_node("reasoner", self.reasoner, retry=retry_policy)
Expand Down
46 changes: 19 additions & 27 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Langchain tools for workspace operations."""

from collections.abc import Callable
from typing import Annotated, ClassVar, Literal, Optional
from typing import ClassVar, Literal

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId
from langchain_core.tools.base import BaseTool
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -54,11 +52,10 @@ class ViewFileInput(BaseModel):
"""Input for viewing a file."""

filepath: str = Field(..., description="Path to the file relative to workspace root")
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 250")
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
tool_call_id: Annotated[str, InjectedToolCallId]
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 250")
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")


class ViewFileTool(BaseTool):
Expand All @@ -76,13 +73,12 @@ def __init__(self, codebase: Codebase) -> None:

def _run(
self,
tool_call_id: str,
filepath: str,
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: Optional[int] = None,
line_numbers: Optional[bool] = True,
) -> ToolMessage:
start_line: int | None = None,
end_line: int | None = None,
max_lines: int | None = None,
line_numbers: bool | None = True,
) -> str:
result = view_file(
self.codebase,
filepath,
Expand All @@ -92,15 +88,14 @@ def _run(
max_lines=max_lines if max_lines is not None else 250,
)

return result.render(tool_call_id)
return result.render()


class ListDirectoryInput(BaseModel):
"""Input for listing directory contents."""

dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
tool_call_id: Annotated[str, InjectedToolCallId]


class ListDirectoryTool(BaseTool):
Expand All @@ -114,9 +109,9 @@ class ListDirectoryTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
result = list_directory(self.codebase, dirpath, depth)
return result.render(tool_call_id)
return result.render()


class SearchInput(BaseModel):
Expand All @@ -131,7 +126,6 @@ class SearchInput(BaseModel):
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
tool_call_id: Annotated[str, InjectedToolCallId]


class SearchTool(BaseTool):
Expand All @@ -145,17 +139,16 @@ class SearchTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
return result.render(tool_call_id)
return result.render()


class EditFileInput(BaseModel):
"""Input for editing a file."""

filepath: str = Field(..., description="Path to the file to edit")
content: str = Field(..., description="New content for the file")
tool_call_id: Annotated[str, InjectedToolCallId]


class EditFileTool(BaseTool):
Expand Down Expand Up @@ -188,9 +181,9 @@ class EditFileTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
def _run(self, filepath: str, content: str) -> str:
result = edit_file(self.codebase, filepath, content)
return result.render(tool_call_id)
return result.render()


class CreateFileInput(BaseModel):
Expand Down Expand Up @@ -347,7 +340,6 @@ class SemanticEditInput(BaseModel):
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
tool_call_id: Annotated[str, InjectedToolCallId]


class SemanticEditTool(BaseTool):
Expand All @@ -361,10 +353,10 @@ class SemanticEditTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
# Create the draft editor mini llm
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
return result.render(tool_call_id)
return result.render()


class RenameFileInput(BaseModel):
Expand Down
33 changes: 6 additions & 27 deletions src/codegen/extensions/tools/edit_file.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,32 @@
"""Tool for editing file contents."""

from typing import TYPE_CHECKING, ClassVar, Optional
from typing import ClassVar

from langchain_core.messages import ToolMessage
from pydantic import Field

from codegen.sdk.core.codebase import Codebase

from .observation import Observation
from .replacement_edit import generate_diff

if TYPE_CHECKING:
from .tool_output_types import EditFileArtifacts


class EditFileObservation(Observation):
"""Response from editing a file."""

filepath: str = Field(
description="Path to the edited file",
)
diff: Optional[str] = Field(
default=None,
diff: str = Field(
description="Unified diff showing the changes made",
)

str_template: ClassVar[str] = "Edited file {filepath}"

def render(self, tool_call_id: str) -> ToolMessage:
def render(self) -> str:
"""Render edit results in a clean format."""
if self.status == "error":
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
return ToolMessage(
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
status=self.status,
tool_name="edit_file",
artifact=artifacts_error,
tool_call_id=tool_call_id,
)

artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}

return ToolMessage(
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
status=self.status,
tool_name="edit_file",
artifact=artifacts_success,
tool_call_id=tool_call_id,
)
return f"""[EDIT FILE]: {self.filepath}

{self.diff}"""


def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:
Expand Down
71 changes: 9 additions & 62 deletions src/codegen/extensions/tools/list_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

from typing import ClassVar

from langchain_core.messages import ToolMessage
from pydantic import Field

from codegen.extensions.tools.observation import Observation
from codegen.extensions.tools.tool_output_types import ListDirectoryArtifacts
from codegen.sdk.core.codebase import Codebase
from codegen.sdk.core.directory import Directory

from .observation import Observation


class DirectoryInfo(Observation):
"""Information about a directory."""
Expand All @@ -32,14 +31,6 @@ class DirectoryInfo(Observation):
default=False,
description="Whether this is a leaf node (at max depth)",
)
depth: int = Field(
default=0,
description="Current depth in the tree",
)
max_depth: int = Field(
default=1,
description="Maximum depth allowed",
)

str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)"

Expand All @@ -50,7 +41,7 @@ def _get_details(self) -> dict[str, int]:
"dir_count": len(self.subdirectories),
}

def render_as_string(self) -> str:
def render(self) -> str:
"""Render directory listing as a file tree."""
lines = [
f"[LIST DIRECTORY]: {self.path}",
Expand Down Expand Up @@ -106,26 +97,6 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st

return "\n".join(lines)

def to_artifacts(self) -> ListDirectoryArtifacts:
"""Convert directory info to artifacts for UI."""
artifacts: ListDirectoryArtifacts = {
"dirpath": self.path,
"name": self.name,
"is_leaf": self.is_leaf,
"depth": self.depth,
"max_depth": self.max_depth,
}

if self.files is not None:
artifacts["files"] = self.files
artifacts["file_paths"] = [f"{self.path}/{f}" for f in self.files]

if self.subdirectories:
artifacts["subdirs"] = [d.name for d in self.subdirectories]
artifacts["subdir_paths"] = [d.path for d in self.subdirectories]

return artifacts


class ListDirectoryObservation(Observation):
"""Response from listing directory contents."""
Expand All @@ -136,29 +107,9 @@ class ListDirectoryObservation(Observation):

str_template: ClassVar[str] = "{directory_info}"

def render(self, tool_call_id: str) -> ToolMessage:
"""Render directory listing with artifacts for UI."""
if self.status == "error":
error_artifacts: ListDirectoryArtifacts = {
"dirpath": self.directory_info.path,
"name": self.directory_info.name,
"error": self.error,
}
return ToolMessage(
content=f"[ERROR LISTING DIRECTORY]: {self.directory_info.path}: {self.error}",
status=self.status,
tool_name="list_directory",
artifact=error_artifacts,
tool_call_id=tool_call_id,
)

return ToolMessage(
content=self.directory_info.render_as_string(),
status=self.status,
tool_name="list_directory",
artifact=self.directory_info.to_artifacts(),
tool_call_id=tool_call_id,
)
def render(self) -> str:
"""Render directory listing."""
return self.directory_info.render()


def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> ListDirectoryObservation:
Expand All @@ -185,7 +136,7 @@ def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> List
),
)

def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -> DirectoryInfo:
def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo:
"""Helper function to get directory info recursively."""
# Get direct files (always include files unless at max depth)
all_files = []
Expand All @@ -200,7 +151,7 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
if current_depth > 1 or current_depth == -1:
# For deeper traversal, get full directory info
new_depth = current_depth - 1 if current_depth > 1 else -1
subdirs.append(get_directory_info(subdir, new_depth, max_depth))
subdirs.append(get_directory_info(subdir, new_depth))
else:
# At max depth, return a leaf node
subdirs.append(
Expand All @@ -210,8 +161,6 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
path=subdir.dirpath,
files=None, # Don't include files at max depth
is_leaf=True,
depth=current_depth,
max_depth=max_depth,
)
)

Expand All @@ -221,11 +170,9 @@ def get_directory_info(dir_obj: Directory, current_depth: int, max_depth: int) -
path=dir_obj.dirpath,
files=sorted(all_files),
subdirectories=subdirs,
depth=current_depth,
max_depth=max_depth,
)

dir_info = get_directory_info(directory, depth, depth)
dir_info = get_directory_info(directory, depth)
return ListDirectoryObservation(
status="success",
directory_info=dir_info,
Expand Down
Loading
Loading