zitadel/internal/api/saml/auth_request.go

100 lines
3.0 KiB
Go
Raw Normal View History

package saml
import (
"context"
"encoding/base64"
"net/url"
"github.com/zitadel/saml/pkg/provider"
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
)
// CreateErrorResponse builds a SAML error response for authReq and encodes it
// for the binding the request was made with.
//
// reason is converted with domain.SAMLErrorReasonToString and handed, together
// with description, to p.AuthCallbackErrorResponse. The return values are
// those of createResponse: for the POST binding the ACS URL and the
// base64-encoded response, for the redirect binding the complete redirect URL
// and an empty string.
func (p *Provider) CreateErrorResponse(authReq models.AuthRequestInt, reason domain.SAMLErrorReason, description string) (string, string, error) {
	// Issuer and audience are swapped relative to the request: the request's
	// destination issues the response and the request's issuer is its audience.
	resp := &provider.Response{
		ProtocolBinding: authReq.GetBindingType(),
		RelayState:      authReq.GetRelayState(),
		AcsUrl:          authReq.GetAccessConsumerServiceURL(),
		RequestID:       authReq.GetAuthRequestID(),
		Issuer:          authReq.GetDestination(),
		Audience:        authReq.GetIssuer(),
	}

	// Build the error response in its own statement. resp is handed over by
	// pointer, and Go does not specify the order between a function call and
	// plain field reads inside a single argument list; doing the call first
	// guarantees the reads of resp's fields below observe anything the
	// provider may have set on resp.
	errorResponse := p.AuthCallbackErrorResponse(resp, domain.SAMLErrorReasonToString(reason), description)

	return createResponse(errorResponse, authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
}
// CreateResponse builds the SAML response for authReq, registers it through
// p.command.CreateSAMLSessionFromSAMLRequest and encodes it for the binding
// the request was made with.
//
// The command is invoked with a context whose user is set by
// setContextUserSystem, the request ID, the checker returned by
// samlComplianceChecker, the ID of the generated response and p.Expiration().
// Any error from building the response or from the command is returned
// unchanged with two empty strings. On success the return values are those of
// createResponse.
func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthRequestInt) (string, string, error) {
	// Issuer and audience are swapped relative to the request: the request's
	// destination issues the response and the request's issuer is its audience.
	response := &provider.Response{
		ProtocolBinding: authReq.GetBindingType(),
		RelayState:      authReq.GetRelayState(),
		AcsUrl:          authReq.GetAccessConsumerServiceURL(),
		RequestID:       authReq.GetAuthRequestID(),
		Issuer:          authReq.GetDestination(),
		Audience:        authReq.GetIssuer(),
	}

	generated, err := p.AuthCallbackResponse(ctx, authReq, response)
	if err != nil {
		return "", "", err
	}

	systemCtx := setContextUserSystem(ctx)
	sessionErr := p.command.CreateSAMLSessionFromSAMLRequest(
		systemCtx,
		authReq.GetID(),
		samlComplianceChecker(),
		generated.Id,
		p.Expiration(),
	)
	if sessionErr != nil {
		return "", "", sessionErr
	}

	return createResponse(
		generated,
		authReq.GetBindingType(),
		authReq.GetAccessConsumerServiceURL(),
		response.RelayState,
		response.SigAlg,
		response.Signature,
	)
}
// createResponse serializes samlResponse with xml.Marshal and packages the
// result according to binding:
//
//   - provider.PostBinding: returns acs and the standard base64 encoding of
//     the marshalled response.
//   - provider.RedirectBinding: returns acs with the query parameters
//     SAMLResponse (the output of xml.DeflateAndBase64), RelayState, SigAlg
//     and Signature added, and an empty second value.
//   - any other binding: returns two empty strings and a nil error.
//
// Errors from marshalling, deflating or parsing acs are returned unchanged.
func createResponse(samlResponse interface{}, binding, acs, relayState, sigAlg, sig string) (string, string, error) {
	marshalled, err := xml.Marshal(samlResponse)
	if err != nil {
		return "", "", err
	}

	if binding == provider.PostBinding {
		return acs, base64.StdEncoding.EncodeToString(marshalled), nil
	}
	if binding != provider.RedirectBinding {
		// Unknown bindings yield an empty result without an error.
		return "", "", nil
	}

	deflated, err := xml.DeflateAndBase64(marshalled)
	if err != nil {
		return "", "", err
	}
	target, err := url.Parse(acs)
	if err != nil {
		return "", "", err
	}

	// Add (not Set) so parameters already present on the ACS URL are kept.
	query := target.Query()
	query.Add("SAMLResponse", string(deflated))
	query.Add("RelayState", relayState)
	query.Add("SigAlg", sigAlg)
	query.Add("Signature", sig)
	target.RawQuery = query.Encode()

	return target.String(), "", nil
}
// setContextUserSystem returns the result of authz.SetCtxData for ctx with
// context data whose UserID is "SYSTEM".
func setContextUserSystem(ctx context.Context) context.Context {
	return authz.SetCtxData(ctx, authz.CtxData{UserID: "SYSTEM"})
}
// samlComplianceChecker returns a compliance checker that ignores its context
// argument and reports whatever error the request's CheckAuthenticated method
// returns, or nil when there is none.
func samlComplianceChecker() command.SAMLRequestComplianceChecker {
	return func(_ context.Context, request *command.SAMLRequestWriteModel) error {
		authErr := request.CheckAuthenticated()
		if authErr == nil {
			return nil
		}
		return authErr
	}
}