@@ -75,6 +75,7 @@ func (ex *Executor) makeRequest(method string, url string) (res *http.Response,
75
75
if err != nil {
76
76
return nil , err
77
77
}
78
+
78
79
req .Header .Add ("Authorization" , "token " + ex .token )
79
80
res , _ = ex .client .Do (req )
80
81
if res .StatusCode >= 400 {
@@ -92,7 +93,7 @@ func (ex *Executor) listClosedPullRequests(user string, repo string, days int) (
92
93
res , err := ex .makeRequest ("GET" , "repos/" + user + "/" + repo + "/pulls?state=closed&sort=updated&direction=desc&per_page=100&page=" + strconv .Itoa (page ))
93
94
94
95
if err != nil {
95
- return pullRequests , errors .New ("failed to get pull requests (" + err .Error () + ")" )
96
+ return pullRequests , errors .New ("failed to get closed pull requests (" + err .Error () + ")" )
96
97
}
97
98
98
99
d := json .NewDecoder (res .Body )
@@ -102,7 +103,7 @@ func (ex *Executor) listClosedPullRequests(user string, repo string, days int) (
102
103
err = d .Decode (& prs .PullRequests )
103
104
104
105
if err != nil {
105
- return pullRequests , errors .New ("failed to parse pull requests (" + err .Error () + ")" )
106
+ return pullRequests , errors .New ("failed to parse closed pull requests (" + err .Error () + ")" )
106
107
}
107
108
108
109
for _ , pr := range prs .PullRequests {
@@ -122,6 +123,36 @@ func (ex *Executor) listClosedPullRequests(user string, repo string, days int) (
122
123
return pullRequests , nil
123
124
}
124
125
126
+ func (ex * Executor ) listOpenPullRequests (user string , repo string ) ([]pullRequest , error ) {
127
+ pullRequests := make ([]pullRequest , 0 , 1 )
128
+
129
+ for page , keepGoing := 1 , true ; keepGoing ; page ++ {
130
+ res , err := ex .makeRequest ("GET" , "repos/" + user + "/" + repo + "/pulls?state=open&sort=updated&direction=desc&per_page=100&page=" + strconv .Itoa (page ))
131
+
132
+ if err != nil {
133
+ return pullRequests , errors .New ("failed to get open pull requests (" + err .Error () + ")" )
134
+ }
135
+
136
+ d := json .NewDecoder (res .Body )
137
+ var prs struct {
138
+ PullRequests []pullRequest
139
+ }
140
+ err = d .Decode (& prs .PullRequests )
141
+
142
+ if err != nil {
143
+ return pullRequests , errors .New ("failed to parse open pull requests (" + err .Error () + ")" )
144
+ }
145
+
146
+ pullRequests = append (pullRequests , prs .PullRequests ... )
147
+
148
+ if len (prs .PullRequests ) == 0 || len (prs .PullRequests ) < 100 {
149
+ break
150
+ }
151
+ }
152
+
153
+ return pullRequests , nil
154
+ }
155
+
125
156
func (ex * Executor ) listUnprotectedBranches (user string , repo string ) ([]branch , error ) {
126
157
branches := make ([]branch , 0 , 1 )
127
158
@@ -174,18 +205,33 @@ func (ex *Executor) deleteBranches(user string, repo string, branches []string)
174
205
return deletedBranches , nil
175
206
}
176
207
177
- func getStaleBranches (branches []branch , pullRequests []pullRequest ) []string {
178
- branchesByName := make (map [string ]branch )
179
- staleBranches := make ([]string , 0 , 1 )
208
+ func getStaleBranches (closedBranches []branch , closedPullRequests []pullRequest , openPullRequests []pullRequest ) []string {
209
+ branchShaMap := make (map [string ]string ) // Map[branch_name]branch_SHA
210
+ staleBranchMap := make (map [string ]bool ) // Map[branch_name]is_stale
211
+
212
+ for _ , b := range closedBranches {
213
+ branchShaMap [b .Name ] = b .Commit .SHA
214
+ }
215
+
216
+ for _ , pr := range closedPullRequests {
217
+ staleBranchSHA , branchExists := branchShaMap [pr .Head .Ref ]
218
+ if branchExists && staleBranchSHA == pr .Head .SHA {
219
+ staleBranchMap [pr .Head .Ref ] = true
220
+ }
221
+ }
180
222
181
- for _ , b := range branches {
182
- branchesByName [b .Name ] = b
223
+ // If we've marked this branch as stale, but there is another open PR tied to the branch, unmark as stale
224
+ for _ , pr := range openPullRequests {
225
+ _ , branchExists := staleBranchMap [pr .Head .Ref ]
226
+ if branchExists {
227
+ staleBranchMap [pr .Head .Ref ] = false
228
+ }
183
229
}
184
230
185
- for _ , pr := range pullRequests {
186
- staleBranch , branchExists := branchesByName [ pr . Head . Ref ]
187
- if branchExists && staleBranch . Commit . SHA == pr . Head . SHA {
188
- staleBranches = append (staleBranches , pr . Head . Ref )
231
+ staleBranches := make ([] string , 0 , 1 )
232
+ for branch , isStale := range staleBranchMap {
233
+ if isStale {
234
+ staleBranches = append (staleBranches , branch )
189
235
}
190
236
}
191
237
@@ -269,11 +315,18 @@ func Run(user string, repo string, days int, ex Executor) error {
269
315
if err != nil {
270
316
return err
271
317
}
318
+
319
+ openPullRequests , err := ex .listOpenPullRequests (user , repo )
320
+ if err != nil {
321
+ return err
322
+ }
323
+
272
324
unprotectedBranches , err := ex .listUnprotectedBranches (user , repo )
273
325
if err != nil {
274
326
return err
275
327
}
276
- staleBranches := getStaleBranches (unprotectedBranches , closedPullRequests )
328
+
329
+ staleBranches := getStaleBranches (unprotectedBranches , closedPullRequests , openPullRequests )
277
330
db , err := ex .deleteBranches (user , repo , staleBranches )
278
331
if err != nil {
279
332
return err
0 commit comments