// Package cache provides an echo middleware that stores handler responses in
// a freecache.Cache and replays them for later matching requests.
package cache

import (
	"errors"
	"fmt"
	"net/http"
	"time"

	"github.com/coocood/freecache"
	"github.com/labstack/echo/v4"
	"github.com/mcuadros/go-defaults"
)

// Config defines the configuration for a cache middleware.
type Config struct {
	// TTL is the time to live of a cache entry. It is passed to freecache in
	// whole seconds; a positive TTL shorter than one second is stored as one
	// second.
	TTL time.Duration `default:"1m"`
	// Methods lists the HTTP methods whose requests are cached.
	Methods []string `default:"[GET]"`
	// StatusCode lists the response status codes that are cached.
	StatusCode []int `default:"[200,404]"`
	// IgnoreQuery, if true, leaves the request's query values out of the
	// cache key.
	IgnoreQuery bool
	// Refresh is called before the cache is read. If it returns true the
	// cached entry is bypassed and the handler runs; its response then
	// replaces the entry when it is cacheable.
	Refresh func(r *http.Request) bool
	// Cache is called to decide whether a request is cached at all; if it
	// returns false the request bypasses the middleware. When set, Methods
	// is ignored.
	Cache func(r *http.Request) bool
}

// New returns an echo middleware that caches responses in cache according to
// cfg. A nil cfg is replaced by an empty Config; in both cases unset fields
// are filled from their `default` struct tags, which mutates the Config the
// caller passed in.
func New(cfg *Config, cache *freecache.Cache) echo.MiddlewareFunc {
	if cfg == nil {
		cfg = &Config{}
	}
	defaults.SetDefaults(cfg)

	m := &CacheMiddleware{cfg: cfg, cache: cache}
	return m.Handler
}

// CacheMiddleware couples a Config with the freecache instance used to store
// the recorded responses.
type CacheMiddleware struct {
	cfg   *Config
	cache *freecache.Cache
}

// Handler wraps next with the caching logic. Non-cacheable requests go
// straight to next. For cacheable requests a stored entry is replayed when
// available; otherwise next runs against a recording writer and its response
// is stored when its status code is cacheable.
func (m *CacheMiddleware) Handler(next echo.HandlerFunc) echo.HandlerFunc {
	return func(c echo.Context) error {
		req := c.Request()
		if !m.isCacheable(req) {
			return next(c)
		}

		// The cache key does not include the body, so two requests that
		// differ only in their body share one entry.
		if mayHasBody(req.Method) {
			c.Logger().Warnf("request with body are cached ignoring the content")
		}

		key := m.getKey(req)
		err := m.readCache(key, c)
		if err == nil {
			return nil
		}
		// A miss is expected; anything else is worth reporting but must not
		// fail the request, so fall through to the real handler.
		if !errors.Is(err, freecache.ErrNotFound) {
			c.Logger().Errorf("error reading cache: %s", err)
		}

		recorder := NewResponseRecorder(c.Response().Writer)
		c.Response().Writer = recorder

		// When the handler fails, echo's error handler produces the response
		// after the middleware chain has returned, so whatever the recorder
		// holds at this point is not the response the client receives.
		if err := next(c); err != nil {
			return err
		}

		// Failing to store the response must not fail the request itself.
		if cacheErr := m.cacheResult(key, recorder); cacheErr != nil {
			c.Logger().Error(cacheErr)
		}
		return nil
	}
}

// readCache replays the entry stored under key onto c's response. It returns
// freecache.ErrNotFound when Config.Refresh asks for a refresh, and otherwise
// any error from the cache lookup, the entry decoding or the replay.
func (m *CacheMiddleware) readCache(key []byte, c echo.Context) error {
	if m.cfg.Refresh != nil && m.cfg.Refresh(c.Request()) {
		return freecache.ErrNotFound
	}

	value, err := m.cache.Get(key)
	if err != nil {
		return err
	}

	entry := &CacheEntry{}
	if err := entry.Decode(value); err != nil {
		return err
	}
	return entry.Replay(c.Response())
}

// cacheResult stores the response captured by r under key when its status
// code is listed in Config.StatusCode; other responses are skipped without
// being encoded.
func (m *CacheMiddleware) cacheResult(key []byte, r *ResponseRecorder) error {
	e := r.Result()
	if !m.isStatusCacheable(e) {
		return nil
	}

	b, err := e.Encode()
	if err != nil {
		return fmt.Errorf("unable to read recorded response: %w", err)
	}
	return m.cache.Set(key, b, expireSeconds(m.cfg.TTL))
}

// expireSeconds converts ttl to the whole-second expiry freecache expects.
// freecache treats a non-positive expiry as "never expire", so a positive
// ttl below one second is raised to one second instead of truncating to 0.
func expireSeconds(ttl time.Duration) int {
	secs := int(ttl.Seconds())
	if secs == 0 && ttl > 0 {
		return 1
	}
	return secs
}

// isStatusCacheable reports whether e's status code is listed in
// Config.StatusCode.
func (m *CacheMiddleware) isStatusCacheable(e *CacheEntry) bool {
	for _, status := range m.cfg.StatusCode {
		if e.StatusCode == status {
			return true
		}
	}
	return false
}

// isCacheable reports whether r should go through the cache. Config.Cache,
// when set, decides alone; otherwise r's method must be listed in
// Config.Methods.
func (m *CacheMiddleware) isCacheable(r *http.Request) bool {
	if m.cfg.Cache != nil {
		return m.cfg.Cache(r)
	}

	for _, method := range m.cfg.Methods {
		if r.Method == method {
			return true
		}
	}
	return false
}

// getKey builds the cache key for r as "METHOD|path", followed by
// "|<encoded query>" unless Config.IgnoreQuery is set. url.Values.Encode
// sorts by key, so the order of query parameters does not change the key.
func (m *CacheMiddleware) getKey(r *http.Request) []byte {
	base := r.Method + "|" + r.URL.Path
	if !m.cfg.IgnoreQuery {
		base += "|" + r.URL.Query().Encode()
	}
	return []byte(base)
}

// mayHasBody reports whether method is POST, PUT, DELETE or PATCH, the
// methods this package treats as possibly carrying a request body.
func mayHasBody(method string) bool {
	switch method {
	case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch:
		return true
	}
	return false
}