2025-01-09 12:46:36 +01:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"net/http"
|
2025-07-23 10:47:05 +02:00
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/zitadel/logging"
|
2025-01-09 12:46:36 +01:00
|
|
|
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
|
|
zhttp "github.com/zitadel/zitadel/internal/api/http/middleware"
|
|
|
|
smetadata "github.com/zitadel/zitadel/internal/api/scim/metadata"
|
2025-07-23 10:47:05 +02:00
|
|
|
sresources "github.com/zitadel/zitadel/internal/api/scim/resources"
|
2025-01-09 12:46:36 +01:00
|
|
|
"github.com/zitadel/zitadel/internal/query"
|
|
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
|
|
)
|
|
|
|
|
|
|
|
func ScimContextMiddleware(q *query.Queries) func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError {
|
|
|
|
return func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError {
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) error {
|
|
|
|
ctx, err := initScimContext(r.Context(), q)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return next(w, r.WithContext(ctx))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func initScimContext(ctx context.Context, q *query.Queries) (context.Context, error) {
|
2025-01-29 15:23:56 +01:00
|
|
|
data := smetadata.NewScimContextData()
|
2025-01-09 12:46:36 +01:00
|
|
|
ctx = smetadata.SetScimContextData(ctx, data)
|
|
|
|
|
|
|
|
userID := authz.GetCtxData(ctx).UserID
|
2025-07-23 10:47:05 +02:00
|
|
|
|
|
|
|
// get the provisioningDomain and ignorePasswordOnCreate metadata keys associated with the service user
|
|
|
|
metadataKeys := []smetadata.Key{
|
|
|
|
smetadata.KeyProvisioningDomain,
|
|
|
|
smetadata.KeyIgnorePasswordOnCreate,
|
|
|
|
}
|
|
|
|
queries := sresources.BuildMetadataQueries(ctx, metadataKeys)
|
|
|
|
|
|
|
|
metadataList, err := q.SearchUserMetadata(ctx, false, userID, queries, nil)
|
2025-01-09 12:46:36 +01:00
|
|
|
if err != nil {
|
|
|
|
if zerrors.IsNotFound(err) {
|
|
|
|
return ctx, nil
|
|
|
|
}
|
|
|
|
return ctx, err
|
|
|
|
}
|
|
|
|
|
2025-07-23 10:47:05 +02:00
|
|
|
if metadataList == nil || len(metadataList.Metadata) == 0 {
|
2025-01-09 12:46:36 +01:00
|
|
|
return ctx, nil
|
|
|
|
}
|
|
|
|
|
2025-07-23 10:47:05 +02:00
|
|
|
for _, metadata := range metadataList.Metadata {
|
|
|
|
switch metadata.Key {
|
|
|
|
case string(smetadata.KeyProvisioningDomain):
|
|
|
|
data.ProvisioningDomain = string(metadata.Value)
|
|
|
|
if data.ProvisioningDomain != "" {
|
|
|
|
data.ExternalIDScopedMetadataKey = smetadata.ScopeExternalIdKey(data.ProvisioningDomain)
|
|
|
|
}
|
|
|
|
case string(smetadata.KeyIgnorePasswordOnCreate):
|
|
|
|
ignorePasswordOnCreate, parseErr := strconv.ParseBool(strings.TrimSpace(string(metadata.Value)))
|
|
|
|
if parseErr != nil {
|
|
|
|
return ctx,
|
|
|
|
zerrors.ThrowInvalidArgumentf(nil, "SMCM-yvw2rt", "Invalid value for metadata key %s: %s", smetadata.KeyIgnorePasswordOnCreate, metadata.Value)
|
|
|
|
}
|
|
|
|
data.IgnorePasswordOnCreate = ignorePasswordOnCreate
|
|
|
|
default:
|
|
|
|
logging.WithFields("user_metadata_key", metadata.Key).Warn("unexpected metadata key")
|
|
|
|
}
|
2025-01-09 12:46:36 +01:00
|
|
|
}
|
|
|
|
return smetadata.SetScimContextData(ctx, data), nil
|
|
|
|
}
|