package http

import (
	"context"
	"net"
	"net/http"
	"strings"
)

// HTTP header names.
//
// Most values are lowercase. http.Header's Get/Set/Values methods canonicalize
// the key (e.g. "x-forwarded-for" -> "X-Forwarded-For"), but indexing an
// http.Header map directly does not, so direct map lookups must account for
// both spellings.
const (
	Authorization   = "authorization"
	Accept          = "accept"
	AcceptLanguage  = "accept-language"
	CacheControl    = "cache-control"
	ContentType     = "content-type"
	ContentLength   = "content-length"
	Expires         = "expires"
	Location        = "location"
	Origin          = "origin"
	Pragma          = "pragma"
	UserAgentHeader = "user-agent"
	ForwardedFor    = "x-forwarded-for"
	XUserAgent      = "x-user-agent"
	XGrpcWeb        = "x-grpc-web"
	XRequestedWith  = "x-requested-with"
	XRobotsTag      = "x-robots-tag"

	// Caching / conditional-request headers (declared in canonical casing,
	// unlike the rest of this block).
	IfNoneMatch  = "If-None-Match"
	LastModified = "Last-Modified"
	Etag         = "Etag"

	// Security-related response headers.
	ContentSecurityPolicy   = "content-security-policy"
	XXSSProtection          = "x-xss-protection"
	StrictTransportSecurity = "strict-transport-security"
	XFrameOptions           = "x-frame-options"
	XContentTypeOptions     = "x-content-type-options"
	ReferrerPolicy          = "referrer-policy"
	FeaturePolicy           = "feature-policy"
	PermissionsPolicy       = "permissions-policy"

	ZitadelOrgID = "x-zitadel-orgid"
)

// key is an unexported type for the context keys defined in this package, so
// they cannot collide with context keys defined by other packages.
type key int

const (
	httpHeaders key = iota // request http.Header, set by CopyHeadersToContext
	remoteAddr             // request RemoteAddr, set by CopyHeadersToContext
	origin                 // composed origin, set by WithComposedOrigin
)

// CopyHeadersToContext wraps h so that, before h runs, the request's headers
// and RemoteAddr are stored in the request context. They can be read back with
// HeadersFromCtx, OriginHeader, RemoteIPFromCtx and RemoteAddrFromCtx.
//
// Despite the name, the header map is stored by reference, not cloned: later
// mutations of r.Header are visible through the context.
func CopyHeadersToContext(h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), httpHeaders, r.Header)
		ctx = context.WithValue(ctx, remoteAddr, r.RemoteAddr)
		h.ServeHTTP(w, r.WithContext(ctx))
	})
}

// HeadersFromCtx returns the request headers stored by CopyHeadersToContext.
// The boolean is false when ctx carries no headers.
func HeadersFromCtx(ctx context.Context) (http.Header, bool) {
	headers, ok := ctx.Value(httpHeaders).(http.Header)
	return headers, ok
}

// OriginHeader returns the Origin request header stored in ctx, or "" if ctx
// carries no headers or the header is absent.
func OriginHeader(ctx context.Context) string {
	headers, ok := HeadersFromCtx(ctx)
	if !ok {
		return ""
	}
	return headers.Get(Origin)
}

// ComposedOrigin returns the value stored by WithComposedOrigin, or "" if none
// was set.
func ComposedOrigin(ctx context.Context) string {
	composed, _ := ctx.Value(origin).(string)
	return composed
}

// WithComposedOrigin returns a copy of ctx carrying composed, retrievable with
// ComposedOrigin.
func WithComposedOrigin(ctx context.Context, composed string) context.Context {
	return context.WithValue(ctx, origin, composed)
}

// RemoteIPFromCtx returns the client address for a context prepared by
// CopyHeadersToContext. It prefers the left-most X-Forwarded-For entry and
// otherwise falls back to the stored RemoteAddr ("" if none was stored).
//
// NOTE(review): unlike RemoteIPStringFromRequest, the fallback returns the
// stored RemoteAddr verbatim, so it may include a port ("host:port"). Confirm
// whether callers rely on that before normalizing it.
func RemoteIPFromCtx(ctx context.Context) string {
	if headers, ok := HeadersFromCtx(ctx); ok {
		if forwarded, ok := GetForwardedFor(headers); ok {
			return forwarded
		}
	}
	return RemoteAddrFromCtx(ctx)
}

// RemoteIPFromRequest parses the result of RemoteIPStringFromRequest. It
// returns nil when that string is not a valid textual IP address.
func RemoteIPFromRequest(r *http.Request) net.IP {
	return net.ParseIP(RemoteIPStringFromRequest(r))
}

// RemoteIPStringFromRequest returns the left-most X-Forwarded-For entry of r
// if present. Otherwise it returns the host part of r.RemoteAddr, or
// r.RemoteAddr unchanged when it is not in "host:port" form.
func RemoteIPStringFromRequest(r *http.Request) string {
	if ip, ok := GetForwardedFor(r.Header); ok {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr has no port (e.g. a bare IP); keep it instead of
		// discarding it as an empty string.
		return r.RemoteAddr
	}
	return host
}

// GetAuthorization returns the Authorization header of r, or "" if absent.
func GetAuthorization(r *http.Request) string {
	return r.Header.Get(Authorization)
}

// GetOrgID returns the x-zitadel-orgid header of r, or "" if absent.
func GetOrgID(r *http.Request) string {
	return r.Header.Get(ZitadelOrgID)
}

// GetForwardedFor returns the left-most, whitespace-trimmed address from the
// first X-Forwarded-For header line. The boolean is false when the header is
// absent or its first entry is empty.
//
// The header is looked up under the literal lowercase key first, then under
// its canonical form "X-Forwarded-For". net/http stores incoming request
// headers under the canonical form, which a direct lowercase map index alone
// would miss.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy overwrites it. Do not base security decisions on the returned value
// without such a proxy in front.
func GetForwardedFor(headers http.Header) (string, bool) {
	values := headers[ForwardedFor]
	if len(values) == 0 {
		values = headers.Values(ForwardedFor)
	}
	if len(values) == 0 {
		return "", false
	}
	ip, _, _ := strings.Cut(values[0], ",")
	ip = strings.TrimSpace(ip)
	if ip == "" {
		return "", false
	}
	return ip, true
}

// RemoteAddrFromCtx returns the RemoteAddr stored by CopyHeadersToContext, or
// "" if ctx carries none.
func RemoteAddrFromCtx(ctx context.Context) string {
	ctxRemoteAddr, _ := ctx.Value(remoteAddr).(string)
	return ctxRemoteAddr
}