14
14
from github .Tag import Tag
15
15
from github .Workflow import Workflow
16
16
17
- from codegen .git .clients .github_client_factory import GithubClientFactory
18
- from codegen .git .clients .types import GithubClientType
19
- from codegen .git .schemas .github import GithubScope , GithubType
17
+ from codegen .git .clients .github_client import GithubClient
20
18
from codegen .git .schemas .repo_config import RepoConfig
21
19
from codegen .git .utils .format import format_comparison
22
20
@@ -27,33 +25,27 @@ class GitRepoClient:
27
25
"""Wrapper around PyGithub's Remote Repository."""
28
26
29
27
repo_config : RepoConfig
30
- github_type : GithubType = GithubType .GithubEnterprise
31
- gh_client : GithubClientType
32
- read_client : Repository
33
- access_scope : GithubScope
34
- __write_client : Repository | None # Will not be initialized if access scope is read-only
28
+ gh_client : GithubClient
29
+ _repo : Repository
35
30
36
- def __init__ (self , repo_config : RepoConfig , github_type : GithubType = GithubType . GithubEnterprise , access_scope : GithubScope = GithubScope . READ ) -> None :
31
+ def __init__ (self , repo_config : RepoConfig ) -> None :
37
32
self .repo_config = repo_config
38
- self .github_type = github_type
39
- self .gh_client = GithubClientFactory . create_from_repo ( self .repo_config , github_type )
40
- self . read_client = self . _create_client ( GithubScope . READ )
41
- self . __write_client = self . _create_client ( GithubScope . WRITE ) if access_scope == GithubScope . WRITE else None
42
- self . access_scope = access_scope
43
-
44
- def _create_client (self , github_scope : GithubScope = GithubScope . READ ) -> Repository :
45
- client = self .gh_client .get_repo_by_full_name (self .repo_config .full_name , github_scope = github_scope )
33
+ self .gh_client = self . _create_github_client ()
34
+ self ._repo = self ._create_client ( )
35
+
36
+ def _create_github_client ( self ) -> GithubClient :
37
+ return GithubClient ()
38
+
39
+ def _create_client (self ) -> Repository :
40
+ client = self .gh_client .get_repo_by_full_name (self .repo_config .full_name )
46
41
if not client :
47
- msg = f"Repo { self .repo_config .full_name } not found in { self . github_type . value } !"
42
+ msg = f"Repo { self .repo_config .full_name } not found!"
48
43
raise ValueError (msg )
49
44
return client
50
45
51
46
@property
52
- def _write_client (self ) -> Repository :
53
- if self .__write_client is None :
54
- msg = "Cannot perform write operations with read-only client! Try setting github_scope to GithubScope.WRITE."
55
- raise ValueError (msg )
56
- return self .__write_client
47
+ def repo (self ) -> Repository :
48
+ return self ._repo
57
49
58
50
####################################################################################################################
59
51
# PROPERTIES
@@ -65,7 +57,7 @@ def id(self) -> int:
65
57
66
58
@property
67
59
def default_branch (self ) -> str :
68
- return self .read_client .default_branch
60
+ return self .repo .default_branch
69
61
70
62
####################################################################################################################
71
63
# CONTENTS
@@ -76,7 +68,7 @@ def get_contents(self, file_path: str, ref: str | None = None) -> str | None:
76
68
if not ref :
77
69
ref = self .default_branch
78
70
try :
79
- file = self .read_client .get_contents (file_path , ref = ref )
71
+ file = self .repo .get_contents (file_path , ref = ref )
80
72
file_contents = file .decoded_content .decode ("utf-8" ) # type: ignore[union-attr]
81
73
return file_contents
82
74
except UnknownObjectException :
@@ -100,7 +92,7 @@ def get_last_modified_date_of_path(self, path: str) -> datetime:
100
92
str: The last modified date of the directory in ISO format (YYYY-MM-DDTHH:MM:SSZ).
101
93
102
94
"""
103
- commits = self .read_client .get_commits (path = path )
95
+ commits = self .repo .get_commits (path = path )
104
96
if commits .totalCount > 0 :
105
97
# Get the date of the latest commit
106
98
last_modified_date = commits [0 ].commit .committer .date
@@ -124,7 +116,7 @@ def create_review_comment(
124
116
start_line : Opt [int ] = NotSet ,
125
117
) -> None :
126
118
# TODO: add protections (ex: can write to PR)
127
- writeable_pr = self ._write_client .get_pull (pull .number )
119
+ writeable_pr = self .repo .get_pull (pull .number )
128
120
writeable_pr .create_review_comment (
129
121
body = body ,
130
122
commit = commit ,
@@ -140,7 +132,7 @@ def create_issue_comment(
140
132
body : str ,
141
133
) -> None :
142
134
# TODO: add protections (ex: can write to PR)
143
- writeable_pr = self ._write_client .get_pull (pull .number )
135
+ writeable_pr = self .repo .get_pull (pull .number )
144
136
writeable_pr .create_issue_comment (body = body )
145
137
146
138
####################################################################################################################
@@ -163,7 +155,7 @@ def get_pull_by_branch_and_state(
163
155
head_branch_name = f"{ self .repo_config .organization_name } :{ head_branch_name } "
164
156
165
157
# retrieve all pulls ordered by created descending
166
- prs = self .read_client .get_pulls (base = base_branch_name , head = head_branch_name , state = state , sort = "created" , direction = "desc" )
158
+ prs = self .repo .get_pulls (base = base_branch_name , head = head_branch_name , state = state , sort = "created" , direction = "desc" )
167
159
if prs .totalCount > 0 :
168
160
return prs [0 ]
169
161
else :
@@ -174,7 +166,7 @@ def get_pull_safe(self, number: int) -> PullRequest | None:
174
166
TODO: catching UnknownObjectException is common enough to create a decorator
175
167
"""
176
168
try :
177
- pr = self .read_client .get_pull (number )
169
+ pr = self .repo .get_pull (number )
178
170
return pr
179
171
except UnknownObjectException as e :
180
172
return None
@@ -209,10 +201,10 @@ def create_pull(
209
201
if base_branch_name is None :
210
202
base_branch_name = self .default_branch
211
203
try :
212
- pr = self ._write_client .create_pull (title = title or f"Draft PR for { head_branch_name } " , body = body or "" , head = head_branch_name , base = base_branch_name , draft = draft )
204
+ pr = self .repo .create_pull (title = title or f"Draft PR for { head_branch_name } " , body = body or "" , head = head_branch_name , base = base_branch_name , draft = draft )
213
205
logger .info (f"Created pull request for head branch: { head_branch_name } at { pr .html_url } " )
214
206
# NOTE: return a read-only copy to prevent people from editing it
215
- return self .read_client .get_pull (pr .number )
207
+ return self .repo .get_pull (pr .number )
216
208
except GithubException as ge :
217
209
logger .warning (f"Failed to create PR got GithubException\n \t { ge } " )
218
210
except Exception as e :
@@ -235,15 +227,15 @@ def squash_and_merge(self, base_branch_name: str, head_branch_name: str, squash_
235
227
merge = squash_pr .merge (commit_message = squash_commit_msg , commit_title = squash_commit_title , merge_method = "squash" ) # type: ignore[arg-type]
236
228
237
229
def edit_pull (self , pull : PullRequest , title : Opt [str ] = NotSet , body : Opt [str ] = NotSet , state : Opt [str ] = NotSet ) -> None :
238
- writable_pr = self ._write_client .get_pull (pull .number )
230
+ writable_pr = self .repo .get_pull (pull .number )
239
231
writable_pr .edit (title = title , body = body , state = state )
240
232
241
233
def add_label_to_pull (self , pull : PullRequest , label : Label ) -> None :
242
- writeable_pr = self ._write_client .get_pull (pull .number )
234
+ writeable_pr = self .repo .get_pull (pull .number )
243
235
writeable_pr .add_to_labels (label )
244
236
245
237
def remove_label_from_pull (self , pull : PullRequest , label : Label ) -> None :
246
- writeable_pr = self ._write_client .get_pull (pull .number )
238
+ writeable_pr = self .repo .get_pull (pull .number )
247
239
writeable_pr .remove_from_labels (label )
248
240
249
241
####################################################################################################################
@@ -264,7 +256,7 @@ def get_or_create_branch(self, new_branch_name: str, base_branch_name: str | Non
264
256
def get_branch_safe (self , branch_name : str , attempts : int = 1 , wait_seconds : int = 1 ) -> Branch | None :
265
257
for i in range (attempts ):
266
258
try :
267
- return self .read_client .get_branch (branch_name )
259
+ return self .repo .get_branch (branch_name )
268
260
except GithubException as e :
269
261
if e .status == 404 and i < attempts - 1 :
270
262
time .sleep (wait_seconds )
@@ -276,14 +268,14 @@ def create_branch(self, new_branch_name: str, base_branch_name: str | None = Non
276
268
if base_branch_name is None :
277
269
base_branch_name = self .default_branch
278
270
279
- base_branch = self .read_client .get_branch (base_branch_name )
271
+ base_branch = self .repo .get_branch (base_branch_name )
280
272
# TODO: also wrap git ref. low pri b/c the only write operation on refs is creating one
281
- self ._write_client .create_git_ref (sha = base_branch .commit .sha , ref = f"refs/heads/{ new_branch_name } " )
273
+ self .repo .create_git_ref (sha = base_branch .commit .sha , ref = f"refs/heads/{ new_branch_name } " )
282
274
branch = self .get_branch_safe (new_branch_name )
283
275
return branch
284
276
285
277
def create_branch_from_sha (self , new_branch_name : str , base_sha : str ) -> Branch | None :
286
- self ._write_client .create_git_ref (ref = f"refs/heads/{ new_branch_name } " , sha = base_sha )
278
+ self .repo .create_git_ref (ref = f"refs/heads/{ new_branch_name } " , sha = base_sha )
287
279
branch = self .get_branch_safe (new_branch_name )
288
280
return branch
289
281
@@ -295,7 +287,7 @@ def delete_branch(self, branch_name: str) -> None:
295
287
296
288
branch_to_delete = self .get_branch_safe (branch_name )
297
289
if branch_to_delete :
298
- ref_to_delete = self ._write_client .get_git_ref (f"heads/{ branch_name } " )
290
+ ref_to_delete = self .repo .get_git_ref (f"heads/{ branch_name } " )
299
291
ref_to_delete .delete ()
300
292
logger .info (f"Branch: { branch_name } deleted successfully!" )
301
293
else :
@@ -307,7 +299,7 @@ def delete_branch(self, branch_name: str) -> None:
307
299
308
300
def get_commit_safe (self , commit_sha : str ) -> Commit | None :
309
301
try :
310
- return self .read_client .get_commit (commit_sha )
302
+ return self .repo .get_commit (commit_sha )
311
303
except UnknownObjectException as e :
312
304
logger .warning (f"Commit { commit_sha } not found:\n \t { e } " )
313
305
return None
@@ -338,7 +330,7 @@ def compare_branches(self, base_branch_name: str | None, head_branch_name: str,
338
330
339
331
# NOTE: base utility that other compare functions should try to use
340
332
def compare (self , base : str , head : str , show_commits : bool = False ) -> str :
341
- comparison = self .read_client .compare (base , head )
333
+ comparison = self .repo .compare (base , head )
342
334
return format_comparison (comparison , show_commits = show_commits )
343
335
344
336
####################################################################################################################
@@ -349,7 +341,7 @@ def compare(self, base: str, head: str, show_commits: bool = False) -> str:
349
341
def get_label_safe (self , label_name : str ) -> Label | None :
350
342
try :
351
343
label_name = label_name .strip ()
352
- label = self .read_client .get_label (label_name )
344
+ label = self .repo .get_label (label_name )
353
345
return label
354
346
except UnknownObjectException as e :
355
347
return None
@@ -360,10 +352,10 @@ def get_label_safe(self, label_name: str) -> Label | None:
360
352
def create_label (self , label_name : str , color : str ) -> Label :
361
353
# TODO: also offer description field
362
354
label_name = label_name .strip ()
363
- self ._write_client .create_label (label_name , color )
355
+ self .repo .create_label (label_name , color )
364
356
# TODO: is there a way to convert new_label to a read-only label without making another API call?
365
357
# NOTE: return a read-only label to prevent people from editing it
366
- return self .read_client .get_label (label_name )
358
+ return self .repo .get_label (label_name )
367
359
368
360
def get_or_create_label (self , label_name : str , color : str ) -> Label :
369
361
existing_label = self .get_label_safe (label_name )
@@ -377,7 +369,7 @@ def get_or_create_label(self, label_name: str, color: str) -> Label:
377
369
378
370
def get_check_suite_safe (self , check_suite_id : int ) -> CheckSuite | None :
379
371
try :
380
- return self .read_client .get_check_suite (check_suite_id )
372
+ return self .repo .get_check_suite (check_suite_id )
381
373
except UnknownObjectException as e :
382
374
return None
383
375
except Exception as e :
@@ -390,7 +382,7 @@ def get_check_suite_safe(self, check_suite_id: int) -> CheckSuite | None:
390
382
391
383
def get_check_run_safe (self , check_run_id : int ) -> CheckRun | None :
392
384
try :
393
- return self .read_client .get_check_run (check_run_id )
385
+ return self .repo .get_check_run (check_run_id )
394
386
except UnknownObjectException as e :
395
387
return None
396
388
except Exception as e :
@@ -406,24 +398,24 @@ def create_check_run(
406
398
conclusion : Opt [str ] = NotSet ,
407
399
output : Opt [dict [str , str | list [dict [str , str | int ]]]] = NotSet ,
408
400
) -> CheckRun :
409
- new_check_run = self ._write_client .create_check_run (name = name , head_sha = head_sha , details_url = details_url , status = status , conclusion = conclusion , output = output )
410
- return self .read_client .get_check_run (new_check_run .id )
401
+ new_check_run = self .repo .create_check_run (name = name , head_sha = head_sha , details_url = details_url , status = status , conclusion = conclusion , output = output )
402
+ return self .repo .get_check_run (new_check_run .id )
411
403
412
404
####################################################################################################################
413
405
# WORKFLOW
414
406
####################################################################################################################
415
407
416
408
def get_workflow_safe (self , file_name : str ) -> Workflow | None :
417
409
try :
418
- return self .read_client .get_workflow (file_name )
410
+ return self .repo .get_workflow (file_name )
419
411
except UnknownObjectException as e :
420
412
return None
421
413
except Exception as e :
422
414
logger .warning (f"Error getting workflow by file name: { file_name } \n \t { e } " )
423
415
return None
424
416
425
417
def create_workflow_dispatch (self , workflow : Workflow , ref : Branch | Tag | Commit | str , inputs : Opt [dict ] = NotSet ):
426
- writeable_workflow = self ._write_client .get_workflow (workflow .id )
418
+ writeable_workflow = self .repo .get_workflow (workflow .id )
427
419
writeable_workflow .create_dispatch (ref = ref , inputs = inputs )
428
420
429
421
####################################################################################################################
@@ -439,5 +431,5 @@ def merge_upstream(self, branch_name: str) -> bool:
439
431
"""
440
432
assert isinstance (branch_name , str ), branch_name
441
433
post_parameters = {"branch" : branch_name }
442
- status , _ , _ = self ._write_client ._requester .requestJson ("POST" , f"{ self ._write_client .url } /merge-upstream" , input = post_parameters )
434
+ status , _ , _ = self .repo ._requester .requestJson ("POST" , f"{ self .repo .url } /merge-upstream" , input = post_parameters )
443
435
return status == 200
0 commit comments