Skip to content

Commit 42d9299

Browse files
authored
Don't coalesce anymore uncoalesceable request (#56)
* Don't coalesce anymore uncoalesceable request * Add uncoalesceable request to the storage layer * Update LayerStorage behaviour
1 parent beb1e43 commit 42d9299

File tree

8 files changed

+171
-80
lines changed

8 files changed

+171
-80
lines changed

cache/types/layerStorage.go

+45
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@ type VaryLayerStorage struct {
88
*ristretto.Cache
99
}
1010

11+
func InitializeVaryLayerStorage() *VaryLayerStorage {
12+
storage, _ := ristretto.NewCache(&ristretto.Config{
13+
NumCounters: 1e7, // number of keys to track frequency of (10M).
14+
MaxCost: 1 << 30, // maximum cost of cache (1GB).
15+
BufferItems: 64, // number of keys per Get buffer.
16+
})
17+
18+
return &VaryLayerStorage{Cache: storage}
19+
}
20+
1121
// Get method returns the varied headers list if exists, empty array then
1222
func (provider *VaryLayerStorage) Get(key string) []string {
1323
val, found := provider.Cache.Get(key)
@@ -24,3 +34,38 @@ func (provider *VaryLayerStorage) Set(key string, headers []string) {
2434
panic("Impossible to set value into Ristretto")
2535
}
2636
}
37+
38+
type CoalescingLayerStorage struct {
39+
*ristretto.Cache
40+
}
41+
42+
func InitializeCoalescingLayerStorage() *CoalescingLayerStorage {
43+
storage, _ := ristretto.NewCache(&ristretto.Config{
44+
NumCounters: 1e7, // number of keys to track frequency of (10M).
45+
MaxCost: 1 << 30, // maximum cost of cache (1GB).
46+
BufferItems: 64, // number of keys per Get buffer.
47+
})
48+
49+
return &CoalescingLayerStorage{Cache: storage}
50+
}
51+
52+
// Exists method returns if the key should coalesce
53+
func (provider *CoalescingLayerStorage) Exists(key string) bool {
54+
_, found := provider.Cache.Get(key)
55+
return !found
56+
}
57+
58+
// Set method will store the response in Ristretto provider
59+
func (provider *CoalescingLayerStorage) Set(key string) {
60+
isSet := provider.Cache.Set(key, nil, 1)
61+
if !isSet {
62+
panic("Impossible to set value into Ristretto")
63+
}
64+
}
65+
66+
// Delete method will delete the response in Ristretto provider if exists corresponding to key param
67+
func (provider *CoalescingLayerStorage) Delete(key string) {
68+
go func() {
69+
provider.Del(key)
70+
}()
71+
}

cache/types/layerStorage_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package types
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/darkweak/souin/errors"
8+
"time"
9+
)
10+
11+
const LAYEREDKEY = "LayeredKey"
12+
const BYTEKEY = "MyByteKey"
13+
const NONEXISTENTKEY = "NonexistentKey"
14+
15+
func TestInitializeCoalescingLayerStorage(t *testing.T) {
16+
r := InitializeCoalescingLayerStorage()
17+
18+
if nil == r || nil == r.Cache {
19+
errors.GenerateError(t, "Ristretto should be instanciated")
20+
}
21+
}
22+
23+
func TestIShouldBeAbleToReadAndWriteDataInTheLayerStorage(t *testing.T) {
24+
store := InitializeCoalescingLayerStorage()
25+
26+
store.Set(LAYEREDKEY)
27+
time.Sleep(1 * time.Second)
28+
29+
if store.Exists(LAYEREDKEY) {
30+
errors.GenerateError(t, fmt.Sprintf("Key %s should exist", LAYEREDKEY))
31+
}
32+
}
33+
34+
func TestLayerStorage_GetRequestInTheLayerStorage(t *testing.T) {
35+
store := InitializeCoalescingLayerStorage()
36+
if !store.Exists(NONEXISTENTKEY) {
37+
errors.GenerateError(t, fmt.Sprintf("Key %s should not exist", NONEXISTENTKEY))
38+
}
39+
}
40+
41+
func TestLayerStorage_DeleteRequestInCache(t *testing.T) {
42+
store := InitializeCoalescingLayerStorage()
43+
store.Delete(BYTEKEY)
44+
time.Sleep(1 * time.Second)
45+
if !store.Exists(BYTEKEY) {
46+
errors.GenerateError(t, fmt.Sprintf("Key %s should not exist", BYTEKEY))
47+
}
48+
}

cache/types/souin.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type TransportInterface interface {
1313
SetURL(url configurationtypes.URL)
1414
UpdateCacheEventually(req *http.Request) (resp *http.Response, err error)
1515
GetVaryLayerStorage() *VaryLayerStorage
16+
GetCoalescingLayerStorage() *CoalescingLayerStorage
1617
}
1718

1819
// Transport is an implementation of http.RoundTripper that will return values from a cache
@@ -21,11 +22,12 @@ type TransportInterface interface {
2122
type Transport struct {
2223
// The RoundTripper interface actually used to make requests
2324
// If nil, http.DefaultTransport is used
24-
Transport http.RoundTripper
25-
Provider AbstractProviderInterface
26-
ConfigurationURL configurationtypes.URL
27-
MarkCachedResponses bool
28-
VaryLayerStorage *VaryLayerStorage
25+
Transport http.RoundTripper
26+
Provider AbstractProviderInterface
27+
ConfigurationURL configurationtypes.URL
28+
MarkCachedResponses bool
29+
VaryLayerStorage *VaryLayerStorage
30+
CoalescingLayerStorage *CoalescingLayerStorage
2931
}
3032

3133
// RetrieverResponsePropertiesInterface interface

plugins/base.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@ func DefaultSouinPluginCallback(
2020
nextMiddleware func(w http.ResponseWriter, r *http.Request) error,
2121
) {
2222
responses := make(chan types.ReverseResponse)
23+
coalesceable := make(chan bool)
2324

2425
go func() {
2526
cacheKey := rfc.GetCacheKey(req)
2627
varied := retriever.GetTransport().GetVaryLayerStorage().Get(cacheKey)
2728
if len(varied) != 0 {
2829
cacheKey = rfc.GetVariedCacheKey(req, varied)
2930
}
31+
go func() {
32+
coalesceable <- retriever.GetTransport().GetCoalescingLayerStorage().Exists(cacheKey)
33+
}()
3034
if http.MethodGet == req.Method {
3135
r, _ := rfc.CachedResponse(
3236
retriever.GetProvider(),
@@ -56,7 +60,12 @@ func DefaultSouinPluginCallback(
5660
}
5761

5862
close(responses)
59-
rc.Temporise(req, res, nextMiddleware)
63+
if <-coalesceable {
64+
rc.Temporise(req, res, nextMiddleware)
65+
} else {
66+
_ = nextMiddleware(res, req)
67+
}
68+
close(coalesceable)
6069
}
6170

6271
// DefaultSouinPluginInitializerFromConfiguration is the default initialization for plugins

rfc/bridge.go

+44-52
Original file line numberDiff line numberDiff line change
@@ -47,43 +47,55 @@ func (t *VaryTransport) BaseRoundTrip(req *http.Request, shouldReUpdate bool) (s
4747
cachedResp = cr.Response
4848
}
4949
} else {
50+
go func() {
51+
t.CoalescingLayerStorage.Set(cacheKey)
52+
}()
5053
t.Provider.Delete(cacheKey)
5154
}
5255

5356
return cacheKey, cacheable, cachedResp
5457
}
5558

56-
// UpdateCacheEventually will handle Request and update the previous one in the cache provider
57-
func (t *VaryTransport) UpdateCacheEventually(req *http.Request) (resp *http.Response, err error) {
58-
cacheKey, cacheable, cachedResp := t.BaseRoundTrip(req, false)
59+
func commonVaryMatchesVerification(cachedResp *http.Response, req *http.Request) *http.Response {
60+
if varyMatches(cachedResp, req) {
61+
// Can only use cached value if the new request doesn't Vary significantly
62+
freshness := getFreshness(cachedResp.Header, req.Header)
63+
if freshness == fresh {
64+
return cachedResp
65+
}
5966

60-
if cacheable && cachedResp != nil {
61-
if varyMatches(cachedResp, req) {
62-
// Can only use cached value if the new request doesn't Vary significantly
63-
freshness := getFreshness(cachedResp.Header, req.Header)
64-
if freshness == fresh {
65-
return cachedResp, nil
67+
if freshness == stale {
68+
var req2 *http.Request
69+
// Add validators if caller hasn't already done so
70+
etag := cachedResp.Header.Get("etag")
71+
if etag != "" && req.Header.Get("etag") == "" {
72+
req2 = cloneRequest(req)
73+
req2.Header.Set("if-none-match", etag)
6674
}
67-
68-
if freshness == stale {
69-
var req2 *http.Request
70-
// Add validators if caller hasn't already done so
71-
etag := cachedResp.Header.Get("etag")
72-
if etag != "" && req.Header.Get("etag") == "" {
75+
lastModified := cachedResp.Header.Get("last-modified")
76+
if lastModified != "" && req.Header.Get("last-modified") == "" {
77+
if req2 == nil {
7378
req2 = cloneRequest(req)
74-
req2.Header.Set("if-none-match", etag)
75-
}
76-
lastModified := cachedResp.Header.Get("last-modified")
77-
if lastModified != "" && req.Header.Get("last-modified") == "" {
78-
if req2 == nil {
79-
req2 = cloneRequest(req)
80-
}
81-
req2.Header.Set("if-modified-since", lastModified)
82-
}
83-
if req2 != nil {
84-
req = req2
8579
}
80+
req2.Header.Set("if-modified-since", lastModified)
8681
}
82+
if req2 != nil {
83+
req = req2
84+
}
85+
}
86+
}
87+
88+
return nil
89+
}
90+
91+
// UpdateCacheEventually will handle Request and update the previous one in the cache provider
92+
func (t *VaryTransport) UpdateCacheEventually(req *http.Request) (resp *http.Response, err error) {
93+
cacheKey, cacheable, cachedResp := t.BaseRoundTrip(req, false)
94+
95+
if cacheable && cachedResp != nil {
96+
r := commonVaryMatchesVerification(cachedResp, req)
97+
if r != nil {
98+
return r, nil
8799
}
88100
} else {
89101
reqCacheControl := parseCacheControl(req.Header)
@@ -126,32 +138,9 @@ func (t *VaryTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
126138
cachedResp.Header.Set(XFromCache, "1")
127139
}
128140

129-
if varyMatches(cachedResp, req) {
130-
// Can only use cached value if the new request doesn't Vary significantly
131-
freshness := getFreshness(cachedResp.Header, req.Header)
132-
if freshness == fresh {
133-
return cachedResp, nil
134-
}
135-
136-
if freshness == stale {
137-
var req2 *http.Request
138-
// Add validators if caller hasn't already done so
139-
etag := cachedResp.Header.Get("etag")
140-
if etag != "" && req.Header.Get("etag") == "" {
141-
req2 = cloneRequest(req)
142-
req2.Header.Set("if-none-match", etag)
143-
}
144-
lastModified := cachedResp.Header.Get("last-modified")
145-
if lastModified != "" && req.Header.Get("last-modified") == "" {
146-
if req2 == nil {
147-
req2 = cloneRequest(req)
148-
}
149-
req2.Header.Set("if-modified-since", lastModified)
150-
}
151-
if req2 != nil {
152-
req = req2
153-
}
154-
}
141+
r := commonVaryMatchesVerification(cachedResp, req)
142+
if r != nil {
143+
return r, nil
155144
}
156145

157146
resp, err = transport.RoundTrip(req)
@@ -188,6 +177,9 @@ func (t *VaryTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
188177
}
189178
resp, err = transport.RoundTrip(req)
190179
if !(cacheable && validateVary(req, resp, cacheKey, t)) {
180+
go func() {
181+
t.CoalescingLayerStorage.Set(cacheKey)
182+
}()
191183
t.Provider.Delete(cacheKey)
192184
}
193185
return resp, nil

rfc/transport.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package rfc
33
import (
44
"github.com/darkweak/souin/cache/types"
55
"github.com/darkweak/souin/configurationtypes"
6-
"github.com/dgraph-io/ristretto"
76
"github.com/pquerna/cachecontrol"
87
"net/http"
98
"net/http/httputil"
@@ -23,13 +22,12 @@ func IsVaryCacheable(req *http.Request) bool {
2322
// NewTransport returns a new Transport with the
2423
// provided Cache implementation and MarkCachedResponses set to true
2524
func NewTransport(p types.AbstractProviderInterface) *VaryTransport {
26-
storage, _ := ristretto.NewCache(&ristretto.Config{
27-
NumCounters: 1e7, // number of keys to track frequency of (10M).
28-
MaxCost: 1 << 30, // maximum cost of cache (1GB).
29-
BufferItems: 64, // number of keys per Get buffer.
30-
OnEvict: func(key, conflict uint64, value interface{}, cost int64) {},
31-
})
32-
return &VaryTransport{Provider: p, VaryLayerStorage: &types.VaryLayerStorage{Cache: storage}, MarkCachedResponses: true}
25+
return &VaryTransport{
26+
Provider: p,
27+
VaryLayerStorage: types.InitializeVaryLayerStorage(),
28+
CoalescingLayerStorage: types.InitializeCoalescingLayerStorage(),
29+
MarkCachedResponses: true,
30+
}
3331
}
3432

3533
// GetProvider returns the associated provider
@@ -46,6 +44,10 @@ func (t *VaryTransport) GetVaryLayerStorage() *types.VaryLayerStorage {
4644
return t.VaryLayerStorage
4745
}
4846

47+
func (t *VaryTransport) GetCoalescingLayerStorage() *types.CoalescingLayerStorage {
48+
return t.CoalescingLayerStorage
49+
}
50+
4951
// SetCache set the cache
5052
func (t *VaryTransport) SetCache(key string, resp *http.Response, req *http.Request) {
5153
r, _, _ := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})

rfc/vary.go

+4-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
func varyMatches(cachedResp *http.Response, req *http.Request) bool {
1212
for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
1313
header = http.CanonicalHeaderKey(header)
14-
if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
14+
if header == "" || req.Header.Get(header) == "" {
1515
return false
1616
}
1717
}
@@ -26,14 +26,6 @@ func validateVary(req *http.Request, resp *http.Response, key string, t *VaryTra
2626
go func() {
2727
t.VaryLayerStorage.Set(key, variedHeaders)
2828
}()
29-
for _, varyKey := range variedHeaders {
30-
varyKey = http.CanonicalHeaderKey(varyKey)
31-
fakeHeader := "X-Varied-" + varyKey
32-
reqValue := req.Header.Get(varyKey)
33-
if reqValue != "" {
34-
resp.Header.Set(fakeHeader, reqValue)
35-
}
36-
}
3729
cacheKey = GetVariedCacheKey(req, variedHeaders)
3830
}
3931
switch req.Method {
@@ -45,6 +37,9 @@ func validateVary(req *http.Request, resp *http.Response, key string, t *VaryTra
4537
resp := *resp
4638
resp.Body = ioutil.NopCloser(r)
4739
t.SetCache(cacheKey, &resp, req)
40+
go func() {
41+
t.CoalescingLayerStorage.Delete(cacheKey)
42+
}()
4843
},
4944
}
5045
}

rfc/vary_test.go

+3-5
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,21 @@ func TestVaryMatches(t *testing.T) {
2525
}
2626

2727
header := "Cache"
28-
r.Header.Set("Vary", header)
2928
r.Header.Set(header, "same")
3029
res.Header.Set("vary", header)
31-
res.Header.Set("X-Varied-"+header, "same")
3230

3331
if !varyMatches(res, r) {
34-
errors.GenerateError(t, "Vary match should return true if Response contains X-Varied-* header is the same than * in Request header")
32+
errors.GenerateError(t, "Vary match should return true if Response contains a vary header that is not null in the request")
3533
}
3634

3735
if !validateVary(r, res, GetCacheKey(r), tr) {
3836
errors.GenerateError(t, fmt.Sprintf("It contains valid vary headers in the Response. It should validate it, %v given", res.Header))
3937
}
4038

41-
res.Header.Set("X-Varied-"+header, "different")
39+
r.Header.Set(header, "")
4240

4341
if varyMatches(res, r) {
44-
errors.GenerateError(t, "Vary match should return false if Response contains X-Varied-* header different than * in Request header")
42+
errors.GenerateError(t, "Vary match should return false if Response contains a vary header that is empty in the request")
4543
}
4644

4745
if !validateVary(r, res, GetCacheKey(r), tr) {

0 commit comments

Comments
 (0)