diff --git a/cmd/start/start.go b/cmd/start/start.go index d75a3bb3d3..3e095eb2b3 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -325,7 +325,7 @@ func startAPIs( } oidcPrefixes := []string{"/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2"} // always set the origin in the context if available in the http headers, no matter for what protocol - router.Use(middleware.OriginHandler) + router.Use(middleware.WithOrigin(config.ExternalSecure)) systemTokenVerifier, err := internal_authz.StartSystemTokenVerifierFromConfig(http_util.BuildHTTP(config.ExternalDomain, config.ExternalPort, config.ExternalSecure), config.SystemAPIUsers) if err != nil { return err diff --git a/internal/api/http/middleware/origin_interceptor.go b/internal/api/http/middleware/origin_interceptor.go index da03145ab0..02a67ad05d 100644 --- a/internal/api/http/middleware/origin_interceptor.go +++ b/internal/api/http/middleware/origin_interceptor.go @@ -4,25 +4,28 @@ import ( "fmt" "net/http" + "github.com/gorilla/mux" "github.com/muhlemmer/httpforwarded" "github.com/zitadel/logging" http_util "github.com/zitadel/zitadel/internal/api/http" ) -func OriginHandler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := composeOrigin(r) - if !http_util.IsOrigin(origin) { - logging.Debugf("extracted origin is not valid: %s", origin) - next.ServeHTTP(w, r) - return - } - next.ServeHTTP(w, r.WithContext(http_util.WithComposedOrigin(r.Context(), origin))) - }) +func WithOrigin(fallBackToHttps bool) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := composeOrigin(r, fallBackToHttps) + if !http_util.IsOrigin(origin) { + logging.Debugf("extracted origin is not valid: %s", origin) + next.ServeHTTP(w, r) + return + } + next.ServeHTTP(w, r.WithContext(http_util.WithComposedOrigin(r.Context(), origin))) + }) + } } -func composeOrigin(r *http.Request) string { +func composeOrigin(r *http.Request, fallBackToHttps bool) string { var proto, host string fwd, fwdErr := httpforwarded.ParseFromRequest(r) if fwdErr == nil { @@ -36,9 +39,8 @@ func composeOrigin(r *http.Request) string { host = r.Header.Get("X-Forwarded-Host") } if proto == "" { - if r.TLS == nil { - proto = "http" - } else { + proto = "http" + if fallBackToHttps { proto = "https" } } diff --git a/internal/api/http/middleware/origin_interceptor_test.go b/internal/api/http/middleware/origin_interceptor_test.go index e8e4f306b8..31b2136b58 100644 --- a/internal/api/http/middleware/origin_interceptor_test.go +++ b/internal/api/http/middleware/origin_interceptor_test.go @@ -9,7 +9,8 @@ import ( func Test_composeOrigin(t *testing.T) { type args struct { - h http.Header + h http.Header + fallBackToHttps bool } tests := []struct { name string @@ -24,6 +25,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https"}, }, + fallBackToHttps: false, }, want: "https://host.header", }, { @@ -32,6 +34,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"host=forwarded.host"}, }, + fallBackToHttps: false, }, want: "http://forwarded.host", }, { @@ -40,6 +43,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host"}, }, + fallBackToHttps: false, }, want: "https://forwarded.host", }, { @@ -48,6 +52,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host, proto=http;host=forwarded.host2"}, }, + fallBackToHttps: false, }, want: "https://forwarded.host", }, { @@ -56,6 +61,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=https;host=forwarded.host, proto=http"}, }, + fallBackToHttps: false, }, want: "https://forwarded.host", }, { @@ -64,14 +70,37 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "Forwarded": []string{"proto=http", "proto=https;host=forwarded.host", "proto=http"}, }, + fallBackToHttps: true, }, want: "http://forwarded.host", }, { - name: "x-forwarded-proto", + name: "x-forwarded-proto https", args: args{ h: http.Header{ "X-Forwarded-Proto": []string{"https"}, }, + fallBackToHttps: false, + }, + want: "https://host.header", + }, { + name: "x-forwarded-proto http", + args: args{ + h: http.Header{ + "X-Forwarded-Proto": []string{"http"}, + }, + fallBackToHttps: true, + }, + want: "http://host.header", + }, { + name: "fallback to http", + args: args{ + fallBackToHttps: false, + }, + want: "http://host.header", + }, { + name: "fallback to https", + args: args{ + fallBackToHttps: true, }, want: "https://host.header", }, { @@ -80,6 +109,7 @@ func Test_composeOrigin(t *testing.T) { h: http.Header{ "X-Forwarded-Host": []string{"x-forwarded.host"}, }, + fallBackToHttps: false, }, want: "http://x-forwarded.host", }, { @@ -89,6 +119,7 @@ func Test_composeOrigin(t *testing.T) { "X-Forwarded-Proto": []string{"https"}, "X-Forwarded-Host": []string{"x-forwarded.host"}, }, + fallBackToHttps: false, }, want: "https://x-forwarded.host", }, { @@ -98,6 +129,7 @@ func Test_composeOrigin(t *testing.T) { "Forwarded": []string{"host=forwarded.host"}, "X-Forwarded-Host": []string{"x-forwarded.host"}, }, + fallBackToHttps: false, }, want: "http://forwarded.host", }, { @@ -107,16 +139,20 @@ func Test_composeOrigin(t *testing.T) { "Forwarded": []string{"host=forwarded.host"}, "X-Forwarded-Proto": []string{"https"}, }, + fallBackToHttps: false, }, want: "https://forwarded.host", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, composeOrigin(&http.Request{ - Host: "host.header", - Header: tt.args.h, - }), "headers: %+v", tt.args.h) + assert.Equalf(t, tt.want, composeOrigin( + &http.Request{ + Host: "host.header", + Header: tt.args.h, + }, + tt.args.fallBackToHttps, + ), "headers: %+v, fallBackToHttps: %t", tt.args.h, tt.args.fallBackToHttps) }) } }