package saml

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/url"

	"github.com/zitadel/saml/pkg/provider"
	"github.com/zitadel/saml/pkg/provider/models"
	"github.com/zitadel/saml/pkg/provider/xml"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/command"
	"github.com/zitadel/zitadel/internal/domain"
)

// CreateErrorResponse builds a SAML error response for authReq carrying the
// given reason (converted with domain.SAMLErrorReasonToString) and
// description, and prepares it for delivery over the request's binding.
//
// The three return values are the same as createResponse: the target URL, the
// base64-encoded body (POST binding only), and an error.
func (p *Provider) CreateErrorResponse(authReq models.AuthRequestInt, reason domain.SAMLErrorReason, description string) (string, string, error) {
	resp := responseFromAuthRequest(authReq)
	// resp is created without SigAlg/Signature and handed to the callback by
	// pointer, so the callback is what fills those fields in. Call it in its
	// own statement: inside a single argument list the Go spec does not order
	// a method call relative to plain field reads such as resp.SigAlg, so they
	// could otherwise be read before they are set.
	errorResponse := p.AuthCallbackErrorResponse(resp, domain.SAMLErrorReasonToString(reason), description)
	return createResponse(errorResponse, authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
}

// CreateResponse builds the successful SAML response for authReq, records a
// SAML session for it, and prepares the response for delivery over the
// request's binding.
//
// The session is created as the "SYSTEM" user (see setContextUserSystem) and
// only if the request is authenticated (see samlComplianceChecker). Any error
// from building the response or creating the session is returned unchanged
// with empty strings.
func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthRequestInt) (string, string, error) {
	resp := responseFromAuthRequest(authReq)
	samlResponse, err := p.AuthCallbackResponse(ctx, authReq, resp)
	if err != nil {
		return "", "", err
	}
	// The session is keyed by the generated response ID, so it must be
	// created after the response exists and before it is sent.
	if err := p.command.CreateSAMLSessionFromSAMLRequest(
		setContextUserSystem(ctx),
		authReq.GetID(),
		samlComplianceChecker(),
		samlResponse.Id,
		p.Expiration(),
	); err != nil {
		return "", "", err
	}
	return createResponse(samlResponse, authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
}

// responseFromAuthRequest creates the provider.Response shared by the success
// and error paths. It copies binding, relay state, ACS URL and request ID from
// authReq. The request's destination becomes the response Issuer, and the
// request's issuer becomes the response Audience.
func responseFromAuthRequest(authReq models.AuthRequestInt) *provider.Response {
	return &provider.Response{
		ProtocolBinding: authReq.GetBindingType(),
		RelayState:      authReq.GetRelayState(),
		AcsUrl:          authReq.GetAccessConsumerServiceURL(),
		RequestID:       authReq.GetAuthRequestID(),
		Issuer:          authReq.GetDestination(),
		Audience:        authReq.GetIssuer(),
	}
}

// createResponse serializes samlResponse to XML and prepares it for the given
// binding. It returns (target, body, err):
//
//   - provider.PostBinding: target is acs and body is the base64-encoded XML.
//   - provider.RedirectBinding: target is acs with the query parameters added,
//     and body is empty. SAMLResponse is the deflated and base64-encoded XML.
//     RelayState is always added. SigAlg and Signature are added only when
//     sig is non-empty, i.e. when the response was signed.
//   - any other binding: an error, rather than a silent empty target.
//
// Marshal, deflate and URL-parse errors are returned unchanged.
func createResponse(samlResponse interface{}, binding, acs, relayState, sigAlg, sig string) (string, string, error) {
	respData, err := xml.Marshal(samlResponse)
	if err != nil {
		return "", "", err
	}

	switch binding {
	case provider.PostBinding:
		return acs, base64.StdEncoding.EncodeToString(respData), nil
	case provider.RedirectBinding:
		encoded, err := xml.DeflateAndBase64(respData)
		if err != nil {
			return "", "", err
		}
		parsed, err := url.Parse(acs)
		if err != nil {
			return "", "", err
		}
		values := parsed.Query()
		// Set, not Add: a same-named parameter already present in the ACS URL
		// must not leave a stale duplicate that the SP could read first.
		values.Set("SAMLResponse", string(encoded))
		// RelayState is kept even when empty, matching the previous output.
		// NOTE(review): confirm the signer's signed octet string treats an
		// empty RelayState the same way before omitting it here.
		values.Set("RelayState", relayState)
		if sig != "" {
			values.Set("SigAlg", sigAlg)
			values.Set("Signature", sig)
		}
		parsed.RawQuery = values.Encode()
		return parsed.String(), "", nil
	default:
		return "", "", fmt.Errorf("unsupported SAML binding %q", binding)
	}
}

// setContextUserSystem returns a copy of ctx whose authz data names "SYSTEM"
// as the acting user. It is used to run the SAML session command on behalf of
// the system rather than the end user.
func setContextUserSystem(ctx context.Context) context.Context {
	data := authz.CtxData{
		UserID: "SYSTEM",
	}
	return authz.SetCtxData(ctx, data)
}

// samlComplianceChecker returns the precondition used when creating a SAML
// session: the SAML request must pass CheckAuthenticated. Its error is
// returned unchanged. Otherwise the check succeeds with a literal nil.
func samlComplianceChecker() command.SAMLRequestComplianceChecker {
	return func(_ context.Context, samlReq *command.SAMLRequestWriteModel) error {
		if err := samlReq.CheckAuthenticated(); err != nil {
			return err
		}
		return nil
	}
}