|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +import logging |
| 10 | +import os |
| 11 | +from datetime import datetime |
| 12 | +from typing import List, Tuple |
| 13 | + |
| 14 | +import requests |
| 15 | + |
| 16 | +from datatypes import GHCommit, GHPullRequest |
| 17 | +from dateutil import parser as dtparser # For flexible date parsing |
| 18 | + |
# Configure the root logger when this module is imported so that INFO-level
# messages (such as the progress lines logged via `logging.info` below) are
# emitted.
logging.basicConfig(level=logging.INFO)
| 20 | + |
| 21 | + |
class GitHubClient:
    """
    Small GitHub REST API client for listing commits and their pull requests.

    The access token is read from the ``GITHUB_TOKEN`` environment variable
    when the client is constructed, and is sent with every request.
    """

    # GitHub API base URL
    API_URL = "https://api.github.com"

    # Seconds to wait on GitHub before giving up. `requests` applies no timeout
    # by default, so without this a stalled connection would block forever.
    REQUEST_TIMEOUT: float = 30.0

    def __init__(self, owner: str, repo: str) -> None:
        """
        Create a client bound to a single repository.

        Args:
            owner: Repository owner (user or organization).
            repo: Repository name.

        Raises:
            RuntimeError: If ``GITHUB_TOKEN`` is missing or empty.
        """
        token = os.environ.get("GITHUB_TOKEN")
        if not token:
            # An empty token would only yield a malformed Authorization header,
            # so it is rejected the same way as a missing one.
            raise RuntimeError("GITHUB_TOKEN not set")
        self.token = token

        self.owner = owner
        self.repo = repo

        # Headers for authentication
        self.headers = {
            "Authorization": f"token {self.token}",
            "Accept": "application/vnd.github.v3+json",
        }

    def _get(
        self, url: str, params: dict[str, str | int] | None = None
    ) -> requests.Response:
        """Send an authenticated GET with a timeout; raise on 4XX/5XX responses."""
        response = requests.get(
            url, headers=self.headers, params=params, timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response

    @staticmethod
    def _format_since(since_date: str | datetime) -> str:
        """Render ``since_date`` as a UTC timestamp of the form ``YYYY-MM-DDTHH:MM:SSZ``."""
        from datetime import timezone

        if isinstance(since_date, str):
            # Strings are expected to be plain YYYY-MM-DD dates.
            return f"{since_date}T00:00:00Z"

        if since_date.utcoffset() is None:
            # Naive datetimes are taken to be UTC, matching the string branch.
            as_utc = since_date.replace(tzinfo=timezone.utc)
        else:
            as_utc = since_date.astimezone(timezone.utc)
        return as_utc.isoformat(timespec="seconds").replace("+00:00", "Z")

    def fetch_commits(self, ref: str, since_date: str | datetime) -> List[GHCommit]:
        """
        Fetch the commits reachable from ``ref`` that are newer than ``since_date``.

        Args:
            ref (str): Branch, tag, or SHA
            since_date (str | datetime): Date in YYYY-MM-DD format, or a datetime
                (naive datetimes are interpreted as UTC)

        Returns:
            list: Commits in the order the API returns them; no additional
            sorting is applied here.

        Raises:
            requests.exceptions.HTTPError: If a request fails
        """
        base_url = f"{self.API_URL}/repos/{self.owner}/{self.repo}/commits"

        since_str = self._format_since(since_date)
        params: dict[str, str | int] = {
            "sha": ref,
            "since": since_str,
            "per_page": 100,
        }
        logging.info(
            "Fetching commits for %s/%s:%s since %s ...",
            self.owner,
            self.repo,
            ref,
            since_str,
        )

        commits: List[GHCommit] = []
        page = 1

        # Walk the pages until the API hands back an empty one.
        while True:
            params["page"] = page
            current_commits = self._get(base_url, params).json()
            if not current_commits:
                break

            commits.extend(
                GHCommit(
                    commit["sha"],
                    dtparser.parse(commit["commit"]["author"]["date"]),
                    commit["commit"]["message"],
                )
                for commit in current_commits
            )
            page += 1

        return commits

    def pr_for_commit(self, commit_sha: str) -> GHPullRequest | None:
        """
        Fetch pull request information from GitHub API given the commit SHA.

        Args:
            commit_sha: SHA of the commit to look up

        Returns:
            GHPullRequest containing PR information, or None if no PR found

        Raises:
            requests.exceptions.HTTPError: If the request fails
        """
        logging.info(
            "Fetching PR associated with %s/%s:%s ...",
            self.owner,
            self.repo,
            commit_sha,
        )
        pr_id = self.pr_id_for_commit(commit_sha)
        return self.fetch_pr(pr_id) if pr_id is not None else None

    def pr_id_for_commit(self, commit_sha: str) -> int | None:
        """
        Find the PR number associated with a commit via the issue search API.

        Args:
            commit_sha: SHA of the commit to look up

        Returns:
            Number of the first matching pull request, or None if the search
            returns no results

        Raises:
            requests.exceptions.HTTPError: If the request fails
        """
        # NOTE(review): the search endpoint presumably has a much tighter rate
        # limit than the rest of the API; confirm before calling this in bulk.
        search_url = f"{self.API_URL}/search/issues"
        query = f"repo:{self.owner}/{self.repo} is:pr {commit_sha}"

        results = self._get(search_url, {"q": query}).json()

        # Trust the returned items rather than `total_count` alone, so an empty
        # page can never cause an IndexError.
        items = results.get("items") or []
        return items[0]["number"] if items else None

    def fetch_pr(self, pr_number: int) -> GHPullRequest:
        """
        Fetch pull request information from GitHub API.

        Args:
            pr_number: Pull request number

        Returns:
            GHPullRequest containing PR information

        Raises:
            requests.exceptions.HTTPError: If the request fails
        """
        url = f"{self.API_URL}/repos/{self.owner}/{self.repo}/pulls/{pr_number}"
        pr_info = self._get(url).json()

        # The pull request payload already carries its labels; using them avoids
        # a second request whose failure would make the PR look unlabeled.
        embedded_labels = pr_info.get("labels")
        if embedded_labels is not None:
            labels = {label["name"] for label in embedded_labels}
        else:
            # Fallback kept for payloads without labels: best-effort lookup that
            # degrades to "no labels" on an HTTP error, but says so in the log.
            labels_url = (
                f"{self.API_URL}/repos/{self.owner}/{self.repo}"
                f"/issues/{pr_number}/labels"
            )
            try:
                labels = {label["name"] for label in self._get(labels_url).json()}
            except requests.exceptions.HTTPError as err:
                logging.warning(
                    "Could not fetch labels for %s/%s#%s (%s); treating as unlabeled",
                    self.owner,
                    self.repo,
                    pr_number,
                    err,
                )
                labels = set()

        closed_at_raw = pr_info.get("closed_at")
        closed_at = dtparser.parse(closed_at_raw) if closed_at_raw else None

        return GHPullRequest(
            pr_info["title"], pr_number, closed_at, labels, pr_info["base"]["ref"]
        ).cleaned()

    def _commits_with_prs(
        self, ref: str, since_date: str | datetime
    ) -> List[Tuple[GHCommit, GHPullRequest]]:
        """Pair each fetched commit with its pull request, skipping commits with none."""
        pairs: List[Tuple[GHCommit, GHPullRequest]] = []
        for commit in self.fetch_commits(ref, since_date):
            pr = self.pr_for_commit(commit.sha)
            if pr is not None:
                pairs.append((commit, pr))
        return pairs

    def fetch_prs(self, ref: str, since_date: str | datetime) -> List[GHPullRequest]:
        """
        Fetch the PRs associated with commits on ``ref`` since a given date.

        Args:
            ref (str): Branch, tag, or SHA
            since_date (str | datetime): Date in YYYY-MM-DD format, or a datetime

        Returns:
            list: One PR per commit that has an associated PR, in commit order
            (a PR is repeated if several commits map to it).
        """
        return [pr for _, pr in self._commits_with_prs(ref, since_date)]

    def fetch_unlabeled_commits(
        self, ref: str, since_date: str | datetime
    ) -> List[Tuple[GHCommit, GHPullRequest]]:
        """
        Fetch commits whose associated PR is unlabeled, along with that PR.

        Args:
            ref (str): Branch, tag, or SHA
            since_date (str | datetime): Date in YYYY-MM-DD format, or a datetime

        Returns:
            List of (GHCommit, GHPullRequest) tuples, in commit order. Commits
            without an associated PR are omitted.
        """
        return [
            (commit, pr)
            for commit, pr in self._commits_with_prs(ref, since_date)
            if pr.unlabeled()
        ]
0 commit comments