fix: potential panics in login and return proper http 405 (#8065)

# Which Problems Are Solved

We identified some parts in the code, which could panic with a nil
pointer when accessed without auth request.
Additionally, if a GRPC method was called with an unmapped HTTP method,
e.g. POST instead of GET a 501 instead of a 405 was returned.

# How the Problems Are Solved

- Additional checks for existing authRequest
- custom http status code mapper for gateway

# Additional Changes

None.

# Additional Context

- noted internally in OPS

(cherry picked from commit 26c7d95c88)
This commit is contained in:
Livio Spring 2024-06-07 09:30:04 +02:00
parent 33235a5cbe
commit 4552fe7d99
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
5 changed files with 29 additions and 6 deletions

View File

@ -10,9 +10,11 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -36,6 +38,23 @@ var (
}, },
} }
httpErrorHandler = runtime.RoutingErrorHandlerFunc(
func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
if httpStatus != http.StatusMethodNotAllowed {
runtime.DefaultRoutingErrorHandler(ctx, mux, marshaler, w, r, httpStatus)
return
}
// Use HTTPStatusError to customize the DefaultHTTPErrorHandler status code
err := &runtime.HTTPStatusError{
HTTPStatus: httpStatus,
Err: status.Error(codes.Unimplemented, http.StatusText(httpStatus)),
}
runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err)
},
)
serveMuxOptions = []runtime.ServeMuxOption{ serveMuxOptions = []runtime.ServeMuxOption{
runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler), runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler),
runtime.WithMarshalerOption(mimeWildcard, jsonMarshaler), runtime.WithMarshalerOption(mimeWildcard, jsonMarshaler),
@ -43,6 +62,7 @@ var (
runtime.WithIncomingHeaderMatcher(headerMatcher), runtime.WithIncomingHeaderMatcher(headerMatcher),
runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher), runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
runtime.WithForwardResponseOption(responseForwarder), runtime.WithForwardResponseOption(responseForwarder),
runtime.WithRoutingErrorHandler(httpErrorHandler),
} }
headerMatcher = runtime.HeaderMatcherFunc( headerMatcher = runtime.HeaderMatcherFunc(

View File

@ -91,16 +91,18 @@ func (l *Login) checkPWCode(w http.ResponseWriter, r *http.Request, authReq *dom
func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *initPasswordFormData) { func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *initPasswordFormData) {
userOrg := data.OrgID userOrg := data.OrgID
userID := data.UserID userID := data.UserID
var authReqID string
if authReq != nil { if authReq != nil {
userOrg = authReq.UserOrgID userOrg = authReq.UserOrgID
userID = authReq.UserID userID = authReq.UserID
authReqID = authReq.ID
} }
passwordCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypePasswordResetCode, l.userCodeAlg) passwordCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypePasswordResetCode, l.userCodeAlg)
if err != nil { if err != nil {
l.renderInitPassword(w, r, authReq, userID, "", err) l.renderInitPassword(w, r, authReq, userID, "", err)
return return
} }
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID) _, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReqID)
l.renderInitPassword(w, r, authReq, userID, "", err) l.renderInitPassword(w, r, authReq, userID, "", err)
} }

View File

@ -69,7 +69,7 @@ func (l *Login) handleLoginNameCheck(w http.ResponseWriter, r *http.Request) {
return return
} }
if data.Register { if data.Register {
if authReq.LoginPolicy != nil && authReq.LoginPolicy.AllowExternalIDP && authReq.AllowedExternalIDPs != nil && len(authReq.AllowedExternalIDPs) > 0 { if authReq != nil && authReq.LoginPolicy != nil && authReq.LoginPolicy.AllowExternalIDP && authReq.AllowedExternalIDPs != nil && len(authReq.AllowedExternalIDPs) > 0 {
l.handleRegisterOption(w, r) l.handleRegisterOption(w, r)
return return
} }

View File

@ -58,16 +58,17 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque
l.checkMailCode(w, r, authReq, data.UserID, data.Code) l.checkMailCode(w, r, authReq, data.UserID, data.Code)
return return
} }
userOrg := "" var userOrg, authReqID string
if authReq != nil { if authReq != nil {
userOrg = authReq.UserOrgID userOrg = authReq.UserOrgID
authReqID = authReq.ID
} }
emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg) emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg)
if err != nil { if err != nil {
l.checkMailCode(w, r, authReq, data.UserID, data.Code) l.checkMailCode(w, r, authReq, data.UserID, data.Code)
return return
} }
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReq.ID) _, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReqID)
l.renderMailVerification(w, r, authReq, data.UserID, err) l.renderMailVerification(w, r, authReq, data.UserID, err)
} }

View File

@ -114,11 +114,11 @@ func (l *Login) renderPasswordlessRegistration(w http.ResponseWriter, r *http.Re
} }
if authReq == nil { if authReq == nil {
policy, err := l.query.ActiveLabelPolicyByOrg(r.Context(), orgID, false) policy, err := l.query.ActiveLabelPolicyByOrg(r.Context(), orgID, false)
logging.Log("HANDL-XjWKE").OnError(err).Error("unable to get active label policy") logging.OnError(err).Error("unable to get active label policy")
data.LabelPolicy = labelPolicyToDomain(policy) data.LabelPolicy = labelPolicyToDomain(policy)
if err == nil { if err == nil {
texts, err := l.authRepo.GetLoginText(r.Context(), orgID) texts, err := l.authRepo.GetLoginText(r.Context(), orgID)
logging.Log("LOGIN-HJK4t").OnError(err).Warn("could not get custom texts") logging.OnError(err).Warn("could not get custom texts")
l.addLoginTranslations(translator, texts) l.addLoginTranslations(translator, texts)
} }
} }