Skip to content

Commit ce56909

Browse files
Julien Cretelgopherbot
Julien Cretel
authored andcommitted
jws: improve fix for CVE-2025-22868
The fix for CVE-2025-22868 relies on strings.Count, which isn't ideal because it precludes failing fast when the token contains an unexpected number of periods. Moreover, Verify still allocates more than necessary. Eschew strings.Count in favor of strings.Cut. Some benchmark results: goos: darwin goarch: amd64 pkg: golang.org/x/oauth2/jws cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz │ old │ new │ │ sec/op │ sec/op vs base │ Verify/full_of_periods-8 24862.50n ± 1% 57.87n ± 0% -99.77% (p=0.000 n=20) Verify/two_trailing_periods-8 3.485m ± 1% 3.445m ± 1% -1.13% (p=0.003 n=20) geomean 294.3µ 14.12µ -95.20% │ old │ new │ │ B/op │ B/op vs base │ Verify/full_of_periods-8 16.00 ± 0% 16.00 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 2.001Mi ± 0% 1.001Mi ± 0% -49.98% (p=0.000 n=20) geomean 5.658Ki 4.002Ki -29.27% ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ Verify/full_of_periods-8 1.000 ± 0% 1.000 ± 0% ~ (p=1.000 n=20) ¹ Verify/two_trailing_periods-8 12.000 ± 0% 9.000 ± 0% -25.00% (p=0.000 n=20) geomean 3.464 3.000 -13.40% ¹ all samples are equal Also, remove all remaining calls to strings.Split. Updates golang/go#71490 Change-Id: Icac3c7a81562161ab6533d892ba19247d6d5b943 GitHub-Last-Rev: 3a82900 GitHub-Pull-Request: #774 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/655455 Commit-Queue: Neal Patel <[email protected]> Reviewed-by: Roland Shoemaker <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Neal Patel <[email protected]> Auto-Submit: Neal Patel <[email protected]>
1 parent 0042180 commit ce56909

File tree

2 files changed

+79
-12
lines changed

2 files changed

+79
-12
lines changed

Diff for: jws/jws.go

+25-9
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
116116
// Decode decodes a claim set from a JWS payload.
117117
func Decode(payload string) (*ClaimSet, error) {
118118
// decode returned id token to get expiry
119-
s := strings.Split(payload, ".")
120-
if len(s) < 2 {
119+
_, claims, _, ok := parseToken(payload)
120+
if !ok {
121121
// TODO(jbd): Provide more context about the error.
122122
return nil, errors.New("jws: invalid token received")
123123
}
124-
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
124+
decoded, err := base64.RawURLEncoding.DecodeString(claims)
125125
if err != nil {
126126
return nil, err
127127
}
@@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
165165
// Verify tests whether the provided JWT token's signature was produced by the private key
166166
// associated with the supplied public key.
167167
func Verify(token string, key *rsa.PublicKey) error {
168-
if strings.Count(token, ".") != 2 {
168+
header, claims, sig, ok := parseToken(token)
169+
if !ok {
169170
return errors.New("jws: invalid token received, token must have 3 parts")
170171
}
171-
172-
parts := strings.SplitN(token, ".", 3)
173-
signedContent := parts[0] + "." + parts[1]
174-
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
172+
signatureString, err := base64.RawURLEncoding.DecodeString(sig)
175173
if err != nil {
176174
return err
177175
}
178176

179177
h := sha256.New()
180-
h.Write([]byte(signedContent))
178+
h.Write([]byte(header + tokenDelim + claims))
181179
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
182180
}
181+
182+
func parseToken(s string) (header, claims, sig string, ok bool) {
183+
header, s, ok = strings.Cut(s, tokenDelim)
184+
if !ok { // no period found
185+
return "", "", "", false
186+
}
187+
claims, s, ok = strings.Cut(s, tokenDelim)
188+
if !ok { // only one period found
189+
return "", "", "", false
190+
}
191+
sig, _, ok = strings.Cut(s, tokenDelim)
192+
if ok { // three periods found
193+
return "", "", "", false
194+
}
195+
return header, claims, sig, true
196+
}
197+
198+
const tokenDelim = "."

Diff for: jws/jws_test.go

+54-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ package jws
77
import (
88
"crypto/rand"
99
"crypto/rsa"
10+
"net/http"
11+
"strings"
1012
"testing"
1113
)
1214

@@ -39,8 +41,57 @@ func TestSignAndVerify(t *testing.T) {
3941
}
4042

4143
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
42-
err := Verify("abc.def", nil)
43-
if err == nil {
44-
t.Error("got no errors; want improperly formed JWT not to be verified")
44+
cases := []struct {
45+
desc string
46+
token string
47+
}{
48+
{
49+
desc: "no periods",
50+
token: "aa",
51+
}, {
52+
desc: "only one period",
53+
token: "a.a",
54+
}, {
55+
desc: "more than two periods",
56+
token: "a.a.a.a",
57+
},
58+
}
59+
for _, tc := range cases {
60+
f := func(t *testing.T) {
61+
err := Verify(tc.token, nil)
62+
if err == nil {
63+
t.Error("got no errors; want improperly formed JWT not to be verified")
64+
}
65+
}
66+
t.Run(tc.desc, f)
67+
}
68+
}
69+
70+
func BenchmarkVerify(b *testing.B) {
71+
cases := []struct {
72+
desc string
73+
token string
74+
}{
75+
{
76+
desc: "full of periods",
77+
token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
78+
}, {
79+
desc: "two trailing periods",
80+
token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
81+
},
82+
}
83+
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
84+
if err != nil {
85+
b.Fatal(err)
86+
}
87+
for _, bc := range cases {
88+
f := func(b *testing.B) {
89+
b.ReportAllocs()
90+
b.ResetTimer()
91+
for range b.N {
92+
Verify(bc.token, &privateKey.PublicKey)
93+
}
94+
}
95+
b.Run(bc.desc, f)
4596
}
4697
}

0 commit comments

Comments
 (0)