mirror of
https://github.com/zitadel/zitadel.git
synced 2025-04-17 03:51:30 +00:00
133 lines
5.1 KiB
Go
133 lines
5.1 KiB
Go
package saml
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/zitadel/logging"
|
|
"github.com/zitadel/oidc/v3/pkg/op"
|
|
"github.com/zitadel/saml/pkg/provider"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
|
"github.com/zitadel/zitadel/internal/api/http"
|
|
"github.com/zitadel/zitadel/internal/api/oidc"
|
|
"github.com/zitadel/zitadel/internal/api/saml"
|
|
"github.com/zitadel/zitadel/internal/domain"
|
|
"github.com/zitadel/zitadel/internal/query"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
|
)
|
|
|
|
func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) {
|
|
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetSamlRequestId(), true)
|
|
if err != nil {
|
|
logging.WithError(err).Error("query samlRequest by ID")
|
|
return nil, err
|
|
}
|
|
return &saml_pb.GetSAMLRequestResponse{
|
|
SamlRequest: samlRequestToPb(authRequest),
|
|
}, nil
|
|
}
|
|
|
|
func samlRequestToPb(a *query.AuthRequest) *saml_pb.SAMLRequest {
|
|
return &saml_pb.SAMLRequest{
|
|
Id: a.ID,
|
|
CreationDate: timestamppb.New(a.CreationDate),
|
|
}
|
|
}
|
|
|
|
func (s *Server) CreateCallback(ctx context.Context, req *saml_pb.CreateCallbackRequest) (*saml_pb.CreateCallbackResponse, error) {
|
|
switch v := req.GetCallbackKind().(type) {
|
|
case *saml_pb.CreateCallbackRequest_Error:
|
|
return s.failAuthRequest(ctx, req.GetSamlRequestId(), v.Error)
|
|
case *saml_pb.CreateCallbackRequest_Session:
|
|
return s.linkSessionToAuthRequest(ctx, req.GetSamlRequestId(), v.Session)
|
|
default:
|
|
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
|
}
|
|
}
|
|
|
|
func (s *Server) failAuthRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateCallbackResponse, error) {
|
|
details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError()))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
|
|
callback, err := saml.CreateErrorCallbackURL(authReq, errorReasonToSAML(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &saml_pb.CreateCallbackResponse{
|
|
Details: object.DomainToDetailsPb(details),
|
|
CallbackUrl: callback,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) linkSessionToAuthRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateCallbackResponse, error) {
|
|
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
|
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
|
|
var callback string
|
|
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
|
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
|
|
} else {
|
|
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &saml_pb.CreateCallbackResponse{
|
|
Details: object.DomainToDetailsPb(details),
|
|
CallbackUrl: callback,
|
|
}, nil
|
|
}
|
|
|
|
func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason {
|
|
switch errorReason {
|
|
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
|
return domain.SAMLErrorReasonUnspecified
|
|
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
|
|
return domain.SAMLErrorReasonVersionMissmatch
|
|
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
|
|
return domain.SAMLErrorReasonAuthNFailed
|
|
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
|
|
return domain.SAMLErrorReasonInvalidAttrNameOrValue
|
|
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
|
|
return domain.SAMLErrorReasonInvalidNameIDPolicy
|
|
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
|
|
return domain.SAMLErrorReasonRequestDenied
|
|
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
|
|
return domain.SAMLErrorReasonRequestUnsupported
|
|
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
|
|
return domain.SAMLErrorReasonUnsupportedBinding
|
|
default:
|
|
return domain.SAMLErrorReasonUnspecified
|
|
}
|
|
}
|
|
|
|
func errorReasonToSAML(reason saml_pb.ErrorReason) string {
|
|
switch reason {
|
|
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
|
return "unspecified error"
|
|
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
|
|
return provider.StatusCodeVersionMissmatch
|
|
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
|
|
return provider.StatusCodeAuthNFailed
|
|
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
|
|
return provider.StatusCodeInvalidAttrNameOrValue
|
|
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
|
|
return provider.StatusCodeInvalidNameIDPolicy
|
|
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
|
|
return provider.StatusCodeRequestDenied
|
|
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
|
|
return provider.StatusCodeRequestUnsupported
|
|
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
|
|
return provider.StatusCodeUnsupportedBinding
|
|
default:
|
|
return "unspecified error"
|
|
}
|
|
}
|