diff --git a/src/codegen/extensions/tools/github/create_pr.py b/src/codegen/extensions/tools/github/create_pr.py index 70da33b3d..a46f8390e 100644 --- a/src/codegen/extensions/tools/github/create_pr.py +++ b/src/codegen/extensions/tools/github/create_pr.py @@ -1,5 +1,6 @@ """Tool for creating pull requests.""" +import re import uuid from typing import ClassVar @@ -23,8 +24,66 @@ class CreatePRObservation(Observation): title: str = Field( description="Title of the PR", ) + changes_summary: str = Field( + description="Summary of changes included in the PR", + default="", + ) + + str_template: ClassVar[str] = "Created PR #{number}: {title}\n\nChanges Summary:\n{changes_summary}" + + +def generate_changes_summary(diff_text: str) -> str: + """Generate a human-readable summary of changes from a git diff. - str_template: ClassVar[str] = "Created PR #{number}: {title}" + Args: + diff_text: The git diff text + + Returns: + A formatted summary of the changes + """ + if not diff_text: + return "No changes detected." + + # Parse the diff to extract file information + file_pattern = re.compile(r"diff --git a/(.*?) b/(.*?)\n") + file_matches = file_pattern.findall(diff_text) + + # Count additions and deletions + addition_pattern = re.compile(r"^\+[^+]", re.MULTILINE) + deletion_pattern = re.compile(r"^-[^-]", re.MULTILINE) + + additions = len(addition_pattern.findall(diff_text)) + deletions = len(deletion_pattern.findall(diff_text)) + + # Get unique files changed + files_changed = set() + for match in file_matches: + # Use the second part of the match (b/file) as it represents the new file + files_changed.add(match[1]) + + # Group files by extension + file_extensions: dict[str, list[str]] = {} + for file in files_changed: + ext = file.split(".")[-1] if "." in file else "other" + if ext not in file_extensions: + file_extensions[ext] = [] + file_extensions[ext].append(file) + + # Build the summary + summary = [] + summary.append(f"**Files Changed:** {len(files_changed)}") + summary.append(f"**Lines Added:** {additions}") + summary.append(f"**Lines Deleted:** {deletions}") + + # Add file details grouped by extension + if file_extensions: + summary.append("\n**Modified Files:**") + for ext, files in file_extensions.items(): + summary.append(f"\n*{ext.upper()} Files:*") + for file in sorted(files): + summary.append(f"- {file}") + + return "\n".join(summary) def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: @@ -37,15 +96,20 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: """ try: # Check for uncommitted changes and commit them - if len(codebase.get_diff()) == 0: + diff_text = codebase.get_diff() + if len(diff_text) == 0: return CreatePRObservation( status="error", error="No changes to create a PR.", url="", number=0, title=title, + changes_summary="", ) + # Generate a summary of changes + changes_summary = generate_changes_summary(diff_text) + # TODO: this is very jank. We should ideally check out the branch before # making the changes, but it looks like `codebase.checkout` blows away # all of your changes @@ -65,6 +129,7 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: url="", number=0, title=title, + changes_summary="", ) return CreatePRObservation( @@ -72,6 +137,7 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: url=pr.html_url, number=pr.number, title=pr.title, + changes_summary=changes_summary, ) except Exception as e: @@ -81,4 +147,5 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: url="", number=0, title=title, + changes_summary="", )