mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-10 18:13:55 +00:00
100 lines
3.0 KiB
Go
100 lines
3.0 KiB
Go
|
package saml
|
||
|
|
||
|
import (
	"context"
	"encoding/base64"
	"fmt"
	"net/url"

	"github.com/zitadel/saml/pkg/provider"
	"github.com/zitadel/saml/pkg/provider/models"
	"github.com/zitadel/saml/pkg/provider/xml"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/command"
	"github.com/zitadel/zitadel/internal/domain"
)
|
||
|
|
||
|
func (p *Provider) CreateErrorResponse(authReq models.AuthRequestInt, reason domain.SAMLErrorReason, description string) (string, string, error) {
|
||
|
resp := &provider.Response{
|
||
|
ProtocolBinding: authReq.GetBindingType(),
|
||
|
RelayState: authReq.GetRelayState(),
|
||
|
AcsUrl: authReq.GetAccessConsumerServiceURL(),
|
||
|
RequestID: authReq.GetAuthRequestID(),
|
||
|
Issuer: authReq.GetDestination(),
|
||
|
Audience: authReq.GetIssuer(),
|
||
|
}
|
||
|
return createResponse(p.AuthCallbackErrorResponse(resp, domain.SAMLErrorReasonToString(reason), description), authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
|
||
|
}
|
||
|
|
||
|
// CreateResponse builds a successful SAML response for authReq, records a SAML
// session for it, and encodes the response for the request's protocol binding.
//
// The session is created through CreateSAMLSessionFromSAMLRequest in a context
// whose acting user is "SYSTEM" (see setContextUserSystem). The request must
// pass samlComplianceChecker. The generated response ID and the provider's
// expiration are stored with the session. An error from building the response
// or creating the session is returned unchanged, with empty strings. On success
// the values returned are those of createResponse.
func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthRequestInt) (string, string, error) {
	resp := &provider.Response{
		ProtocolBinding: authReq.GetBindingType(),
		RelayState:      authReq.GetRelayState(),
		AcsUrl:          authReq.GetAccessConsumerServiceURL(),
		RequestID:       authReq.GetAuthRequestID(),
		Issuer:          authReq.GetDestination(),
		Audience:        authReq.GetIssuer(),
	}

	samlResponse, err := p.AuthCallbackResponse(ctx, authReq, resp)
	if err != nil {
		return "", "", err
	}

	systemCtx := setContextUserSystem(ctx)
	err = p.command.CreateSAMLSessionFromSAMLRequest(
		systemCtx,
		authReq.GetID(),
		samlComplianceChecker(),
		samlResponse.Id,
		p.Expiration(),
	)
	if err != nil {
		return "", "", err
	}

	return createResponse(
		samlResponse,
		authReq.GetBindingType(),
		authReq.GetAccessConsumerServiceURL(),
		resp.RelayState,
		resp.SigAlg,
		resp.Signature,
	)
}
|
||
|
|
||
|
// createResponse XML-marshals samlResponse and encodes it for the given SAML
// protocol binding.
//
// For provider.PostBinding it returns acs unchanged plus the base64
// (StdEncoding) XML as the second value.
//
// For provider.RedirectBinding it returns acs with these query parameters
// added: the deflated and base64-encoded response as "SAMLResponse", plus
// "RelayState", "SigAlg" and "Signature". Query parameters already present on
// acs are kept. The second value is empty.
//
// Any other binding is rejected with an error.
func createResponse(samlResponse interface{}, binding, acs, relayState, sigAlg, sig string) (string, string, error) {
	respData, err := xml.Marshal(samlResponse)
	if err != nil {
		return "", "", err
	}

	switch binding {
	case provider.PostBinding:
		return acs, base64.StdEncoding.EncodeToString(respData), nil
	case provider.RedirectBinding:
		encoded, err := xml.DeflateAndBase64(respData)
		if err != nil {
			return "", "", err
		}
		parsed, err := url.Parse(acs)
		if err != nil {
			return "", "", err
		}
		values := parsed.Query()
		values.Add("SAMLResponse", string(encoded))
		values.Add("RelayState", relayState)
		values.Add("SigAlg", sigAlg)
		values.Add("Signature", sig)
		parsed.RawQuery = values.Encode()
		return parsed.String(), "", nil
	default:
		// Fail loudly instead of handing the caller an empty target URL with a
		// nil error, which would be indistinguishable from success.
		return "", "", fmt.Errorf("unsupported SAML protocol binding %q", binding)
	}
}
|
||
|
|
||
|
// setContextUserSystem returns the context produced by authz.SetCtxData(ctx, ...).
// In that context data, the acting UserID is "SYSTEM" and every other
// authz.CtxData field is left at its zero value.
func setContextUserSystem(ctx context.Context) context.Context {
	return authz.SetCtxData(ctx, authz.CtxData{UserID: "SYSTEM"})
}
|
||
|
|
||
|
// samlComplianceChecker returns the check run before a SAML session is created
// from a SAML request. The request passes only if its CheckAuthenticated
// reports no error. Otherwise that error is returned unchanged. The context
// argument is ignored.
func samlComplianceChecker() command.SAMLRequestComplianceChecker {
	return func(_ context.Context, samlReq *command.SAMLRequestWriteModel) error {
		err := samlReq.CheckAuthenticated()
		if err == nil {
			// Return a literal nil so success is always an untyped nil error.
			return nil
		}
		return err
	}
}
|