Skip to content

Commit c8c0b89

Browse files
authored
CG-10694: Remove lowside + enterprise from codegen.git (#349)
1 parent ff0727a commit c8c0b89

36 files changed

+456
-523
lines changed

.github/workflows/unit-tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ jobs:
141141
timeout-minutes: 5
142142
env:
143143
GITHUB_WORKSPACE: $GITHUB_WORKSPACE
144+
GITHUB_TOKEN: ${{ secrets.GHA_PAT }}
144145
run: |
145146
uv run pytest \
146147
-n auto \

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ dev-dependencies = [
143143
"isort>=5.13.2",
144144
"emoji>=2.14.0",
145145
"pytest-benchmark[histogram]>=5.1.0",
146+
"pytest-asyncio<1.0.0,>=0.21.1",
146147
"loguru>=0.7.3",
147148
"httpx<0.28.2,>=0.28.1",
148149
]

src/codegen/git/clients/git_integration_client.py

-56
This file was deleted.

src/codegen/git/clients/git_repo_client.py

+44-52
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
from github.Tag import Tag
1515
from github.Workflow import Workflow
1616

17-
from codegen.git.clients.github_client_factory import GithubClientFactory
18-
from codegen.git.clients.types import GithubClientType
19-
from codegen.git.schemas.github import GithubScope, GithubType
17+
from codegen.git.clients.github_client import GithubClient
2018
from codegen.git.schemas.repo_config import RepoConfig
2119
from codegen.git.utils.format import format_comparison
2220

@@ -27,33 +25,27 @@ class GitRepoClient:
2725
"""Wrapper around PyGithub's Remote Repository."""
2826

2927
repo_config: RepoConfig
30-
github_type: GithubType = GithubType.GithubEnterprise
31-
gh_client: GithubClientType
32-
read_client: Repository
33-
access_scope: GithubScope
34-
__write_client: Repository | None # Will not be initialized if access scope is read-only
28+
gh_client: GithubClient
29+
_repo: Repository
3530

36-
def __init__(self, repo_config: RepoConfig, github_type: GithubType = GithubType.GithubEnterprise, access_scope: GithubScope = GithubScope.READ) -> None:
31+
def __init__(self, repo_config: RepoConfig) -> None:
3732
self.repo_config = repo_config
38-
self.github_type = github_type
39-
self.gh_client = GithubClientFactory.create_from_repo(self.repo_config, github_type)
40-
self.read_client = self._create_client(GithubScope.READ)
41-
self.__write_client = self._create_client(GithubScope.WRITE) if access_scope == GithubScope.WRITE else None
42-
self.access_scope = access_scope
43-
44-
def _create_client(self, github_scope: GithubScope = GithubScope.READ) -> Repository:
45-
client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name, github_scope=github_scope)
33+
self.gh_client = self._create_github_client()
34+
self._repo = self._create_client()
35+
36+
def _create_github_client(self) -> GithubClient:
37+
return GithubClient()
38+
39+
def _create_client(self) -> Repository:
40+
client = self.gh_client.get_repo_by_full_name(self.repo_config.full_name)
4641
if not client:
47-
msg = f"Repo {self.repo_config.full_name} not found in {self.github_type.value}!"
42+
msg = f"Repo {self.repo_config.full_name} not found!"
4843
raise ValueError(msg)
4944
return client
5045

5146
@property
52-
def _write_client(self) -> Repository:
53-
if self.__write_client is None:
54-
msg = "Cannot perform write operations with read-only client! Try setting github_scope to GithubScope.WRITE."
55-
raise ValueError(msg)
56-
return self.__write_client
47+
def repo(self) -> Repository:
48+
return self._repo
5749

5850
####################################################################################################################
5951
# PROPERTIES
@@ -65,7 +57,7 @@ def id(self) -> int:
6557

6658
@property
6759
def default_branch(self) -> str:
68-
return self.read_client.default_branch
60+
return self.repo.default_branch
6961

7062
####################################################################################################################
7163
# CONTENTS
@@ -76,7 +68,7 @@ def get_contents(self, file_path: str, ref: str | None = None) -> str | None:
7668
if not ref:
7769
ref = self.default_branch
7870
try:
79-
file = self.read_client.get_contents(file_path, ref=ref)
71+
file = self.repo.get_contents(file_path, ref=ref)
8072
file_contents = file.decoded_content.decode("utf-8") # type: ignore[union-attr]
8173
return file_contents
8274
except UnknownObjectException:
@@ -100,7 +92,7 @@ def get_last_modified_date_of_path(self, path: str) -> datetime:
10092
str: The last modified date of the directory in ISO format (YYYY-MM-DDTHH:MM:SSZ).
10193
10294
"""
103-
commits = self.read_client.get_commits(path=path)
95+
commits = self.repo.get_commits(path=path)
10496
if commits.totalCount > 0:
10597
# Get the date of the latest commit
10698
last_modified_date = commits[0].commit.committer.date
@@ -124,7 +116,7 @@ def create_review_comment(
124116
start_line: Opt[int] = NotSet,
125117
) -> None:
126118
# TODO: add protections (ex: can write to PR)
127-
writeable_pr = self._write_client.get_pull(pull.number)
119+
writeable_pr = self.repo.get_pull(pull.number)
128120
writeable_pr.create_review_comment(
129121
body=body,
130122
commit=commit,
@@ -140,7 +132,7 @@ def create_issue_comment(
140132
body: str,
141133
) -> None:
142134
# TODO: add protections (ex: can write to PR)
143-
writeable_pr = self._write_client.get_pull(pull.number)
135+
writeable_pr = self.repo.get_pull(pull.number)
144136
writeable_pr.create_issue_comment(body=body)
145137

146138
####################################################################################################################
@@ -163,7 +155,7 @@ def get_pull_by_branch_and_state(
163155
head_branch_name = f"{self.repo_config.organization_name}:{head_branch_name}"
164156

165157
# retrieve all pulls ordered by created descending
166-
prs = self.read_client.get_pulls(base=base_branch_name, head=head_branch_name, state=state, sort="created", direction="desc")
158+
prs = self.repo.get_pulls(base=base_branch_name, head=head_branch_name, state=state, sort="created", direction="desc")
167159
if prs.totalCount > 0:
168160
return prs[0]
169161
else:
@@ -174,7 +166,7 @@ def get_pull_safe(self, number: int) -> PullRequest | None:
174166
TODO: catching UnknownObjectException is common enough to create a decorator
175167
"""
176168
try:
177-
pr = self.read_client.get_pull(number)
169+
pr = self.repo.get_pull(number)
178170
return pr
179171
except UnknownObjectException as e:
180172
return None
@@ -209,10 +201,10 @@ def create_pull(
209201
if base_branch_name is None:
210202
base_branch_name = self.default_branch
211203
try:
212-
pr = self._write_client.create_pull(title=title or f"Draft PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=draft)
204+
pr = self.repo.create_pull(title=title or f"Draft PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=draft)
213205
logger.info(f"Created pull request for head branch: {head_branch_name} at {pr.html_url}")
214206
# NOTE: return a read-only copy to prevent people from editing it
215-
return self.read_client.get_pull(pr.number)
207+
return self.repo.get_pull(pr.number)
216208
except GithubException as ge:
217209
logger.warning(f"Failed to create PR got GithubException\n\t{ge}")
218210
except Exception as e:
@@ -235,15 +227,15 @@ def squash_and_merge(self, base_branch_name: str, head_branch_name: str, squash_
235227
merge = squash_pr.merge(commit_message=squash_commit_msg, commit_title=squash_commit_title, merge_method="squash") # type: ignore[arg-type]
236228

237229
def edit_pull(self, pull: PullRequest, title: Opt[str] = NotSet, body: Opt[str] = NotSet, state: Opt[str] = NotSet) -> None:
238-
writable_pr = self._write_client.get_pull(pull.number)
230+
writable_pr = self.repo.get_pull(pull.number)
239231
writable_pr.edit(title=title, body=body, state=state)
240232

241233
def add_label_to_pull(self, pull: PullRequest, label: Label) -> None:
242-
writeable_pr = self._write_client.get_pull(pull.number)
234+
writeable_pr = self.repo.get_pull(pull.number)
243235
writeable_pr.add_to_labels(label)
244236

245237
def remove_label_from_pull(self, pull: PullRequest, label: Label) -> None:
246-
writeable_pr = self._write_client.get_pull(pull.number)
238+
writeable_pr = self.repo.get_pull(pull.number)
247239
writeable_pr.remove_from_labels(label)
248240

249241
####################################################################################################################
@@ -264,7 +256,7 @@ def get_or_create_branch(self, new_branch_name: str, base_branch_name: str | Non
264256
def get_branch_safe(self, branch_name: str, attempts: int = 1, wait_seconds: int = 1) -> Branch | None:
265257
for i in range(attempts):
266258
try:
267-
return self.read_client.get_branch(branch_name)
259+
return self.repo.get_branch(branch_name)
268260
except GithubException as e:
269261
if e.status == 404 and i < attempts - 1:
270262
time.sleep(wait_seconds)
@@ -276,14 +268,14 @@ def create_branch(self, new_branch_name: str, base_branch_name: str | None = Non
276268
if base_branch_name is None:
277269
base_branch_name = self.default_branch
278270

279-
base_branch = self.read_client.get_branch(base_branch_name)
271+
base_branch = self.repo.get_branch(base_branch_name)
280272
# TODO: also wrap git ref. low pri b/c the only write operation on refs is creating one
281-
self._write_client.create_git_ref(sha=base_branch.commit.sha, ref=f"refs/heads/{new_branch_name}")
273+
self.repo.create_git_ref(sha=base_branch.commit.sha, ref=f"refs/heads/{new_branch_name}")
282274
branch = self.get_branch_safe(new_branch_name)
283275
return branch
284276

285277
def create_branch_from_sha(self, new_branch_name: str, base_sha: str) -> Branch | None:
286-
self._write_client.create_git_ref(ref=f"refs/heads/{new_branch_name}", sha=base_sha)
278+
self.repo.create_git_ref(ref=f"refs/heads/{new_branch_name}", sha=base_sha)
287279
branch = self.get_branch_safe(new_branch_name)
288280
return branch
289281

@@ -295,7 +287,7 @@ def delete_branch(self, branch_name: str) -> None:
295287

296288
branch_to_delete = self.get_branch_safe(branch_name)
297289
if branch_to_delete:
298-
ref_to_delete = self._write_client.get_git_ref(f"heads/{branch_name}")
290+
ref_to_delete = self.repo.get_git_ref(f"heads/{branch_name}")
299291
ref_to_delete.delete()
300292
logger.info(f"Branch: {branch_name} deleted successfully!")
301293
else:
@@ -307,7 +299,7 @@ def delete_branch(self, branch_name: str) -> None:
307299

308300
def get_commit_safe(self, commit_sha: str) -> Commit | None:
309301
try:
310-
return self.read_client.get_commit(commit_sha)
302+
return self.repo.get_commit(commit_sha)
311303
except UnknownObjectException as e:
312304
logger.warning(f"Commit {commit_sha} not found:\n\t{e}")
313305
return None
@@ -338,7 +330,7 @@ def compare_branches(self, base_branch_name: str | None, head_branch_name: str,
338330

339331
# NOTE: base utility that other compare functions should try to use
340332
def compare(self, base: str, head: str, show_commits: bool = False) -> str:
341-
comparison = self.read_client.compare(base, head)
333+
comparison = self.repo.compare(base, head)
342334
return format_comparison(comparison, show_commits=show_commits)
343335

344336
####################################################################################################################
@@ -349,7 +341,7 @@ def compare(self, base: str, head: str, show_commits: bool = False) -> str:
349341
def get_label_safe(self, label_name: str) -> Label | None:
350342
try:
351343
label_name = label_name.strip()
352-
label = self.read_client.get_label(label_name)
344+
label = self.repo.get_label(label_name)
353345
return label
354346
except UnknownObjectException as e:
355347
return None
@@ -360,10 +352,10 @@ def get_label_safe(self, label_name: str) -> Label | None:
360352
def create_label(self, label_name: str, color: str) -> Label:
361353
# TODO: also offer description field
362354
label_name = label_name.strip()
363-
self._write_client.create_label(label_name, color)
355+
self.repo.create_label(label_name, color)
364356
# TODO: is there a way to convert new_label to a read-only label without making another API call?
365357
# NOTE: return a read-only label to prevent people from editing it
366-
return self.read_client.get_label(label_name)
358+
return self.repo.get_label(label_name)
367359

368360
def get_or_create_label(self, label_name: str, color: str) -> Label:
369361
existing_label = self.get_label_safe(label_name)
@@ -377,7 +369,7 @@ def get_or_create_label(self, label_name: str, color: str) -> Label:
377369

378370
def get_check_suite_safe(self, check_suite_id: int) -> CheckSuite | None:
379371
try:
380-
return self.read_client.get_check_suite(check_suite_id)
372+
return self.repo.get_check_suite(check_suite_id)
381373
except UnknownObjectException as e:
382374
return None
383375
except Exception as e:
@@ -390,7 +382,7 @@ def get_check_suite_safe(self, check_suite_id: int) -> CheckSuite | None:
390382

391383
def get_check_run_safe(self, check_run_id: int) -> CheckRun | None:
392384
try:
393-
return self.read_client.get_check_run(check_run_id)
385+
return self.repo.get_check_run(check_run_id)
394386
except UnknownObjectException as e:
395387
return None
396388
except Exception as e:
@@ -406,24 +398,24 @@ def create_check_run(
406398
conclusion: Opt[str] = NotSet,
407399
output: Opt[dict[str, str | list[dict[str, str | int]]]] = NotSet,
408400
) -> CheckRun:
409-
new_check_run = self._write_client.create_check_run(name=name, head_sha=head_sha, details_url=details_url, status=status, conclusion=conclusion, output=output)
410-
return self.read_client.get_check_run(new_check_run.id)
401+
new_check_run = self.repo.create_check_run(name=name, head_sha=head_sha, details_url=details_url, status=status, conclusion=conclusion, output=output)
402+
return self.repo.get_check_run(new_check_run.id)
411403

412404
####################################################################################################################
413405
# WORKFLOW
414406
####################################################################################################################
415407

416408
def get_workflow_safe(self, file_name: str) -> Workflow | None:
417409
try:
418-
return self.read_client.get_workflow(file_name)
410+
return self.repo.get_workflow(file_name)
419411
except UnknownObjectException as e:
420412
return None
421413
except Exception as e:
422414
logger.warning(f"Error getting workflow by file name: {file_name}\n\t{e}")
423415
return None
424416

425417
def create_workflow_dispatch(self, workflow: Workflow, ref: Branch | Tag | Commit | str, inputs: Opt[dict] = NotSet):
426-
writeable_workflow = self._write_client.get_workflow(workflow.id)
418+
writeable_workflow = self.repo.get_workflow(workflow.id)
427419
writeable_workflow.create_dispatch(ref=ref, inputs=inputs)
428420

429421
####################################################################################################################
@@ -439,5 +431,5 @@ def merge_upstream(self, branch_name: str) -> bool:
439431
"""
440432
assert isinstance(branch_name, str), branch_name
441433
post_parameters = {"branch": branch_name}
442-
status, _, _ = self._write_client._requester.requestJson("POST", f"{self._write_client.url}/merge-upstream", input=post_parameters)
434+
status, _, _ = self.repo._requester.requestJson("POST", f"{self.repo.url}/merge-upstream", input=post_parameters)
443435
return status == 200

0 commit comments

Comments
 (0)