diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index e376e8a488..74ffffafcd 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -590,6 +590,11 @@ SAML: # Company: ZITADEL # ZITADEL_SAML_PROVIDERCONFIG_CONTACTPERSON_COMPANY # EmailAddress: hi@zitadel.com # ZITADEL_SAML_PROVIDERCONFIG_CONTACTPERSON_EMAILADDRESS +SCIM: + # default values whether an email/phone is considered verified when a users email/phone is created or updated + EmailVerified: true # ZITADEL_SCIM_EMAILVERIFIED + PhoneVerified: true # ZITADEL_SCIM_PHONEVERIFIED + Login: LanguageCookieName: zitadel.login.lang # ZITADEL_LOGIN_LANGUAGECOOKIENAME CSRFCookieName: zitadel.login.csrf # ZITADEL_LOGIN_CSRFCOOKIENAME diff --git a/cmd/start/config.go b/cmd/start/config.go index 6182342592..d63b8a319a 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -15,6 +15,7 @@ import ( "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/api/saml" + scim_config "github.com/zitadel/zitadel/internal/api/scim/config" "github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/login" auth_es "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing" @@ -60,6 +61,7 @@ type Config struct { UserAgentCookie *middleware.UserAgentCookieConfig OIDC oidc.Config SAML saml.Config + SCIM scim_config.Config Login login.Config Console console.Config AssetStorage static_config.AssetStorageConfig diff --git a/cmd/start/start.go b/cmd/start/start.go index 154c683481..21f445cfd6 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -63,6 +63,8 @@ import ( "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/api/robots_txt" "github.com/zitadel/zitadel/internal/api/saml" + "github.com/zitadel/zitadel/internal/api/scim" + "github.com/zitadel/zitadel/internal/api/scim/schemas" "github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/console/path" "github.com/zitadel/zitadel/internal/api/ui/login" @@ -519,6 +521,17 @@ func startAPIs( } apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler()) + apis.RegisterHandlerOnPrefix( + schemas.HandlerPrefix, + scim.NewServer( + commands, + queries, + verifier, + keys.User, + &config.SCIM, + instanceInterceptor.HandlerFuncWithError, + middleware.AuthorizationInterceptor(verifier, config.InternalAuthZ).HandlerFuncWithError)) + c, err := console.Start(config.Console, config.ExternalSecure, oidcServer.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, limitingAccessInterceptor, config.CustomerPortal) if err != nil { return nil, fmt.Errorf("unable to start console: %w", err) diff --git a/internal/api/http/header.go b/internal/api/http/header.go index 982684c77c..a6c2818728 100644 --- a/internal/api/http/header.go +++ b/internal/api/http/header.go @@ -5,6 +5,8 @@ import ( "net" "net/http" "strings" + + "github.com/gorilla/mux" ) const ( @@ -14,6 +16,7 @@ const ( CacheControl = "cache-control" ContentType = "content-type" ContentLength = "content-length" + ContentLocation = "content-location" Expires = "expires" Location = "location" Origin = "origin" @@ -42,6 +45,9 @@ const ( PermissionsPolicy = "permissions-policy" ZitadelOrgID = "x-zitadel-orgid" + + OrgIdInPathVariableName = "orgId" + OrgIdInPathVariable = "{" + OrgIdInPathVariableName + "}" ) type key int @@ -104,6 +110,12 @@ func GetAuthorization(r *http.Request) string { } func GetOrgID(r *http.Request) string { + // path variable takes precedence over header + orgID, ok := mux.Vars(r)[OrgIdInPathVariableName] + if ok { + return orgID + } + return r.Header.Get(ZitadelOrgID) } diff --git a/internal/api/http/middleware/auth_interceptor.go b/internal/api/http/middleware/auth_interceptor.go index c327d8c846..1581d401b4 100644 --- a/internal/api/http/middleware/auth_interceptor.go +++ b/internal/api/http/middleware/auth_interceptor.go @@ -2,12 +2,15 @@ package middleware import ( "context" - "errors" "net/http" + "strings" + + "github.com/gorilla/mux" "github.com/zitadel/zitadel/internal/api/authz" http_util "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" ) type AuthInterceptor struct { @@ -23,34 +26,40 @@ func AuthorizationInterceptor(verifier authz.APITokenVerifier, authConfig authz. } func (a *AuthInterceptor) Handler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, err := authorize(r, a.verifier, a.authConfig) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - }) + return a.HandlerFunc(next) } -func (a *AuthInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc { +func (a *AuthInterceptor) HandlerFunc(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, err := authorize(r, a.verifier, a.authConfig) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } + r = r.WithContext(ctx) next.ServeHTTP(w, r) } } +func (a *AuthInterceptor) HandlerFuncWithError(next HandlerFuncWithError) HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := authorize(r, a.verifier, a.authConfig) + if err != nil { + return err + } + + r = r.WithContext(ctx) + return next(w, r) + } +} + type httpReq struct{} func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig authz.Config) (_ context.Context, err error) { ctx := r.Context() - authOpt, needsToken := verifier.CheckAuthMethod(r.Method + ":" + r.RequestURI) + + authOpt, needsToken := checkAuthMethod(r, verifier) if !needsToken { return ctx, nil } @@ -59,7 +68,7 @@ func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig auth authToken := http_util.GetAuthorization(r) if authToken == "" { - return nil, errors.New("auth header missing") + return nil, zerrors.ThrowUnauthenticated(nil, "AUT-1179", "auth header missing") } ctxSetter, err := authz.CheckUserAuthorization(authCtx, &httpReq{}, authToken, http_util.GetOrgID(r), "", verifier, authConfig, authOpt, r.RequestURI) @@ -69,3 +78,30 @@ func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig auth span.End() return ctxSetter(ctx), nil } + +func checkAuthMethod(r *http.Request, verifier authz.APITokenVerifier) (authz.Option, bool) { + authOpt, needsToken := verifier.CheckAuthMethod(r.Method + ":" + r.RequestURI) + if needsToken { + return authOpt, true + } + + route := mux.CurrentRoute(r) + if route == nil { + return authOpt, false + } + + pathTemplate, err := route.GetPathTemplate() + if err != nil || pathTemplate == "" { + return authOpt, false + } + + // the path prefix is usually handled in a router in upper layer + // trim the query and the path of the url to get the correct path prefix + pathPrefix := r.RequestURI + if i := strings.Index(pathPrefix, "?"); i != -1 { + pathPrefix = pathPrefix[0:i] + } + pathPrefix = strings.TrimSuffix(pathPrefix, r.URL.Path) + + return verifier.CheckAuthMethod(r.Method + ":" + pathPrefix + pathTemplate) +} diff --git a/internal/api/http/middleware/handler.go b/internal/api/http/middleware/handler.go new file mode 100644 index 0000000000..2c79b6227a --- /dev/null +++ b/internal/api/http/middleware/handler.go @@ -0,0 +1,26 @@ +package middleware + +import "net/http" + +// HandlerFuncWithError is a http handler func which can return an error +// the error should then get handled later on in the pipeline by an error handler +// the error handler can be dependent on the interface standard (e.g. SCIM, Problem Details, ...) +type HandlerFuncWithError = func(w http.ResponseWriter, r *http.Request) error + +// MiddlewareWithErrorFunc is a http middleware which can return an error +// the error should then get handled later on in the pipeline by an error handler +// the error handler can be dependent on the interface standard (e.g. SCIM, Problem Details, ...) +type MiddlewareWithErrorFunc = func(HandlerFuncWithError) HandlerFuncWithError + +// ErrorHandlerFunc handles errors and returns a regular http handler +type ErrorHandlerFunc = func(HandlerFuncWithError) http.Handler + +func ChainedWithErrorHandler(errorHandler ErrorHandlerFunc, middlewares ...MiddlewareWithErrorFunc) func(HandlerFuncWithError) http.Handler { + return func(next HandlerFuncWithError) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + next = middlewares[i](next) + } + + return errorHandler(next) + } +} diff --git a/internal/api/http/middleware/instance_interceptor.go b/internal/api/http/middleware/instance_interceptor.go index facb2ceec0..3ae5dfbb88 100644 --- a/internal/api/http/middleware/instance_interceptor.go +++ b/internal/api/http/middleware/instance_interceptor.go @@ -34,43 +34,57 @@ func InstanceInterceptor(verifier authz.InstanceVerifier, externalDomain string, } func (a *instanceInterceptor) Handler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a.handleInstance(w, r, next) - }) + return a.HandlerFunc(next) } -func (a *instanceInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc { +func (a *instanceInterceptor) HandlerFunc(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - a.handleInstance(w, r, next) - } -} - -func (a *instanceInterceptor) handleInstance(w http.ResponseWriter, r *http.Request, next http.Handler) { - for _, prefix := range a.ignoredPrefixes { - if strings.HasPrefix(r.URL.Path, prefix) { + ctx, err := a.setInstanceIfNeeded(r.Context(), r) + if err == nil { + r = r.WithContext(ctx) next.ServeHTTP(w, r) return } - } - ctx, err := setInstance(r, a.verifier) - if err != nil { + origin := zitadel_http.DomainContext(r.Context()) logging.WithFields("origin", origin.Origin(), "externalDomain", a.externalDomain).WithError(err).Error("unable to set instance") + zErr := new(zerrors.ZitadelError) if errors.As(err, &zErr) { zErr.SetMessage(a.translator.LocalizeFromRequest(r, zErr.GetMessage(), nil)) http.Error(w, fmt.Sprintf("unable to set instance using origin %s (ExternalDomain is %s): %s", origin, a.externalDomain, zErr), http.StatusNotFound) return } + http.Error(w, fmt.Sprintf("unable to set instance using origin %s (ExternalDomain is %s)", origin, a.externalDomain), http.StatusNotFound) - return } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) } -func setInstance(r *http.Request, verifier authz.InstanceVerifier) (_ context.Context, err error) { - ctx := r.Context() +func (a *instanceInterceptor) HandlerFuncWithError(next HandlerFuncWithError) HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := a.setInstanceIfNeeded(r.Context(), r) + if err != nil { + origin := zitadel_http.DomainContext(r.Context()) + logging.WithFields("origin", origin.Origin(), "externalDomain", a.externalDomain).WithError(err).Error("unable to set instance") + return err + } + + r = r.WithContext(ctx) + return next(w, r) + } +} + +func (a *instanceInterceptor) setInstanceIfNeeded(ctx context.Context, r *http.Request) (context.Context, error) { + for _, prefix := range a.ignoredPrefixes { + if strings.HasPrefix(r.URL.Path, prefix) { + return ctx, nil + } + } + + return setInstance(ctx, a.verifier) +} + +func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ context.Context, err error) { authCtx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/api/http/middleware/instance_interceptor_test.go b/internal/api/http/middleware/instance_interceptor_test.go index 51c0fb9a10..da831dff65 100644 --- a/internal/api/http/middleware/instance_interceptor_test.go +++ b/internal/api/http/middleware/instance_interceptor_test.go @@ -72,7 +72,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) { translator: newZitadelTranslator(), } next := &testHandler{} - got := a.HandlerFunc(next.ServeHTTP) + got := a.HandlerFunc(next) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) @@ -136,7 +136,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) { translator: newZitadelTranslator(), } next := &testHandler{} - got := a.HandlerFunc(next.ServeHTTP) + got := a.HandlerFunc(next) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) @@ -145,9 +145,78 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) { } } +func Test_instanceInterceptor_HandlerFuncWithError(t *testing.T) { + type fields struct { + verifier authz.InstanceVerifier + } + type args struct { + request *http.Request + } + type res struct { + wantErr bool + context context.Context + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "setInstance error", + fields{ + verifier: &mockInstanceVerifier{}, + }, + args{ + request: httptest.NewRequest("", "/url", nil), + }, + res{ + wantErr: true, + context: nil, + }, + }, + { + "setInstance ok", + fields{ + verifier: &mockInstanceVerifier{instanceHost: "host"}, + }, + args{ + request: func() *http.Request { + r := httptest.NewRequest("", "/url", nil) + r = r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host"})) + return r + }(), + }, + res{ + context: authz.WithInstance(zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host"}), &mockInstance{}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &instanceInterceptor{ + verifier: tt.fields.verifier, + translator: newZitadelTranslator(), + } + var ctx context.Context + got := a.HandlerFuncWithError(func(w http.ResponseWriter, r *http.Request) error { + ctx = r.Context() + return nil + }) + rr := httptest.NewRecorder() + err := got(rr, tt.args.request) + if (err != nil) != tt.res.wantErr { + t.Errorf("got error %v, want %v", err, tt.res.wantErr) + } + + assert.Equal(t, tt.res.context, ctx) + }) + } +} + func Test_setInstance(t *testing.T) { type args struct { - r *http.Request + ctx context.Context verifier authz.InstanceVerifier } type res struct { @@ -162,10 +231,7 @@ func Test_setInstance(t *testing.T) { { "no domain context, not found error", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r - }(), + ctx: context.Background(), verifier: &mockInstanceVerifier{}, }, res{ @@ -176,10 +242,7 @@ func Test_setInstance(t *testing.T) { { "instanceHost found, ok", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"})) - }(), + ctx: zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"}), verifier: &mockInstanceVerifier{instanceHost: "host"}, }, res{ @@ -190,10 +253,7 @@ func Test_setInstance(t *testing.T) { { "instanceHost not found, error", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "fromorigin:9999", Protocol: "https"})) - }(), + ctx: zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "fromorigin:9999", Protocol: "https"}), verifier: &mockInstanceVerifier{instanceHost: "unknowndomain"}, }, res{ @@ -204,7 +264,7 @@ func Test_setInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := setInstance(tt.args.r, tt.args.verifier) + got, err := setInstance(tt.args.ctx, tt.args.verifier) if (err != nil) != tt.res.err { t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err) return diff --git a/internal/api/scim/authz.go b/internal/api/scim/authz.go new file mode 100644 index 0000000000..a89df38061 --- /dev/null +++ b/internal/api/scim/authz.go @@ -0,0 +1,16 @@ +package scim + +import ( + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/domain" +) + +var AuthMapping = authz.MethodMapping{ + "POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": { + Permission: domain.PermissionUserWrite, + }, + "DELETE:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": { + Permission: domain.PermissionUserDelete, + }, +} diff --git a/internal/api/scim/config/config.go b/internal/api/scim/config/config.go new file mode 100644 index 0000000000..6199f0a2ea --- /dev/null +++ b/internal/api/scim/config/config.go @@ -0,0 +1,6 @@ +package config + +type Config struct { + EmailVerified bool + PhoneVerified bool +} diff --git a/internal/api/scim/integration_test/scim_test.go b/internal/api/scim/integration_test/scim_test.go new file mode 100644 index 0000000000..e722ffdb18 --- /dev/null +++ b/internal/api/scim/integration_test/scim_test.go @@ -0,0 +1,28 @@ +//go:build integration + +package integration_test + +import ( + "context" + "github.com/zitadel/zitadel/internal/integration" + "os" + "testing" + "time" +) + +var ( + Instance *integration.Instance + CTX context.Context +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + return m.Run() + }()) +} diff --git a/internal/api/scim/integration_test/testdata/users_create_test_full.json b/internal/api/scim/integration_test/testdata/users_create_test_full.json new file mode 100644 index 0000000000..7879ecf160 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_full.json @@ -0,0 +1,116 @@ +{ + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "externalId": "701984", + "userName": "bjensen@example.com", + "name": { + "formatted": "Ms. Barbara J Jensen, III", + "familyName": "Jensen", + "givenName": "Barbara", + "middleName": "Jane", + "honorificPrefix": "Ms.", + "honorificSuffix": "III" + }, + "displayName": "Babs Jensen", + "nickName": "Babs", + "profileUrl": "http://login.example.com/bjensen", + "emails": [ + { + "value": "bjensen@example.com", + "type": "work", + "primary": true + }, + { + "value": "babs@jensen.org", + "type": "home" + } + ], + "addresses": [ + { + "type": "work", + "streetAddress": "100 Universal City Plaza", + "locality": "Hollywood", + "region": "CA", + "postalCode": "91608", + "country": "USA", + "formatted": "100 Universal City Plaza\nHollywood, CA 91608 USA", + "primary": true + }, + { + "type": "home", + "streetAddress": "456 Hollywood Blvd", + "locality": "Hollywood", + "region": "CA", + "postalCode": "91608", + "country": "USA", + "formatted": "456 Hollywood Blvd\nHollywood, CA 91608 USA" + } + ], + "phoneNumbers": [ + { + "value": "555-555-5555", + "type": "work", + "primary": true + }, + { + "value": "555-555-4444", + "type": "mobile" + } + ], + "ims": [ + { + "value": "someaimhandle", + "type": "aim" + }, + { + "value": "twitterhandle", + "type": "X" + } + ], + "photos": [ + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/F", + "type": "photo" + }, + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/T", + "type": "thumbnail" + } + ], + "roles": [ + { + "value": "my-role-1", + "display": "Rolle 1", + "type": "main-role", + "primary": true + }, + { + "value": "my-role-2", + "display": "Rolle 2", + "type": "secondary-role", + "primary": false + } + ], + "entitlements": [ + { + "value": "my-entitlement-1", + "display": "Entitlement 1", + "type": "main-entitlement", + "primary": true + }, + { + "value": "my-entitlement-2", + "display": "Entitlement 2", + "type": "secondary-entitlement", + "primary": false + } + ], + "userType": "Employee", + "title": "Tour Guide", + "preferredLanguage": "en-US", + "locale": "en-US", + "timezone": "America/Los_Angeles", + "active":true, + "password": "Password1!" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json new file mode 100644 index 0000000000..eaadac8b90 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "locale": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json new file mode 100644 index 0000000000..7a3d71cbed --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "password": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json new file mode 100644 index 0000000000..3bc8fee87b --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "profileUrl": "ftp://login.example.com/bjensen" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json new file mode 100644 index 0000000000..d4ac9aa0a5 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "timezone": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_minimal.json b/internal/api/scim/integration_test/testdata/users_create_test_minimal.json new file mode 100644 index 0000000000..c51f416bc7 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_minimal.json @@ -0,0 +1,16 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json new file mode 100644 index 0000000000..c68ebf98a0 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json @@ -0,0 +1,10 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + } +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json new file mode 100644 index 0000000000..d1d3375f89 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json @@ -0,0 +1,15 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json new file mode 100644 index 0000000000..9446665226 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json @@ -0,0 +1,15 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/users_create_test.go b/internal/api/scim/integration_test/users_create_test.go new file mode 100644 index 0000000000..b7d97e342f --- /dev/null +++ b/internal/api/scim/integration_test/users_create_test.go @@ -0,0 +1,244 @@ +//go:build integration + +package integration_test + +import ( + "context" + _ "embed" + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/management" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" + "google.golang.org/grpc/codes" + "net/http" + "path" + "testing" +) + +var ( + //go:embed testdata/users_create_test_minimal.json + minimalUserJson []byte + + //go:embed testdata/users_create_test_full.json + fullUserJson []byte + + //go:embed testdata/users_create_test_missing_username.json + missingUserNameUserJson []byte + + //go:embed testdata/users_create_test_missing_name.json + missingNameUserJson []byte + + //go:embed testdata/users_create_test_missing_email.json + missingEmailUserJson []byte + + //go:embed testdata/users_create_test_invalid_password.json + invalidPasswordUserJson []byte + + //go:embed testdata/users_create_test_invalid_profile_url.json + invalidProfileUrlUserJson []byte + + //go:embed testdata/users_create_test_invalid_locale.json + invalidLocaleUserJson []byte + + //go:embed testdata/users_create_test_invalid_timezone.json + invalidTimeZoneUserJson []byte +) + +func TestCreateUser(t *testing.T) { + tests := []struct { + name string + body []byte + ctx context.Context + wantErr bool + scimErrorType string + errorStatus int + zitadelErrID string + }{ + { + name: "minimal user", + body: minimalUserJson, + }, + { + name: "full user", + body: fullUserJson, + }, + { + name: "missing userName", + wantErr: true, + scimErrorType: "invalidValue", + body: missingUserNameUserJson, + }, + { + // this is an expected schema violation + name: "missing name", + wantErr: true, + scimErrorType: "invalidValue", + body: missingNameUserJson, + }, + { + name: "missing email", + wantErr: true, + scimErrorType: "invalidValue", + body: missingEmailUserJson, + }, + { + name: "password complexity violation", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidPasswordUserJson, + }, + { + name: "invalid profile url", + wantErr: true, + scimErrorType: "invalidValue", + zitadelErrID: "SCIM-htturl1", + body: invalidProfileUrlUserJson, + }, + { + name: "invalid time zone", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidTimeZoneUserJson, + }, + { + name: "invalid locale", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidLocaleUserJson, + }, + { + name: "not authenticated", + body: minimalUserJson, + ctx: context.Background(), + wantErr: true, + errorStatus: http.StatusUnauthorized, + }, + { + name: "no permissions", + body: minimalUserJson, + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + wantErr: true, + errorStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + createdUser, err := Instance.Client.SCIM.Users.Create(ctx, Instance.DefaultOrg.Id, tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("CreateUser() error = %v, wantErr %v", err, tt.wantErr) + } + + if err != nil { + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + scimErr := scim.RequireScimError(t, statusCode, err) + assert.Equal(t, tt.scimErrorType, scimErr.Error.ScimType) + if tt.zitadelErrID != "" { + assert.Equal(t, tt.zitadelErrID, scimErr.Error.ZitadelDetail.ID) + } + + return + } + + assert.NotEmpty(t, createdUser.ID) + assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, createdUser.Resource.Schemas) + assert.Equal(t, schemas.ScimResourceTypeSingular("User"), createdUser.Resource.Meta.ResourceType) + assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), createdUser.Resource.Meta.Location) + assert.Nil(t, createdUser.Password) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + assert.NoError(t, err) + }) + } +} + +func TestCreateUser_duplicate(t *testing.T) { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson) + require.NoError(t, err) + + _, err = Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson) + scimErr := scim.RequireScimError(t, http.StatusConflict, err) + assert.Equal(t, "User already exists", scimErr.Error.Detail) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) +} + +func TestCreateUser_metadata(t *testing.T) { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + md, err := Instance.Client.Mgmt.ListUserMetadata(CTX, &management.ListUserMetadataRequest{ + Id: createdUser.ID, + }) + require.NoError(t, err) + + mdMap := make(map[string]string) + for i := range md.Result { + mdMap[md.Result[i].Key] = string(md.Result[i].Value) + } + + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`) + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`) + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`) + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:externalId", "701984") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:name.middleName", "Jane") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:title", "Tour Guide") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:locale", "en-US") + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`) + integration.AssertMapContains(t, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) +} + +func TestCreateUser_scopedExternalID(t *testing.T) { + _, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + Value: []byte("fooBar"), + }) + require.NoError(t, err) + + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + // unscoped externalID should not exist + _, err = Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{ + Id: createdUser.ID, + Key: "urn:zitadel:scim:externalId", + }) + integration.AssertGrpcStatus(t, codes.NotFound, err) + + // scoped externalID should exist + md, err := Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{ + Id: createdUser.ID, + Key: "urn:zitadel:scim:fooBar:externalId", + }) + require.NoError(t, err) + assert.Equal(t, "701984", string(md.Metadata.Value)) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) +} + +func TestCreateUser_anotherOrg(t *testing.T) { + org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email()) + _, err := Instance.Client.SCIM.Users.Create(CTX, org.OrganizationId, fullUserJson) + scim.RequireScimError(t, http.StatusNotFound, err) +} diff --git a/internal/api/scim/integration_test/users_delete_test.go b/internal/api/scim/integration_test/users_delete_test.go new file mode 100644 index 0000000000..6d3f73a71e --- /dev/null +++ b/internal/api/scim/integration_test/users_delete_test.go @@ -0,0 +1,84 @@ +//go:build integration + +package integration_test + +import ( + "context" + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" + "google.golang.org/grpc/codes" + "net/http" + "testing" +) + +func TestDeleteUser_errors(t *testing.T) { + tests := []struct { + name string + ctx context.Context + errorStatus int + }{ + { + name: "not authenticated", + ctx: context.Background(), + errorStatus: http.StatusUnauthorized, + }, + { + name: "no permissions", + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + errorStatus: http.StatusNotFound, + }, + { + name: "unknown user id", + errorStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + err := Instance.Client.SCIM.Users.Delete(ctx, Instance.DefaultOrg.Id, "1") + + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + + scim.RequireScimError(t, statusCode, err) + }) + } +} + +func TestDeleteUser_ensureReallyDeleted(t *testing.T) { + // create user and dependencies + createUserResp := Instance.CreateHumanUser(CTX) + proj, err := Instance.CreateProject(CTX) + require.NoError(t, err) + + Instance.CreateProjectUserGrant(t, CTX, proj.Id, createUserResp.UserId) + + // delete user via scim + err = Instance.Client.SCIM.Users.Delete(CTX, Instance.DefaultOrg.Id, createUserResp.UserId) + assert.NoError(t, err) + + // ensure it is really deleted => try to delete again => should 404 + err = Instance.Client.SCIM.Users.Delete(CTX, Instance.DefaultOrg.Id, createUserResp.UserId) + scim.RequireScimError(t, http.StatusNotFound, err) + + // try to get user via api => should 404 + _, err = Instance.Client.UserV2.GetUserByID(CTX, &user.GetUserByIDRequest{UserId: createUserResp.UserId}) + integration.AssertGrpcStatus(t, codes.NotFound, err) +} + +func TestDeleteUser_anotherOrg(t *testing.T) { + createUserResp := Instance.CreateHumanUser(CTX) + org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email()) + err := Instance.Client.SCIM.Users.Delete(CTX, org.OrganizationId, createUserResp.UserId) + scim.RequireScimError(t, http.StatusNotFound, err) +} diff --git a/internal/api/scim/metadata/context.go b/internal/api/scim/metadata/context.go new file mode 100644 index 0000000000..5be54d7123 --- /dev/null +++ b/internal/api/scim/metadata/context.go @@ -0,0 +1,23 @@ +package metadata + +import ( + "context" +) + +type provisioningDomainKeyType struct{} + +var provisioningDomainKey provisioningDomainKeyType + +type ScimContextData struct { + ProvisioningDomain string + ExternalIDScopedMetadataKey ScopedKey +} + +func SetScimContextData(ctx context.Context, data ScimContextData) context.Context { + return context.WithValue(ctx, provisioningDomainKey, data) +} + +func GetScimContextData(ctx context.Context) ScimContextData { + data, _ := ctx.Value(provisioningDomainKey).(ScimContextData) + return data +} diff --git a/internal/api/scim/metadata/metadata.go b/internal/api/scim/metadata/metadata.go new file mode 100644 index 0000000000..626d938234 --- /dev/null +++ b/internal/api/scim/metadata/metadata.go @@ -0,0 +1,60 @@ +package metadata + +import ( + "context" + "strings" +) + +type Key string +type ScopedKey string + +const ( + externalIdProvisioningDomainPlaceholder = "{provisioningDomain}" + + KeyPrefix = "urn:zitadel:scim:" + KeyProvisioningDomain Key = KeyPrefix + "provisioning_domain" + + KeyExternalId Key = KeyPrefix + "externalId" + keyScopedExternalIdTemplate = KeyPrefix + externalIdProvisioningDomainPlaceholder + ":externalId" + KeyMiddleName Key = KeyPrefix + "name.middleName" + KeyHonorificPrefix Key = KeyPrefix + "name.honorificPrefix" + KeyHonorificSuffix Key = KeyPrefix + "name.honorificSuffix" + KeyProfileUrl Key = KeyPrefix + "profileURL" + KeyTitle Key = KeyPrefix + "title" + KeyLocale Key = KeyPrefix + "locale" + KeyTimezone Key = KeyPrefix + "timezone" + KeyIms Key = KeyPrefix + "ims" + KeyPhotos Key = KeyPrefix + "photos" + KeyAddresses Key = KeyPrefix + "addresses" + KeyEntitlements Key = KeyPrefix + "entitlements" + KeyRoles Key = KeyPrefix + "roles" +) + +var ScimUserRelevantMetadataKeys = []Key{ + KeyExternalId, + KeyMiddleName, + KeyHonorificPrefix, + KeyHonorificSuffix, + KeyProfileUrl, + KeyTitle, + KeyLocale, + KeyTimezone, + KeyIms, + KeyPhotos, + KeyAddresses, + KeyEntitlements, + KeyRoles, +} + +func ScopeExternalIdKey(provisioningDomain string) ScopedKey { + return ScopedKey(strings.Replace(keyScopedExternalIdTemplate, externalIdProvisioningDomainPlaceholder, provisioningDomain, 1)) +} + +func ScopeKey(ctx context.Context, key Key) ScopedKey { + // only the externalID is scoped + if key == KeyExternalId { + return GetScimContextData(ctx).ExternalIDScopedMetadataKey + } + + return ScopedKey(key) +} diff --git a/internal/api/scim/middleware/content_type_middleware.go b/internal/api/scim/middleware/content_type_middleware.go new file mode 100644 index 0000000000..9b456bb141 --- /dev/null +++ b/internal/api/scim/middleware/content_type_middleware.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "mime" + "net/http" + "strings" + + "github.com/zitadel/logging" + + zhttp "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + ContentTypeScim = "application/scim+json" + ContentTypeJson = "application/json" +) + +func ContentTypeMiddleware(next middleware.HandlerFuncWithError) middleware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set(zhttp.ContentType, ContentTypeScim) + + if !validateContentType(r.Header.Get(zhttp.ContentType)) { + return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x4", "Invalid content type header") + } + + if !validateContentType(r.Header.Get(zhttp.Accept)) { + return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x5", "Invalid accept header") + } + + return next(w, r) + } +} + +func validateContentType(contentType string) bool { + if contentType == "" { + return true + } + + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + logging.OnError(err).Warn("failed to parse content type header") + return false + } + + if mediaType != "" && !strings.EqualFold(mediaType, ContentTypeJson) && !strings.EqualFold(mediaType, ContentTypeScim) { + return false + } + + charset, ok := params["charset"] + return !ok || strings.EqualFold(charset, "utf-8") +} diff --git a/internal/api/scim/middleware/content_type_middleware_test.go b/internal/api/scim/middleware/content_type_middleware_test.go new file mode 100644 index 0000000000..918d4618ae --- /dev/null +++ b/internal/api/scim/middleware/content_type_middleware_test.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + zhttp "github.com/zitadel/zitadel/internal/api/http" +) + +func TestContentTypeMiddleware(t *testing.T) { + tests := []struct { + name string + contentTypeHeader string + acceptHeader string + wantErr bool + }{ + { + name: "valid", + contentTypeHeader: "application/scim+json", + acceptHeader: "application/scim+json", + wantErr: false, + }, + { + name: "invalid content type", + contentTypeHeader: "application/octet-stream", + acceptHeader: "application/json", + wantErr: true, + }, + { + name: "invalid accept", + contentTypeHeader: "application/json", + acceptHeader: "application/octet-stream", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.acceptHeader != "" { + req.Header.Set(zhttp.Accept, tt.acceptHeader) + } + + if tt.contentTypeHeader != "" { + req.Header.Set(zhttp.ContentType, tt.contentTypeHeader) + } + + err := ContentTypeMiddleware(func(w http.ResponseWriter, r *http.Request) error { + return nil + })(httptest.NewRecorder(), req) + if (err != nil) != tt.wantErr { + t.Errorf("ContentTypeMiddleware() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validateContentType(t *testing.T) { + tests := []struct { + name string + contentType string + want bool + }{ + { + name: "empty", + contentType: "", + want: true, + }, + { + name: "json", + contentType: "application/json", + want: true, + }, + { + name: "scim", + contentType: "application/scim+json", + want: true, + }, + { + name: "json utf-8", + contentType: "application/json; charset=utf-8", + want: true, + }, + { + name: "scim utf-8", + contentType: "application/scim+json; charset=utf-8", + want: true, + }, + { + name: "unknown content type", + contentType: "application/octet-stream", + want: false, + }, + { + name: "unknown charset", + contentType: "application/scim+json; charset=utf-16", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validateContentType(tt.contentType); got != tt.want { + t.Errorf("validateContentType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/api/scim/middleware/scim_context_middleware.go b/internal/api/scim/middleware/scim_context_middleware.go new file mode 100644 index 0000000000..c52f6f13f6 --- /dev/null +++ b/internal/api/scim/middleware/scim_context_middleware.go @@ -0,0 +1,54 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/zitadel/zitadel/internal/api/authz" + zhttp "github.com/zitadel/zitadel/internal/api/http/middleware" + smetadata "github.com/zitadel/zitadel/internal/api/scim/metadata" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func ScimContextMiddleware(q *query.Queries) func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError { + return func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := initScimContext(r.Context(), q) + if err != nil { + return err + } + + return next(w, r.WithContext(ctx)) + } + } +} + +func initScimContext(ctx context.Context, q *query.Queries) (context.Context, error) { + data := smetadata.ScimContextData{ + ProvisioningDomain: "", + ExternalIDScopedMetadataKey: smetadata.ScopedKey(smetadata.KeyExternalId), + } + + ctx = smetadata.SetScimContextData(ctx, data) + + userID := authz.GetCtxData(ctx).UserID + metadata, err := q.GetUserMetadataByKey(ctx, false, userID, string(smetadata.KeyProvisioningDomain), false) + if err != nil { + if zerrors.IsNotFound(err) { + return ctx, nil + } + + return ctx, err + } + + if metadata == nil { + return ctx, nil + } + + data.ProvisioningDomain = string(metadata.Value) + if data.ProvisioningDomain != "" { + data.ExternalIDScopedMetadataKey = smetadata.ScopeExternalIdKey(data.ProvisioningDomain) + } + return smetadata.SetScimContextData(ctx, data), nil +} diff --git a/internal/api/scim/resources/resource_handler.go b/internal/api/scim/resources/resource_handler.go new file mode 100644 index 0000000000..2d601fd1fc --- /dev/null +++ b/internal/api/scim/resources/resource_handler.go @@ -0,0 +1,62 @@ +package resources + +import ( + "context" + "path" + "strconv" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/domain" +) + +type ResourceHandler[T ResourceHolder] interface { + ResourceNameSingular() schemas.ScimResourceTypeSingular + ResourceNamePlural() schemas.ScimResourceTypePlural + SchemaType() schemas.ScimSchemaType + NewResource() T + + Create(ctx context.Context, resource T) (T, error) + Delete(ctx context.Context, id string) error +} + +type Resource struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + Meta *ResourceMeta `json:"meta"` +} + +type ResourceMeta struct { + ResourceType schemas.ScimResourceTypeSingular `json:"resourceType"` + Created time.Time `json:"created"` + LastModified time.Time `json:"lastModified"` + Version string `json:"version"` + Location string `json:"location"` +} + +type ResourceHolder interface { + GetResource() *Resource +} + +func buildResource[T ResourceHolder](ctx context.Context, handler ResourceHandler[T], details *domain.ObjectDetails) *Resource { + created := details.CreationDate.UTC() + if created.IsZero() { + created = details.EventDate.UTC() + } + + return &Resource{ + Schemas: []schemas.ScimSchemaType{handler.SchemaType()}, + Meta: &ResourceMeta{ + ResourceType: handler.ResourceNameSingular(), + Created: created, + LastModified: details.EventDate.UTC(), + Version: strconv.FormatUint(details.Sequence, 10), + Location: buildLocation(ctx, handler, details.ID), + }, + } +} + +func buildLocation[T ResourceHolder](ctx context.Context, handler ResourceHandler[T], id string) string { + return http.DomainContext(ctx).Origin() + path.Join(schemas.HandlerPrefix, authz.GetCtxData(ctx).OrgID, string(handler.ResourceNamePlural()), id) +} diff --git a/internal/api/scim/resources/resource_handler_adapter.go b/internal/api/scim/resources/resource_handler_adapter.go new file mode 100644 index 0000000000..979fdad99a --- /dev/null +++ b/internal/api/scim/resources/resource_handler_adapter.go @@ -0,0 +1,76 @@ +package resources + +import ( + "encoding/json" + "net/http" + "slices" + + "github.com/gorilla/mux" + + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type ResourceHandlerAdapter[T ResourceHolder] struct { + handler ResourceHandler[T] +} + +type ListRequest struct { + // Count An integer indicating the desired maximum number of query results per page. OPTIONAL. + Count uint64 `json:"count" schema:"count"` + + // StartIndex An integer indicating the 1-based index of the first query result. Optional. + StartIndex uint64 `json:"startIndex" schema:"startIndex"` +} + +type ListResponse[T any] struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + ItemsPerPage uint64 `json:"itemsPerPage"` + TotalResults uint64 `json:"totalResults"` + StartIndex uint64 `json:"startIndex"` + Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase... +} + +func NewResourceHandlerAdapter[T ResourceHolder](handler ResourceHandler[T]) *ResourceHandlerAdapter[T] { + return &ResourceHandlerAdapter[T]{ + handler, + } +} + +func (adapter *ResourceHandlerAdapter[T]) Create(r *http.Request) (T, error) { + entity, err := adapter.readEntityFromBody(r) + if err != nil { + return entity, err + } + + return adapter.handler.Create(r.Context(), entity) +} + +func (adapter *ResourceHandlerAdapter[T]) Delete(r *http.Request) error { + id := mux.Vars(r)["id"] + return adapter.handler.Delete(r.Context(), id) +} + +func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T, error) { + entity := adapter.handler.NewResource() + err := json.NewDecoder(r.Body).Decode(entity) + if err != nil { + if zerrors.IsZitadelError(err) { + return entity, err + } + + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgumentf(nil, "SCIM-ucrjson", "Could not deserialize json: %v", err.Error())) + } + + resource := entity.GetResource() + if resource == nil { + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgument(nil, "SCIM-xxrjson", "Could not get resource, is the schema correct?")) + } + + if !slices.Contains(resource.Schemas, adapter.handler.SchemaType()) { + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgumentf(nil, "SCIM-xxrschema", "Expected schema %v is not provided", adapter.handler.SchemaType())) + } + + return entity, nil +} diff --git a/internal/api/scim/resources/user.go b/internal/api/scim/resources/user.go new file mode 100644 index 0000000000..14f5af6115 --- /dev/null +++ b/internal/api/scim/resources/user.go @@ -0,0 +1,184 @@ +package resources + +import ( + "context" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/authz" + scim_config "github.com/zitadel/zitadel/internal/api/scim/config" + scim_schemas "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/query" +) + +type UsersHandler struct { + command *command.Commands + query *query.Queries + userCodeAlg crypto.EncryptionAlgorithm + config *scim_config.Config +} + +type ScimUser struct { + *Resource + ID string `json:"id"` + ExternalID string `json:"externalId,omitempty"` + UserName string `json:"userName,omitempty"` + Name *ScimUserName `json:"name,omitempty"` + DisplayName string `json:"displayName,omitempty"` + NickName string `json:"nickName,omitempty"` + ProfileUrl *scim_schemas.HttpURL `json:"profileUrl,omitempty"` + Title string `json:"title,omitempty"` + PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"` + Locale string `json:"locale,omitempty"` + Timezone string `json:"timezone,omitempty"` + Active *bool `json:"active,omitempty"` + Emails []*ScimEmail `json:"emails,omitempty"` + PhoneNumbers []*ScimPhoneNumber `json:"phoneNumbers,omitempty"` + Password *scim_schemas.WriteOnlyString `json:"password,omitempty"` + Ims []*ScimIms `json:"ims,omitempty"` + Addresses []*ScimAddress `json:"addresses,omitempty"` + Photos []*ScimPhoto `json:"photos,omitempty"` + Entitlements []*ScimEntitlement `json:"entitlements,omitempty"` + Roles []*ScimRole `json:"roles,omitempty"` +} + +type ScimEntitlement struct { + Value string `json:"value,omitempty"` + Display string `json:"display,omitempty"` + Type string `json:"type,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimRole struct { + Value string `json:"value,omitempty"` + Display string `json:"display,omitempty"` + Type string `json:"type,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimPhoto struct { + Value scim_schemas.HttpURL `json:"value"` + Display string `json:"display,omitempty"` + Type string `json:"type"` + Primary bool `json:"primary,omitempty"` +} + +type ScimAddress struct { + Type string `json:"type,omitempty"` + StreetAddress string `json:"streetAddress,omitempty"` + Locality string `json:"locality,omitempty"` + Region string `json:"region,omitempty"` + PostalCode string `json:"postalCode,omitempty"` + Country string `json:"country,omitempty"` + Formatted string `json:"formatted,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimIms struct { + Value string `json:"value"` + Type string `json:"type"` +} + +type ScimEmail struct { + Value string `json:"value"` + Primary bool `json:"primary"` +} + +type ScimPhoneNumber struct { + Value string `json:"value"` + Primary bool `json:"primary"` +} + +type ScimUserName struct { + Formatted string `json:"formatted,omitempty"` + FamilyName string `json:"familyName,omitempty"` + GivenName string `json:"givenName,omitempty"` + MiddleName string `json:"middleName,omitempty"` + HonorificPrefix string `json:"honorificPrefix,omitempty"` + HonorificSuffix string `json:"honorificSuffix,omitempty"` +} + +func NewUsersHandler( + command *command.Commands, + query *query.Queries, + userCodeAlg crypto.EncryptionAlgorithm, + config *scim_config.Config) ResourceHandler[*ScimUser] { + return &UsersHandler{command, query, userCodeAlg, config} +} + +func (h *UsersHandler) ResourceNameSingular() scim_schemas.ScimResourceTypeSingular { + return scim_schemas.UserResourceType +} + +func (h *UsersHandler) ResourceNamePlural() scim_schemas.ScimResourceTypePlural { + return scim_schemas.UsersResourceType +} + +func (u *ScimUser) GetResource() *Resource { + return u.Resource +} + +func (h *UsersHandler) NewResource() *ScimUser { + return new(ScimUser) +} + +func (h *UsersHandler) SchemaType() scim_schemas.ScimSchemaType { + return scim_schemas.IdUser +} + +func (h *UsersHandler) Create(ctx context.Context, user *ScimUser) (*ScimUser, error) { + orgID := authz.GetCtxData(ctx).OrgID + addHuman, err := h.mapToAddHuman(ctx, user) + if err != nil { + return nil, err + } + + err = h.command.AddUserHuman(ctx, orgID, addHuman, true, h.userCodeAlg) + if err != nil { + return nil, err + } + + user.ID = addHuman.Details.ID + user.Resource = buildResource(ctx, h, addHuman.Details) + return user, nil +} + +func (h *UsersHandler) Delete(ctx context.Context, id string) error { + memberships, grants, err := h.queryUserDependencies(ctx, id) + if err != nil { + return err + } + + _, err = h.command.RemoveUserV2(ctx, id, memberships, grants...) + return err +} + +func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) { + userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID) + if err != nil { + return nil, nil, err + } + + grants, err := h.query.UserGrants(ctx, &query.UserGrantsQueries{ + Queries: []query.SearchQuery{userGrantUserQuery}, + }, true) + if err != nil { + return nil, nil, err + } + + membershipsUserQuery, err := query.NewMembershipUserIDQuery(userID) + if err != nil { + return nil, nil, err + } + + memberships, err := h.query.Memberships(ctx, &query.MembershipSearchQuery{ + Queries: []query.SearchQuery{membershipsUserQuery}, + }, false) + + if err != nil { + return nil, nil, err + } + return cascadingMemberships(memberships.Memberships), userGrantsToIDs(grants.UserGrants), nil +} diff --git a/internal/api/scim/resources/user_mapping.go b/internal/api/scim/resources/user_mapping.go new file mode 100644 index 0000000000..bc40005382 --- /dev/null +++ b/internal/api/scim/resources/user_mapping.go @@ -0,0 +1,133 @@ +package resources + +import ( + "context" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" +) + +func (h *UsersHandler) mapToAddHuman(ctx context.Context, scimUser *ScimUser) (*command.AddHuman, error) { + // zitadel has its own state mechanism + // ignore scimUser.Active + human := &command.AddHuman{ + Username: scimUser.UserName, + NickName: scimUser.NickName, + DisplayName: scimUser.DisplayName, + Email: h.mapPrimaryEmail(scimUser), + Phone: h.mapPrimaryPhone(scimUser), + } + + md, err := h.mapMetadataToCommands(ctx, scimUser) + if err != nil { + return nil, err + } + human.Metadata = md + + if scimUser.Password != nil { + human.Password = scimUser.Password.String() + scimUser.Password = nil + } + + if scimUser.Name != nil { + human.FirstName = scimUser.Name.GivenName + human.LastName = scimUser.Name.FamilyName + + // the direct mapping displayName => displayName has priority + // over the formatted name assignment + if human.DisplayName == "" { + human.DisplayName = scimUser.Name.Formatted + } + } + + if err := domain.LanguageIsDefined(scimUser.PreferredLanguage); err != nil { + human.PreferredLanguage = language.English + scimUser.PreferredLanguage = language.English + } + + return human, nil +} + +func (h *UsersHandler) mapPrimaryEmail(scimUser *ScimUser) command.Email { + for _, email := range scimUser.Emails { + if !email.Primary { + continue + } + + return command.Email{ + Address: domain.EmailAddress(email.Value), + Verified: h.config.EmailVerified, + } + } + + return command.Email{} +} + +func (h *UsersHandler) mapPrimaryPhone(scimUser *ScimUser) command.Phone { + for _, phone := range scimUser.PhoneNumbers { + if !phone.Primary { + continue + } + + return command.Phone{ + Number: domain.PhoneNumber(phone.Value), + Verified: h.config.PhoneVerified, + } + } + + return command.Phone{} +} + +func cascadingMemberships(memberships []*query.Membership) []*command.CascadingMembership { + cascades := make([]*command.CascadingMembership, len(memberships)) + for i, membership := range memberships { + cascades[i] = &command.CascadingMembership{ + UserID: membership.UserID, + ResourceOwner: membership.ResourceOwner, + IAM: cascadingIAMMembership(membership.IAM), + Org: cascadingOrgMembership(membership.Org), + Project: cascadingProjectMembership(membership.Project), + ProjectGrant: cascadingProjectGrantMembership(membership.ProjectGrant), + } + } + return cascades +} + +func cascadingIAMMembership(membership *query.IAMMembership) *command.CascadingIAMMembership { + if membership == nil { + return nil + } + return &command.CascadingIAMMembership{IAMID: membership.IAMID} +} + +func cascadingOrgMembership(membership *query.OrgMembership) *command.CascadingOrgMembership { + if membership == nil { + return nil + } + return &command.CascadingOrgMembership{OrgID: membership.OrgID} +} + +func cascadingProjectMembership(membership *query.ProjectMembership) *command.CascadingProjectMembership { + if membership == nil { + return nil + } + return &command.CascadingProjectMembership{ProjectID: membership.ProjectID} +} + +func cascadingProjectGrantMembership(membership *query.ProjectGrantMembership) *command.CascadingProjectGrantMembership { + if membership == nil { + return nil + } + return &command.CascadingProjectGrantMembership{ProjectID: membership.ProjectID, GrantID: membership.GrantID} +} + +func userGrantsToIDs(userGrants []*query.UserGrant) []string { + converted := make([]string, len(userGrants)) + for i, grant := range userGrants { + converted[i] = grant.ID + } + return converted +} diff --git a/internal/api/scim/resources/user_metadata.go b/internal/api/scim/resources/user_metadata.go new file mode 100644 index 0000000000..3d745d6857 --- /dev/null +++ b/internal/api/scim/resources/user_metadata.go @@ -0,0 +1,150 @@ +package resources + +import ( + "context" + "encoding/json" + "time" + + "github.com/zitadel/logging" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/scim/metadata" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func (h *UsersHandler) mapMetadataToCommands(ctx context.Context, user *ScimUser) ([]*command.AddMetadataEntry, error) { + md := make([]*command.AddMetadataEntry, 0, len(metadata.ScimUserRelevantMetadataKeys)) + for _, key := range metadata.ScimUserRelevantMetadataKeys { + value, err := getValueForMetadataKey(user, key) + if err != nil { + return nil, err + } + + if len(value) > 0 { + md = append(md, &command.AddMetadataEntry{ + Key: string(metadata.ScopeKey(ctx, key)), + Value: value, + }) + } + } + + return md, nil +} + +func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) { + value := getRawValueForMetadataKey(user, key) + if value == nil { + return nil, nil + } + + switch key { + // json values + case metadata.KeyEntitlements: + fallthrough + case metadata.KeyIms: + fallthrough + case metadata.KeyPhotos: + fallthrough + case metadata.KeyAddresses: + fallthrough + case metadata.KeyRoles: + return json.Marshal(value) + + // http url values + case metadata.KeyProfileUrl: + return []byte(value.(*schemas.HttpURL).String()), nil + + // raw values + case metadata.KeyProvisioningDomain: + fallthrough + case metadata.KeyExternalId: + fallthrough + case metadata.KeyMiddleName: + fallthrough + case metadata.KeyHonorificSuffix: + fallthrough + case metadata.KeyHonorificPrefix: + fallthrough + case metadata.KeyTitle: + fallthrough + case metadata.KeyLocale: + fallthrough + case metadata.KeyTimezone: + valueStr := value.(string) + if valueStr == "" { + return nil, nil + } + + return []byte(valueStr), validateValueForMetadataKey(valueStr, key) + } + + logging.Panicf("Unknown metadata key %s", key) + return nil, nil +} + +func validateValueForMetadataKey(v string, key metadata.Key) error { + //nolint:exhaustive + switch key { + case metadata.KeyLocale: + if _, err := language.Parse(v); err != nil { + return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-MD11", "Could not parse locale")) + } + return nil + case metadata.KeyTimezone: + if _, err := time.LoadLocation(v); err != nil { + return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-MD12", "Could not parse timezone")) + } + + return nil + } + + return nil +} + +func getRawValueForMetadataKey(user *ScimUser, key metadata.Key) interface{} { + switch key { + case metadata.KeyIms: + return user.Ims + case metadata.KeyPhotos: + return user.Photos + case metadata.KeyAddresses: + return user.Addresses + case metadata.KeyEntitlements: + return user.Entitlements + case metadata.KeyRoles: + return user.Roles + case metadata.KeyMiddleName: + if user.Name == nil { + return "" + } + return user.Name.MiddleName + case metadata.KeyHonorificPrefix: + if user.Name == nil { + return "" + } + return user.Name.HonorificPrefix + case metadata.KeyHonorificSuffix: + if user.Name == nil { + return "" + } + return user.Name.HonorificSuffix + case metadata.KeyExternalId: + return user.ExternalID + case metadata.KeyProfileUrl: + return user.ProfileUrl + case metadata.KeyTitle: + return user.Title + case metadata.KeyLocale: + return user.Locale + case metadata.KeyTimezone: + return user.Timezone + case metadata.KeyProvisioningDomain: + break + } + + logging.Panicf("Unknown or unsupported metadata key %s", key) + return nil +} diff --git a/internal/api/scim/schemas/schemas.go b/internal/api/scim/schemas/schemas.go new file mode 100644 index 0000000000..662a31f46f --- /dev/null +++ b/internal/api/scim/schemas/schemas.go @@ -0,0 +1,20 @@ +package schemas + +type ScimSchemaType string +type ScimResourceTypeSingular string +type ScimResourceTypePlural string + +const ( + idPrefixMessages = "urn:ietf:params:scim:api:messages:2.0:" + idPrefixCore = "urn:ietf:params:scim:schemas:core:2.0:" + idPrefixZitadelMessages = "urn:ietf:params:scim:api:zitadel:messages:2.0:" + + IdUser ScimSchemaType = idPrefixCore + "User" + IdError ScimSchemaType = idPrefixMessages + "Error" + IdZitadelErrorDetail ScimSchemaType = idPrefixZitadelMessages + "ErrorDetail" + + UserResourceType ScimResourceTypeSingular = "User" + UsersResourceType ScimResourceTypePlural = "Users" + + HandlerPrefix = "/scim/v2" +) diff --git a/internal/api/scim/schemas/string.go b/internal/api/scim/schemas/string.go new file mode 100644 index 0000000000..b62e50893d --- /dev/null +++ b/internal/api/scim/schemas/string.go @@ -0,0 +1,28 @@ +package schemas + +import "encoding/json" + +// WriteOnlyString a write only string is not serializable to json. +// in the SCIM RFC it has a mutability of writeOnly. +// This increases security to really ensure this is never sent to a client. +type WriteOnlyString string + +func NewWriteOnlyString(s string) *WriteOnlyString { + wos := WriteOnlyString(s) + return &wos +} + +func (s *WriteOnlyString) MarshalJSON() ([]byte, error) { + return []byte("null"), nil +} + +func (s *WriteOnlyString) UnmarshalJSON(bytes []byte) error { + var str string + err := json.Unmarshal(bytes, &str) + *s = WriteOnlyString(str) + return err +} + +func (s *WriteOnlyString) String() string { + return string(*s) +} diff --git a/internal/api/scim/schemas/string_test.go b/internal/api/scim/schemas/string_test.go new file mode 100644 index 0000000000..c48130a5d1 --- /dev/null +++ b/internal/api/scim/schemas/string_test.go @@ -0,0 +1,70 @@ +package schemas + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteOnlyString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + s WriteOnlyString + }{ + { + name: "always returns null", + s: "foo bar", + }, + { + name: "empty string returns null", + s: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(&tt.s) + assert.NoError(t, err) + assert.Equal(t, "null", string(got)) + }) + } +} + +func TestWriteOnlyString_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input []byte + want WriteOnlyString + wantErr bool + }{ + { + name: "string", + input: []byte(`"fooBar"`), + want: "fooBar", + wantErr: false, + }, + { + name: "empty string", + input: []byte(`""`), + want: "", + wantErr: false, + }, + { + name: "bad format", + input: []byte(`"bad "format"`), + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got WriteOnlyString + err := json.Unmarshal(tt.input, &got) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/scim/schemas/url.go b/internal/api/scim/schemas/url.go new file mode 100644 index 0000000000..343803bc04 --- /dev/null +++ b/internal/api/scim/schemas/url.go @@ -0,0 +1,50 @@ +package schemas + +import ( + "encoding/json" + "net/url" + + "github.com/zitadel/zitadel/internal/zerrors" +) + +type HttpURL url.URL + +func ParseHTTPURL(rawURL string) (*HttpURL, error) { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return nil, zerrors.ThrowInvalidArgumentf(nil, "SCIM-htturl1", "HTTP URL expected, got %v", parsedURL.Scheme) + } + + return (*HttpURL)(parsedURL), nil +} + +func (u *HttpURL) UnmarshalJSON(data []byte) error { + var urlStr string + if err := json.Unmarshal(data, &urlStr); err != nil { + return err + } + + parsedURL, err := ParseHTTPURL(urlStr) + if err != nil { + return err + } + + *u = *parsedURL + return nil +} + +func (u *HttpURL) MarshalJSON() ([]byte, error) { + return json.Marshal(u.String()) +} + +func (u *HttpURL) String() string { + if u == nil { + return "" + } + + return (*url.URL)(u).String() +} diff --git a/internal/api/scim/schemas/url_test.go b/internal/api/scim/schemas/url_test.go new file mode 100644 index 0000000000..a6a60322e0 --- /dev/null +++ b/internal/api/scim/schemas/url_test.go @@ -0,0 +1,182 @@ +package schemas + +import ( + "reflect" + "testing" + + "github.com/goccy/go-json" + "github.com/stretchr/testify/assert" + "github.com/zitadel/logging" +) + +func TestHttpURL_MarshalJSON(t *testing.T) { + tests := []struct { + name string + u *HttpURL + want []byte + wantErr bool + }{ + { + name: "http url", + u: mustParseURL("http://example.com"), + want: []byte(`"http://example.com"`), + wantErr: false, + }, + { + name: "https url", + u: mustParseURL("https://example.com"), + want: []byte(`"https://example.com"`), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.u) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + + assert.Equal(t, string(got), string(tt.want)) + }) + } +} + +func TestHttpURL_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + data []byte + want *HttpURL + wantErr bool + }{ + { + name: "http url", + data: []byte(`"http://example.com"`), + want: mustParseURL("http://example.com"), + wantErr: false, + }, + { + name: "https url", + data: []byte(`"https://example.com"`), + want: mustParseURL("https://example.com"), + wantErr: false, + }, + { + name: "ftp url should fail", + data: []byte(`"ftp://example.com"`), + want: nil, + wantErr: true, + }, + { + name: "no url should fail", + data: []byte(`"test"`), + want: nil, + wantErr: true, + }, + { + name: "number should fail", + data: []byte(`120`), + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := new(HttpURL) + err := json.Unmarshal(tt.data, url) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + assert.Equal(t, tt.want.String(), url.String()) + }) + } +} + +func TestHttpURL_String(t *testing.T) { + tests := []struct { + name string + u *HttpURL + want string + }{ + { + name: "http url", + u: mustParseURL("http://example.com"), + want: "http://example.com", + }, + { + name: "https url", + u: mustParseURL("https://example.com"), + want: "https://example.com", + }, + { + name: "nil", + u: nil, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.u.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseHTTPURL(t *testing.T) { + tests := []struct { + name string + rawURL string + want *HttpURL + wantErr bool + }{ + { + name: "http url", + rawURL: "http://example.com", + want: mustParseURL("http://example.com"), + wantErr: false, + }, + { + name: "https url", + rawURL: "https://example.com", + want: mustParseURL("https://example.com"), + wantErr: false, + }, + { + name: "ftp url should fail", + rawURL: "ftp://example.com", + want: nil, + wantErr: true, + }, + { + name: "no url should fail", + rawURL: "test", + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseHTTPURL(tt.rawURL) + if (err != nil) != tt.wantErr { + t.Errorf("ParseHTTPURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseHTTPURL() got = %v, want %v", got, tt.want) + } + }) + } +} + +func mustParseURL(rawURL string) *HttpURL { + url, err := ParseHTTPURL(rawURL) + logging.OnError(err).Fatal("failed to parse URL") + return url +} diff --git a/internal/api/scim/serrors/errors.go b/internal/api/scim/serrors/errors.go new file mode 100644 index 0000000000..fffd598b27 --- /dev/null +++ b/internal/api/scim/serrors/errors.go @@ -0,0 +1,140 @@ +package serrors + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/zitadel/logging" + "golang.org/x/text/language" + + http_util "github.com/zitadel/zitadel/internal/api/http" + zhttp_middleware "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type scimErrorType string + +type wrappedScimError struct { + Parent error + ScimType scimErrorType +} + +type scimError struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + ScimType scimErrorType `json:"scimType,omitempty"` + Detail string `json:"detail,omitempty"` + StatusCode int `json:"-"` + Status string `json:"status"` + ZitadelDetail *errorDetail `json:"urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail,omitempty"` +} + +type errorDetail struct { + ID string `json:"id"` + Message string `json:"message"` +} + +const ( + // ScimTypeInvalidValue A required value was missing, + // or the value specified was not compatible with the operation, + // or attribute type (see Section 2.2 of RFC7643), + // or resource schema (see Section 4 of RFC7643). + ScimTypeInvalidValue scimErrorType = "invalidValue" + + // ScimTypeInvalidSyntax The request body message structure was invalid or did + // not conform to the request schema. + ScimTypeInvalidSyntax scimErrorType = "invalidSyntax" +) + +var translator *i18n.Translator + +func ErrorHandler(next zhttp_middleware.HandlerFuncWithError) http.Handler { + var err error + translator, err = i18n.NewZitadelTranslator(language.English) + logging.OnError(err).Panic("unable to get translator") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err = next(w, r); err == nil { + return + } + + scimErr := mapToScimJsonError(r.Context(), err) + w.WriteHeader(scimErr.StatusCode) + + jsonErr := json.NewEncoder(w).Encode(scimErr) + logging.OnError(jsonErr).Warn("Failed to marshal scim error response") + }) +} + +func ThrowInvalidValue(parent error) error { + return &wrappedScimError{ + Parent: parent, + ScimType: ScimTypeInvalidValue, + } +} + +func ThrowInvalidSyntax(parent error) error { + return &wrappedScimError{ + Parent: parent, + ScimType: ScimTypeInvalidSyntax, + } +} + +func (err *scimError) Error() string { + return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Detail) +} + +func (err *wrappedScimError) Error() string { + return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Parent.Error()) +} + +func mapToScimJsonError(ctx context.Context, err error) *scimError { + scimErr := new(wrappedScimError) + if ok := errors.As(err, &scimErr); ok { + mappedErr := mapToScimJsonError(ctx, scimErr.Parent) + mappedErr.ScimType = scimErr.ScimType + return mappedErr + } + + zitadelErr := new(zerrors.ZitadelError) + if ok := errors.As(err, &zitadelErr); !ok { + return &scimError{ + Schemas: []schemas.ScimSchemaType{schemas.IdError}, + Detail: "Unknown internal server error", + Status: strconv.Itoa(http.StatusInternalServerError), + StatusCode: http.StatusInternalServerError, + } + } + + statusCode, ok := http_util.ZitadelErrorToHTTPStatusCode(err) + if !ok { + statusCode = http.StatusInternalServerError + } + + localizedMsg := translator.LocalizeFromCtx(ctx, zitadelErr.GetMessage(), nil) + return &scimError{ + Schemas: []schemas.ScimSchemaType{schemas.IdError, schemas.IdZitadelErrorDetail}, + ScimType: mapErrorToScimErrorType(err), + Detail: localizedMsg, + StatusCode: statusCode, + Status: strconv.Itoa(statusCode), + ZitadelDetail: &errorDetail{ + ID: zitadelErr.GetID(), + Message: zitadelErr.GetMessage(), + }, + } +} + +func mapErrorToScimErrorType(err error) scimErrorType { + switch { + case zerrors.IsErrorInvalidArgument(err): + return ScimTypeInvalidValue + default: + return "" + } +} diff --git a/internal/api/scim/serrors/errors_test.go b/internal/api/scim/serrors/errors_test.go new file mode 100644 index 0000000000..71d8018355 --- /dev/null +++ b/internal/api/scim/serrors/errors_test.go @@ -0,0 +1,110 @@ +package serrors + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestErrorHandler(t *testing.T) { + i18n.MustLoadSupportedLanguagesFromDir() + + tests := []struct { + name string + err error + wantStatus int + wantBody string + }{ + { + name: "scim error", + err: ThrowInvalidSyntax(zerrors.ThrowInvalidArgument(nil, "FOO", "Invalid syntax")), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "scimType":"invalidSyntax", + "detail":"Invalid syntax", + "status":"400", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Invalid syntax" + } + }`, + }, + { + name: "zitadel error", + err: zerrors.ThrowInvalidArgument(nil, "FOO", "Invalid syntax"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "scimType":"invalidValue", + "detail":"Invalid syntax", + "status":"400", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Invalid syntax" + } + }`, + }, + { + name: "zitadel internal error", + err: zerrors.ThrowInternal(nil, "FOO", "Internal error"), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "detail":"Internal error", + "status":"500", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Internal error" + } + }`, + }, + { + name: "unknown error", + err: errors.New("FOO"), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error" + ], + "detail":"Unknown internal server error", + "status":"500" + }`, + }, + { + name: "no error", + err: nil, + wantStatus: http.StatusOK, + wantBody: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + ErrorHandler(func(http.ResponseWriter, *http.Request) error { + return tt.err + }).ServeHTTP(recorder, req) + assert.Equal(t, tt.wantStatus, recorder.Code) + + if tt.wantBody != "" { + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + } + }) + } +} diff --git a/internal/api/scim/server.go b/internal/api/scim/server.go new file mode 100644 index 0000000000..a2f9c7e7bf --- /dev/null +++ b/internal/api/scim/server.go @@ -0,0 +1,87 @@ +package scim + +import ( + "encoding/json" + "net/http" + "path" + + "github.com/gorilla/mux" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + zhttp "github.com/zitadel/zitadel/internal/api/http" + zhttp_middlware "github.com/zitadel/zitadel/internal/api/http/middleware" + sconfig "github.com/zitadel/zitadel/internal/api/scim/config" + smiddleware "github.com/zitadel/zitadel/internal/api/scim/middleware" + sresources "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/query" +) + +func NewServer( + command *command.Commands, + query *query.Queries, + verifier *authz.ApiTokenVerifier, + userCodeAlg crypto.EncryptionAlgorithm, + config *sconfig.Config, + middlewares ...zhttp_middlware.MiddlewareWithErrorFunc) http.Handler { + verifier.RegisterServer("SCIM-V2", schemas.HandlerPrefix, AuthMapping) + return buildHandler(command, query, userCodeAlg, config, middlewares...) +} + +func buildHandler( + command *command.Commands, + query *query.Queries, + userCodeAlg crypto.EncryptionAlgorithm, + cfg *sconfig.Config, + middlewares ...zhttp_middlware.MiddlewareWithErrorFunc) http.Handler { + + router := mux.NewRouter() + + // content type middleware needs to run at the very beginning to correctly set content types of errors + middlewares = append([]zhttp_middlware.MiddlewareWithErrorFunc{smiddleware.ContentTypeMiddleware}, middlewares...) + middlewares = append(middlewares, smiddleware.ScimContextMiddleware(query)) + scimMiddleware := zhttp_middlware.ChainedWithErrorHandler(serrors.ErrorHandler, middlewares...) + mapResource(router, scimMiddleware, sresources.NewUsersHandler(command, query, userCodeAlg, cfg)) + return router +} + +func mapResource[T sresources.ResourceHolder](router *mux.Router, mw zhttp_middlware.ErrorHandlerFunc, handler sresources.ResourceHandler[T]) { + adapter := sresources.NewResourceHandlerAdapter[T](handler) + resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter() + + resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost) + resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete) +} + +func handleResourceCreatedResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + entity, err := next(r) + if err != nil { + return err + } + + resource := entity.GetResource() + w.Header().Set(zhttp.Location, resource.Meta.Location) + w.WriteHeader(http.StatusCreated) + + err = json.NewEncoder(w).Encode(entity) + logging.OnError(err).Warn("scim json response encoding failed") + return nil + } +} + +func handleEmptyResponse(next func(*http.Request) error) zhttp_middlware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + err := next(r) + if err != nil { + return err + } + + w.WriteHeader(http.StatusNoContent) + return nil + } +} diff --git a/internal/integration/assert.go b/internal/integration/assert.go index 6743c8297e..3f5ebdf54f 100644 --- a/internal/integration/assert.go +++ b/internal/integration/assert.go @@ -6,6 +6,8 @@ import ( "github.com/pmezard/go-difflib/difflib" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -128,6 +130,13 @@ func AssertResourceListDetails[D ResourceListDetailsMsg](t assert.TestingT, expe } } +func AssertGrpcStatus(t assert.TestingT, expected codes.Code, err error) { + assert.Error(t, err) + statusErr, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, expected, statusErr.Code()) +} + // EqualProto is inspired by [assert.Equal], only that it tests equality of a proto message. // A message diff is printed on the error test log if the messages are not equal. // @@ -160,3 +169,9 @@ func diffProto(expected, actual proto.Message) string { } return "\n\nDiff:\n" + diff } + +func AssertMapContains[M ~map[K]V, K comparable, V any](t *testing.T, m M, key K, expectedValue V) { + val, exists := m[key] + assert.True(t, exists, "Key '%s' should exist in the map", key) + assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue) +} diff --git a/internal/integration/client.go b/internal/integration/client.go index af30f0e642..d18c2d9b12 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -17,6 +17,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/integration/scim" "github.com/zitadel/zitadel/pkg/grpc/admin" "github.com/zitadel/zitadel/pkg/grpc/auth" "github.com/zitadel/zitadel/pkg/grpc/feature/v2" @@ -67,6 +68,7 @@ type Client struct { IDPv2 idp_pb.IdentityProviderServiceClient UserV3Alpha user_v3alpha.ZITADELUsersClient SAMLv2 saml_pb.SAMLServiceClient + SCIM *scim.Client } func newClient(ctx context.Context, target string) (*Client, error) { @@ -99,6 +101,7 @@ func newClient(ctx context.Context, target string) (*Client, error) { IDPv2: idp_pb.NewIdentityProviderServiceClient(cc), UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc), SAMLv2: saml_pb.NewSAMLServiceClient(cc), + SCIM: scim.NewScimClient(target), } return client, client.pollHealth(ctx) } diff --git a/internal/integration/scim/assertions.go b/internal/integration/scim/assertions.go new file mode 100644 index 0000000000..a91c33da82 --- /dev/null +++ b/internal/integration/scim/assertions.go @@ -0,0 +1,22 @@ +package scim + +import ( + "errors" + "strconv" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type AssertedScimError struct { + Error *ScimError +} + +func RequireScimError(t require.TestingT, httpStatus int, err error) AssertedScimError { + require.Error(t, err) + + var scimErr *ScimError + assert.True(t, errors.As(err, &scimErr)) + assert.Equal(t, strconv.Itoa(httpStatus), scimErr.Status) + return AssertedScimError{scimErr} // wrap it, otherwise error handling is enforced +} diff --git a/internal/integration/scim/client.go b/internal/integration/scim/client.go new file mode 100644 index 0000000000..478c831826 --- /dev/null +++ b/internal/integration/scim/client.go @@ -0,0 +1,146 @@ +package scim + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "path" + + "github.com/zitadel/logging" + "google.golang.org/grpc/metadata" + + zhttp "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/scim/middleware" + "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" +) + +type Client struct { + Users *ResourceClient +} + +type ResourceClient struct { + client *http.Client + baseUrl string + resourceName string +} + +type ScimError struct { + Schemas []string `json:"schemas"` + ScimType string `json:"scimType"` + Detail string `json:"detail"` + Status string `json:"status"` + ZitadelDetail *ZitadelErrorDetail `json:"urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail,omitempty"` +} + +type ZitadelErrorDetail struct { + ID string `json:"id"` + Message string `json:"message"` +} + +func NewScimClient(target string) *Client { + target = "http://" + target + schemas.HandlerPrefix + client := &http.Client{} + return &Client{ + Users: &ResourceClient{ + client: client, + baseUrl: target, + resourceName: "Users", + }, + } +} + +func (c *ResourceClient) Create(ctx context.Context, orgID string, body []byte) (*resources.ScimUser, error) { + user := new(resources.ScimUser) + err := c.doWithBody(ctx, http.MethodPost, orgID, "", bytes.NewReader(body), user) + return user, err +} + +func (c *ResourceClient) Delete(ctx context.Context, orgID, id string) error { + return c.do(ctx, http.MethodDelete, orgID, id) +} + +func (c *ResourceClient) do(ctx context.Context, method, orgID, url string) error { + req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), nil) + if err != nil { + return err + } + + return c.doReq(req, nil) +} + +func (c *ResourceClient) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader, responseEntity interface{}) error { + req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body) + if err != nil { + return err + } + + req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim) + return c.doReq(req, responseEntity) +} + +func (c *ResourceClient) doReq(req *http.Request, responseEntity interface{}) error { + addTokenAsHeader(req) + + resp, err := c.client.Do(req) + defer func() { + err := resp.Body.Close() + logging.OnError(err).Error("Failed to close response body") + }() + + if err != nil { + return err + } + + if (resp.StatusCode / 100) != 2 { + return readScimError(resp) + } + + if responseEntity == nil { + return nil + } + + err = readJson(responseEntity, resp) + return err +} + +func addTokenAsHeader(req *http.Request) { + md, ok := metadata.FromOutgoingContext(req.Context()) + if !ok { + return + } + + req.Header.Set("Authorization", md.Get("Authorization")[0]) +} + +func readJson(entity interface{}, resp *http.Response) error { + defer func(body io.ReadCloser) { + err := body.Close() + logging.OnError(err).Panic("Failed to close response body") + }(resp.Body) + + err := json.NewDecoder(resp.Body).Decode(entity) + logging.OnError(err).Panic("Failed decoding entity") + return err +} + +func readScimError(resp *http.Response) error { + scimErr := new(ScimError) + readErr := readJson(scimErr, resp) + logging.OnError(readErr).Panic("Failed reading scim error") + return scimErr +} + +func (c *ResourceClient) buildURL(orgID, segment string) string { + if segment == "" { + return c.baseUrl + "/" + path.Join(orgID, c.resourceName) + } + + return c.baseUrl + "/" + path.Join(orgID, c.resourceName, segment) +} + +func (err *ScimError) Error() string { + return "scim error: " + err.Detail +} diff --git a/internal/zerrors/zerror.go b/internal/zerrors/zerror.go index d7b85b84a7..996f67ce29 100644 --- a/internal/zerrors/zerror.go +++ b/internal/zerrors/zerror.go @@ -79,3 +79,8 @@ func (err *ZitadelError) As(target interface{}) bool { reflect.Indirect(reflect.ValueOf(target)).Set(reflect.ValueOf(err)) return true } + +func IsZitadelError(err error) bool { + zitadelErr := new(ZitadelError) + return errors.As(err, &zitadelErr) +} diff --git a/internal/zerrors/zerror_test.go b/internal/zerrors/zerror_test.go index 3a11a8e78e..517f938ee4 100644 --- a/internal/zerrors/zerror_test.go +++ b/internal/zerrors/zerror_test.go @@ -1,6 +1,7 @@ package zerrors_test import ( + "errors" "testing" "github.com/stretchr/testify/assert" @@ -17,3 +18,27 @@ func TestErrorMethod(t *testing.T) { subExptected := "ID=subID Message=subMsg Parent=(ID=id Message=msg)" assert.Equal(t, subExptected, err.Error()) } + +func TestIsZitadelError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "zitadel error", + err: zerrors.ThrowInvalidArgument(nil, "id", "msg"), + want: true, + }, + { + name: "other error", + err: errors.New("just a random error"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, zerrors.IsZitadelError(tt.err), "IsZitadelError(%v)", tt.err) + }) + } +}