Skip to content

Commit 26f16aa

Browse files
committed
fix(surrogate): invalidation
1 parent af6ec14 commit 26f16aa

File tree

7 files changed

+37
-27
lines changed

7 files changed

+37
-27
lines changed

pkg/middleware/middleware.go

+19-14
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ func (s *SouinBaseHandler) Store(
190190
rq *http.Request,
191191
requestCc *cacheobject.RequestCacheDirectives,
192192
cachedKey string,
193+
uri string,
193194
) error {
194195
statusCode := customWriter.GetStatusCode()
195196
if !isCacheableCode(statusCode) {
@@ -341,6 +342,7 @@ func (s *SouinBaseHandler) Store(
341342
variedKey,
342343
) == nil {
343344
s.Configuration.GetLogger().Sugar().Debugf("Stored the key %s in the %s provider", variedKey, currentStorer.Name())
345+
res.Request = rq
344346
} else {
345347
mu.Lock()
346348
fails = append(fails, fmt.Sprintf("; detail=%s-INSERTION-ERROR", currentStorer.Name()))
@@ -351,9 +353,9 @@ func (s *SouinBaseHandler) Store(
351353

352354
wg.Wait()
353355
if len(fails) < s.storersLen {
354-
go func(rs http.Response, key string) {
355-
_ = s.SurrogateKeyStorer.Store(&rs, key)
356-
}(res, variedKey)
356+
go func(rs http.Response, key string, basekey string) {
357+
_ = s.SurrogateKeyStorer.Store(&rs, key, uri, basekey)
358+
}(res, variedKey, cachedKey)
357359
status += "; stored"
358360
}
359361

@@ -387,6 +389,7 @@ func (s *SouinBaseHandler) Upstream(
387389
next handlerFunc,
388390
requestCc *cacheobject.RequestCacheDirectives,
389391
cachedKey string,
392+
uri string,
390393
) error {
391394
s.Configuration.GetLogger().Sugar().Debug("Request the upstream server")
392395
prometheus.Increment(prometheus.RequestCounter)
@@ -434,7 +437,7 @@ func (s *SouinBaseHandler) Upstream(
434437
customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl)
435438
}
436439

437-
err := s.Store(customWriter, rq, requestCc, cachedKey)
440+
err := s.Store(customWriter, rq, requestCc, cachedKey, uri)
438441
defer customWriter.Buf.Reset()
439442

440443
return singleflightValue{
@@ -458,7 +461,7 @@ func (s *SouinBaseHandler) Upstream(
458461
for _, vh := range variedHeaders {
459462
if rq.Header.Get(vh) != sfWriter.requestHeaders.Get(vh) {
460463
// cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders)
461-
return s.Upstream(customWriter, rq, next, requestCc, cachedKey)
464+
return s.Upstream(customWriter, rq, next, requestCc, cachedKey, uri)
462465
}
463466
}
464467
}
@@ -474,7 +477,7 @@ func (s *SouinBaseHandler) Upstream(
474477
return nil
475478
}
476479

477-
func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string) error {
480+
func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string, uri string) error {
478481
s.Configuration.GetLogger().Sugar().Debug("Revalidate the request with the upstream server")
479482
prometheus.Increment(prometheus.RequestRevalidationCounter)
480483

@@ -496,7 +499,7 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
496499
}
497500

498501
if statusCode != http.StatusNotModified {
499-
err = s.Store(customWriter, rq, requestCc, cachedKey)
502+
err = s.Store(customWriter, rq, requestCc, cachedKey, uri)
500503
}
501504
}
502505

@@ -616,6 +619,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
616619
}
617620
cachedKey := req.Context().Value(context.Key).(string)
618621

622+
// Need to copy URL path before calling next because it can alter the URI
623+
uri := req.URL.Path
619624
bufPool := s.bufPool.Get().(*bytes.Buffer)
620625
bufPool.Reset()
621626
defer s.bufPool.Put(bufPool)
@@ -669,14 +674,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
669674
}
670675

671676
if validator.NeedRevalidation {
672-
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
677+
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
673678
_, _ = customWriter.Send()
674679

675680
return err
676681
}
677682
if resCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, headerName)); resCc.NoCachePresent {
678683
prometheus.Increment(prometheus.NoCachedResponseCounter)
679-
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
684+
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
680685
_, _ = customWriter.Send()
681686

682687
return err
@@ -711,9 +716,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
711716
_, _ = io.Copy(customWriter.Buf, response.Body)
712717
_, err := customWriter.Send()
713718
customWriter = NewCustomWriter(req, rw, bufPool)
714-
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) {
715-
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk)
716-
}(validator, customWriter, req, next, requestCc, cachedKey)
719+
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
720+
_ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri)
721+
}(validator, customWriter, req, next, requestCc, cachedKey, uri)
717722
buf := s.bufPool.Get().(*bytes.Buffer)
718723
buf.Reset()
719724
defer s.bufPool.Put(buf)
@@ -723,7 +728,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
723728

724729
if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
725730
req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag)
726-
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey)
731+
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
727732
statusCode := customWriter.GetStatusCode()
728733
if err != nil {
729734
if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 {
@@ -785,7 +790,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
785790
errorCacheCh := make(chan error)
786791
go func(vr *http.Request, cw *CustomWriter) {
787792
prometheus.Increment(prometheus.NoCachedResponseCounter)
788-
errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey)
793+
errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey, uri)
789794
}(req, customWriter)
790795

791796
select {

pkg/surrogate/providers/akamai.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ func (*AkamaiSurrogateStorage) getHeaderSeparator() string {
3939
}
4040

4141
// Store stores the response tags located in the first non empty supported header
42-
func (a *AkamaiSurrogateStorage) Store(response *http.Response, cacheKey string) error {
42+
func (a *AkamaiSurrogateStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
4343
defer func() {
4444
response.Header.Del(surrogateKey)
4545
response.Header.Del(surrogateControl)
4646
}()
47-
e := a.baseStorage.Store(response, cacheKey)
47+
e := a.baseStorage.Store(response, cacheKey, uri, basekey)
4848
response.Header.Set(edgeCacheTag, response.Header.Get(surrogateKey))
4949

5050
return e

pkg/surrogate/providers/cloudflare.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ func (*CloudflareSurrogateStorage) getHeaderSeparator() string {
3838
}
3939

4040
// Store stores the response tags located in the first non empty supported header
41-
func (c *CloudflareSurrogateStorage) Store(response *http.Response, cacheKey string) error {
41+
func (c *CloudflareSurrogateStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
4242
defer func() {
4343
response.Header.Del(surrogateKey)
4444
response.Header.Del(surrogateControl)
4545
}()
46-
e := c.baseStorage.Store(response, cacheKey)
46+
e := c.baseStorage.Store(response, cacheKey, uri, basekey)
4747
response.Header.Set(cacheTag, strings.Join(c.ParseHeaders(response.Header.Get(surrogateKey)), c.getHeaderSeparator()))
4848

4949
return e

pkg/surrogate/providers/common.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ func (s *baseStorage) purgeTag(tag string) []string {
204204
}
205205

206206
// Store will take the lead to store the cache key for each provided Surrogate-key
207-
func (s *baseStorage) Store(response *http.Response, cacheKey string) error {
207+
func (s *baseStorage) Store(response *http.Response, cacheKey, uri, basekey string) error {
208208
h := response.Header
209209

210210
cacheKey = url.QueryEscape(cacheKey)
@@ -223,13 +223,18 @@ func (s *baseStorage) Store(response *http.Response, cacheKey string) error {
223223
for _, control := range controls {
224224
if s.parent.candidateStore(control) {
225225
s.storeTag(key, cacheKey, urlRegexp)
226+
227+
break
226228
}
227229
}
228230
} else {
229231
s.storeTag(key, cacheKey, urlRegexp)
230232
}
231233
}
232234

235+
urlRegexp = regexp.MustCompile("(^|" + regexp.QuoteMeta(souinStorageSeparator) + ")" + regexp.QuoteMeta(basekey) + "(" + regexp.QuoteMeta(souinStorageSeparator) + "|$)")
236+
s.storeTag(uri, basekey, urlRegexp)
237+
233238
return nil
234239
}
235240

pkg/surrogate/providers/common_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func TestBaseStorage_Store(t *testing.T) {
106106

107107
bs := mockCommonProvider()
108108

109-
e := bs.Store(&res, "((((invalid_key_but_escaped")
109+
e := bs.Store(&res, "((((invalid_key_but_escaped", "", "")
110110
if e != nil {
111111
t.Error("It shouldn't throw an error with a valid key.")
112112
}
@@ -116,7 +116,7 @@ func TestBaseStorage_Store(t *testing.T) {
116116
_ = bs.Storage.Set("test5", []byte("first,second,fifth"), storageToInfiniteTTLMap[bs.Storage.Name()])
117117
_ = bs.Storage.Set("testInvalid", []byte("invalid"), storageToInfiniteTTLMap[bs.Storage.Name()])
118118

119-
if e = bs.Store(&res, "stored"); e != nil {
119+
if e = bs.Store(&res, "stored", "", ""); e != nil {
120120
t.Error("It shouldn't throw an error with a valid key.")
121121
}
122122

@@ -133,10 +133,10 @@ func TestBaseStorage_Store(t *testing.T) {
133133
}
134134

135135
res.Header.Set(surrogateKey, "something")
136-
_ = bs.Store(&res, "/something")
137-
_ = bs.Store(&res, "/something")
136+
_ = bs.Store(&res, "/something", "", "")
137+
_ = bs.Store(&res, "/something", "", "")
138138
res.Header.Set(surrogateKey, "something")
139-
_ = bs.Store(&res, "/some")
139+
_ = bs.Store(&res, "/some", "", "")
140140

141141
storageSize := len(bs.Storage.MapKeys(surrogatePrefix))
142142
if storageSize != 6 {
@@ -161,7 +161,7 @@ func TestBaseStorage_Store_Load(t *testing.T) {
161161
wg.Add(1)
162162
go func(r http.Response, iteration int, group *sync.WaitGroup) {
163163
defer wg.Done()
164-
_ = bs.Store(&r, fmt.Sprintf("my_dynamic_cache_key_%d", iteration))
164+
_ = bs.Store(&r, fmt.Sprintf("my_dynamic_cache_key_%d", iteration), "", "")
165165
}(res, i, &wg)
166166
}
167167

pkg/surrogate/providers/types.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type SurrogateInterface interface {
1616
Purge(http.Header) (cacheKeys []string, surrogateKeys []string)
1717
Invalidate(method string, h http.Header)
1818
purgeTag(string) []string
19-
Store(*http.Response, string) error
19+
Store(*http.Response, string, string, string) error
2020
storeTag(string, string, *regexp.Regexp)
2121
ParseHeaders(string) []string
2222
List() map[string]string

plugins/caddy/httpcache_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ func TestInvalidationAPI(t *testing.T) {
287287

288288
_, _ = tester.AssertGetResponse(`http://localhost:9080/invalidation-api/souin-api/souin`, 200, `["GET-http-localhost:9080-/invalidation-api"]`)
289289

290-
_, _ = tester.AssertGetResponse(`http://localhost:9080/invalidation-api/souin-api/souin/surrogate_keys`, 200, `{"":",GET-http-localhost%3A9080-%2Finvalidation-api"}`)
290+
_, _ = tester.AssertGetResponse(`http://localhost:9080/invalidation-api/souin-api/souin/surrogate_keys`, 200, `{"":",GET-http-localhost%3A9080-%2Finvalidation-api","/invalidation-api":",GET-http-localhost:9080-/invalidation-api"}`)
291291

292292
purgeRq, _ := http.NewRequest("PURGE", "http://localhost:9080/invalidation-api/souin-api/souin", nil)
293293
purgeRq.Header.Set("Surrogate-Key", " , /something")

0 commit comments

Comments
 (0)