diff --git a/internal/api/http/middleware/origin_interceptor.go b/internal/api/http/middleware/origin_interceptor.go index 2cf9a644f5..da03145ab0 100644 --- a/internal/api/http/middleware/origin_interceptor.go +++ b/internal/api/http/middleware/origin_interceptor.go @@ -3,7 +3,6 @@ package middleware import ( "fmt" "net/http" - "net/url" "github.com/muhlemmer/httpforwarded" "github.com/zitadel/logging" @@ -24,74 +23,32 @@ func OriginHandler(next http.Handler) http.Handler { } func composeOrigin(r *http.Request) string { - if origin, err := originFromForwardedHeader(r); err != nil { - logging.OnError(err).Debug("failed to build origin from forwarded header, trying x-forwarded-* headers") - } else { - return origin + var proto, host string + fwd, fwdErr := httpforwarded.ParseFromRequest(r) + if fwdErr == nil { + proto = oldestForwardedValue(fwd, "proto") + host = oldestForwardedValue(fwd, "host") } - if origin, err := originFromXForwardedHeaders(r); err != nil { - logging.OnError(err).Debug("failed to build origin from x-forwarded-* headers, using host header") - } else { - return origin + if proto == "" { + proto = r.Header.Get("X-Forwarded-Proto") } - scheme := "https" - if r.TLS == nil { - scheme = "http" - } - return fmt.Sprintf("%s://%s", scheme, r.Host) -} - -func originFromForwardedHeader(r *http.Request) (string, error) { - fwd, err := httpforwarded.ParseFromRequest(r) - if err != nil { - return "", err - } - var fwdProto, fwdHost, fwdPort string - if fwdProto = mostRecentValue(fwd, "proto"); fwdProto == "" { - return "", fmt.Errorf("no proto in forwarded header") - } - if fwdHost = mostRecentValue(fwd, "host"); fwdHost == "" { - return "", fmt.Errorf("no host in forwarded header") - } - fwdPort, foundFwdFor := extractPort(mostRecentValue(fwd, "for")) - if !foundFwdFor { - return "", fmt.Errorf("no for in forwarded header") - } - o := fmt.Sprintf("%s://%s", fwdProto, fwdHost) - if fwdPort != "" { - o += ":" + fwdPort - } - return o, nil -} - -func originFromXForwardedHeaders(r *http.Request) (string, error) { - scheme := r.Header.Get("X-Forwarded-Proto") - if scheme == "" { - return "", fmt.Errorf("no X-Forwarded-Proto header") - } - host := r.Header.Get("X-Forwarded-Host") if host == "" { - return "", fmt.Errorf("no X-Forwarded-Host header") + host = r.Header.Get("X-Forwarded-Host") } - return fmt.Sprintf("%s://%s", scheme, host), nil + if proto == "" { + if r.TLS == nil { + proto = "http" + } else { + proto = "https" + } + } + if host == "" { + host = r.Host + } + return fmt.Sprintf("%s://%s", proto, host) } -func extractPort(raw string) (string, bool) { - if u, ok := parseURL(raw); ok { - return u.Port(), ok - } - return "", false -} - -func parseURL(raw string) (*url.URL, bool) { - if raw == "" { - return nil, false - } - u, err := url.Parse(raw) - return u, err == nil -} - -func mostRecentValue(forwarded map[string][]string, key string) string { +func oldestForwardedValue(forwarded map[string][]string, key string) string { if forwarded == nil { return "" } @@ -99,5 +56,5 @@ func mostRecentValue(forwarded map[string][]string, key string) string { if len(values) == 0 { return "" } - return values[len(values)-1] + return values[0] } diff --git a/internal/api/http/middleware/origin_interceptor_test.go b/internal/api/http/middleware/origin_interceptor_test.go new file mode 100644 index 0000000000..e8e4f306b8 --- /dev/null +++ b/internal/api/http/middleware/origin_interceptor_test.go @@ -0,0 +1,122 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_composeOrigin(t *testing.T) { + type args struct { + h http.Header + } + tests := []struct { + name string + args args + want string + }{{ + name: "no proxy headers", + want: "http://host.header", + }, { + name: "forwarded proto", + args: args{ + h: http.Header{ + "Forwarded": []string{"proto=https"}, + }, + }, + want: "https://host.header", + }, { + name: "forwarded host", + args: args{ + h: http.Header{ + "Forwarded": []string{"host=forwarded.host"}, + }, + }, + want: "http://forwarded.host", + }, { + name: "forwarded proto and host", + args: args{ + h: http.Header{ + "Forwarded": []string{"proto=https;host=forwarded.host"}, + }, + }, + want: "https://forwarded.host", + }, { + name: "forwarded proto and host with multiple complete entries", + args: args{ + h: http.Header{ + "Forwarded": []string{"proto=https;host=forwarded.host, proto=http;host=forwarded.host2"}, + }, + }, + want: "https://forwarded.host", + }, { + name: "forwarded proto and host with multiple incomplete entries", + args: args{ + h: http.Header{ + "Forwarded": []string{"proto=https;host=forwarded.host, proto=http"}, + }, + }, + want: "https://forwarded.host", + }, { + name: "forwarded proto and host with incomplete entries in different values", + args: args{ + h: http.Header{ + "Forwarded": []string{"proto=http", "proto=https;host=forwarded.host", "proto=http"}, + }, + }, + want: "http://forwarded.host", + }, { + name: "x-forwarded-proto", + args: args{ + h: http.Header{ + "X-Forwarded-Proto": []string{"https"}, + }, + }, + want: "https://host.header", + }, { + name: "x-forwarded-host", + args: args{ + h: http.Header{ + "X-Forwarded-Host": []string{"x-forwarded.host"}, + }, + }, + want: "http://x-forwarded.host", + }, { + name: "x-forwarded-proto and x-forwarded-host", + args: args{ + h: http.Header{ + "X-Forwarded-Proto": []string{"https"}, + "X-Forwarded-Host": []string{"x-forwarded.host"}, + }, + }, + want: "https://x-forwarded.host", + }, { + name: "forwarded host and x-forwarded-host", + args: args{ + h: http.Header{ + "Forwarded": []string{"host=forwarded.host"}, + "X-Forwarded-Host": []string{"x-forwarded.host"}, + }, + }, + want: "http://forwarded.host", + }, { + name: "forwarded host and x-forwarded-proto", + args: args{ + h: http.Header{ + "Forwarded": []string{"host=forwarded.host"}, + "X-Forwarded-Proto": []string{"https"}, + }, + }, + want: "https://forwarded.host", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, composeOrigin(&http.Request{ + Host: "host.header", + Header: tt.args.h, + }), "headers: %+v", tt.args.h) + }) + } +}