@@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
116
116
// Decode decodes a claim set from a JWS payload.
117
117
func Decode (payload string ) (* ClaimSet , error ) {
118
118
// decode returned id token to get expiry
119
- s := strings . Split (payload , "." )
120
- if len ( s ) < 2 {
119
+ _ , claims , _ , ok := parseToken (payload )
120
+ if ! ok {
121
121
// TODO(jbd): Provide more context about the error.
122
122
return nil , errors .New ("jws: invalid token received" )
123
123
}
124
- decoded , err := base64 .RawURLEncoding .DecodeString (s [ 1 ] )
124
+ decoded , err := base64 .RawURLEncoding .DecodeString (claims )
125
125
if err != nil {
126
126
return nil , err
127
127
}
@@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
165
165
// Verify tests whether the provided JWT token's signature was produced by the private key
166
166
// associated with the supplied public key.
167
167
func Verify (token string , key * rsa.PublicKey ) error {
168
- if strings .Count (token , "." ) != 2 {
168
+ header , claims , sig , ok := parseToken (token )
169
+ if ! ok {
169
170
return errors .New ("jws: invalid token received, token must have 3 parts" )
170
171
}
171
-
172
- parts := strings .SplitN (token , "." , 3 )
173
- signedContent := parts [0 ] + "." + parts [1 ]
174
- signatureString , err := base64 .RawURLEncoding .DecodeString (parts [2 ])
172
+ signatureString , err := base64 .RawURLEncoding .DecodeString (sig )
175
173
if err != nil {
176
174
return err
177
175
}
178
176
179
177
h := sha256 .New ()
180
- h .Write ([]byte (signedContent ))
178
+ h .Write ([]byte (header + tokenDelim + claims ))
181
179
return rsa .VerifyPKCS1v15 (key , crypto .SHA256 , h .Sum (nil ), signatureString )
182
180
}
181
+
182
+ func parseToken (s string ) (header , claims , sig string , ok bool ) {
183
+ header , s , ok = strings .Cut (s , tokenDelim )
184
+ if ! ok { // no period found
185
+ return "" , "" , "" , false
186
+ }
187
+ claims , s , ok = strings .Cut (s , tokenDelim )
188
+ if ! ok { // only one period found
189
+ return "" , "" , "" , false
190
+ }
191
+ sig , _ , ok = strings .Cut (s , tokenDelim )
192
+ if ok { // three periods found
193
+ return "" , "" , "" , false
194
+ }
195
+ return header , claims , sig , true
196
+ }
197
+
198
+ const tokenDelim = "."
0 commit comments