Stuart Douglas 81920e599b
fix(SAML): log underlying error if SAML response validation fails (#8721)
# Which Problems Are Solved

If SAML response validation in crewjam/saml fails, a generic
"Authentication failed" error is thrown. This makes it challenging to
determine the actual cause, since there are a variety of reasons
response validation may fail.

# How the Problems Are Solved

When crewjam/saml returns a response validation error, log its internal
`InvalidResponseError.PrivateErr` error to stdout. We continue to return
a generic error message to the client to prevent leaking data.

Verified by running `go test -v ./internal/idp/providers/saml` in
verbose mode, which output the following line for the "response_invalid"
test case:
```
time="2024-10-03T14:53:10+01:00" level=info msg="invalid SAML response details" caller="/Users/sdouglas/Documents/thirdparty-repos/zitadel/internal/idp/providers/saml/session.go:72" error="cannot parse base64: illegal base64 data at input byte 2"
```

# Additional Changes

None

# Additional Context

- closes #8717

---------

Co-authored-by: Stuart Douglas <sdouglas@hopper.com>
2024-10-11 07:04:15 +00:00

142 lines
3.9 KiB
Go

package saml
import (
"bytes"
"context"
"errors"
"net/http"
"net/url"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Compile-time check that *Session satisfies the idp.Session interface.
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the SAML provider.
// It holds the service provider middleware used to start the auth flow and
// the request data needed to parse the response into an assertion.
type Session struct {
// ServiceProvider is the samlsp middleware; NewSession fills it from provider.GetSP().
ServiceProvider *samlsp.Middleware
// state is parsed as the URL of the synthetic request built in GetAuth.
state string
// TransientMappingAttributeName is the name of the assertion attribute whose
// single value is used as the user ID when the NameID format is transient.
TransientMappingAttributeName string
// RequestID is handed to ParseResponse as the only possible request ID;
// FetchUser rejects an empty value.
RequestID string
// Request is the HTTP request handed to ParseResponse; FetchUser rejects nil.
Request *http.Request
// Assertion is set by FetchUser to the result of ParseResponse.
Assertion *saml.Assertion
}
// NewSession creates a Session for the given provider.
//
// The service provider middleware is taken from provider.GetSP(); if that
// call fails, its error is returned unchanged and no Session is created.
// requestID and request are stored on the Session as-is.
func NewSession(provider *Provider, requestID string, request *http.Request) (*Session, error) {
	middleware, err := provider.GetSP()
	if err != nil {
		return nil, err
	}

	session := new(Session)
	session.ServiceProvider = middleware
	session.TransientMappingAttributeName = provider.TransientMappingAttributeName()
	session.RequestID = requestID
	session.Request = request
	return session, nil
}
// GetAuth implements the [idp.Session] interface.
//
// It runs the service provider's HandleStartAuthFlow against an in-memory
// response writer, using a synthetic request whose URL is parsed from s.state.
// If the handler set a "Location" header, that location is returned through
// idp.Redirect; otherwise the body the handler wrote is returned through
// idp.Form.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
	// Named stateURL so the local does not shadow the net/url package.
	stateURL, err := url.Parse(s.state)
	if err != nil {
		// The signature offers no way to report the failure. Fall back to an
		// empty URL so the handler never receives a request with a nil URL.
		stateURL = new(url.URL)
	}

	resp := NewTempResponseWriter()
	request := &http.Request{URL: stateURL}
	s.ServiceProvider.HandleStartAuthFlow(resp, request.WithContext(ctx))

	if location := resp.Header().Get("Location"); location != "" {
		return idp.Redirect(location)
	}
	return idp.Form(resp.content.String())
}
// FetchUser implements the [idp.Session] interface.
//
// It parses the SAML response carried by s.Request (accepting only
// s.RequestID as request ID), stores the resulting assertion in s.Assertion
// and maps it to a user:
//   - the user ID is the NameID value, or, for the transient NameID format,
//     the value returned by transientMappingID;
//   - every attribute of every attribute statement is copied into the user's
//     Attributes, keyed by attribute name.
//
// An invalid-argument error is returned when RequestID or Request is missing,
// when the response cannot be parsed, or when the assertion has no NameID.
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
	if s.Request == nil || s.RequestID == "" {
		return nil, zerrors.ThrowInvalidArgument(nil, "SAML-d09hy0wkex", "Errors.Intent.ResponseInvalid")
	}

	possibleRequestIDs := []string{s.RequestID}
	s.Assertion, err = s.ServiceProvider.ServiceProvider.ParseResponse(s.Request, possibleRequestIDs)
	if err != nil {
		// Log the private cause of an invalid response before returning the error.
		var responseErr *saml.InvalidResponseError
		if errors.As(err, &responseErr) {
			logging.WithError(responseErr.PrivateErr).Info("invalid SAML response details")
		}
		return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
	}

	// nameID is required, but at least in ADFS it will not be sent unless explicitly configured
	subject := s.Assertion.Subject
	if subject == nil || subject.NameID == nil {
		// err is nil at this point, so there is no parent error to attach.
		return nil, zerrors.ThrowInvalidArgument(nil, "SAML-EFG32", "Errors.Intent.ResponseInvalid")
	}
	nameID := subject.NameID

	mapped := NewUser()
	// use the nameID as default mapping id
	mapped.SetID(nameID.Value)
	if nameID.Format == string(saml.TransientNameIDFormat) {
		transientID, idErr := s.transientMappingID()
		if idErr != nil {
			return nil, idErr
		}
		mapped.SetID(transientID)
	}

	for _, statement := range s.Assertion.AttributeStatements {
		for _, attribute := range statement.Attributes {
			values := make([]string, 0, len(attribute.Values))
			for _, v := range attribute.Values {
				values = append(values, v.Value)
			}
			mapped.Attributes[attribute.Name] = values
		}
	}
	return mapped, nil
}
// transientMappingID returns the value of the first assertion attribute whose
// name equals s.TransientMappingAttributeName.
//
// An invalid-argument error is returned when that attribute does not carry
// exactly one value, or when no attribute with that name exists.
func (s *Session) transientMappingID() (string, error) {
	wanted := s.TransientMappingAttributeName
	for _, statement := range s.Assertion.AttributeStatements {
		for i := range statement.Attributes {
			candidate := &statement.Attributes[i]
			if candidate.Name == wanted {
				// The first match decides the outcome; later attributes are not inspected.
				if len(candidate.Values) == 1 {
					return candidate.Values[0].Value, nil
				}
				return "", zerrors.ThrowInvalidArgument(nil, "SAML-Soij4", "Errors.Intent.MissingSingleMappingAttribute")
			}
		}
	}
	return "", zerrors.ThrowInvalidArgument(nil, "SAML-swwg2", "Errors.Intent.MissingSingleMappingAttribute")
}
// TempResponseWriter is an in-memory response writer: it keeps the headers
// and the written body so they can be inspected after a handler has run.
// Status codes are not recorded.
type TempResponseWriter struct {
	header  http.Header
	content *bytes.Buffer
}

// NewTempResponseWriter returns a TempResponseWriter with an empty header map
// and an empty body buffer.
func NewTempResponseWriter() *TempResponseWriter {
	w := new(TempResponseWriter)
	w.header = make(http.Header)
	w.content = new(bytes.Buffer)
	return w
}

// Header returns the header map held by the writer.
func (w *TempResponseWriter) Header() http.Header {
	return w.header
}

// Write appends content to the buffered body and reports the buffer's result.
func (w *TempResponseWriter) Write(content []byte) (int, error) {
	written, err := w.content.Write(content)
	return written, err
}

// WriteHeader ignores the status code; nothing is recorded.
func (w *TempResponseWriter) WriteHeader(statusCode int) {}