Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "Adding Schema for Tool Outputs"" #894

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"metadata": {},
"outputs": [],
"source": [
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions src/codegen/agents/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ToolMessageData(BaseMessage):
tool_name: Optional[str] = None
tool_response: Optional[str] = None
tool_id: Optional[str] = None
status: Optional[str] = None


@dataclass
Expand Down
9 changes: 8 additions & 1 deletion src/codegen/agents/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,14 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
elif message_type == "tool":
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
return ToolMessageData(
type=message_type,
content=content,
tool_name=getattr(latest_message, "name", None),
tool_response=getattr(latest_message, "artifact", content),
tool_id=getattr(latest_message, "tool_call_id", None),
status=getattr(latest_message, "status", None),
)
elif message_type == "function":
return FunctionMessageData(type=message_type, content=content)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/extensions/langchain/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def reasoner(self, state: GraphState) -> dict[str, Any]:
messages.append(HumanMessage(content=query))

result = self.model.invoke([self.system_message, *messages])
if isinstance(result, AIMessage):
if isinstance(result, AIMessage) and not result.tool_calls:
updated_messages = [*messages, result]
return {"messages": updated_messages, "final_answer": result.content}

Expand Down Expand Up @@ -455,7 +455,7 @@ def get_field_descriptions(tool_obj):
return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters."

# For other types of errors
return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters."
return f"Error executing tool: {exception!s}\n\nPlease check your tool usage and try again with the correct parameters."

# Add nodes
builder.add_node("reasoner", self.reasoner, retry=retry_policy)
Expand Down
51 changes: 30 additions & 21 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Langchain tools for workspace operations."""

from collections.abc import Callable
from typing import ClassVar, Literal
from typing import Annotated, ClassVar, Literal, Optional

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId
from langchain_core.tools.base import BaseTool
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -52,10 +54,11 @@ class ViewFileInput(BaseModel):
"""Input for viewing a file."""

filepath: str = Field(..., description="Path to the file relative to workspace root")
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 500")
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 500")
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
tool_call_id: Annotated[str, InjectedToolCallId]


class ViewFileTool(BaseTool):
Expand All @@ -73,12 +76,13 @@ def __init__(self, codebase: Codebase) -> None:

def _run(
self,
tool_call_id: str,
filepath: str,
start_line: int | None = None,
end_line: int | None = None,
max_lines: int | None = None,
line_numbers: bool | None = True,
) -> str:
start_line: Optional[int] = None,
end_line: Optional[int] = None,
max_lines: Optional[int] = None,
line_numbers: Optional[bool] = True,
) -> ToolMessage:
result = view_file(
self.codebase,
filepath,
Expand All @@ -88,14 +92,15 @@ def _run(
max_lines=max_lines if max_lines is not None else 500,
)

return result.render()
return result.render(tool_call_id)


class ListDirectoryInput(BaseModel):
"""Input for listing directory contents."""

dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
tool_call_id: Annotated[str, InjectedToolCallId]


class ListDirectoryTool(BaseTool):
Expand All @@ -109,9 +114,9 @@ class ListDirectoryTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, dirpath: str = "./", depth: int = 1) -> str:
def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
result = list_directory(self.codebase, dirpath, depth)
return result.render()
return result.render(tool_call_id)


class SearchInput(BaseModel):
Expand All @@ -126,6 +131,7 @@ class SearchInput(BaseModel):
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
tool_call_id: Annotated[str, InjectedToolCallId]


class SearchTool(BaseTool):
Expand All @@ -139,16 +145,17 @@ class SearchTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
return result.render()
return result.render(tool_call_id)


class EditFileInput(BaseModel):
"""Input for editing a file."""

filepath: str = Field(..., description="Path to the file to edit")
content: str = Field(..., description="New content for the file")
tool_call_id: Annotated[str, InjectedToolCallId]


class EditFileTool(BaseTool):
Expand Down Expand Up @@ -181,9 +188,9 @@ class EditFileTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, content: str) -> str:
def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
result = edit_file(self.codebase, filepath, content)
return result.render()
return result.render(tool_call_id)


class CreateFileInput(BaseModel):
Expand Down Expand Up @@ -340,6 +347,7 @@ class SemanticEditInput(BaseModel):
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
tool_call_id: Annotated[str, InjectedToolCallId]


class SemanticEditTool(BaseTool):
Expand All @@ -353,10 +361,10 @@ class SemanticEditTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
# Create the the draft editor mini llm
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
return result.render()
return result.render(tool_call_id)


class RenameFileInput(BaseModel):
Expand Down Expand Up @@ -1033,6 +1041,7 @@ class RelaceEditInput(BaseModel):

filepath: str = Field(..., description="Path of the file relative to workspace root")
edit_snippet: str = Field(..., description=RELACE_EDIT_PROMPT)
tool_call_id: Annotated[str, InjectedToolCallId]


class RelaceEditTool(BaseTool):
Expand All @@ -1046,9 +1055,9 @@ class RelaceEditTool(BaseTool):
def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, filepath: str, edit_snippet: str) -> str:
def _run(self, filepath: str, edit_snippet: str, tool_call_id: str) -> ToolMessage:
result = relace_edit(self.codebase, filepath, edit_snippet)
return result.render()
return result.render(tool_call_id=tool_call_id)


class ReflectionInput(BaseModel):
Expand Down
33 changes: 27 additions & 6 deletions src/codegen/extensions/tools/edit_file.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,53 @@
"""Tool for editing file contents."""

from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar, Optional

from langchain_core.messages import ToolMessage
from pydantic import Field

from codegen.sdk.core.codebase import Codebase

from .observation import Observation
from .replacement_edit import generate_diff

if TYPE_CHECKING:
from .tool_output_types import EditFileArtifacts


class EditFileObservation(Observation):
"""Response from editing a file."""

filepath: str = Field(
description="Path to the edited file",
)
diff: str = Field(
diff: Optional[str] = Field(
default=None,
description="Unified diff showing the changes made",
)

str_template: ClassVar[str] = "Edited file {filepath}"

def render(self) -> str:
def render(self, tool_call_id: str) -> ToolMessage:
"""Render edit results in a clean format."""
return f"""[EDIT FILE]: {self.filepath}

{self.diff}"""
if self.status == "error":
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
return ToolMessage(
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
status=self.status,
name="edit_file",
artifact=artifacts_error,
tool_call_id=tool_call_id,
)

artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}

return ToolMessage(
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
status=self.status,
name="edit_file",
artifact=artifacts_success,
tool_call_id=tool_call_id,
)


def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:
Expand Down
Loading
Loading