diff --git a/internal/api/http/middleware/origin_interceptor.go b/internal/api/http/middleware/origin_interceptor.go index 35af8770b7..607855b80f 100644 --- a/internal/api/http/middleware/origin_interceptor.go +++ b/internal/api/http/middleware/origin_interceptor.go @@ -10,12 +10,12 @@ import ( http_util "github.com/zitadel/zitadel/internal/api/http" ) -func WithOrigin(fallBackToHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc { +func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := composeDomainContext( r, - fallBackToHttps, + enforceHttps, // to make sure we don't break existing configurations we append the existing checked headers as well slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)), publicDomainHeaders, @@ -25,28 +25,32 @@ func WithOrigin(fallBackToHttps bool, http1Header, http2Header string, instanceH } } -func composeDomainContext(r *http.Request, fallBackToHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx { +func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx { instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders) publicHost, publicProto := hostFromRequest(r, publicDomainHeaders) - if publicProto == "" { - publicProto = instanceProto - } - if publicProto == "" { - publicProto = "http" - if fallBackToHttps { - publicProto = "https" - } - } if instanceHost == "" { instanceHost = r.Host } return &http_util.DomainCtx{ InstanceHost: instanceHost, - Protocol: publicProto, + Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps), PublicHost: publicHost, } } +func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string { + if enforceHttps { + return "https" + } + if publicProto != "" { + return publicProto + } + if instanceProto != "" { + return instanceProto + } + return "http" +} + func hostFromRequest(r *http.Request, headers []string) (host, proto string) { var hostFromHeader, protoFromHeader string for _, header := range headers { @@ -65,7 +69,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) { if host == "" { host = hostFromHeader } - if proto == "" { + if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") { proto = protoFromHeader } } diff --git a/internal/api/http/middleware/origin_interceptor_test.go b/internal/api/http/middleware/origin_interceptor_test.go index 989e4d48b3..7419c91aba 100644 --- a/internal/api/http/middleware/origin_interceptor_test.go +++ b/internal/api/http/middleware/origin_interceptor_test.go @@ -11,8 +11,8 @@ import ( func Test_composeOrigin(t *testing.T) { type args struct { - h http.Header - fallBackToHttps bool + h http.Header + enforceHttps bool } tests := []struct { name string @@ -30,7 +30,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "host.header", @@ -42,7 +42,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"host=forwarded.host"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -54,7 +54,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -66,7 +66,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host, proto=http;host=forwarded.host2"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -78,7 +78,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host, proto=http"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -90,11 +90,11 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=http", "proto=https;host=forwarded.host", "proto=http"}, }, - fallBackToHttps: true, + enforceHttps: true, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", - Protocol: "http", + Protocol: "https", }, }, { name: "x-forwarded-proto https", @@ -102,7 +102,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "X-Forwarded-Proto": []string{"https"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "host.header", @@ -114,25 +114,25 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "X-Forwarded-Proto": []string{"http"}, }, - fallBackToHttps: true, + enforceHttps: true, }, want: &http_util.DomainCtx{ InstanceHost: "host.header", - Protocol: "http", + Protocol: "https", }, }, { name: "fallback to http", args: args{ - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "host.header", Protocol: "http", }, }, { - name: "fallback to https", + name: "enforce https", args: args{ - fallBackToHttps: true, + enforceHttps: true, }, want: &http_util.DomainCtx{ InstanceHost: "host.header", @@ -144,7 +144,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "X-Forwarded-Host": []string{"x-forwarded.host"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "x-forwarded.host", @@ -157,7 +157,7 @@ func Test_composeOrigin(t *testing.T) { "X-Forwarded-Proto": []string{"https"}, "X-Forwarded-Host": []string{"x-forwarded.host"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "x-forwarded.host", @@ -170,7 +170,7 @@ func Test_composeOrigin(t *testing.T) { "Forwarded": []string{"host=forwarded.host"}, "X-Forwarded-Host": []string{"x-forwarded.host"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -183,7 +183,7 @@ func Test_composeOrigin(t *testing.T) { "Forwarded": []string{"host=forwarded.host"}, "X-Forwarded-Proto": []string{"https"}, }, - fallBackToHttps: false, + enforceHttps: false, }, want: &http_util.DomainCtx{ InstanceHost: "forwarded.host", @@ -198,10 +198,10 @@ func Test_composeOrigin(t *testing.T) { Host: "host.header", Header: tt.args.h, }, - tt.args.fallBackToHttps, + tt.args.enforceHttps, []string{http_util.Forwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto}, []string{"x-zitadel-public-host"}, - ), "headers: %+v, fallBackToHttps: %t", tt.args.h, tt.args.fallBackToHttps) + ), "headers: %+v, enforceHttps: %t", tt.args.h, tt.args.enforceHttps) }) } }