133 lines
5.1 KiB
Go

package saml
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/saml/pkg/provider"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/api/saml"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
)
func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) {
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetSamlRequestId(), true)
if err != nil {
logging.WithError(err).Error("query samlRequest by ID")
return nil, err
}
return &saml_pb.GetSAMLRequestResponse{
SamlRequest: samlRequestToPb(authRequest),
}, nil
}
func samlRequestToPb(a *query.AuthRequest) *saml_pb.SAMLRequest {
return &saml_pb.SAMLRequest{
Id: a.ID,
CreationDate: timestamppb.New(a.CreationDate),
}
}
func (s *Server) CreateCallback(ctx context.Context, req *saml_pb.CreateCallbackRequest) (*saml_pb.CreateCallbackResponse, error) {
switch v := req.GetCallbackKind().(type) {
case *saml_pb.CreateCallbackRequest_Error:
return s.failAuthRequest(ctx, req.GetSamlRequestId(), v.Error)
case *saml_pb.CreateCallbackRequest_Session:
return s.linkSessionToAuthRequest(ctx, req.GetSamlRequestId(), v.Session)
default:
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
}
}
func (s *Server) failAuthRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError()))
if err != nil {
return nil, err
}
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
callback, err := saml.CreateErrorCallbackURL(authReq, errorReasonToSAML(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri())
if err != nil {
return nil, err
}
return &saml_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func (s *Server) linkSessionToAuthRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true)
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
var callback string
if aar.ResponseType == domain.OIDCResponseTypeCode {
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
} else {
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
}
if err != nil {
return nil, err
}
return &saml_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason {
switch errorReason {
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return domain.SAMLErrorReasonUnspecified
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
return domain.SAMLErrorReasonVersionMissmatch
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
return domain.SAMLErrorReasonAuthNFailed
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
return domain.SAMLErrorReasonInvalidAttrNameOrValue
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
return domain.SAMLErrorReasonInvalidNameIDPolicy
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
return domain.SAMLErrorReasonRequestDenied
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
return domain.SAMLErrorReasonRequestUnsupported
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
return domain.SAMLErrorReasonUnsupportedBinding
default:
return domain.SAMLErrorReasonUnspecified
}
}
func errorReasonToSAML(reason saml_pb.ErrorReason) string {
switch reason {
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return "unspecified error"
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
return provider.StatusCodeVersionMissmatch
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
return provider.StatusCodeAuthNFailed
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
return provider.StatusCodeInvalidAttrNameOrValue
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
return provider.StatusCodeInvalidNameIDPolicy
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
return provider.StatusCodeRequestDenied
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
return provider.StatusCodeRequestUnsupported
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
return provider.StatusCodeUnsupportedBinding
default:
return "unspecified error"
}
}