Skip to content

Commit 00bf849

Browse files
committed
Correcting AsResponseError method
1 parent 24238e7 commit 00bf849

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

errors.go

+23-6
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,29 @@ func newResponseError(resp *http.Response, target int, targets ...int) error {
3838
}
3939
}
4040

41-
// IsResponseError is convenience function to see
42-
// if it can convert into RequestError.
43-
func IsResponseError(err error) (*ResponseError, bool) {
44-
var re *ResponseError
45-
if errors.As(err, &re) {
46-
return err.(*ResponseError), true
41+
// AsResponseError is a convenience function to check the error
42+
// to see if it contains an `ResponseError` and returns the value with true.
43+
// If the error was initially joined using [errors.Join], it will check each error
44+
// within the list and return the first matching error.
45+
func AsResponseError(err error) (*ResponseError, bool) {
46+
// When `errors.Join` is called, it returns an error that
47+
// matches the provided interface.
48+
if joined, ok := err.(interface{ Unwrap() []error }); ok {
49+
for _, err := range joined.Unwrap() {
50+
if re, ok := AsResponseError(err); ok {
51+
return re, ok
52+
}
53+
}
54+
return nil, false
55+
}
56+
57+
for err != nil {
58+
if re, ok := err.(*ResponseError); ok {
59+
return re, true
60+
}
61+
// In case the error is wrapped using `fmt.Errorf`
62+
// this will also account for that.
63+
err = errors.Unwrap(err)
4764
}
4865
return nil, false
4966
}

errors_test.go

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package signalfx
22

33
import (
44
"errors"
5+
"fmt"
56
"io"
67
"net/http"
78
"net/url"
@@ -130,12 +131,22 @@ func TestIsRequestError(t *testing.T) {
130131
err: &ResponseError{},
131132
expected: true,
132133
},
134+
{
135+
name: "joined errors",
136+
err: errors.Join(errors.New("boom"), &ResponseError{}),
137+
expected: true,
138+
},
139+
{
140+
name: "fmt error",
141+
err: fmt.Errorf("check permissions: %w", &ResponseError{}),
142+
expected: true,
143+
},
133144
} {
134145
tc := tc
135146
t.Run(tc.name, func(t *testing.T) {
136147
t.Parallel()
137148

138-
_, ok := IsResponseError(tc.err)
149+
_, ok := AsResponseError(tc.err)
139150
assert.Equal(t, tc.expected, ok, "Must match the expected value")
140151
})
141152
}

0 commit comments

Comments
 (0)