From 4552fe7d99bd4287280c66f79c3addc7d6751320 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 7 Jun 2024 09:30:04 +0200 Subject: [PATCH] fix: potential panics in login and return proper http 405 (#8065) # Which Problems Are Solved We identified some parts in the code, which could panic with a nil pointer when accessed without auth request. Additionally, if a GRPC method was called with an unmapped HTTP method, e.g. POST instead of GET a 501 instead of a 405 was returned. # How the Problems Are Solved - Additional checks for existing authRequest - custom http status code mapper for gateway # Additional Changes None. # Additional Context - noted internally in OPS (cherry picked from commit 26c7d95c882295a012dbdd2963b13dbad6272527) --- internal/api/grpc/server/gateway.go | 20 +++++++++++++++++++ .../api/ui/login/init_password_handler.go | 4 +++- internal/api/ui/login/login_handler.go | 2 +- internal/api/ui/login/mail_verify_handler.go | 5 +++-- .../passwordless_registration_handler.go | 4 ++-- 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/internal/api/grpc/server/gateway.go b/internal/api/grpc/server/gateway.go index 6a3ac94bad..327865bd6c 100644 --- a/internal/api/grpc/server/gateway.go +++ b/internal/api/grpc/server/gateway.go @@ -10,9 +10,11 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/zitadel/logging" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -36,6 +38,23 @@ var ( }, } + httpErrorHandler = runtime.RoutingErrorHandlerFunc( + func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) { + if httpStatus != http.StatusMethodNotAllowed { + runtime.DefaultRoutingErrorHandler(ctx, mux, marshaler, w, r, httpStatus) + return + } + + // Use HTTPStatusError to customize the DefaultHTTPErrorHandler status code + err := &runtime.HTTPStatusError{ + HTTPStatus: httpStatus, + Err: status.Error(codes.Unimplemented, http.StatusText(httpStatus)), + } + + runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err) + }, + ) + serveMuxOptions = []runtime.ServeMuxOption{ runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler), runtime.WithMarshalerOption(mimeWildcard, jsonMarshaler), @@ -43,6 +62,7 @@ var ( runtime.WithIncomingHeaderMatcher(headerMatcher), runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher), runtime.WithForwardResponseOption(responseForwarder), + runtime.WithRoutingErrorHandler(httpErrorHandler), } headerMatcher = runtime.HeaderMatcherFunc( diff --git a/internal/api/ui/login/init_password_handler.go b/internal/api/ui/login/init_password_handler.go index d57d1f83ff..91a197ef64 100644 --- a/internal/api/ui/login/init_password_handler.go +++ b/internal/api/ui/login/init_password_handler.go @@ -91,16 +91,18 @@ func (l *Login) checkPWCode(w http.ResponseWriter, r *http.Request, authReq *dom func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *initPasswordFormData) { userOrg := data.OrgID userID := data.UserID + var authReqID string if authReq != nil { userOrg = authReq.UserOrgID userID = authReq.UserID + authReqID = authReq.ID } passwordCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypePasswordResetCode, l.userCodeAlg) if err != nil { l.renderInitPassword(w, r, authReq, userID, "", err) return } - _, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID) + _, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReqID) l.renderInitPassword(w, r, authReq, userID, "", err) } diff --git a/internal/api/ui/login/login_handler.go b/internal/api/ui/login/login_handler.go index ae21d84d87..059048eecb 100644 --- a/internal/api/ui/login/login_handler.go +++ b/internal/api/ui/login/login_handler.go @@ -69,7 +69,7 @@ func (l *Login) handleLoginNameCheck(w http.ResponseWriter, r *http.Request) { return } if data.Register { - if authReq.LoginPolicy != nil && authReq.LoginPolicy.AllowExternalIDP && authReq.AllowedExternalIDPs != nil && len(authReq.AllowedExternalIDPs) > 0 { + if authReq != nil && authReq.LoginPolicy != nil && authReq.LoginPolicy.AllowExternalIDP && authReq.AllowedExternalIDPs != nil && len(authReq.AllowedExternalIDPs) > 0 { l.handleRegisterOption(w, r) return } diff --git a/internal/api/ui/login/mail_verify_handler.go b/internal/api/ui/login/mail_verify_handler.go index 327b8a1182..aa14969008 100644 --- a/internal/api/ui/login/mail_verify_handler.go +++ b/internal/api/ui/login/mail_verify_handler.go @@ -58,16 +58,17 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque l.checkMailCode(w, r, authReq, data.UserID, data.Code) return } - userOrg := "" + var userOrg, authReqID string if authReq != nil { userOrg = authReq.UserOrgID + authReqID = authReq.ID } emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg) if err != nil { l.checkMailCode(w, r, authReq, data.UserID, data.Code) return } - _, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReq.ID) + _, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReqID) l.renderMailVerification(w, r, authReq, data.UserID, err) } diff --git a/internal/api/ui/login/passwordless_registration_handler.go b/internal/api/ui/login/passwordless_registration_handler.go index 0346374cee..976a9277b2 100644 --- a/internal/api/ui/login/passwordless_registration_handler.go +++ b/internal/api/ui/login/passwordless_registration_handler.go @@ -114,11 +114,11 @@ func (l *Login) renderPasswordlessRegistration(w http.ResponseWriter, r *http.Re } if authReq == nil { policy, err := l.query.ActiveLabelPolicyByOrg(r.Context(), orgID, false) - logging.Log("HANDL-XjWKE").OnError(err).Error("unable to get active label policy") + logging.OnError(err).Error("unable to get active label policy") data.LabelPolicy = labelPolicyToDomain(policy) if err == nil { texts, err := l.authRepo.GetLoginText(r.Context(), orgID) - logging.Log("LOGIN-HJK4t").OnError(err).Warn("could not get custom texts") + logging.OnError(err).Warn("could not get custom texts") l.addLoginTranslations(translator, texts) } }