Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update searchbyfilename tool to paginate #896

Merged
merged 8 commits into from
Mar 18, 2025
9 changes: 6 additions & 3 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,24 +1091,27 @@ class SearchFilesByNameInput(BaseModel):
"""Input for searching files by name pattern."""

pattern: str = Field(..., description="`fd`-compatible glob pattern to search for (e.g. '*.py', 'test_*.py')")

page: int = Field(default=1, description="Page number to return (1-based)")
files_per_page: int | float = Field(default=10, description="Number of files per page to return, use math.inf to return all files")

class SearchFilesByNameTool(BaseTool):
"""Tool for searching files by filename across a codebase."""

name: ClassVar[str] = "search_files_by_name"
description: ClassVar[str] = """
Search for files and directories by glob pattern across the active codebase. This is useful when you need to:
Search for files and directories by glob pattern (with pagination) across the active codebase. This is useful when you need to:
- Find specific file types (e.g., '*.py', '*.tsx')
- Locate configuration files (e.g., 'package.json', 'requirements.txt')
- Find files with specific names (e.g., 'README.md', 'Dockerfile')
"""
args_schema: ClassVar[type[BaseModel]] = SearchFilesByNameInput
codebase: Codebase = Field(exclude=True)



def __init__(self, codebase: Codebase):
super().__init__(codebase=codebase)

def _run(self, pattern: str) -> str:
"""Execute the glob pattern search using fd."""
return search_files_by_name(self.codebase, pattern).render()
return search_files_by_name(self.codebase, pattern, page=self.page, files_per_page=self.files_per_page).render()
3 changes: 2 additions & 1 deletion src/codegen/extensions/tools/global_replacement_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import difflib
import logging
import math
import re
from typing import ClassVar

Expand Down Expand Up @@ -103,7 +104,7 @@ def replacement_edit_global(
)

diffs = []
for file in search_files_by_name(codebase, file_pattern).files:
for file in search_files_by_name(codebase, file_pattern, page=1, files_per_page=math.inf).files:
if count is not None and count <= 0:
break
try:
Expand Down
63 changes: 57 additions & 6 deletions src/codegen/extensions/tools/search_files_by_name.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import shutil
import subprocess
from typing import ClassVar
from typing import ClassVar, Optional

from pydantic import Field

Expand All @@ -20,33 +21,55 @@ class SearchFilesByNameResultObservation(Observation):
files: list[str] = Field(
description="List of matching file paths",
)
page: int = Field(
description="Current page number (1-based)",
)
total_pages: int = Field(
description="Total number of pages available",
)
total_files: int = Field(
description="Total number of files with matches",
)
files_per_page: int | float = Field(
description="Number of files shown per page",
)

str_template: ClassVar[str] = "Found {total} files matching pattern: {pattern}"
str_template: ClassVar[str] = "Found {total_files} files matching pattern: {pattern} (page {page}/{total_pages})"

@property
def total(self) -> int:
return len(self.files)
return self.total_files


def search_files_by_name(
codebase: Codebase,
pattern: str,
page: int = 1,
files_per_page: int | float = 10,
) -> SearchFilesByNameResultObservation:
"""Search for files by name pattern in the codebase.

Args:
codebase: The codebase to search in
pattern: Glob pattern to search for (e.g. "*.py", "test_*.py")
page: Page number to return (1-based, default: 1)
files_per_page: Number of files to return per page (default: 10)
"""
try:
# Validate pagination parameters
if page < 1:
page = 1
if files_per_page is not None and files_per_page < 1:
files_per_page = 20

if shutil.which("fd") is None:
logger.warning("fd is not installed, falling back to find")
results = subprocess.check_output(
["find", "-name", pattern],
cwd=codebase.repo_path,
timeout=30,
)
files = [path.removeprefix("./") for path in results.decode("utf-8").strip().split("\n")] if results.strip() else []
all_files = [path.removeprefix("./") for path in results.decode("utf-8").strip().split("\n")] if results.strip() else []

else:
logger.info(f"Searching for files with pattern: {pattern}")
Expand All @@ -55,12 +78,36 @@ def search_files_by_name(
cwd=codebase.repo_path,
timeout=30,
)
files = results.decode("utf-8").strip().split("\n") if results.strip() else []
all_files = results.decode("utf-8").strip().split("\n") if results.strip() else []

# Sort files for consistent pagination
all_files.sort()

# Calculate pagination
total_files = len(all_files)
if files_per_page == math.inf:
files_per_page = total_files
total_pages = 1
else:
total_pages = (total_files + files_per_page - 1) // files_per_page if total_files > 0 else 1


# Ensure page is within valid range
page = min(page, total_pages)

# Get paginated results
start_idx = (page - 1) * files_per_page
end_idx = start_idx + files_per_page
paginated_files = all_files[start_idx:end_idx]

return SearchFilesByNameResultObservation(
status="success",
pattern=pattern,
files=files,
files=paginated_files,
page=page,
total_pages=total_pages,
total_files=total_files,
files_per_page=files_per_page,
)

except Exception as e:
Expand All @@ -69,4 +116,8 @@ def search_files_by_name(
error=f"Error searching files: {e!s}",
pattern=pattern,
files=[],
page=page,
total_pages=0,
total_files=0,
files_per_page=files_per_page,
)
Loading