package middleware

import (
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"time"

	http_utils "github.com/caos/zitadel/internal/api/http"
	"github.com/caos/zitadel/internal/config/types"
)

// Cache describes the caching headers written for a response; serializeHeaders
// renders it into Cache-Control, Expires and Pragma headers.
//
// Fields:
//   - Cacheability: leading Cache-Control directive ("public"/"private"). When
//     it is not set, max-age and s-maxage are not emitted at all.
//   - NoCache: adds the "no-cache" directive and "Pragma: no-cache".
//   - NoStore: adds the "no-store" directive and "Pragma: no-cache".
//   - MaxAge: rendered as max-age (only when Cacheability is set) and used to
//     compute Expires; zero dates Expires one hour in the past.
//   - SharedMaxAge: rendered as s-maxage when Cacheability is set and the value
//     differs from MaxAge.
//   - NoTransform: adds the "no-transform" directive.
//   - Revalidation: appended as a trailing directive when set.
type Cache struct {
	Cacheability Cacheability
	NoCache      bool
	NoStore      bool
	MaxAge       time.Duration
	SharedMaxAge time.Duration
	NoTransform  bool
	Revalidation Revalidation
}

// Cacheability is the optional leading Cache-Control directive of a Cache.
type Cacheability string

// Cacheability values. Each constant carries an explicit type so it is a
// Cacheability rather than an untyped string constant (only implicitly
// repeated specs inherit the type of the first one in a const group).
const (
	// CacheabilityNotSet omits the directive, and with it max-age/s-maxage.
	CacheabilityNotSet Cacheability = ""
	// CacheabilityPublic emits the "public" directive.
	CacheabilityPublic Cacheability = "public"
	// CacheabilityPrivate emits the "private" directive.
	CacheabilityPrivate Cacheability = "private"
)

// Revalidation is an optional Cache-Control directive appended after all other
// directives of a Cache.
type Revalidation string

// Revalidation values. Each constant carries an explicit type so it is a
// Revalidation rather than an untyped string constant.
const (
	// RevalidationNotSet omits the directive.
	RevalidationNotSet Revalidation = ""
	// RevalidationMust emits the "must-revalidate" directive.
	RevalidationMust Revalidation = "must-revalidate"
	// RevalidationProxy emits the "proxy-revalidate" directive.
	RevalidationProxy Revalidation = "proxy-revalidate"
)

// CacheConfig holds configurable cache lifetimes.
//
// NOTE(review): presumably MaxAge and SharedMaxAge are converted by callers
// into the time.Duration arguments of AssetOptions / AssetsCacheInterceptor;
// types.Duration is declared elsewhere — confirm how it is unwrapped.
type CacheConfig struct {
	MaxAge       types.Duration
	SharedMaxAge types.Duration
}

var (
	// NeverCacheOptions forbids storing the response: it serializes to
	// "Cache-Control: no-store", "Pragma: no-cache" and an Expires date one
	// hour in the past. cachingResponseWriter also applies it to every
	// response with a status code >= 400.
	//
	// It is a single shared pointer used by every request; do not modify it.
	NeverCacheOptions = &Cache{
		NoStore: true,
	}
	// AssetOptions returns options for publicly cacheable assets: "public" with
	// max-age set to maxAge, and s-maxage set to sharedMaxAge whenever it
	// differs from maxAge. A fresh *Cache is returned on every call.
	AssetOptions = func(maxAge, sharedMaxAge time.Duration) *Cache {
		return &Cache{
			Cacheability: CacheabilityPublic,
			MaxAge:       maxAge,
			SharedMaxAge: sharedMaxAge,
		}
	}
)

// DefaultCacheInterceptor returns a middleware that serves requests whose URL
// path matches the regular expression pattern through AssetsCacheInterceptor
// (using maxAge and sharedMaxAge), and all other requests through
// NoCacheInterceptor.
//
// It returns an error if pattern does not compile.
func DefaultCacheInterceptor(pattern string, maxAge, sharedMaxAge time.Duration) (func(http.Handler) http.Handler, error) {
	regex, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	return func(handler http.Handler) http.Handler {
		// Both wrapped handlers are independent of the request, so build them
		// once per wrapped handler instead of allocating them on every request.
		assets := AssetsCacheInterceptor(maxAge, sharedMaxAge, handler)
		noCache := NoCacheInterceptor(handler)
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if regex.MatchString(r.URL.Path) {
				assets.ServeHTTP(w, r)
				return
			}
			noCache.ServeHTTP(w, r)
		})
	}, nil
}

func NoCacheInterceptor(h http.Handler) http.Handler {
	return CacheInterceptorOpts(h, NeverCacheOptions)
}

func AssetsCacheInterceptor(maxAge, sharedMaxAge time.Duration, h http.Handler) http.Handler {
	return CacheInterceptorOpts(h, AssetOptions(maxAge, sharedMaxAge))
}

// CacheInterceptorOpts wraps h so that each request is served with its
// ResponseWriter replaced by a cachingResponseWriter configured with cache.
// The same cache pointer is shared by all requests.
func CacheInterceptorOpts(h http.Handler, cache *Cache) http.Handler {
	serve := func(w http.ResponseWriter, req *http.Request) {
		h.ServeHTTP(&cachingResponseWriter{ResponseWriter: w, Cache: cache}, req)
	}
	return http.HandlerFunc(serve)
}

// cachingResponseWriter wraps an http.ResponseWriter and writes the caching
// headers described by the embedded *Cache just before the response status is
// sent.
type cachingResponseWriter struct {
	http.ResponseWriter
	*Cache
	// wroteHeader records that a final (non-1xx) status has been written.
	wroteHeader bool
}

// WriteHeader serializes the caching headers and forwards code to the wrapped
// writer. Responses with code >= 400 always get NeverCacheOptions instead of
// the configured Cache. Informational 1xx codes do not count as the final
// status, so headers are serialized again for the status that follows.
// Repeated calls are forwarded untouched (net/http reports them as superfluous).
func (w *cachingResponseWriter) WriteHeader(code int) {
	if !w.wroteHeader {
		cache := w.Cache
		if code >= 400 {
			cache = NeverCacheOptions
		}
		cache.serializeHeaders(w.ResponseWriter)
		if code >= 200 {
			w.wroteHeader = true
		}
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write sends part of the response body. A handler that writes without calling
// WriteHeader gets an implicit 200; routing that implicit status through this
// wrapper's WriteHeader ensures the caching headers are still applied, whereas
// the embedded writer's own implicit WriteHeader would bypass them.
func (w *cachingResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	return w.ResponseWriter.Write(b)
}

// serializeHeaders writes the headers described by c to w, replacing any
// existing values:
//   - Cache-Control: the Cacheability directive followed by max-age and, when
//     it differs, s-maxage (all three only when Cacheability is set), then
//     no-cache, no-store, no-transform and the Revalidation directive as
//     configured, joined with ", ".
//   - Expires: now plus MaxAge, or one hour in the past when MaxAge is zero.
//   - Pragma: "no-cache", only when NoCache or NoStore is set.
func (c *Cache) serializeHeaders(w http.ResponseWriter) {
	// deltaSeconds renders d as the non-negative whole number of seconds that
	// max-age/s-maxage require. Printing Duration.Seconds() with %v would emit
	// fractions ("1.5") or, from 1e6 seconds upward, exponent notation
	// ("3.1536e+07" for one year), neither of which is a valid value.
	deltaSeconds := func(d time.Duration) int64 {
		if d < 0 {
			return 0
		}
		return int64(d / time.Second)
	}

	// Room for every directive this method can emit.
	control := make([]string, 0, 7)
	pragma := false

	if c.Cacheability != CacheabilityNotSet {
		maxAge, sharedMaxAge := deltaSeconds(c.MaxAge), deltaSeconds(c.SharedMaxAge)
		control = append(control, string(c.Cacheability), fmt.Sprintf("max-age=%d", maxAge))
		if sharedMaxAge != maxAge {
			control = append(control, fmt.Sprintf("s-maxage=%d", sharedMaxAge))
		}
	}

	expiresIn := c.MaxAge
	if expiresIn == 0 {
		// Without a lifetime, date Expires in the past.
		expiresIn = -time.Hour
	}
	expires := time.Now().UTC().Add(expiresIn).Format(http.TimeFormat)

	if c.NoCache {
		control = append(control, "no-cache")
		pragma = true
	}
	if c.NoStore {
		control = append(control, "no-store")
		pragma = true
	}
	if c.NoTransform {
		control = append(control, "no-transform")
	}
	if c.Revalidation != RevalidationNotSet {
		control = append(control, string(c.Revalidation))
	}

	w.Header().Set(http_utils.CacheControl, strings.Join(control, ", "))
	w.Header().Set(http_utils.Expires, expires)
	if pragma {
		w.Header().Set(http_utils.Pragma, "no-cache")
	}
}