mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
feat: create user scim v2 endpoint (#9132)
# Which Problems Are Solved - Adds infrastructure code (basic implementation, error handling, middlewares, ...) to implement the SCIM v2 interface - Adds support for the user create SCIM v2 endpoint # How the Problems Are Solved - Adds support for the user create SCIM v2 endpoint under `POST /scim/v2/{orgID}/Users` # Additional Context Part of #8140
This commit is contained in:
53
internal/api/scim/middleware/content_type_middleware.go
Normal file
53
internal/api/scim/middleware/content_type_middleware.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
ContentTypeScim = "application/scim+json"
|
||||
ContentTypeJson = "application/json"
|
||||
)
|
||||
|
||||
func ContentTypeMiddleware(next middleware.HandlerFuncWithError) middleware.HandlerFuncWithError {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
w.Header().Set(zhttp.ContentType, ContentTypeScim)
|
||||
|
||||
if !validateContentType(r.Header.Get(zhttp.ContentType)) {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x4", "Invalid content type header")
|
||||
}
|
||||
|
||||
if !validateContentType(r.Header.Get(zhttp.Accept)) {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x5", "Invalid accept header")
|
||||
}
|
||||
|
||||
return next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func validateContentType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
mediaType, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
logging.OnError(err).Warn("failed to parse content type header")
|
||||
return false
|
||||
}
|
||||
|
||||
if mediaType != "" && !strings.EqualFold(mediaType, ContentTypeJson) && !strings.EqualFold(mediaType, ContentTypeScim) {
|
||||
return false
|
||||
}
|
||||
|
||||
charset, ok := params["charset"]
|
||||
return !ok || strings.EqualFold(charset, "utf-8")
|
||||
}
|
107
internal/api/scim/middleware/content_type_middleware_test.go
Normal file
107
internal/api/scim/middleware/content_type_middleware_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
||||
)
|
||||
|
||||
func TestContentTypeMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentTypeHeader string
|
||||
acceptHeader string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
contentTypeHeader: "application/scim+json",
|
||||
acceptHeader: "application/scim+json",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid content type",
|
||||
contentTypeHeader: "application/octet-stream",
|
||||
acceptHeader: "application/json",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid accept",
|
||||
contentTypeHeader: "application/json",
|
||||
acceptHeader: "application/octet-stream",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
if tt.acceptHeader != "" {
|
||||
req.Header.Set(zhttp.Accept, tt.acceptHeader)
|
||||
}
|
||||
|
||||
if tt.contentTypeHeader != "" {
|
||||
req.Header.Set(zhttp.ContentType, tt.contentTypeHeader)
|
||||
}
|
||||
|
||||
err := ContentTypeMiddleware(func(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
})(httptest.NewRecorder(), req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ContentTypeMiddleware() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_validateContentType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
contentType: "",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "json",
|
||||
contentType: "application/json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scim",
|
||||
contentType: "application/scim+json",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "json utf-8",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scim utf-8",
|
||||
contentType: "application/scim+json; charset=utf-8",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "unknown content type",
|
||||
contentType: "application/octet-stream",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown charset",
|
||||
contentType: "application/scim+json; charset=utf-16",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := validateContentType(tt.contentType); got != tt.want {
|
||||
t.Errorf("validateContentType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
54
internal/api/scim/middleware/scim_context_middleware.go
Normal file
54
internal/api/scim/middleware/scim_context_middleware.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
zhttp "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
smetadata "github.com/zitadel/zitadel/internal/api/scim/metadata"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func ScimContextMiddleware(q *query.Queries) func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError {
|
||||
return func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, err := initScimContext(r.Context(), q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func initScimContext(ctx context.Context, q *query.Queries) (context.Context, error) {
|
||||
data := smetadata.ScimContextData{
|
||||
ProvisioningDomain: "",
|
||||
ExternalIDScopedMetadataKey: smetadata.ScopedKey(smetadata.KeyExternalId),
|
||||
}
|
||||
|
||||
ctx = smetadata.SetScimContextData(ctx, data)
|
||||
|
||||
userID := authz.GetCtxData(ctx).UserID
|
||||
metadata, err := q.GetUserMetadataByKey(ctx, false, userID, string(smetadata.KeyProvisioningDomain), false)
|
||||
if err != nil {
|
||||
if zerrors.IsNotFound(err) {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
return ctx, err
|
||||
}
|
||||
|
||||
if metadata == nil {
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
data.ProvisioningDomain = string(metadata.Value)
|
||||
if data.ProvisioningDomain != "" {
|
||||
data.ExternalIDScopedMetadataKey = smetadata.ScopeExternalIdKey(data.ProvisioningDomain)
|
||||
}
|
||||
return smetadata.SetScimContextData(ctx, data), nil
|
||||
}
|
Reference in New Issue
Block a user