Merge branch 'main' into fix-project-grant-owners

This commit is contained in:
Silvan 2025-01-10 13:15:51 +01:00 committed by GitHub
commit 1ed82d76c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 749 additions and 20 deletions

View File

@ -152,7 +152,7 @@ curl --request POST \
If you didn't get a user ID in the parameters of your success page, you know that there is no existing user in ZITADEL with that provider, and you can register a new user or link it to an existing account (read the next section). If you didn't get a user ID in the parameters of your success page, you know that there is no existing user in ZITADEL with that provider, and you can register a new user or link it to an existing account (read the next section).
Fill the IdP links in the create user request to add a user with an external login provider. Fill the IdP links in the create user request to add a user with an external login provider.
The idpId is the ID of the provider in ZITADEL, the idpExternalId is the ID of the user in the external identity provider; usually, this is sent in the “sub”. The idpId is the ID of the provider in ZITADEL, the userId is the ID of the user in the external identity provider; usually, this is sent in the “sub”.
The display name is used to list the linkings on the users. The display name is used to list the linkings on the users.
[Create User API Documentation](/docs/apis/resources/user_service_v2/user-service-add-human-user) [Create User API Documentation](/docs/apis/resources/user_service_v2/user-service-add-human-user)
@ -181,8 +181,8 @@ curl --request POST \
"idpLinks": [ "idpLinks": [
{ {
"idpId": "218528353504723201", "idpId": "218528353504723201",
"idpExternalId": "111392805975715856637", "userId": "111392805975715856637",
"displayName": "Minnie Mouse" "userName": "Minnie Mouse"
} }
] ]
}' }'
@ -205,8 +205,8 @@ curl --request POST \
--data '{ --data '{
"idpLink": { "idpLink": {
"idpId": "218528353504723201", "idpId": "218528353504723201",
"idpExternalId": "1113928059757158566371", "userId": "1113928059757158566371",
"displayName": "Minnie Mouse" "userName": "Minnie Mouse"
} }
}' }'
``` ```

View File

@ -10,6 +10,9 @@ var AuthMapping = authz.MethodMapping{
"POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": { "POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": {
Permission: domain.PermissionUserWrite, Permission: domain.PermissionUserWrite,
}, },
"GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": {
Permission: domain.PermissionUserRead,
},
"DELETE:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": { "DELETE:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": {
Permission: domain.PermissionUserDelete, Permission: domain.PermissionUserDelete,
}, },

View File

@ -0,0 +1,255 @@
//go:build integration
package integration_test
import (
"context"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/scim/resources"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/pkg/grpc/management"
guser "github.com/zitadel/zitadel/pkg/grpc/user/v2"
"golang.org/x/text/language"
"net/http"
"path"
"testing"
)
func TestGetUser(t *testing.T) {
tests := []struct {
name string
buildUserID func() (userID string, deleteUser bool)
ctx context.Context
want *resources.ScimUser
wantErr bool
errorStatus int
}{
{
name: "not authenticated",
ctx: context.Background(),
errorStatus: http.StatusUnauthorized,
wantErr: true,
},
{
name: "no permissions",
ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission),
errorStatus: http.StatusNotFound,
wantErr: true,
},
{
name: "unknown user id",
buildUserID: func() (string, bool) {
return "unknown", false
},
errorStatus: http.StatusNotFound,
wantErr: true,
},
{
name: "created via grpc",
want: &resources.ScimUser{
Name: &resources.ScimUserName{
FamilyName: "Mouse",
GivenName: "Mickey",
},
PreferredLanguage: language.MustParse("nl"),
PhoneNumbers: []*resources.ScimPhoneNumber{
{
Value: "+41791234567",
Primary: true,
},
},
},
},
{
name: "created via scim",
buildUserID: func() (string, bool) {
user, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
require.NoError(t, err)
return user.ID, true
},
want: &resources.ScimUser{
ExternalID: "701984",
UserName: "bjensen@example.com",
Name: &resources.ScimUserName{
Formatted: "Babs Jensen", // DisplayName takes precedence
FamilyName: "Jensen",
GivenName: "Barbara",
MiddleName: "Jane",
HonorificPrefix: "Ms.",
HonorificSuffix: "III",
},
DisplayName: "Babs Jensen",
NickName: "Babs",
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
Title: "Tour Guide",
PreferredLanguage: language.Make("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
Emails: []*resources.ScimEmail{
{
Value: "bjensen@example.com",
Primary: true,
},
},
PhoneNumbers: []*resources.ScimPhoneNumber{
{
Value: "+415555555555",
Primary: true,
},
},
Ims: []*resources.ScimIms{
{
Value: "someaimhandle",
Type: "aim",
},
{
Value: "twitterhandle",
Type: "X",
},
},
Addresses: []*resources.ScimAddress{
{
Type: "work",
StreetAddress: "100 Universal City Plaza",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA",
Primary: true,
},
{
Type: "home",
StreetAddress: "456 Hollywood Blvd",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA",
},
},
Photos: []*resources.ScimPhoto{
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
Type: "photo",
},
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")),
Type: "thumbnail",
},
},
Roles: []*resources.ScimRole{
{
Value: "my-role-1",
Display: "Rolle 1",
Type: "main-role",
Primary: true,
},
{
Value: "my-role-2",
Display: "Rolle 2",
Type: "secondary-role",
Primary: false,
},
},
Entitlements: []*resources.ScimEntitlement{
{
Value: "my-entitlement-1",
Display: "Entitlement 1",
Type: "main-entitlement",
Primary: true,
},
{
Value: "my-entitlement-2",
Display: "Entitlement 2",
Type: "secondary-entitlement",
Primary: false,
},
},
},
},
{
name: "scoped externalID",
buildUserID: func() (string, bool) {
// create user without provisioning domain
user, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
require.NoError(t, err)
// set provisioning domain of service user
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioning_domain",
Value: []byte("fooBar"),
})
require.NoError(t, err)
// set externalID for provisioning domain
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: user.ID,
Key: "urn:zitadel:scim:fooBar:externalId",
Value: []byte("100-scopedExternalId"),
})
require.NoError(t, err)
return user.ID, true
},
want: &resources.ScimUser{
ExternalID: "100-scopedExternalId",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.ctx
if ctx == nil {
ctx = CTX
}
var userID string
var deleteUserAfterTest bool
if tt.buildUserID != nil {
userID, deleteUserAfterTest = tt.buildUserID()
} else {
createUserResp := Instance.CreateHumanUser(CTX)
userID = createUserResp.UserId
}
user, err := Instance.Client.SCIM.Users.Get(ctx, Instance.DefaultOrg.Id, userID)
if tt.wantErr {
statusCode := tt.errorStatus
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
scim.RequireScimError(t, statusCode, err)
return
}
assert.Equal(t, userID, user.ID)
assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, user.Schemas)
assert.Equal(t, schemas.ScimResourceTypeSingular("User"), user.Resource.Meta.ResourceType)
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", user.ID), user.Resource.Meta.Location)
assert.Nil(t, user.Password)
if !integration.PartiallyDeepEqual(tt.want, user) {
t.Errorf("keysFromArgs() got = %v, want %v", user, tt.want)
}
if deleteUserAfterTest {
_, err = Instance.Client.UserV2.DeleteUser(CTX, &guser.DeleteUserRequest{UserId: user.ID})
require.NoError(t, err)
}
})
}
}
func TestGetUser_anotherOrg(t *testing.T) {
createUserResp := Instance.CreateHumanUser(CTX)
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
_, err := Instance.Client.SCIM.Users.Get(CTX, org.OrganizationId, createUserResp.UserId)
scim.RequireScimError(t, http.StatusNotFound, err)
}

View File

@ -20,6 +20,7 @@ type ResourceHandler[T ResourceHolder] interface {
Create(ctx context.Context, resource T) (T, error) Create(ctx context.Context, resource T) (T, error)
Delete(ctx context.Context, id string) error Delete(ctx context.Context, id string) error
Get(ctx context.Context, id string) (T, error)
} }
type Resource struct { type Resource struct {

View File

@ -52,6 +52,11 @@ func (adapter *ResourceHandlerAdapter[T]) Delete(r *http.Request) error {
return adapter.handler.Delete(r.Context(), id) return adapter.handler.Delete(r.Context(), id)
} }
func (adapter *ResourceHandlerAdapter[T]) Get(r *http.Request) (T, error) {
id := mux.Vars(r)["id"]
return adapter.handler.Get(r.Context(), id)
}
func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T, error) { func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T, error) {
entity := adapter.handler.NewResource() entity := adapter.handler.NewResource()
err := json.NewDecoder(r.Body).Decode(entity) err := json.NewDecoder(r.Body).Decode(entity)

View File

@ -155,6 +155,19 @@ func (h *UsersHandler) Delete(ctx context.Context, id string) error {
return err return err
} }
func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) {
user, err := h.query.GetUserByID(ctx, false, id)
if err != nil {
return nil, err
}
metadata, err := h.queryMetadataForUser(ctx, id)
if err != nil {
return nil, err
}
return h.mapToScimUser(ctx, user, metadata), nil
}
func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) { func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) {
userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID) userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID)
if err != nil { if err != nil {

View File

@ -2,9 +2,15 @@ package resources
import ( import (
"context" "context"
"strconv"
"time"
"github.com/muhlemmer/gu"
"github.com/zitadel/logging"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/scim/metadata"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
@ -81,6 +87,112 @@ func (h *UsersHandler) mapPrimaryPhone(scimUser *ScimUser) command.Phone {
return command.Phone{} return command.Phone{}
} }
func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md map[metadata.ScopedKey][]byte) *ScimUser {
scimUser := &ScimUser{
Resource: h.buildResourceForQuery(ctx, user),
ID: user.ID,
ExternalID: extractScalarMetadata(ctx, md, metadata.KeyExternalId),
UserName: user.Username,
ProfileUrl: extractHttpURLMetadata(ctx, md, metadata.KeyProfileUrl),
Title: extractScalarMetadata(ctx, md, metadata.KeyTitle),
Locale: extractScalarMetadata(ctx, md, metadata.KeyLocale),
Timezone: extractScalarMetadata(ctx, md, metadata.KeyTimezone),
Active: gu.Ptr(user.State.IsEnabled()),
Ims: make([]*ScimIms, 0),
Addresses: make([]*ScimAddress, 0),
Photos: make([]*ScimPhoto, 0),
Entitlements: make([]*ScimEntitlement, 0),
Roles: make([]*ScimRole, 0),
}
if scimUser.Locale != "" {
_, err := language.Parse(scimUser.Locale)
if err != nil {
logging.OnError(err).Warn("Failed to load locale of scim user")
scimUser.Locale = ""
}
}
if scimUser.Timezone != "" {
_, err := time.LoadLocation(scimUser.Timezone)
if err != nil {
logging.OnError(err).Warn("Failed to load timezone of scim user")
scimUser.Timezone = ""
}
}
if err := extractJsonMetadata(ctx, md, metadata.KeyIms, &scimUser.Ims); err != nil {
logging.OnError(err).Warn("Could not deserialize scim ims metadata")
}
if err := extractJsonMetadata(ctx, md, metadata.KeyAddresses, &scimUser.Addresses); err != nil {
logging.OnError(err).Warn("Could not deserialize scim addresses metadata")
}
if err := extractJsonMetadata(ctx, md, metadata.KeyPhotos, &scimUser.Photos); err != nil {
logging.OnError(err).Warn("Could not deserialize scim photos metadata")
}
if err := extractJsonMetadata(ctx, md, metadata.KeyEntitlements, &scimUser.Entitlements); err != nil {
logging.OnError(err).Warn("Could not deserialize scim entitlements metadata")
}
if err := extractJsonMetadata(ctx, md, metadata.KeyRoles, &scimUser.Roles); err != nil {
logging.OnError(err).Warn("Could not deserialize scim roles metadata")
}
if user.Human != nil {
mapHumanToScimUser(ctx, user.Human, scimUser, md)
}
return scimUser
}
func mapHumanToScimUser(ctx context.Context, human *query.Human, user *ScimUser, md map[metadata.ScopedKey][]byte) {
user.DisplayName = human.DisplayName
user.NickName = human.NickName
user.PreferredLanguage = human.PreferredLanguage
user.Name = &ScimUserName{
Formatted: human.DisplayName,
FamilyName: human.LastName,
GivenName: human.FirstName,
MiddleName: extractScalarMetadata(ctx, md, metadata.KeyMiddleName),
HonorificPrefix: extractScalarMetadata(ctx, md, metadata.KeyHonorificPrefix),
HonorificSuffix: extractScalarMetadata(ctx, md, metadata.KeyHonorificSuffix),
}
if string(human.Email) != "" {
user.Emails = []*ScimEmail{
{
Value: string(human.Email),
Primary: true,
},
}
}
if string(human.Phone) != "" {
user.PhoneNumbers = []*ScimPhoneNumber{
{
Value: string(human.Phone),
Primary: true,
},
}
}
}
func (h *UsersHandler) buildResourceForQuery(ctx context.Context, user *query.User) *Resource {
return &Resource{
Schemas: []schemas.ScimSchemaType{schemas.IdUser},
Meta: &ResourceMeta{
ResourceType: schemas.UserResourceType,
Created: user.CreationDate.UTC(),
LastModified: user.ChangeDate.UTC(),
Version: strconv.FormatUint(user.Sequence, 10),
Location: buildLocation(ctx, h, user.ID),
},
}
}
func cascadingMemberships(memberships []*query.Membership) []*command.CascadingMembership { func cascadingMemberships(memberships []*query.Membership) []*command.CascadingMembership {
cascades := make([]*command.CascadingMembership, len(memberships)) cascades := make([]*command.CascadingMembership, len(memberships))
for i, membership := range memberships { for i, membership := range memberships {

View File

@ -12,9 +12,49 @@ import (
"github.com/zitadel/zitadel/internal/api/scim/schemas" "github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors" "github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
func (h *UsersHandler) queryMetadataForUser(ctx context.Context, id string) (map[metadata.ScopedKey][]byte, error) {
queries := h.buildMetadataQueries(ctx)
md, err := h.query.SearchUserMetadata(ctx, false, id, queries, false)
if err != nil {
return nil, err
}
metadataMap := make(map[metadata.ScopedKey][]byte, len(md.Metadata))
for _, entry := range md.Metadata {
metadataMap[metadata.ScopedKey(entry.Key)] = entry.Value
}
return metadataMap, nil
}
func (h *UsersHandler) buildMetadataQueries(ctx context.Context) *query.UserMetadataSearchQueries {
keyQueries := make([]query.SearchQuery, len(metadata.ScimUserRelevantMetadataKeys))
for i, key := range metadata.ScimUserRelevantMetadataKeys {
keyQueries[i] = buildMetadataKeyQuery(ctx, key)
}
queries := &query.UserMetadataSearchQueries{
SearchRequest: query.SearchRequest{},
Queries: []query.SearchQuery{query.Or(keyQueries...)},
}
return queries
}
func buildMetadataKeyQuery(ctx context.Context, key metadata.Key) query.SearchQuery {
scopedKey := metadata.ScopeKey(ctx, key)
q, err := query.NewUserMetadataKeySearchQuery(string(scopedKey), query.TextEquals)
if err != nil {
logging.Panic("Error build user metadata query for key " + key)
}
return q
}
func (h *UsersHandler) mapMetadataToCommands(ctx context.Context, user *ScimUser) ([]*command.AddMetadataEntry, error) { func (h *UsersHandler) mapMetadataToCommands(ctx context.Context, user *ScimUser) ([]*command.AddMetadataEntry, error) {
md := make([]*command.AddMetadataEntry, 0, len(metadata.ScimUserRelevantMetadataKeys)) md := make([]*command.AddMetadataEntry, 0, len(metadata.ScimUserRelevantMetadataKeys))
for _, key := range metadata.ScimUserRelevantMetadataKeys { for _, key := range metadata.ScimUserRelevantMetadataKeys {
@ -51,7 +91,17 @@ func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) {
case metadata.KeyAddresses: case metadata.KeyAddresses:
fallthrough fallthrough
case metadata.KeyRoles: case metadata.KeyRoles:
return json.Marshal(value) val, err := json.Marshal(value)
if err != nil {
return nil, err
}
// null is considered no value
if len(val) == 4 && string(val) == "null" {
return nil, nil
}
return val, nil
// http url values // http url values
case metadata.KeyProfileUrl: case metadata.KeyProfileUrl:
@ -148,3 +198,36 @@ func getRawValueForMetadataKey(user *ScimUser, key metadata.Key) interface{} {
logging.Panicf("Unknown or unsupported metadata key %s", key) logging.Panicf("Unknown or unsupported metadata key %s", key)
return nil return nil
} }
func extractScalarMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key) string {
val, ok := md[metadata.ScopeKey(ctx, key)]
if !ok {
return ""
}
return string(val)
}
func extractHttpURLMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key) *schemas.HttpURL {
val, ok := md[metadata.ScopeKey(ctx, key)]
if !ok {
return nil
}
url, err := schemas.ParseHTTPURL(string(val))
if err != nil {
logging.OnError(err).Warn("Failed to parse scim url metadata for " + key)
return nil
}
return url
}
func extractJsonMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key, v interface{}) error {
val, ok := md[metadata.ScopeKey(ctx, key)]
if !ok {
return nil
}
return json.Unmarshal(val, v)
}

View File

@ -54,6 +54,7 @@ func mapResource[T sresources.ResourceHolder](router *mux.Router, mw zhttp_middl
resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter() resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter()
resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost) resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost)
resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Get))).Methods(http.MethodGet)
resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete) resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete)
} }
@ -74,6 +75,22 @@ func handleResourceCreatedResponse[T sresources.ResourceHolder](next func(*http.
} }
} }
func handleResourceResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError {
return func(w http.ResponseWriter, r *http.Request) error {
entity, err := next(r)
if err != nil {
return err
}
resource := entity.GetResource()
w.Header().Set(zhttp.ContentLocation, resource.Meta.Location)
err = json.NewEncoder(w).Encode(entity)
logging.OnError(err).Warn("scim json response encoding failed")
return nil
}
}
func handleEmptyResponse(next func(*http.Request) error) zhttp_middlware.HandlerFuncWithError { func handleEmptyResponse(next func(*http.Request) error) zhttp_middlware.HandlerFuncWithError {
return func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error {
err := next(r) err := next(r)

View File

@ -1,6 +1,7 @@
package integration package integration
import ( import (
"reflect"
"testing" "testing"
"time" "time"
@ -175,3 +176,89 @@ func AssertMapContains[M ~map[K]V, K comparable, V any](t *testing.T, m M, key K
assert.True(t, exists, "Key '%s' should exist in the map", key) assert.True(t, exists, "Key '%s' should exist in the map", key)
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue) assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
} }
// PartiallyDeepEqual is similar to reflect.DeepEqual,
// but only compares exported non-zero fields of the expectedValue
func PartiallyDeepEqual(expected, actual interface{}) bool {
if expected == nil {
return actual == nil
}
if actual == nil {
return false
}
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func partiallyDeepEqual(expected, actual reflect.Value) bool {
// Dereference pointers if needed
if expected.Kind() == reflect.Ptr {
if expected.IsNil() {
return actual.IsNil()
}
expected = expected.Elem()
}
if actual.Kind() == reflect.Ptr {
if actual.IsNil() {
return false
}
actual = actual.Elem()
}
if expected.Type() != actual.Type() {
return false
}
switch expected.Kind() { //nolint:exhaustive
case reflect.Struct:
for i := 0; i < expected.NumField(); i++ {
field := expected.Type().Field(i)
if field.PkgPath != "" { // Skip unexported fields
continue
}
expectedField := expected.Field(i)
actualField := actual.Field(i)
// Skip zero-value fields in expected
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
continue
}
// Compare fields recursively
if !partiallyDeepEqual(expectedField, actualField) {
return false
}
}
return true
case reflect.Slice, reflect.Array:
if expected.Len() > actual.Len() {
return false
}
for i := 0; i < expected.Len(); i++ {
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
return false
}
}
return true
default:
// Compare primitive types
return reflect.DeepEqual(expected.Interface(), actual.Interface())
}
}
func Must[T any](result T, error error) T {
if error != nil {
panic(error)
}
return result
}

View File

@ -50,3 +50,153 @@ func TestAssertDetails(t *testing.T) {
}) })
} }
} }
func TestPartiallyDeepEqual(t *testing.T) {
type SecondaryNestedType struct {
Value int
}
type NestedType struct {
Value int
ValueSlice []int
Nested SecondaryNestedType
NestedPointer *SecondaryNestedType
}
type args struct {
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil",
args: args{
expected: nil,
actual: nil,
},
want: true,
},
{
name: "scalar value",
args: args{
expected: 10,
actual: 10,
},
want: true,
},
{
name: "different scalar value",
args: args{
expected: 11,
actual: 10,
},
want: false,
},
{
name: "string value",
args: args{
expected: "foo",
actual: "foo",
},
want: true,
},
{
name: "different string value",
args: args{
expected: "foo2",
actual: "foo",
},
want: false,
},
{
name: "scalar only set in actual",
args: args{
expected: &SecondaryNestedType{},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar equal",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar only set in expected",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{},
},
want: false,
},
{
name: "ptr only set in expected",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{},
},
want: false,
},
{
name: "ptr only set in actual",
args: args{
expected: &NestedType{},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "ptr equal",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "nested equal",
args: args{
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "slice equal",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: true,
},
{
name: "slice additional in expected",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: false,
},
{
name: "slice additional in actual",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -18,10 +18,10 @@ import (
) )
type Client struct { type Client struct {
Users *ResourceClient Users *ResourceClient[resources.ScimUser]
} }
type ResourceClient struct { type ResourceClient[T any] struct {
client *http.Client client *http.Client
baseUrl string baseUrl string
resourceName string resourceName string
@ -44,7 +44,7 @@ func NewScimClient(target string) *Client {
target = "http://" + target + schemas.HandlerPrefix target = "http://" + target + schemas.HandlerPrefix
client := &http.Client{} client := &http.Client{}
return &Client{ return &Client{
Users: &ResourceClient{ Users: &ResourceClient[resources.ScimUser]{
client: client, client: client,
baseUrl: target, baseUrl: target,
resourceName: "Users", resourceName: "Users",
@ -52,17 +52,19 @@ func NewScimClient(target string) *Client {
} }
} }
func (c *ResourceClient) Create(ctx context.Context, orgID string, body []byte) (*resources.ScimUser, error) { func (c *ResourceClient[T]) Create(ctx context.Context, orgID string, body []byte) (*T, error) {
user := new(resources.ScimUser) return c.doWithBody(ctx, http.MethodPost, orgID, "", bytes.NewReader(body))
err := c.doWithBody(ctx, http.MethodPost, orgID, "", bytes.NewReader(body), user)
return user, err
} }
func (c *ResourceClient) Delete(ctx context.Context, orgID, id string) error { func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) {
return c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil)
}
func (c *ResourceClient[T]) Delete(ctx context.Context, orgID, id string) error {
return c.do(ctx, http.MethodDelete, orgID, id) return c.do(ctx, http.MethodDelete, orgID, id)
} }
func (c *ResourceClient) do(ctx context.Context, method, orgID, url string) error { func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, url string) error {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), nil) req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), nil)
if err != nil { if err != nil {
return err return err
@ -71,17 +73,18 @@ func (c *ResourceClient) do(ctx context.Context, method, orgID, url string) erro
return c.doReq(req, nil) return c.doReq(req, nil)
} }
func (c *ResourceClient) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader, responseEntity interface{}) error { func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body) req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
if err != nil { if err != nil {
return err return nil, err
} }
req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim) req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
return c.doReq(req, responseEntity) responseEntity := new(T)
return responseEntity, c.doReq(req, responseEntity)
} }
func (c *ResourceClient) doReq(req *http.Request, responseEntity interface{}) error { func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity *T) error {
addTokenAsHeader(req) addTokenAsHeader(req)
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
@ -133,7 +136,7 @@ func readScimError(resp *http.Response) error {
return scimErr return scimErr
} }
func (c *ResourceClient) buildURL(orgID, segment string) string { func (c *ResourceClient[T]) buildURL(orgID, segment string) string {
if segment == "" { if segment == "" {
return c.baseUrl + "/" + path.Join(orgID, c.resourceName) return c.baseUrl + "/" + path.Join(orgID, c.resourceName)
} }