zitadel/internal/api/http/middleware/cache_interceptor.go

146 lines
3.4 KiB
Go
Raw Normal View History

package middleware
import (
"fmt"
"net/http"
"regexp"
"strings"
"time"
http_utils "github.com/caos/zitadel/internal/api/http"
)
type Cache struct {
Cacheability Cacheability
NoCache bool
NoStore bool
MaxAge time.Duration
SharedMaxAge time.Duration
NoTransform bool
Revalidation Revalidation
}
type Cacheability string
const (
CacheabilityNotSet Cacheability = ""
CacheabilityPublic = "public"
CacheabilityPrivate = "private"
)
type Revalidation string
const (
RevalidationNotSet Revalidation = ""
RevalidationMust = "must-revalidate"
RevalidationProxy = "proxy-revalidate"
)
type CacheConfig struct {
MaxAge time.Duration
SharedMaxAge time.Duration
}
var (
NeverCacheOptions = &Cache{
NoStore: true,
}
AssetOptions = func(maxAge, SharedMaxAge time.Duration) *Cache {
return &Cache{
Cacheability: CacheabilityPublic,
MaxAge: maxAge,
SharedMaxAge: SharedMaxAge,
}
}
)
func DefaultCacheInterceptor(pattern string, maxAge, sharedMaxAge time.Duration) (func(http.Handler) http.Handler, error) {
regex, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if regex.MatchString(r.URL.Path) {
AssetsCacheInterceptor(maxAge, sharedMaxAge, handler).ServeHTTP(w, r)
return
}
NoCacheInterceptor(handler).ServeHTTP(w, r)
})
}, nil
}
func NoCacheInterceptor(h http.Handler) http.Handler {
return CacheInterceptorOpts(h, NeverCacheOptions)
}
func AssetsCacheInterceptor(maxAge, sharedMaxAge time.Duration, h http.Handler) http.Handler {
return CacheInterceptorOpts(h, AssetOptions(maxAge, sharedMaxAge))
}
func CacheInterceptorOpts(h http.Handler, cache *Cache) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cachingResponseWriter := &cachingResponseWriter{
ResponseWriter: w,
Cache: cache,
}
h.ServeHTTP(cachingResponseWriter, req)
})
}
type cachingResponseWriter struct {
http.ResponseWriter
*Cache
}
func (w *cachingResponseWriter) WriteHeader(code int) {
if code >= 400 {
NeverCacheOptions.serializeHeaders(w.ResponseWriter)
w.ResponseWriter.WriteHeader(code)
return
}
w.Cache.serializeHeaders(w.ResponseWriter)
w.ResponseWriter.WriteHeader(code)
}
func (c *Cache) serializeHeaders(w http.ResponseWriter) {
control := make([]string, 0, 6)
pragma := false
if c.Cacheability != CacheabilityNotSet {
control = append(control, string(c.Cacheability))
control = append(control, fmt.Sprintf("max-age=%v", c.MaxAge.Seconds()))
if c.SharedMaxAge != c.MaxAge {
control = append(control, fmt.Sprintf("s-maxage=%v", c.SharedMaxAge.Seconds()))
}
}
maxAge := c.MaxAge
if maxAge == 0 {
maxAge = -time.Hour
}
expires := time.Now().UTC().Add(maxAge).Format(http.TimeFormat)
if c.NoCache {
control = append(control, fmt.Sprintf("no-cache"))
pragma = true
}
if c.NoStore {
control = append(control, fmt.Sprintf("no-store"))
pragma = true
}
if c.NoTransform {
control = append(control, fmt.Sprintf("no-transform"))
}
if c.Revalidation != RevalidationNotSet {
control = append(control, string(c.Revalidation))
}
w.Header().Set(http_utils.CacheControl, strings.Join(control, ", "))
w.Header().Set(http_utils.Expires, expires)
if pragma {
w.Header().Set(http_utils.Pragma, "no-cache")
}
}