mirror of
https://github.com/zitadel/zitadel.git
synced 2025-02-28 19:47:23 +00:00
feat: list users scim v2 endpoint (#9187)
# Which Problems Are Solved - Adds support for the list users SCIM v2 endpoint # How the Problems Are Solved - Adds support for the list users SCIM v2 endpoints under `GET /scim/v2/{orgID}/Users` and `POST /scim/v2/{orgID}/Users/.search` # Additional Changes - adds a new function `SearchUserMetadataForUsers` to the query layer to query a metadata keyset for given user ids - adds a new function `NewUserMetadataExistsQuery` to the query layer to query a given metadata key value pair exists - adds a new function `CountUsers` to the query layer to count users without reading any rows - handle `ErrorAlreadyExists` as scim errors `uniqueness` - adds `NumberLessOrEqual` and `NumberGreaterOrEqual` query comparison methods - adds `BytesQuery` with `BytesEquals` and `BytesNotEquals` query comparison methods # Additional Context Part of #8140 Supported fields for scim filters: * `meta.created` * `meta.lastModified` * `id` * `username` * `name.familyName` * `name.givenName` * `emails` and `emails.value` * `active` only eq and ne * `externalId` only eq and ne
This commit is contained in:
parent
926e7169b2
commit
1915d35605
1
go.mod
1
go.mod
@ -10,6 +10,7 @@ require (
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.24.0
|
||||
github.com/Masterminds/squirrel v1.5.4
|
||||
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b
|
||||
github.com/alecthomas/participle/v2 v2.1.1
|
||||
github.com/alicebob/miniredis/v2 v2.33.0
|
||||
github.com/benbjohnson/clock v1.3.5
|
||||
github.com/boombuler/barcode v1.0.2
|
||||
|
8
go.sum
8
go.sum
@ -49,6 +49,12 @@ github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm
|
||||
github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk=
|
||||
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw=
|
||||
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
|
||||
github.com/alecthomas/assert/v2 v2.3.0 h1:mAsH2wmvjsuvyBvAmCtm7zFsBlb8mIHx5ySLVdDZXL0=
|
||||
github.com/alecthomas/assert/v2 v2.3.0/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
|
||||
github.com/alecthomas/participle/v2 v2.1.1 h1:hrjKESvSqGHzRb4yW1ciisFJ4p3MGYih6icjJvbsmV8=
|
||||
github.com/alecthomas/participle/v2 v2.1.1/go.mod h1:Y1+hAs8DHPmc3YUFzqllV+eSQ9ljPTk0ZkPMtEdAx2c=
|
||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
|
||||
@ -400,6 +406,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
|
||||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
||||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
||||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
|
||||
github.com/improbable-eng/grpc-web v0.15.0 h1:BN+7z6uNXZ1tQGcNAuaU1YjsLTApzkjt2tzCixLaUPQ=
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/schema"
|
||||
@ -26,3 +27,24 @@ func (p *Parser) Parse(r *http.Request, data interface{}) error {
|
||||
|
||||
return p.decoder.Decode(data, r.Form)
|
||||
}
|
||||
|
||||
func (p *Parser) UnwrapParserError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// try to unwrap the error
|
||||
var multiErr schema.MultiError
|
||||
if errors.As(err, &multiErr) && len(multiErr) == 1 {
|
||||
for _, v := range multiErr {
|
||||
var schemaErr schema.ConversionError
|
||||
if errors.As(v, &schemaErr) {
|
||||
return schemaErr.Err
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
81
internal/api/http/parser_test.go
Normal file
81
internal/api/http/parser_test.go
Normal file
@ -0,0 +1,81 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
gschema "github.com/gorilla/schema"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type SampleSchema struct {
|
||||
Value *SampleSchemaValue `schema:"value"`
|
||||
IntValue int `schema:"intvalue"`
|
||||
}
|
||||
|
||||
type SampleSchemaValue struct{}
|
||||
|
||||
func (s *SampleSchemaValue) UnmarshalText(text []byte) error {
|
||||
if string(text) == "foo" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("this is a test error")
|
||||
}
|
||||
|
||||
func TestParser_UnwrapParserError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantErr bool
|
||||
assertUnwrappedError func(err error, unwrappedErr error)
|
||||
}{
|
||||
{
|
||||
name: "unwrap ok",
|
||||
query: "value=test",
|
||||
wantErr: true,
|
||||
assertUnwrappedError: func(_, err error) {
|
||||
require.Equal(t, "this is a test error", err.Error())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple errors",
|
||||
query: "value=test&intvalue=foo",
|
||||
wantErr: true,
|
||||
assertUnwrappedError: func(err error, unwrappedErr error) {
|
||||
require.Equal(t, err, unwrappedErr)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no error",
|
||||
query: "value=foo&intvalue=1",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewParser()
|
||||
encodedFormData := url.Values{}.Encode()
|
||||
r, err := http.NewRequest(http.MethodPost, "http://exmaple.com?"+tt.query, bytes.NewBufferString(encodedFormData))
|
||||
require.NoError(t, err)
|
||||
|
||||
data := new(SampleSchema)
|
||||
err = p.Parse(r, data)
|
||||
if !tt.wantErr {
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, p.UnwrapParserError(err))
|
||||
return
|
||||
}
|
||||
|
||||
require.Error(t, err)
|
||||
require.IsType(t, gschema.MultiError{}, err)
|
||||
|
||||
unwrappedErr := p.UnwrapParserError(err)
|
||||
require.Error(t, unwrappedErr)
|
||||
tt.assertUnwrappedError(err, unwrappedErr)
|
||||
})
|
||||
}
|
||||
}
|
@ -10,6 +10,12 @@ var AuthMapping = authz.MethodMapping{
|
||||
"POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": {
|
||||
Permission: domain.PermissionUserWrite,
|
||||
},
|
||||
"POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users/.search": {
|
||||
Permission: domain.PermissionUserRead,
|
||||
},
|
||||
"GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users": {
|
||||
Permission: domain.PermissionUserRead,
|
||||
},
|
||||
"GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": {
|
||||
Permission: domain.PermissionUserRead,
|
||||
},
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
@ -55,6 +56,104 @@ var (
|
||||
|
||||
//go:embed testdata/users_create_test_invalid_timezone.json
|
||||
invalidTimeZoneUserJson []byte
|
||||
|
||||
fullUser = &resources.ScimUser{
|
||||
ExternalID: "701984",
|
||||
UserName: "bjensen@example.com",
|
||||
Name: &resources.ScimUserName{
|
||||
Formatted: "Babs Jensen", // DisplayName takes precedence in Zitadel
|
||||
FamilyName: "Jensen",
|
||||
GivenName: "Barbara",
|
||||
MiddleName: "Jane",
|
||||
HonorificPrefix: "Ms.",
|
||||
HonorificSuffix: "III",
|
||||
},
|
||||
DisplayName: "Babs Jensen",
|
||||
NickName: "Babs",
|
||||
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "bjensen@example.com",
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
Addresses: []*resources.ScimAddress{
|
||||
{
|
||||
Type: "work",
|
||||
StreetAddress: "100 Universal City Plaza",
|
||||
Locality: "Hollywood",
|
||||
Region: "CA",
|
||||
PostalCode: "91608",
|
||||
Country: "USA",
|
||||
Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Type: "home",
|
||||
StreetAddress: "456 Hollywood Blvd",
|
||||
Locality: "Hollywood",
|
||||
Region: "CA",
|
||||
PostalCode: "91608",
|
||||
Country: "USA",
|
||||
Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA",
|
||||
},
|
||||
},
|
||||
PhoneNumbers: []*resources.ScimPhoneNumber{
|
||||
{
|
||||
Value: "+415555555555",
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
Ims: []*resources.ScimIms{
|
||||
{
|
||||
Value: "someaimhandle",
|
||||
Type: "aim",
|
||||
},
|
||||
{
|
||||
Value: "twitterhandle",
|
||||
Type: "X",
|
||||
},
|
||||
},
|
||||
Photos: []*resources.ScimPhoto{
|
||||
{
|
||||
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
|
||||
Type: "photo",
|
||||
},
|
||||
},
|
||||
Roles: []*resources.ScimRole{
|
||||
{
|
||||
Value: "my-role-1",
|
||||
Display: "Rolle 1",
|
||||
Type: "main-role",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Value: "my-role-2",
|
||||
Display: "Rolle 2",
|
||||
Type: "secondary-role",
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
Entitlements: []*resources.ScimEntitlement{
|
||||
{
|
||||
Value: "my-entitlement-1",
|
||||
Display: "Entitlement 1",
|
||||
Type: "main-entitlement",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Value: "my-entitlement-2",
|
||||
Display: "Entitlement 2",
|
||||
Type: "secondary-entitlement",
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
Title: "Tour Guide",
|
||||
PreferredLanguage: language.MustParse("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/Los_Angeles",
|
||||
Active: gu.Ptr(true),
|
||||
}
|
||||
)
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
@ -95,103 +194,7 @@ func TestCreateUser(t *testing.T) {
|
||||
{
|
||||
name: "full user",
|
||||
body: fullUserJson,
|
||||
want: &resources.ScimUser{
|
||||
ExternalID: "701984",
|
||||
UserName: "bjensen@example.com",
|
||||
Name: &resources.ScimUserName{
|
||||
Formatted: "Babs Jensen", // DisplayName takes precedence in Zitadel
|
||||
FamilyName: "Jensen",
|
||||
GivenName: "Barbara",
|
||||
MiddleName: "Jane",
|
||||
HonorificPrefix: "Ms.",
|
||||
HonorificSuffix: "III",
|
||||
},
|
||||
DisplayName: "Babs Jensen",
|
||||
NickName: "Babs",
|
||||
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "bjensen@example.com",
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
Addresses: []*resources.ScimAddress{
|
||||
{
|
||||
Type: "work",
|
||||
StreetAddress: "100 Universal City Plaza",
|
||||
Locality: "Hollywood",
|
||||
Region: "CA",
|
||||
PostalCode: "91608",
|
||||
Country: "USA",
|
||||
Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Type: "home",
|
||||
StreetAddress: "456 Hollywood Blvd",
|
||||
Locality: "Hollywood",
|
||||
Region: "CA",
|
||||
PostalCode: "91608",
|
||||
Country: "USA",
|
||||
Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA",
|
||||
},
|
||||
},
|
||||
PhoneNumbers: []*resources.ScimPhoneNumber{
|
||||
{
|
||||
Value: "+415555555555",
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
Ims: []*resources.ScimIms{
|
||||
{
|
||||
Value: "someaimhandle",
|
||||
Type: "aim",
|
||||
},
|
||||
{
|
||||
Value: "twitterhandle",
|
||||
Type: "X",
|
||||
},
|
||||
},
|
||||
Photos: []*resources.ScimPhoto{
|
||||
{
|
||||
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
|
||||
Type: "photo",
|
||||
},
|
||||
},
|
||||
Roles: []*resources.ScimRole{
|
||||
{
|
||||
Value: "my-role-1",
|
||||
Display: "Rolle 1",
|
||||
Type: "main-role",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Value: "my-role-2",
|
||||
Display: "Rolle 2",
|
||||
Type: "secondary-role",
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
Entitlements: []*resources.ScimEntitlement{
|
||||
{
|
||||
Value: "my-entitlement-1",
|
||||
Display: "Entitlement 1",
|
||||
Type: "main-entitlement",
|
||||
Primary: true,
|
||||
},
|
||||
{
|
||||
Value: "my-entitlement-2",
|
||||
Display: "Entitlement 2",
|
||||
Type: "secondary-entitlement",
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
Title: "Tour Guide",
|
||||
PreferredLanguage: language.MustParse("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/Los_Angeles",
|
||||
Active: gu.Ptr(true),
|
||||
},
|
||||
want: fullUser,
|
||||
},
|
||||
{
|
||||
name: "missing userName",
|
||||
@ -290,7 +293,7 @@ func TestCreateUser(t *testing.T) {
|
||||
assert.Nil(t, createdUser.Password)
|
||||
|
||||
if tt.want != nil {
|
||||
if !integration.PartiallyDeepEqual(tt.want, createdUser) {
|
||||
if !test.PartiallyDeepEqual(tt.want, createdUser) {
|
||||
t.Errorf("CreateUser() got = %v, want %v", createdUser, tt.want)
|
||||
}
|
||||
|
||||
@ -299,7 +302,7 @@ func TestCreateUser(t *testing.T) {
|
||||
// ensure the user is really stored and not just returned to the caller
|
||||
fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, createdUser.ID)
|
||||
require.NoError(ttt, err)
|
||||
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
ttt.Errorf("GetUser() got = %v, want %v", fetchedUser, tt.want)
|
||||
}
|
||||
}, retryDuration, tick)
|
||||
@ -315,6 +318,7 @@ func TestCreateUser_duplicate(t *testing.T) {
|
||||
_, err = Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson)
|
||||
scimErr := scim.RequireScimError(t, http.StatusConflict, err)
|
||||
assert.Equal(t, "User already exists", scimErr.Error.Detail)
|
||||
assert.Equal(t, "uniqueness", scimErr.Error.ScimType)
|
||||
|
||||
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
|
||||
require.NoError(t, err)
|
||||
@ -341,19 +345,19 @@ func TestCreateUser_metadata(t *testing.T) {
|
||||
mdMap[md.Result[i].Key] = string(md.Result[i].Value)
|
||||
}
|
||||
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`)
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`)
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`)
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.middleName", "Jane")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:title", "Tour Guide")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:locale", "en-US")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`)
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`)
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`)
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`)
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`)
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.middleName", "Jane")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:title", "Tour Guide")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:locale", "en-US")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`)
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`)
|
||||
}, retryDuration, tick)
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
@ -93,7 +94,7 @@ func TestGetUser(t *testing.T) {
|
||||
},
|
||||
DisplayName: "Babs Jensen",
|
||||
NickName: "Babs",
|
||||
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
|
||||
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
|
||||
Title: "Tour Guide",
|
||||
PreferredLanguage: language.Make("en-US"),
|
||||
Locale: "en-US",
|
||||
@ -144,11 +145,11 @@ func TestGetUser(t *testing.T) {
|
||||
},
|
||||
Photos: []*resources.ScimPhoto{
|
||||
{
|
||||
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
|
||||
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
|
||||
Type: "photo",
|
||||
},
|
||||
{
|
||||
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")),
|
||||
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")),
|
||||
Type: "thumbnail",
|
||||
},
|
||||
},
|
||||
@ -256,7 +257,7 @@ func TestGetUser(t *testing.T) {
|
||||
assert.Equal(ttt, schemas.ScimResourceTypeSingular("User"), fetchedUser.Resource.Meta.ResourceType)
|
||||
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
|
||||
assert.Nil(ttt, fetchedUser.Password)
|
||||
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
|
||||
}
|
||||
}, retryDuration, tick)
|
||||
|
492
internal/api/scim/integration_test/users_list_test.go
Normal file
492
internal/api/scim/integration_test/users_list_test.go
Normal file
@ -0,0 +1,492 @@
|
||||
//go:build integration
|
||||
|
||||
package integration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/resources"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
|
||||
user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
|
||||
var totalCountOfHumanUsers = 13
|
||||
|
||||
func TestListUser(t *testing.T) {
|
||||
createdUserIDs := createUsers(t, CTX, Instance.DefaultOrg.Id)
|
||||
defer func() {
|
||||
// only the full user needs to be deleted, all others have random identification data
|
||||
// fullUser is always the first one.
|
||||
_, err := Instance.Client.UserV2.DeleteUser(CTX, &user_v2.DeleteUserRequest{
|
||||
UserId: createdUserIDs[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// secondary organization with same set of users,
|
||||
// these should never be modified.
|
||||
// This allows testing list requests without filters.
|
||||
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
|
||||
secondaryOrg := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
|
||||
secondaryOrgCreatedUserIDs := createUsers(t, iamOwnerCtx, secondaryOrg.OrganizationId)
|
||||
|
||||
testsInitializedUtc := time.Now().UTC()
|
||||
|
||||
// Wait one second to ensure a change in the least significant value of the timestamp.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
orgID string
|
||||
req *scim.ListRequest
|
||||
prepare func(require.TestingT) *scim.ListRequest
|
||||
wantErr bool
|
||||
errorStatus int
|
||||
errorType string
|
||||
assert func(assert.TestingT, *scim.ListResponse[*resources.ScimUser])
|
||||
cleanup func(require.TestingT)
|
||||
}{
|
||||
{
|
||||
name: "not authenticated",
|
||||
ctx: context.Background(),
|
||||
req: new(scim.ListRequest),
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "no permissions",
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission),
|
||||
req: new(scim.ListRequest),
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "unknown sort order",
|
||||
req: &scim.ListRequest{
|
||||
SortBy: gu.Ptr("id"),
|
||||
SortOrder: gu.Ptr(scim.ListRequestSortOrder("fooBar")),
|
||||
},
|
||||
wantErr: true,
|
||||
errorType: "invalidValue",
|
||||
},
|
||||
{
|
||||
name: "unknown sort field",
|
||||
req: &scim.ListRequest{
|
||||
SortBy: gu.Ptr("fooBar"),
|
||||
},
|
||||
wantErr: true,
|
||||
errorType: "invalidValue",
|
||||
},
|
||||
{
|
||||
name: "unknown filter field",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`fooBar eq "10"`),
|
||||
},
|
||||
wantErr: true,
|
||||
errorType: "invalidFilter",
|
||||
},
|
||||
{
|
||||
name: "invalid filter",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`fooBarBaz`),
|
||||
},
|
||||
wantErr: true,
|
||||
errorType: "invalidFilter",
|
||||
},
|
||||
{
|
||||
name: "list users without filter",
|
||||
// use other org, modifications of users happens only on primary org
|
||||
orgID: secondaryOrg.OrganizationId,
|
||||
ctx: iamOwnerCtx,
|
||||
req: new(scim.ListRequest),
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 100, resp.ItemsPerPage)
|
||||
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, totalCountOfHumanUsers)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list paged sorted users without filter",
|
||||
// use other org, modifications of users happens only on primary org
|
||||
orgID: secondaryOrg.OrganizationId,
|
||||
ctx: iamOwnerCtx,
|
||||
req: &scim.ListRequest{
|
||||
Count: gu.Ptr(2),
|
||||
StartIndex: gu.Ptr(5),
|
||||
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
|
||||
SortBy: gu.Ptr("username"),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 2, resp.ItemsPerPage)
|
||||
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
|
||||
assert.Equal(t, 5, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 2)
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-1: "))
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[1].UserName, "scim-username-2: "))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with simple filter",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`username sw "scim-username-1"`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 100, resp.ItemsPerPage)
|
||||
assert.Equal(t, 2, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 2)
|
||||
for _, resource := range resp.Resources {
|
||||
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list paged sorted users with filter",
|
||||
req: &scim.ListRequest{
|
||||
Count: gu.Ptr(5),
|
||||
StartIndex: gu.Ptr(1),
|
||||
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
|
||||
SortBy: gu.Ptr("username"),
|
||||
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 5, resp.ItemsPerPage)
|
||||
assert.Equal(t, 2, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 2)
|
||||
for _, resource := range resp.Resources {
|
||||
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
|
||||
assert.Len(t, resource.Emails, 1)
|
||||
assert.True(t, strings.HasPrefix(resource.Emails[0].Value, "scim-email-1"))
|
||||
assert.True(t, strings.HasSuffix(resource.Emails[0].Value, "@example.com"))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list paged sorted users with filter as post",
|
||||
req: &scim.ListRequest{
|
||||
Count: gu.Ptr(5),
|
||||
StartIndex: gu.Ptr(1),
|
||||
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
|
||||
SortBy: gu.Ptr("username"),
|
||||
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
|
||||
SendAsPost: true,
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 5, resp.ItemsPerPage)
|
||||
assert.Equal(t, 2, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 2)
|
||||
for _, resource := range resp.Resources {
|
||||
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
|
||||
assert.Len(t, resource.Emails, 1)
|
||||
assert.True(t, strings.HasPrefix(resource.Emails[0].Value, "scim-email-1"))
|
||||
assert.True(t, strings.HasSuffix(resource.Emails[0].Value, "@example.com"))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "count users without filter",
|
||||
// use other org, modifications of users happens only on primary org
|
||||
orgID: secondaryOrg.OrganizationId,
|
||||
ctx: iamOwnerCtx,
|
||||
prepare: func(t require.TestingT) *scim.ListRequest {
|
||||
return &scim.ListRequest{
|
||||
Count: gu.Ptr(0),
|
||||
}
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 0, resp.ItemsPerPage)
|
||||
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with active filter",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`active eq false`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 100, resp.ItemsPerPage)
|
||||
assert.Equal(t, 1, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-0"))
|
||||
assert.False(t, *resp.Resources[0].Active)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with externalid filter",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`externalid eq "701984"`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 100, resp.ItemsPerPage)
|
||||
assert.Equal(t, 1, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.Equal(t, resp.Resources[0].ExternalID, "701984")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with externalid filter invalid operator",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`externalid pr`),
|
||||
},
|
||||
wantErr: true,
|
||||
errorType: "invalidFilter",
|
||||
},
|
||||
{
|
||||
name: "list users with externalid complex filter",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(`externalid eq "701984" and username eq "bjensen@example.com"`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 100, resp.ItemsPerPage)
|
||||
assert.Equal(t, 1, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.Equal(t, resp.Resources[0].UserName, "bjensen@example.com")
|
||||
assert.Equal(t, resp.Resources[0].ExternalID, "701984")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "count users with filter",
|
||||
req: &scim.ListRequest{
|
||||
Count: gu.Ptr(0),
|
||||
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Equal(t, 0, resp.ItemsPerPage)
|
||||
assert.Equal(t, 2, resp.TotalResults)
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with modification date filter",
|
||||
prepare: func(t require.TestingT) *scim.ListRequest {
|
||||
userID := createdUserIDs[len(createdUserIDs)-1] // use the last entry, as we use the others for other assertions
|
||||
_, err := Instance.Client.UserV2.UpdateHumanUser(CTX, &user_v2.UpdateHumanUserRequest{
|
||||
UserId: userID,
|
||||
|
||||
Profile: &user_v2.SetHumanProfile{
|
||||
GivenName: "scim-user-given-name-modified-0: " + gofakeit.FirstName(),
|
||||
FamilyName: "scim-user-family-name-modified-0: " + gofakeit.LastName(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return &scim.ListRequest{
|
||||
// filter by id too to exclude other random users
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s" and meta.LASTMODIFIED gt "%s"`, userID, testsInitializedUtc.Format(time.RFC3339))),
|
||||
}
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.Equal(t, resp.Resources[0].ID, createdUserIDs[len(createdUserIDs)-1])
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].Name.FamilyName, "scim-user-family-name-modified-0:"))
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].Name.GivenName, "scim-user-given-name-modified-0:"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list users with creation date filter",
|
||||
prepare: func(t require.TestingT) *scim.ListRequest {
|
||||
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 100)
|
||||
return &scim.ListRequest{
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s" and meta.created gt "%s"`, resp.UserId, testsInitializedUtc.Format(time.RFC3339))),
|
||||
}
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-100:"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "validate returned objects",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, createdUserIDs[0])),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
if !test.PartiallyDeepEqual(fullUser, resp.Resources[0]) {
|
||||
t.Errorf("got = %#v, want %#v", resp.Resources[0], fullUser)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "do not return user of other org",
|
||||
req: &scim.ListRequest{
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, secondaryOrgCreatedUserIDs[0])),
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "do not count user of other org",
|
||||
prepare: func(t require.TestingT) *scim.ListRequest {
|
||||
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
|
||||
org := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
|
||||
resp := createHumanUser(t, iamOwnerCtx, org.OrganizationId, 102)
|
||||
|
||||
return &scim.ListRequest{
|
||||
Count: gu.Ptr(0),
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
|
||||
}
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "scoped externalID",
|
||||
prepare: func(t require.TestingT) *scim.ListRequest {
|
||||
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 102)
|
||||
|
||||
// set provisioning domain of service user
|
||||
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioning_domain",
|
||||
Value: []byte("fooBar"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// set externalID for provisioning domain
|
||||
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: resp.UserId,
|
||||
Key: "urn:zitadel:scim:fooBar:externalId",
|
||||
Value: []byte("100-scopedExternalId"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return &scim.ListRequest{
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
|
||||
}
|
||||
},
|
||||
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.Equal(t, resp.Resources[0].ExternalID, "100-scopedExternalId")
|
||||
},
|
||||
cleanup: func(t require.TestingT) {
|
||||
// delete provisioning domain of service user
|
||||
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioning_domain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.ctx == nil {
|
||||
tt.ctx = CTX
|
||||
}
|
||||
|
||||
if tt.prepare != nil {
|
||||
tt.req = tt.prepare(t)
|
||||
}
|
||||
|
||||
if tt.orgID == "" {
|
||||
tt.orgID = Instance.DefaultOrg.Id
|
||||
}
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.ctx, time.Minute)
|
||||
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
listResp, err := Instance.Client.SCIM.Users.List(tt.ctx, tt.orgID, tt.req)
|
||||
if tt.wantErr {
|
||||
statusCode := tt.errorStatus
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
|
||||
scimErr := scim.RequireScimError(ttt, statusCode, err)
|
||||
if tt.errorType != "" {
|
||||
assert.Equal(t, tt.errorType, scimErr.Error.ScimType)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(ttt, []schemas.ScimSchemaType{"urn:ietf:params:scim:api:messages:2.0:ListResponse"}, listResp.Schemas)
|
||||
if tt.assert != nil {
|
||||
tt.assert(ttt, listResp)
|
||||
}
|
||||
}, retryDuration, tick)
|
||||
|
||||
if tt.cleanup != nil {
|
||||
tt.cleanup(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createUsers(t *testing.T, ctx context.Context, orgID string) []string {
|
||||
count := totalCountOfHumanUsers - 1 // zitadel admin is always created by default
|
||||
createdUserIDs := make([]string, 0, count)
|
||||
|
||||
// create the full scim user if on primary org
|
||||
if orgID == Instance.DefaultOrg.Id {
|
||||
fullUserCreatedResp, err := Instance.Client.SCIM.Users.Create(ctx, orgID, fullUserJson)
|
||||
require.NoError(t, err)
|
||||
createdUserIDs = append(createdUserIDs, fullUserCreatedResp.ID)
|
||||
count--
|
||||
}
|
||||
|
||||
// set the first user inactive
|
||||
resp := createHumanUser(t, ctx, orgID, 0)
|
||||
_, err := Instance.Client.UserV2.DeactivateUser(ctx, &user_v2.DeactivateUserRequest{
|
||||
UserId: resp.UserId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
createdUserIDs = append(createdUserIDs, resp.UserId)
|
||||
|
||||
for i := 1; i < count; i++ {
|
||||
resp = createHumanUser(t, ctx, orgID, i)
|
||||
createdUserIDs = append(createdUserIDs, resp.UserId)
|
||||
}
|
||||
|
||||
return createdUserIDs
|
||||
}
|
||||
|
||||
func createHumanUser(t require.TestingT, ctx context.Context, orgID string, i int) *user_v2.AddHumanUserResponse {
|
||||
// create remaining minimal users with faker data
|
||||
// no need to clean these up as identification attributes change each time
|
||||
resp, err := Instance.Client.UserV2.AddHumanUser(ctx, &user_v2.AddHumanUserRequest{
|
||||
Organization: &object.Organization{
|
||||
Org: &object.Organization_OrgId{
|
||||
OrgId: orgID,
|
||||
},
|
||||
},
|
||||
Username: gu.Ptr(fmt.Sprintf("scim-username-%d: %s", i, gofakeit.Username())),
|
||||
Profile: &user_v2.SetHumanProfile{
|
||||
GivenName: fmt.Sprintf("scim-givenname-%d: %s", i, gofakeit.FirstName()),
|
||||
FamilyName: fmt.Sprintf("scim-familyname-%d: %s", i, gofakeit.LastName()),
|
||||
PreferredLanguage: gu.Ptr("en-US"),
|
||||
Gender: gu.Ptr(user_v2.Gender_GENDER_MALE),
|
||||
},
|
||||
Email: &user_v2.SetHumanEmail{
|
||||
Email: fmt.Sprintf("scim-email-%d-%d@example.com", i, gofakeit.Number(0, 1_000_000)),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return resp
|
||||
}
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
@ -78,7 +79,7 @@ func TestReplaceUser(t *testing.T) {
|
||||
},
|
||||
DisplayName: "Babs Jensen-updated",
|
||||
NickName: "Babs-updated",
|
||||
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen-updated")),
|
||||
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen-updated")),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "bjensen-replaced-full@example.com",
|
||||
@ -124,11 +125,11 @@ func TestReplaceUser(t *testing.T) {
|
||||
},
|
||||
Photos: []*resources.ScimPhoto{
|
||||
{
|
||||
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F-updated")),
|
||||
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F-updated")),
|
||||
Type: "photo-updated",
|
||||
},
|
||||
{
|
||||
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T-updated")),
|
||||
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T-updated")),
|
||||
Type: "thumbnail-updated",
|
||||
},
|
||||
},
|
||||
@ -247,7 +248,7 @@ func TestReplaceUser(t *testing.T) {
|
||||
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), replacedUser.Resource.Meta.Location)
|
||||
assert.Nil(t, createdUser.Password)
|
||||
|
||||
if !integration.PartiallyDeepEqual(tt.want, replacedUser) {
|
||||
if !test.PartiallyDeepEqual(tt.want, replacedUser) {
|
||||
t.Errorf("ReplaceUser() got = %#v, want %#v", replacedUser, tt.want)
|
||||
}
|
||||
|
||||
@ -256,7 +257,7 @@ func TestReplaceUser(t *testing.T) {
|
||||
// ensure the user is really stored and not just returned to the caller
|
||||
fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, replacedUser.ID)
|
||||
require.NoError(ttt, err)
|
||||
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
|
||||
}
|
||||
}, retryDuration, tick)
|
||||
@ -316,8 +317,8 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
|
||||
}
|
||||
|
||||
// both external IDs should be present on the user
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
|
||||
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:fooBazz:externalId", "replaced-external-id")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
|
||||
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:fooBazz:externalId", "replaced-external-id")
|
||||
}, retryDuration, tick)
|
||||
|
||||
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
|
||||
|
340
internal/api/scim/resources/filter/filter_parser.go
Normal file
340
internal/api/scim/resources/filter/filter_parser.go
Normal file
@ -0,0 +1,340 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/participle/v2"
|
||||
"github.com/alecthomas/participle/v2/lexer"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/serrors"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// Filter The scim v2 filter
|
||||
// Separation between FilterSegment and Filter is required
|
||||
// due to the UnmarshalText method, which is used by the schema parser
|
||||
// as well as the participle parser but should do different things here.
|
||||
type Filter struct {
|
||||
Root Segment
|
||||
}
|
||||
|
||||
// Segment The root ast node for the filter grammar
|
||||
// according to the filter ABNF of https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
|
||||
// FILTER = attrExp / logExp / valuePath / *1"not" "(" FILTER ")"
|
||||
// to reduce lookahead needs and reduce stack depth of the parser,
|
||||
// always match log expressions with optional operators
|
||||
type Segment struct {
|
||||
OrExp OrLogExp `parser:"@@"`
|
||||
}
|
||||
|
||||
// OrLogExp The logical expression according to the filter ABNF
|
||||
// separated in OrLogExp and AndLogExp to simplify parser stack depth and precedence
|
||||
// logExp = FILTER SP ("and" / "or") SP FILTER
|
||||
type OrLogExp struct {
|
||||
Left AndLogExp `parser:"@@"`
|
||||
Right *OrLogExp `parser:"(Whitespace 'or' Whitespace @@)?"`
|
||||
}
|
||||
|
||||
type AndLogExp struct {
|
||||
Left ValueAtom `parser:"@@"`
|
||||
Right *AndLogExp `parser:"(Whitespace 'and' Whitespace @@)?"`
|
||||
}
|
||||
|
||||
type ValueAtom struct {
|
||||
SubFilter *Segment `parser:"'(' @@ ')' |"`
|
||||
Negation *Segment `parser:"'not' '(' @@ ')' |"`
|
||||
ValuePath *ValuePath `parser:"@@ |"`
|
||||
AttrExp *AttrExp `parser:"@@"`
|
||||
}
|
||||
|
||||
// ValuePath The value path according to the filter ABNF
|
||||
// valuePath = attrPath "[" valFilter "]"
|
||||
// instead of a separate valFilter entity the LogExp
|
||||
// is used to simplify parsing.
|
||||
type ValuePath struct {
|
||||
AttrPath AttrPath `parser:"@@"`
|
||||
ValFilter OrLogExp `parser:"'[' @@ ']'"`
|
||||
}
|
||||
|
||||
// AttrExp The attribute expression according to the filter ABNF
|
||||
// attrExp = (attrPath SP "pr") / (attrPath SP compareOp SP compValue)
|
||||
type AttrExp struct {
|
||||
UnaryCondition *UnaryCondition `parser:"@@ |"`
|
||||
BinaryCondition *BinaryCondition `parser:"@@"`
|
||||
}
|
||||
|
||||
type UnaryCondition struct {
|
||||
Left AttrPath `parser:"@@ Whitespace"`
|
||||
Operator UnaryConditionOperator `parser:"@@"`
|
||||
}
|
||||
|
||||
type UnaryConditionOperator struct {
|
||||
Present bool `parser:"@'pr'"`
|
||||
}
|
||||
|
||||
type BinaryCondition struct {
|
||||
Left AttrPath `parser:"@@ Whitespace"`
|
||||
Operator CompareOp `parser:"@@ Whitespace"`
|
||||
Right CompValue `parser:"@@"`
|
||||
}
|
||||
|
||||
// CompareOp according to the scim filter ABNF
|
||||
// compareOp = "eq" / "ne" / "co" /
|
||||
// "sw" / "ew" /
|
||||
// "gt" / "lt" /
|
||||
// "ge" / "le"
|
||||
type CompareOp struct {
|
||||
Equal bool `parser:"@'eq' |"`
|
||||
NotEqual bool `parser:"@'ne' |"`
|
||||
Contains bool `parser:"@'co' |"`
|
||||
StartsWith bool `parser:"@'sw' |"`
|
||||
EndsWith bool `parser:"@'ew' |"`
|
||||
GreaterThan bool `parser:"@'gt' |"`
|
||||
GreaterThanOrEqual bool `parser:"@'ge' |"`
|
||||
LessThan bool `parser:"@'lt' |"`
|
||||
LessThanOrEqual bool `parser:"@'le'"`
|
||||
}
|
||||
|
||||
// CompValue the compare value according to the scim filter ABNF
|
||||
// compValue = false / null / true / number / string
|
||||
type CompValue struct {
|
||||
Null bool `parser:"@'null' |"`
|
||||
BooleanTrue bool `parser:"@'true' |"`
|
||||
BooleanFalse bool `parser:"@'false' |"`
|
||||
Int *int `parser:"@Int |"`
|
||||
Float *float64 `parser:"@Float |"`
|
||||
StringValue *string `parser:"@String"`
|
||||
}
|
||||
|
||||
// AttrPath the attribute path according to the scim filter ABNF
|
||||
// [URI ":"] AttrName *1subAttr
|
||||
type AttrPath struct {
|
||||
UrnAttributePrefix *string `parser:"(@UrnAttributePrefix)?"`
|
||||
AttrName string `parser:"@AttrName"`
|
||||
SubAttr *string `parser:"('.' @AttrName)?"`
|
||||
}
|
||||
|
||||
const (
|
||||
maxInputLength = 1000
|
||||
)
|
||||
|
||||
var (
|
||||
scimFilterLexer = lexer.MustSimple([]lexer.SimpleRule{
|
||||
// simplified version of RFC8141, last part isn't matched as in scim this is the attribute name
|
||||
// urn is additionally verified after parsing, use a more relaxed matching here
|
||||
{Name: "UrnAttributePrefix", Pattern: `urn:([\w()+,.=@;$_!*'%/?#-]+:)+`},
|
||||
{Name: "Float", Pattern: `[-+]?\d*\.\d+`},
|
||||
{Name: "Int", Pattern: `[-+]?\d+`},
|
||||
{Name: "Parenthesis", Pattern: `\(|\)|\[|\]`},
|
||||
{Name: "Punctuation", Pattern: `\.`},
|
||||
{Name: "String", Pattern: `"(\\"|[^"])*"`},
|
||||
{Name: "AttrName", Pattern: `[a-zA-Z][\w-]*`}, // AttrName according to the scim ABNF
|
||||
{Name: "Whitespace", Pattern: `[ \t\n\r]+`},
|
||||
})
|
||||
|
||||
scimFilterParser = buildParser[Segment]()
|
||||
)
|
||||
|
||||
func buildParser[T any]() *participle.Parser[T] {
|
||||
return participle.MustBuild[T](
|
||||
participle.Lexer(scimFilterLexer),
|
||||
participle.Unquote("String"),
|
||||
// Keyword literals are matched case-insensitive (according to https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2)
|
||||
// Keywords are a subset of AttrName
|
||||
participle.CaseInsensitive("AttrName"),
|
||||
participle.Elide("Whitespace"),
|
||||
participle.UseLookahead(participle.MaxLookahead),
|
||||
)
|
||||
}
|
||||
|
||||
func (f *Filter) UnmarshalText(text []byte) error {
|
||||
if len(text) == 0 {
|
||||
*f = Filter{}
|
||||
return nil
|
||||
}
|
||||
|
||||
parsedFilter, err := ParseFilter(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*f = *parsedFilter
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Filter) UnmarshalJSON(data []byte) error {
|
||||
var rawFilter string
|
||||
if err := json.Unmarshal(data, &rawFilter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return f.UnmarshalText([]byte(rawFilter))
|
||||
}
|
||||
|
||||
func (f *Filter) IsZero() bool {
|
||||
return f == nil || *f == Filter{}
|
||||
}
|
||||
|
||||
func ParseFilter(filter string) (*Filter, error) {
|
||||
if filter == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(filter) > maxInputLength {
|
||||
logging.WithFields("len", len(filter)).Infof("scim: filter exceeds maximum allowed length: %d", maxInputLength)
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgumentf(nil, "SCIM-filt13", "filter exceeds maximum allowed length: %d", maxInputLength))
|
||||
}
|
||||
|
||||
parsedFilter, err := scimFilterParser.ParseString("", filter)
|
||||
if err != nil {
|
||||
logging.WithError(err).Info("scim: failed to parse filter")
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(err, "SCIM-filt14", "failed to parse filter"))
|
||||
}
|
||||
|
||||
return &Filter{Root: *parsedFilter}, nil
|
||||
}
|
||||
|
||||
func (f *Filter) String() string {
|
||||
return f.Root.String()
|
||||
}
|
||||
|
||||
func (f *Segment) String() string {
|
||||
return f.OrExp.String()
|
||||
}
|
||||
|
||||
func (o *OrLogExp) String() string {
|
||||
if o.Right == nil {
|
||||
return o.Left.String()
|
||||
}
|
||||
|
||||
return "((" + o.Left.String() + ") or (" + o.Right.String() + "))"
|
||||
}
|
||||
|
||||
func (a *AndLogExp) String() string {
|
||||
if a.Right == nil {
|
||||
return a.Left.String()
|
||||
}
|
||||
|
||||
return "((" + a.Left.String() + ") and (" + a.Right.String() + "))"
|
||||
}
|
||||
|
||||
func (a *ValueAtom) String() string {
|
||||
switch {
|
||||
case a.SubFilter != nil:
|
||||
return "(" + a.SubFilter.String() + ")"
|
||||
case a.Negation != nil:
|
||||
return "not(" + a.Negation.String() + ")"
|
||||
case a.ValuePath != nil:
|
||||
return a.ValuePath.String()
|
||||
}
|
||||
|
||||
return a.AttrExp.String()
|
||||
}
|
||||
|
||||
func (v *ValuePath) String() string {
|
||||
return v.AttrPath.String() + "[" + v.ValFilter.String() + "]"
|
||||
}
|
||||
|
||||
func (a *AttrExp) String() string {
|
||||
if a.UnaryCondition != nil {
|
||||
return a.UnaryCondition.String()
|
||||
}
|
||||
|
||||
return a.BinaryCondition.String()
|
||||
}
|
||||
|
||||
func (u *UnaryCondition) String() string {
|
||||
return u.Left.String() + " " + u.Operator.String()
|
||||
}
|
||||
|
||||
func (u *UnaryConditionOperator) String() string {
|
||||
return "pr"
|
||||
}
|
||||
|
||||
func (b *BinaryCondition) String() string {
|
||||
return b.Left.String() + " " + b.Operator.String() + " " + b.Right.String()
|
||||
}
|
||||
|
||||
func (c *CompareOp) String() string {
|
||||
switch {
|
||||
case c.Equal:
|
||||
return "eq"
|
||||
case c.NotEqual:
|
||||
return "ne"
|
||||
case c.Contains:
|
||||
return "co"
|
||||
case c.StartsWith:
|
||||
return "sw"
|
||||
case c.EndsWith:
|
||||
return "ew"
|
||||
case c.GreaterThan:
|
||||
return "gt"
|
||||
case c.GreaterThanOrEqual:
|
||||
return "ge"
|
||||
case c.LessThan:
|
||||
return "lt"
|
||||
case c.LessThanOrEqual:
|
||||
return "le"
|
||||
}
|
||||
|
||||
return "<unknown CompareOp>"
|
||||
}
|
||||
|
||||
func (c *CompValue) String() string {
|
||||
switch {
|
||||
case c.Null:
|
||||
return "null"
|
||||
case c.BooleanTrue:
|
||||
return "true"
|
||||
case c.BooleanFalse:
|
||||
return "false"
|
||||
case c.Int != nil:
|
||||
return strconv.Itoa(*c.Int)
|
||||
case c.Float != nil:
|
||||
return strconv.FormatFloat(*c.Float, 'f', -1, 64)
|
||||
case c.StringValue != nil:
|
||||
return "\"" + *c.StringValue + "\""
|
||||
}
|
||||
return "<unknown CompValue>"
|
||||
}
|
||||
|
||||
func (a *AttrPath) String() string {
|
||||
var s = ""
|
||||
if a.UrnAttributePrefix != nil {
|
||||
s += *a.UrnAttributePrefix
|
||||
}
|
||||
|
||||
s += a.AttrName
|
||||
|
||||
if a.SubAttr != nil {
|
||||
s += "." + *a.SubAttr
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (a *AttrPath) validateSchema(expectedSchema schemas.ScimSchemaType) error {
|
||||
if a.UrnAttributePrefix == nil || *a.UrnAttributePrefix == string(expectedSchema)+":" {
|
||||
return nil
|
||||
}
|
||||
|
||||
logging.WithFields("urnPrefix", *a.UrnAttributePrefix).Info("scim filter: Invalid filter expression: unknown urn attribute prefix")
|
||||
return serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF431", "Invalid filter expression: unknown urn attribute prefix"))
|
||||
}
|
||||
|
||||
func (a *AttrPath) Segments() []string {
|
||||
// user lower, since attribute names in scim are always case-insensitive
|
||||
if a.SubAttr != nil {
|
||||
return []string{strings.ToLower(a.AttrName), strings.ToLower(*a.SubAttr)}
|
||||
}
|
||||
|
||||
return []string{strings.ToLower(a.AttrName)}
|
||||
}
|
||||
|
||||
func (a *AttrPath) FieldPath() string {
|
||||
return strings.Join(a.Segments(), ".")
|
||||
}
|
868
internal/api/scim/resources/filter/filter_parser_test.go
Normal file
868
internal/api/scim/resources/filter/filter_parser_test.go
Normal file
@ -0,0 +1,868 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
)
|
||||
|
||||
var longString = ""
|
||||
|
||||
func init() {
|
||||
var sb strings.Builder
|
||||
for i := 0; i < maxInputLength+1; i++ {
|
||||
sb.WriteRune('x')
|
||||
}
|
||||
|
||||
longString = sb.String()
|
||||
}
|
||||
|
||||
func TestParseFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter string
|
||||
want *Filter
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "too long",
|
||||
filter: longString,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid syntax",
|
||||
filter: "fooBar[['baz']]",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown binary operator",
|
||||
filter: `userName fu "bjensen"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown unary operator",
|
||||
filter: `userName ok`,
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// test cases from https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
|
||||
{
|
||||
name: "negation",
|
||||
filter: `not(username pr)`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
Negation: &Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
UnaryCondition: &UnaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "username",
|
||||
},
|
||||
Operator: UnaryConditionOperator{
|
||||
Present: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
filter: `age gt 10`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "age",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
GreaterThan: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
Int: gu.Ptr(10),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
filter: `age gt 10.5`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "age",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
GreaterThan: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
Float: gu.Ptr(10.5),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null",
|
||||
filter: `age eq null`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "age",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
Null: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "simple binary operator",
|
||||
filter: `userName eq "bjensen"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userName",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("bjensen"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uppercase binary operator",
|
||||
filter: `userName EQ "bjensen"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userName",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("bjensen"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "case-insensitive literals and operators",
|
||||
filter: `active Eq TRue`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "active",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
BooleanTrue: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace",
|
||||
filter: `userName eq "bjensen"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userName",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("bjensen"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested attribute binary operator",
|
||||
filter: `name.familyName co "O'Malley"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "name",
|
||||
SubAttr: gu.Ptr("familyName"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("O'Malley"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "urn prefixed",
|
||||
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
|
||||
AttrName: "userName",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
StartsWith: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("J"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "urn prefixed nested",
|
||||
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J" and urn:ietf:params:scim:schemas:core:2.0:User:emails.value ew "@example.com"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
|
||||
AttrName: "userName",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
StartsWith: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("J"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
|
||||
AttrName: "emails",
|
||||
SubAttr: gu.Ptr("value"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
EndsWith: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("@example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unary operator",
|
||||
filter: `title pr`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
UnaryCondition: &UnaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "title",
|
||||
},
|
||||
Operator: UnaryConditionOperator{
|
||||
Present: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "binary nested date operator",
|
||||
filter: `meta.lastModified gt "2011-05-13T04:42:34Z"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "meta",
|
||||
SubAttr: gu.Ptr("lastModified"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
GreaterThan: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("2011-05-13T04:42:34Z"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "and logical expression",
|
||||
filter: `title pr and userType eq "Employee"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
UnaryCondition: &UnaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "title",
|
||||
},
|
||||
Operator: UnaryConditionOperator{
|
||||
Present: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userType",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("Employee"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested and / or with grouping",
|
||||
filter: `userType eq "Employee" and (emails co "example.com" or emails.value co "example.org")`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userType",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("Employee"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
SubFilter: &Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
SubAttr: gu.Ptr("value"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example.org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested and / or without grouping",
|
||||
filter: `userType eq "Employee" and emails co "example.com" or emails.value co "example2.org"`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userType",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("Employee"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
SubAttr: gu.Ptr("value"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example2.org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested and / or with negated grouping",
|
||||
filter: `userType ne "Employee" and not (emails co "example.com" or emails.value co "example.org")`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userType",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
NotEqual: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("Employee"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
Negation: &Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "emails",
|
||||
SubAttr: gu.Ptr("value"),
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("example.org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested value path path",
|
||||
filter: `userType eq "Employee" and emails[type eq "work" and value co "@example.com"]`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "userType",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("Employee"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
ValuePath: &ValuePath{
|
||||
AttrPath: AttrPath{
|
||||
AttrName: "emails",
|
||||
},
|
||||
ValFilter: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "type",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "value",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("@example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex value path filter",
|
||||
filter: `emails[type eq "work" and value co "@example.com"] or ims[type eq "xmpp" and value co "@foo.com"]`,
|
||||
want: &Filter{
|
||||
Root: Segment{
|
||||
OrExp: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
ValuePath: &ValuePath{
|
||||
AttrPath: AttrPath{
|
||||
AttrName: "emails",
|
||||
},
|
||||
ValFilter: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "type",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("work"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "value",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("@example.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
ValuePath: &ValuePath{
|
||||
AttrPath: AttrPath{
|
||||
AttrName: "ims",
|
||||
},
|
||||
ValFilter: OrLogExp{
|
||||
Left: AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "type",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Equal: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("xmpp"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Right: &AndLogExp{
|
||||
Left: ValueAtom{
|
||||
AttrExp: &AttrExp{
|
||||
BinaryCondition: &BinaryCondition{
|
||||
Left: AttrPath{
|
||||
AttrName: "value",
|
||||
},
|
||||
Operator: CompareOp{
|
||||
Contains: true,
|
||||
},
|
||||
Right: CompValue{
|
||||
StringValue: gu.Ptr("@foo.com"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseFilter(tt.filter)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseFilter() error = %#v, wantErr %#v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseFilter() got = %s, want %s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
347
internal/api/scim/resources/filter/filter_query_builder.go
Normal file
347
internal/api/scim/resources/filter/filter_query_builder.go
Normal file
@ -0,0 +1,347 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/serrors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// FieldPathMapping maps lowercase json field names of the resource to the matching column in the projection
|
||||
type FieldPathMapping map[string]*QueryFieldInfo
|
||||
|
||||
// queryBuilder builds a query for a filter based on the visitor pattern
|
||||
type queryBuilder struct {
|
||||
ctx context.Context
|
||||
schema schemas.ScimSchemaType
|
||||
fieldPathMapping FieldPathMapping
|
||||
|
||||
// attrPathPrefixes prefixes of attributes that
|
||||
// should also take into account when resolving an attr path to a column.
|
||||
// This is used for "a[b eq 10]" expressions, when resolving b, a would be the prefix.
|
||||
attrPathPrefixStack []*AttrPath
|
||||
}
|
||||
|
||||
type MappedQueryBuilderFunc func(ctx context.Context, compareValue *CompValue, op *CompareOp) (query.SearchQuery, error)
|
||||
|
||||
type QueryFieldInfo struct {
|
||||
Column query.Column
|
||||
FieldType FieldType
|
||||
BuildMappedQuery MappedQueryBuilderFunc
|
||||
}
|
||||
|
||||
type FieldType int
|
||||
|
||||
const (
|
||||
FieldTypeCustom FieldType = iota
|
||||
FieldTypeString
|
||||
FieldTypeNumber
|
||||
FieldTypeBoolean
|
||||
FieldTypeTimestamp
|
||||
)
|
||||
|
||||
func (m FieldPathMapping) Resolve(path string) (*QueryFieldInfo, error) {
|
||||
info, ok := m[strings.ToLower(path)]
|
||||
if !ok {
|
||||
logging.WithFields("fieldPath", path).Info("scim filter: Invalid filter expression: unknown or unsupported field")
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgumentf(nil, "SCIM-FF433", "Invalid filter expression: unknown or unsupported field %s", path))
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func (f *Filter) BuildQuery(ctx context.Context, schema schemas.ScimSchemaType, fieldPathColumnMapping FieldPathMapping) (query.SearchQuery, error) {
|
||||
builder := &queryBuilder{
|
||||
ctx: ctx,
|
||||
schema: schema,
|
||||
fieldPathMapping: fieldPathColumnMapping,
|
||||
}
|
||||
return builder.visitSegment(&f.Root)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) pushAttrPath(path *AttrPath) {
|
||||
b.attrPathPrefixStack = append(b.attrPathPrefixStack, path)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) popAttrPath() {
|
||||
b.attrPathPrefixStack = b.attrPathPrefixStack[:len(b.attrPathPrefixStack)-1]
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitSegment(s *Segment) (query.SearchQuery, error) {
|
||||
return b.visitOr(&s.OrExp)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitOr(or *OrLogExp) (query.SearchQuery, error) {
|
||||
left, err := b.visitAnd(&or.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if or.Right == nil {
|
||||
return left, nil
|
||||
}
|
||||
|
||||
right, err := b.visitOr(or.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// flatten nested or queries
|
||||
if rightOr, ok := right.(*query.OrQuery); ok {
|
||||
rightOr.Prepend(left)
|
||||
return rightOr, nil
|
||||
}
|
||||
|
||||
return query.NewOrQuery(left, right)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitAnd(and *AndLogExp) (query.SearchQuery, error) {
|
||||
left, err := b.visitAtom(&and.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if and.Right == nil {
|
||||
return left, nil
|
||||
}
|
||||
|
||||
right, err := b.visitAnd(and.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// flatten nested and queries
|
||||
if rightAnd, ok := right.(*query.AndQuery); ok {
|
||||
rightAnd.Prepend(left)
|
||||
return rightAnd, nil
|
||||
}
|
||||
|
||||
return query.NewAndQuery(left, right)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitAtom(atom *ValueAtom) (query.SearchQuery, error) {
|
||||
switch {
|
||||
case atom.SubFilter != nil:
|
||||
return b.visitSegment(atom.SubFilter)
|
||||
case atom.Negation != nil:
|
||||
f, err := b.visitSegment(atom.Negation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return query.NewNotQuery(f)
|
||||
case atom.ValuePath != nil:
|
||||
return b.visitValuePath(atom.ValuePath)
|
||||
case atom.AttrExp != nil:
|
||||
return b.visitAttrExp(atom.AttrExp)
|
||||
}
|
||||
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF412", "Invalid filter expression"))
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitValuePath(path *ValuePath) (query.SearchQuery, error) {
|
||||
b.pushAttrPath(&path.AttrPath)
|
||||
defer b.popAttrPath()
|
||||
return b.visitOr(&path.ValFilter)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitAttrExp(exp *AttrExp) (query.SearchQuery, error) {
|
||||
switch {
|
||||
case exp.UnaryCondition != nil:
|
||||
return b.visitUnaryCondition(exp.UnaryCondition)
|
||||
case exp.BinaryCondition != nil:
|
||||
return b.visitBinaryCondition(exp.BinaryCondition)
|
||||
}
|
||||
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF413", "Invalid filter expression"))
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitUnaryCondition(condition *UnaryCondition) (query.SearchQuery, error) {
|
||||
// only supported unary operator is present
|
||||
if !condition.Operator.Present {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF419", "Unknown unary filter operator"))
|
||||
}
|
||||
|
||||
field, err := b.visitAttrPath(&condition.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if field.FieldType == FieldTypeCustom {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FXX49", "Unsupported attribute for unary filter operator"))
|
||||
}
|
||||
|
||||
return query.NewNotNullQuery(field.Column)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitBinaryCondition(condition *BinaryCondition) (query.SearchQuery, error) {
|
||||
left, err := b.visitAttrPath(&condition.Left)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if condition.Operator.Equal && condition.Right.Null {
|
||||
return query.NewIsNullQuery(left.Column)
|
||||
}
|
||||
|
||||
if condition.Operator.NotEqual && condition.Right.Null {
|
||||
return query.NewNotNullQuery(left.Column)
|
||||
}
|
||||
|
||||
switch left.FieldType {
|
||||
case FieldTypeCustom:
|
||||
return left.BuildMappedQuery(b.ctx, &condition.Right, &condition.Operator)
|
||||
case FieldTypeTimestamp:
|
||||
return b.buildTimestampQuery(left, condition.Right, &condition.Operator)
|
||||
case FieldTypeString:
|
||||
return b.buildTextQuery(left, condition.Right, &condition.Operator)
|
||||
case FieldTypeNumber:
|
||||
return b.buildNumberQuery(left, condition.Right, &condition.Operator)
|
||||
case FieldTypeBoolean:
|
||||
return b.buildBooleanQuery(left, condition.Right, &condition.Operator)
|
||||
}
|
||||
|
||||
logging.WithFields("fieldType", left.FieldType).Error("Unknown field type")
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF417", "Unknown filter expression field type"))
|
||||
}
|
||||
|
||||
func (b *queryBuilder) buildTimestampQuery(left *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
|
||||
if right.StringValue == nil {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF451", "Invalid filter expression: the compare value for a timestamp has to be a RFC3339 string"))
|
||||
}
|
||||
|
||||
timestamp, err := time.Parse(time.RFC3339, *right.StringValue)
|
||||
if err != nil {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(err, "SCIM-FF421", "Invalid filter expression: the compare value for a timestamp has to be a RFC3339 string"))
|
||||
}
|
||||
|
||||
var comp query.TimestampComparison
|
||||
switch {
|
||||
case op.Equal:
|
||||
comp = query.TimestampEquals
|
||||
case op.GreaterThan:
|
||||
comp = query.TimestampGreater
|
||||
case op.GreaterThanOrEqual:
|
||||
comp = query.TimestampGreaterOrEquals
|
||||
case op.LessThan:
|
||||
comp = query.TimestampLess
|
||||
case op.LessThanOrEqual:
|
||||
comp = query.TimestampLessOrEquals
|
||||
default:
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF422", "Invalid filter expression: unsupported comparison operator for timestamp fields"))
|
||||
}
|
||||
|
||||
return query.NewTimestampQuery(left.Column, timestamp, comp)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) buildNumberQuery(left *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
|
||||
if right.Int == nil && right.Float == nil {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF423", "Invalid filter expression: unsupported comparison value for numeric fields"))
|
||||
}
|
||||
|
||||
var comp query.NumberComparison
|
||||
switch {
|
||||
case op.Equal:
|
||||
comp = query.NumberEquals
|
||||
case op.NotEqual:
|
||||
comp = query.NumberNotEquals
|
||||
case op.GreaterThan:
|
||||
comp = query.NumberGreater
|
||||
case op.GreaterThanOrEqual:
|
||||
comp = query.NumberGreaterOrEqual
|
||||
case op.LessThan:
|
||||
comp = query.NumberLess
|
||||
case op.LessThanOrEqual:
|
||||
comp = query.NumberLessOrEqual
|
||||
default:
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF424", "Invalid filter expression: unsupported comparison operator for number fields"))
|
||||
}
|
||||
|
||||
var value interface{}
|
||||
if right.Int != nil {
|
||||
value = *right.Int
|
||||
} else {
|
||||
value = *right.Float
|
||||
}
|
||||
return query.NewNumberQuery(left.Column, value, comp)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) buildBooleanQuery(field *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
|
||||
if !right.BooleanTrue && !right.BooleanFalse {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF428", "Invalid filter expression: unsupported comparison value for boolean field"))
|
||||
}
|
||||
|
||||
if !op.Equal && !op.NotEqual {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF427", "Invalid filter expression: unsupported comparison operator for boolean field"))
|
||||
}
|
||||
|
||||
return query.NewBoolQuery(field.Column, (op.Equal && right.BooleanTrue) || (op.NotEqual && right.BooleanFalse))
|
||||
}
|
||||
|
||||
func (b *queryBuilder) buildTextQuery(field *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
|
||||
if right.StringValue == nil {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF429", "Invalid filter expression: unsupported comparison value for text field"))
|
||||
}
|
||||
|
||||
var comp query.TextComparison
|
||||
switch {
|
||||
case op.Equal:
|
||||
comp = query.TextEquals
|
||||
case op.NotEqual:
|
||||
comp = query.TextNotEquals
|
||||
case op.Contains:
|
||||
comp = query.TextContains
|
||||
case op.StartsWith:
|
||||
comp = query.TextStartsWith
|
||||
case op.EndsWith:
|
||||
comp = query.TextEndsWith
|
||||
default:
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF425", "Invalid filter expression: unsupported comparison operator for text fields"))
|
||||
}
|
||||
|
||||
return query.NewTextQuery(field.Column, *right.StringValue, comp)
|
||||
}
|
||||
|
||||
func (b *queryBuilder) visitAttrPath(attrPath *AttrPath) (*QueryFieldInfo, error) {
|
||||
b.pushAttrPath(attrPath)
|
||||
defer b.popAttrPath()
|
||||
|
||||
field, err := b.reduceAttrPaths(b.attrPathPrefixStack)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.fieldPathMapping.Resolve(field)
|
||||
}
|
||||
|
||||
// reduceAttrPaths reduces a slice of AttrPath
|
||||
// to a simple urn + fieldPath combination.
|
||||
// The urn is ensured to be unique across all segments and either to be empty or to match the schema of the builder.
|
||||
// The resulting fieldPath is in the form of a.b.c with a minimum of one path segment.
|
||||
func (b *queryBuilder) reduceAttrPaths(attrPaths []*AttrPath) (fieldPath string, err error) {
|
||||
if len(attrPaths) == 0 {
|
||||
err = serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF431", "Invalid filter expression: unknown urn attribute prefix"))
|
||||
return fieldPath, err
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
|
||||
for _, p := range attrPaths {
|
||||
if err = p.validateSchema(b.schema); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sb.WriteString(p.FieldPath())
|
||||
sb.WriteRune('.')
|
||||
}
|
||||
|
||||
fieldPath = sb.String()
|
||||
fieldPath = strings.TrimRight(fieldPath, ".") // trim very last '.'
|
||||
return fieldPath, err
|
||||
}
|
497
internal/api/scim/resources/filter/filter_query_builder_test.go
Normal file
497
internal/api/scim/resources/filter/filter_query_builder_test.go
Normal file
@ -0,0 +1,497 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
)
|
||||
|
||||
var fieldPathColumnMapping = FieldPathMapping{
|
||||
// a timestamp field
|
||||
"meta.lastmodified": {
|
||||
Column: query.UserChangeDateCol,
|
||||
FieldType: FieldTypeTimestamp,
|
||||
},
|
||||
// a string field
|
||||
"username": {
|
||||
Column: query.UserUsernameCol,
|
||||
FieldType: FieldTypeString,
|
||||
},
|
||||
// a nested string field
|
||||
"name.familyname": {
|
||||
Column: query.HumanLastNameCol,
|
||||
FieldType: FieldTypeString,
|
||||
},
|
||||
// a field which is a list in scim
|
||||
"emails": {
|
||||
Column: query.HumanEmailCol,
|
||||
FieldType: FieldTypeString,
|
||||
},
|
||||
// the default value field
|
||||
"emails.value": {
|
||||
Column: query.HumanEmailCol,
|
||||
FieldType: FieldTypeString,
|
||||
},
|
||||
// pseudo field to test number queries
|
||||
"age": {
|
||||
Column: query.HumanGenderCol,
|
||||
FieldType: FieldTypeNumber,
|
||||
},
|
||||
// pseudo field to test boolean queries
|
||||
"locked": {
|
||||
Column: query.HumanPasswordChangeRequiredCol,
|
||||
FieldType: FieldTypeBoolean,
|
||||
},
|
||||
// mapped field
|
||||
"active": {
|
||||
Column: query.UserStateCol,
|
||||
FieldType: FieldTypeCustom,
|
||||
BuildMappedQuery: func(ctx context.Context, compareValue *CompValue, op *CompareOp) (query.SearchQuery, error) {
|
||||
// very simple mock implementation
|
||||
return query.NewTextQuery(query.UserUsernameCol, "fooBar", query.TextContains)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestFilter_BuildQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter string
|
||||
want query.SearchQuery
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "unknown attribute",
|
||||
filter: `foobar eq "bjensen"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "simple binary operator",
|
||||
filter: `userName eq "bjensen"`,
|
||||
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "bjensen", query.TextEquals)),
|
||||
},
|
||||
{
|
||||
name: "binary operator equals null",
|
||||
filter: `userName eq null`,
|
||||
want: test.Must(query.NewIsNullQuery(query.UserUsernameCol)),
|
||||
},
|
||||
{
|
||||
name: "binary operator not equals null",
|
||||
filter: `userName ne null`,
|
||||
want: test.Must(query.NewNotNullQuery(query.UserUsernameCol)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator on string field",
|
||||
filter: `userName gt 10`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "binary number operator greater",
|
||||
filter: `age gt 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberGreater)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator greater equal",
|
||||
filter: `age ge 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberGreaterOrEqual)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator less",
|
||||
filter: `age lt 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberLess)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator less float",
|
||||
filter: `age lt 10.5`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10.5, query.NumberLess)),
|
||||
},
|
||||
{
|
||||
name: "binary number unsupported operator",
|
||||
filter: `age co 10.5`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "binary number unsupported comparison value",
|
||||
filter: `age gt "foo"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "binary number operator less equal",
|
||||
filter: `age le 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberLessOrEqual)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator equals",
|
||||
filter: `age eq 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberEquals)),
|
||||
},
|
||||
{
|
||||
name: "binary number operator not equals",
|
||||
filter: `age ne 10`,
|
||||
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberNotEquals)),
|
||||
},
|
||||
{
|
||||
name: "binary bool operator equals string",
|
||||
filter: `locked eq "foo"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "binary bool operator startswith bool",
|
||||
filter: `locked sw true`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "binary bool operator equals",
|
||||
filter: `locked eq true`,
|
||||
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, true)),
|
||||
},
|
||||
{
|
||||
name: "binary bool operator not equals",
|
||||
filter: `locked ne true`,
|
||||
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, false)),
|
||||
},
|
||||
{
|
||||
name: "binary bool operator not equals false",
|
||||
filter: `locked ne false`,
|
||||
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, true)),
|
||||
},
|
||||
{
|
||||
name: "binary string invalid operator",
|
||||
filter: `username gt "test"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nested attribute binary operator",
|
||||
filter: `name.familyName co "O'Malley"`,
|
||||
want: test.Must(query.NewTextQuery(query.HumanLastNameCol, "O'Malley", query.TextContains)),
|
||||
},
|
||||
{
|
||||
name: "urn prefixed binary operator",
|
||||
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J"`,
|
||||
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "J", query.TextStartsWith)),
|
||||
},
|
||||
{
|
||||
name: "urn prefixed nested binary operator",
|
||||
filter: `urn:ietf:params:scim:schemas:core:2.0:User:emails[value sw "hans.peter@"]`,
|
||||
want: test.Must(query.NewTextQuery(query.HumanEmailCol, "hans.peter@", query.TextStartsWith)),
|
||||
},
|
||||
{
|
||||
name: "invalid urn prefixed nested binary operator",
|
||||
filter: `urn:ietf:params:scim:schemas:core:2.0:UserFoo:emails[value sw "hans.peter@"]`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unary operator",
|
||||
filter: `name.familyName pr`,
|
||||
want: test.Must(query.NewNotNullQuery(query.HumanLastNameCol)),
|
||||
},
|
||||
{
|
||||
name: "and logical expression",
|
||||
filter: `name.familyName pr and userName eq "bjensen"`,
|
||||
want: test.Must(query.NewAndQuery(test.Must(query.NewNotNullQuery(query.HumanLastNameCol)), test.Must(query.NewTextQuery(query.UserUsernameCol, "bjensen", query.TextEquals)))),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition equal",
|
||||
filter: `meta.lastModified eq "2011-05-13T04:42:34Z"`,
|
||||
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampEquals)),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition greater equals",
|
||||
filter: `meta.lastModified ge "2011-05-13T04:42:34Z"`,
|
||||
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampGreaterOrEquals)),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition greater",
|
||||
filter: `meta.lastModified gt "2011-05-13T04:42:34Z"`,
|
||||
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampGreater)),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition less equals",
|
||||
filter: `meta.lastModified le "2011-05-13T04:42:34Z"`,
|
||||
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampLessOrEquals)),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition less",
|
||||
filter: `meta.lastModified lt "2011-05-13T04:42:34Z"`,
|
||||
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampLess)),
|
||||
},
|
||||
{
|
||||
name: "timestamp condition invalid operator",
|
||||
filter: `meta.lastModified ew "2011-05-13T04:42:34Z"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "timestamp condition invalid format",
|
||||
filter: `meta.lastModified ge "2011-05-13T0:34Z"`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "timestamp condition invalid comparison value",
|
||||
filter: `meta.lastModified ge 15`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nested and / or without grouping",
|
||||
filter: `userName eq "rudolpho" and emails co "example.com" or emails.value co "example2.org"`,
|
||||
want: test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewTextQuery(query.UserUsernameCol, "rudolpho", query.TextEquals)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.com", query.TextContains))),
|
||||
),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "example2.org", query.TextContains)))),
|
||||
},
|
||||
{
|
||||
name: "nested and / or with grouping",
|
||||
filter: `userName ne "rudolpho" and (emails co "example.com" or emails.value co "example.org")`,
|
||||
want: test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewTextQuery(query.UserUsernameCol, "rudolpho", query.TextNotEquals)),
|
||||
test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.com", query.TextContains)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.org", query.TextContains)),
|
||||
)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "nested value path path",
|
||||
filter: `userName eq "Hans" and emails[value ew "@example.org" or value ew "@example.com"]`,
|
||||
want: test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewTextQuery(query.UserUsernameCol, "Hans", query.TextEquals)),
|
||||
test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
|
||||
)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "or value path filter",
|
||||
filter: `emails[value ew "@example.org" and value co "@example.com"] or emails[value sw "hans" or value sw "peter"]`,
|
||||
want: test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextContains)),
|
||||
)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "hans", query.TextStartsWith)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "peter", query.TextStartsWith)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "and value path filter",
|
||||
filter: `emails[value ew "@example.com"] and name.familyname co "hans" and username co "peter"`,
|
||||
want: test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
|
||||
test.Must(query.NewTextQuery(query.HumanLastNameCol, "hans", query.TextContains)),
|
||||
test.Must(query.NewTextQuery(query.UserUsernameCol, "peter", query.TextContains)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "negation",
|
||||
filter: `not(username eq "foo")`,
|
||||
want: test.Must(query.NewNotQuery(test.Must(query.NewTextQuery(query.UserUsernameCol, "foo", query.TextEquals)))),
|
||||
},
|
||||
{
|
||||
name: "negation with complex filter",
|
||||
filter: `not(emails[value ew "@example.com"])`,
|
||||
want: test.Must(query.NewNotQuery(test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)))),
|
||||
},
|
||||
{
|
||||
name: "nested negation",
|
||||
filter: `emails[not(value ew "@example.org" or value ew "@example.com")]`,
|
||||
want: test.Must(query.NewNotQuery(
|
||||
test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
|
||||
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
|
||||
)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "mapped field",
|
||||
filter: `active eq true`,
|
||||
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "fooBar", query.TextContains)),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := ParseFilter(tt.filter)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := f.BuildQuery(context.Background(), schemas.IdUser, fieldPathColumnMapping)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("BuildQuery() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("BuildQuery() got = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_queryBuilder_reduceAttrPaths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
attrPaths []*AttrPath
|
||||
wantFieldPath string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
attrPaths: []*AttrPath{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "simple",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo",
|
||||
},
|
||||
{
|
||||
name: "multiple simple",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
},
|
||||
{
|
||||
AttrName: "bar",
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar",
|
||||
},
|
||||
{
|
||||
name: "with sub attr",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
SubAttr: gu.Ptr("bar"),
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar",
|
||||
},
|
||||
{
|
||||
name: "multiple with sub attr",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
SubAttr: gu.Ptr("bar"),
|
||||
},
|
||||
{
|
||||
AttrName: "baz",
|
||||
SubAttr: gu.Ptr("woo"),
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar.baz.woo",
|
||||
},
|
||||
{
|
||||
name: "with urn and sub attr",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
|
||||
AttrName: "foo",
|
||||
SubAttr: gu.Ptr("bar"),
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar",
|
||||
},
|
||||
{
|
||||
name: "multiple with urn and sub attr",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
|
||||
AttrName: "foo",
|
||||
SubAttr: gu.Ptr("bar"),
|
||||
},
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
|
||||
AttrName: "foo2",
|
||||
SubAttr: gu.Ptr("bar2"),
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar.foo2.bar2",
|
||||
},
|
||||
{
|
||||
name: "secondary with urn and sub attr",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
SubAttr: gu.Ptr("bar"),
|
||||
},
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
|
||||
AttrName: "foo2",
|
||||
SubAttr: gu.Ptr("bar2"),
|
||||
},
|
||||
},
|
||||
wantFieldPath: "foo.bar.foo2.bar2",
|
||||
},
|
||||
{
|
||||
name: "urn mismatch",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
|
||||
AttrName: "foo",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nested urn mismatch",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
|
||||
AttrName: "foo",
|
||||
},
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
|
||||
AttrName: "foo2",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "secondary urn mismatch",
|
||||
schema: "urn:foo:bar",
|
||||
attrPaths: []*AttrPath{
|
||||
{
|
||||
AttrName: "foo",
|
||||
},
|
||||
{
|
||||
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
|
||||
AttrName: "foo2",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := &queryBuilder{
|
||||
schema: schemas.ScimSchemaType(tt.schema),
|
||||
}
|
||||
gotFieldPath, err := b.reduceAttrPaths(tt.attrPaths)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("reduceAttrPaths() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if gotFieldPath != tt.wantFieldPath {
|
||||
t.Errorf("reduceAttrPaths() gotFieldPath = %v, want %v", gotFieldPath, tt.wantFieldPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -22,6 +22,7 @@ type ResourceHandler[T ResourceHolder] interface {
|
||||
Replace(ctx context.Context, id string, resource T) (T, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
Get(ctx context.Context, id string) (T, error)
|
||||
List(ctx context.Context, request *ListRequest) (*ListResponse[T], error)
|
||||
}
|
||||
|
||||
type Resource struct {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/serrors"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -16,22 +15,6 @@ type ResourceHandlerAdapter[T ResourceHolder] struct {
|
||||
handler ResourceHandler[T]
|
||||
}
|
||||
|
||||
type ListRequest struct {
|
||||
// Count An integer indicating the desired maximum number of query results per page. OPTIONAL.
|
||||
Count uint64 `json:"count" schema:"count"`
|
||||
|
||||
// StartIndex An integer indicating the 1-based index of the first query result. Optional.
|
||||
StartIndex uint64 `json:"startIndex" schema:"startIndex"`
|
||||
}
|
||||
|
||||
type ListResponse[T any] struct {
|
||||
Schemas []schemas.ScimSchemaType `json:"schemas"`
|
||||
ItemsPerPage uint64 `json:"itemsPerPage"`
|
||||
TotalResults uint64 `json:"totalResults"`
|
||||
StartIndex uint64 `json:"startIndex"`
|
||||
Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase...
|
||||
}
|
||||
|
||||
func NewResourceHandlerAdapter[T ResourceHolder](handler ResourceHandler[T]) *ResourceHandlerAdapter[T] {
|
||||
return &ResourceHandlerAdapter[T]{
|
||||
handler,
|
||||
@ -62,6 +45,15 @@ func (adapter *ResourceHandlerAdapter[T]) Delete(r *http.Request) error {
|
||||
return adapter.handler.Delete(r.Context(), id)
|
||||
}
|
||||
|
||||
func (adapter *ResourceHandlerAdapter[T]) List(r *http.Request) (*ListResponse[T], error) {
|
||||
request, err := readListRequest(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return adapter.handler.List(r.Context(), request)
|
||||
}
|
||||
|
||||
func (adapter *ResourceHandlerAdapter[T]) Get(r *http.Request) (T, error) {
|
||||
id := mux.Vars(r)["id"]
|
||||
return adapter.handler.Get(r.Context(), id)
|
||||
@ -71,7 +63,7 @@ func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T
|
||||
entity := adapter.handler.NewResource()
|
||||
err := json.NewDecoder(r.Body).Decode(entity)
|
||||
if err != nil {
|
||||
if zerrors.IsZitadelError(err) {
|
||||
if serrors.IsScimOrZitadelError(err) {
|
||||
return entity, err
|
||||
}
|
||||
|
||||
|
148
internal/api/scim/resources/resource_list.go
Normal file
148
internal/api/scim/resources/resource_list.go
Normal file
@ -0,0 +1,148 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/serrors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type ListRequest struct {
|
||||
// Count An integer indicating the desired maximum number of query results per page.
|
||||
Count int64 `json:"count" schema:"count"`
|
||||
|
||||
// StartIndex An integer indicating the 1-based index of the first query result.
|
||||
StartIndex int64 `json:"startIndex" schema:"startIndex"`
|
||||
|
||||
// Filter a scim filter expression to filter the query result.
|
||||
Filter *filter.Filter `json:"filter,omitempty" schema:"filter"`
|
||||
|
||||
// SortBy attribute path to the sort attribute
|
||||
SortBy string `json:"sortBy" schema:"sortBy"`
|
||||
SortOrder ListRequestSortOrder `json:"sortOrder" schema:"sortOrder"`
|
||||
}
|
||||
|
||||
type ListResponse[T ResourceHolder] struct {
|
||||
Schemas []schemas.ScimSchemaType `json:"schemas"`
|
||||
ItemsPerPage uint64 `json:"itemsPerPage"`
|
||||
TotalResults uint64 `json:"totalResults"`
|
||||
StartIndex uint64 `json:"startIndex"`
|
||||
Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase...
|
||||
}
|
||||
|
||||
type ListRequestSortOrder string
|
||||
|
||||
const (
|
||||
ListRequestSortOrderAsc ListRequestSortOrder = "ascending"
|
||||
ListRequestSortOrderDsc ListRequestSortOrder = "descending"
|
||||
|
||||
defaultListCount = 100
|
||||
maxListCount = 100
|
||||
)
|
||||
|
||||
var parser = zhttp.NewParser()
|
||||
|
||||
func (o ListRequestSortOrder) isDefined() bool {
|
||||
switch o {
|
||||
case ListRequestSortOrderAsc, ListRequestSortOrderDsc:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (o ListRequestSortOrder) IsAscending() bool {
|
||||
return o == ListRequestSortOrderAsc
|
||||
}
|
||||
|
||||
func newListResponse[T ResourceHolder](totalResultCount uint64, q query.SearchRequest, resources []T) *ListResponse[T] {
|
||||
return &ListResponse[T]{
|
||||
Schemas: []schemas.ScimSchemaType{schemas.IdListResponse},
|
||||
ItemsPerPage: q.Limit,
|
||||
TotalResults: totalResultCount,
|
||||
StartIndex: q.Offset + 1, // start index is 1 based
|
||||
Resources: resources,
|
||||
}
|
||||
}
|
||||
|
||||
func readListRequest(r *http.Request) (*ListRequest, error) {
|
||||
request := &ListRequest{
|
||||
Count: defaultListCount,
|
||||
StartIndex: 1,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
if err := parser.Parse(r, request); err != nil {
|
||||
err = parser.UnwrapParserError(err)
|
||||
|
||||
if serrors.IsScimOrZitadelError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "SCIM-ullform", "Could not decode form: "+err.Error())
|
||||
}
|
||||
case http.MethodPost:
|
||||
if err := json.NewDecoder(r.Body).Decode(request); err != nil {
|
||||
if serrors.IsScimOrZitadelError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "SCIM-ulljson", "Could not decode json: "+err.Error())
|
||||
}
|
||||
|
||||
// json deserialization initializes this field if an empty string is provided
|
||||
// to not special case this in the resource implementation,
|
||||
// set it to nil here.
|
||||
if request.Filter.IsZero() {
|
||||
request.Filter = nil
|
||||
}
|
||||
}
|
||||
|
||||
return request, request.validate()
|
||||
}
|
||||
|
||||
func (r *ListRequest) toSearchRequest(defaultSortCol query.Column, fieldPathColumnMapping filter.FieldPathMapping) (query.SearchRequest, error) {
|
||||
sr := query.SearchRequest{
|
||||
Offset: uint64(r.StartIndex - 1), // start index is 1 based
|
||||
Limit: uint64(r.Count),
|
||||
Asc: r.SortOrder.IsAscending(),
|
||||
}
|
||||
|
||||
if r.SortBy == "" {
|
||||
// set a default sort to ensure consistent results
|
||||
sr.SortingColumn = defaultSortCol
|
||||
} else if sortCol, err := fieldPathColumnMapping.Resolve(r.SortBy); err != nil {
|
||||
return sr, serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-SRT1", "SortBy field is unknown or not supported"))
|
||||
} else {
|
||||
sr.SortingColumn = sortCol.Column
|
||||
}
|
||||
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
func (r *ListRequest) validate() error {
|
||||
// according to the spec values < 1 are treated as 1
|
||||
if r.StartIndex < 1 {
|
||||
r.StartIndex = 1
|
||||
}
|
||||
|
||||
// according to the spec values < 0 are treated as 0
|
||||
if r.Count < 0 {
|
||||
r.Count = 0
|
||||
} else if r.Count > maxListCount {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "SCIM-ucr", "Limit count exceeded, set a count <= %v", maxListCount)
|
||||
}
|
||||
|
||||
if !r.SortOrder.isDefined() {
|
||||
return zerrors.ThrowInvalidArgument(nil, "SCIM-ucx", "Invalid sort order")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
79
internal/api/scim/resources/resource_list_test.go
Normal file
79
internal/api/scim/resources/resource_list_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestListRequest_validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *ListRequest
|
||||
want *ListRequest
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid",
|
||||
req: &ListRequest{
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid sort order",
|
||||
req: &ListRequest{
|
||||
SortOrder: "fooBar",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "count too big",
|
||||
req: &ListRequest{
|
||||
Count: 99999999,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative start index",
|
||||
req: &ListRequest{
|
||||
StartIndex: -1,
|
||||
Count: 10,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
want: &ListRequest{
|
||||
StartIndex: 1,
|
||||
Count: 10,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative count",
|
||||
req: &ListRequest{
|
||||
StartIndex: 10,
|
||||
Count: -1,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
want: &ListRequest{
|
||||
StartIndex: 10,
|
||||
Count: 0,
|
||||
SortOrder: ListRequestSortOrderAsc,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.req.validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if tt.want != nil && !reflect.DeepEqual(tt.req, tt.want) {
|
||||
t.Errorf("got: %#v, want: %#v", tt.req, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -183,6 +183,35 @@ func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) {
|
||||
return h.mapToScimUser(ctx, user, metadata), nil
|
||||
}
|
||||
|
||||
func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListResponse[*ScimUser], error) {
|
||||
q, err := h.buildListQuery(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.Count == 0 {
|
||||
count, err := h.query.CountUsers(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
|
||||
}
|
||||
|
||||
users, err := h.query.SearchUsers(ctx, q, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metadata, err := h.queryMetadataForUsers(ctx, usersToIDs(users.Users))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scimUsers := h.mapToScimUsers(ctx, users.Users, metadata)
|
||||
return newListResponse(users.SearchResponse.Count, q.SearchRequest, scimUsers), nil
|
||||
}
|
||||
|
||||
func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) {
|
||||
userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID)
|
||||
if err != nil {
|
||||
|
@ -208,6 +208,20 @@ func (h *UsersHandler) mapChangeCommandToScimUser(ctx context.Context, user *Sci
|
||||
}
|
||||
}
|
||||
|
||||
func (h *UsersHandler) mapToScimUsers(ctx context.Context, users []*query.User, md map[string]map[metadata.ScopedKey][]byte) []*ScimUser {
|
||||
result := make([]*ScimUser, len(users))
|
||||
for i, user := range users {
|
||||
userMetadata, ok := md[user.ID]
|
||||
if !ok {
|
||||
userMetadata = make(map[metadata.ScopedKey][]byte)
|
||||
}
|
||||
|
||||
result[i] = h.mapToScimUser(ctx, user, userMetadata)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md map[metadata.ScopedKey][]byte) *ScimUser {
|
||||
scimUser := &ScimUser{
|
||||
Resource: h.buildResourceForQuery(ctx, user),
|
||||
@ -364,3 +378,11 @@ func userGrantsToIDs(userGrants []*query.UserGrant) []string {
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func usersToIDs(users []*query.User) []string {
|
||||
ids := make([]string, len(users))
|
||||
for i, user := range users {
|
||||
ids[i] = user.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
@ -20,6 +20,28 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (h *UsersHandler) queryMetadataForUsers(ctx context.Context, userIds []string) (map[string]map[metadata.ScopedKey][]byte, error) {
|
||||
queries := h.buildMetadataQueries(ctx)
|
||||
|
||||
md, err := h.query.SearchUserMetadataForUsers(ctx, false, userIds, queries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metadataMap := make(map[string]map[metadata.ScopedKey][]byte, len(md.Metadata))
|
||||
for _, entry := range md.Metadata {
|
||||
userMetadata, ok := metadataMap[entry.UserID]
|
||||
if !ok {
|
||||
userMetadata = make(map[metadata.ScopedKey][]byte)
|
||||
metadataMap[entry.UserID] = userMetadata
|
||||
}
|
||||
|
||||
userMetadata[metadata.ScopedKey(entry.Key)] = entry.Value
|
||||
}
|
||||
|
||||
return metadataMap, nil
|
||||
}
|
||||
|
||||
func (h *UsersHandler) queryMetadataForUser(ctx context.Context, id string) (map[metadata.ScopedKey][]byte, error) {
|
||||
queries := h.buildMetadataQueries(ctx)
|
||||
|
||||
@ -108,15 +130,11 @@ func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) {
|
||||
|
||||
switch key {
|
||||
// json values
|
||||
case metadata.KeyEntitlements:
|
||||
fallthrough
|
||||
case metadata.KeyIms:
|
||||
fallthrough
|
||||
case metadata.KeyPhotos:
|
||||
fallthrough
|
||||
case metadata.KeyAddresses:
|
||||
fallthrough
|
||||
case metadata.KeyRoles:
|
||||
case metadata.KeyRoles,
|
||||
metadata.KeyAddresses,
|
||||
metadata.KeyEntitlements,
|
||||
metadata.KeyIms,
|
||||
metadata.KeyPhotos:
|
||||
val, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -134,21 +152,14 @@ func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) {
|
||||
return []byte(value.(*schemas.HttpURL).String()), nil
|
||||
|
||||
// raw values
|
||||
case metadata.KeyProvisioningDomain:
|
||||
fallthrough
|
||||
case metadata.KeyExternalId:
|
||||
fallthrough
|
||||
case metadata.KeyMiddleName:
|
||||
fallthrough
|
||||
case metadata.KeyHonorificSuffix:
|
||||
fallthrough
|
||||
case metadata.KeyHonorificPrefix:
|
||||
fallthrough
|
||||
case metadata.KeyTitle:
|
||||
fallthrough
|
||||
case metadata.KeyLocale:
|
||||
fallthrough
|
||||
case metadata.KeyTimezone:
|
||||
case metadata.KeyTimezone,
|
||||
metadata.KeyLocale,
|
||||
metadata.KeyTitle,
|
||||
metadata.KeyHonorificPrefix,
|
||||
metadata.KeyHonorificSuffix,
|
||||
metadata.KeyMiddleName,
|
||||
metadata.KeyExternalId,
|
||||
metadata.KeyProvisioningDomain:
|
||||
valueStr := value.(string)
|
||||
if valueStr == "" {
|
||||
return nil, nil
|
||||
|
160
internal/api/scim/resources/user_query_builder.go
Normal file
160
internal/api/scim/resources/user_query_builder.go
Normal file
@ -0,0 +1,160 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/metadata"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/serrors"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// fieldPathColumnMapping maps lowercase json field names of the scim user to the matching column in the projection
|
||||
// only a limited set of fields is supported
|
||||
// to ensure database performance.
|
||||
var fieldPathColumnMapping = filter.FieldPathMapping{
|
||||
"meta.created": {
|
||||
Column: query.UserCreationDateCol,
|
||||
FieldType: filter.FieldTypeTimestamp,
|
||||
},
|
||||
"meta.lastmodified": {
|
||||
Column: query.UserChangeDateCol,
|
||||
FieldType: filter.FieldTypeTimestamp,
|
||||
},
|
||||
"id": {
|
||||
Column: query.UserIDCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"username": {
|
||||
Column: query.UserUsernameCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"name.familyname": {
|
||||
Column: query.HumanLastNameCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"name.givenname": {
|
||||
Column: query.HumanFirstNameCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"emails": {
|
||||
Column: query.HumanEmailCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"emails.value": {
|
||||
Column: query.HumanEmailCol,
|
||||
FieldType: filter.FieldTypeString,
|
||||
},
|
||||
"active": {
|
||||
FieldType: filter.FieldTypeCustom,
|
||||
BuildMappedQuery: buildActiveUserStateQuery,
|
||||
},
|
||||
"externalid": {
|
||||
FieldType: filter.FieldTypeCustom,
|
||||
BuildMappedQuery: newMetadataQueryBuilder(metadata.KeyExternalId),
|
||||
},
|
||||
}
|
||||
|
||||
func (h *UsersHandler) buildListQuery(ctx context.Context, request *ListRequest) (*query.UserSearchQueries, error) {
|
||||
searchRequest, err := request.toSearchRequest(query.UserIDCol, fieldPathColumnMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := &query.UserSearchQueries{
|
||||
SearchRequest: searchRequest,
|
||||
}
|
||||
|
||||
// the zitadel scim implementation only supports humans for now
|
||||
userTypeQuery, err := query.NewUserTypeSearchQuery(int32(domain.UserTypeHuman))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// the scim service is always limited to one organization
|
||||
// the organization is the resource owner
|
||||
orgIDQuery, err := query.NewUserResourceOwnerSearchQuery(authz.GetCtxData(ctx).OrgID, query.TextEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.Queries = append(q.Queries, orgIDQuery, userTypeQuery)
|
||||
|
||||
if request.Filter == nil {
|
||||
return q, nil
|
||||
}
|
||||
|
||||
filterQuery, err := request.Filter.BuildQuery(ctx, h.SchemaType(), fieldPathColumnMapping)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q.Queries = append(q.Queries, filterQuery)
|
||||
return q, nil
|
||||
}
|
||||
|
||||
func newMetadataQueryBuilder(key metadata.Key) filter.MappedQueryBuilderFunc {
|
||||
return func(ctx context.Context, compareValue *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
|
||||
return buildMetadataQuery(ctx, key, compareValue, op)
|
||||
}
|
||||
}
|
||||
|
||||
func buildMetadataQuery(ctx context.Context, key metadata.Key, value *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
|
||||
if value.StringValue == nil {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-EXid1", "invalid filter expression: unsupported comparison value"))
|
||||
}
|
||||
|
||||
var comparisonOperator query.BytesComparison
|
||||
|
||||
switch {
|
||||
case op.Equal:
|
||||
comparisonOperator = query.BytesEquals
|
||||
case op.NotEqual:
|
||||
comparisonOperator = query.BytesNotEquals
|
||||
default:
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-EXid1", "invalid filter expression: unsupported comparison operator"))
|
||||
}
|
||||
|
||||
scopedKey := string(metadata.ScopeKey(ctx, key))
|
||||
return query.NewUserMetadataExistsQuery(scopedKey, []byte(*value.StringValue), query.TextEquals, comparisonOperator)
|
||||
}
|
||||
|
||||
func buildActiveUserStateQuery(_ context.Context, compareValue *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
|
||||
if !op.Equal && !op.NotEqual {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-MGdg", "invalid filter expression: active unsupported comparison operator"))
|
||||
}
|
||||
|
||||
if !compareValue.BooleanTrue && !compareValue.BooleanFalse {
|
||||
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-MGdr", "invalid filter expression: active unsupported comparison value"))
|
||||
}
|
||||
|
||||
active := compareValue.BooleanTrue && op.Equal || compareValue.BooleanFalse && op.NotEqual
|
||||
if active {
|
||||
activeQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initialQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return query.NewOrQuery(initialQuery, activeQuery)
|
||||
}
|
||||
|
||||
activeQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initialQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return query.NewAndQuery(initialQuery, activeQuery)
|
||||
}
|
144
internal/api/scim/resources/user_query_builder_test.go
Normal file
144
internal/api/scim/resources/user_query_builder_test.go
Normal file
@ -0,0 +1,144 @@
|
||||
package resources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/scim/metadata"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
)
|
||||
|
||||
func Test_buildMetadataQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key metadata.Key
|
||||
value *filter.CompValue
|
||||
op *filter.CompareOp
|
||||
want query.SearchQuery
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "equals",
|
||||
key: "foo",
|
||||
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
|
||||
op: &filter.CompareOp{Equal: true},
|
||||
want: test.Must(query.NewUserMetadataExistsQuery("foo", []byte("bar"), query.TextEquals, query.BytesEquals)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "not equals",
|
||||
key: "foo",
|
||||
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
|
||||
op: &filter.CompareOp{NotEqual: true},
|
||||
want: test.Must(query.NewUserMetadataExistsQuery("foo", []byte("bar"), query.TextEquals, query.BytesNotEquals)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported operator",
|
||||
key: "foo",
|
||||
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
|
||||
op: &filter.CompareOp{StartsWith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported comparison value",
|
||||
key: "foo",
|
||||
value: &filter.CompValue{Int: gu.Ptr(10)},
|
||||
op: &filter.CompareOp{Equal: true},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := buildMetadataQuery(context.Background(), tt.key, tt.value, tt.op)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("buildMetadataQuery() got = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildActiveUserStateQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
compareValue *filter.CompValue
|
||||
compOp *filter.CompareOp
|
||||
want query.SearchQuery
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "eq true",
|
||||
compareValue: &filter.CompValue{BooleanTrue: true},
|
||||
compOp: &filter.CompareOp{Equal: true},
|
||||
want: test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)),
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "eq false",
|
||||
compareValue: &filter.CompValue{BooleanFalse: true},
|
||||
compOp: &filter.CompareOp{Equal: true},
|
||||
want: test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)),
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "ne true",
|
||||
compareValue: &filter.CompValue{BooleanTrue: true},
|
||||
compOp: &filter.CompareOp{NotEqual: true},
|
||||
want: test.Must(query.NewAndQuery(
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)),
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "ne false",
|
||||
compareValue: &filter.CompValue{BooleanTrue: true},
|
||||
compOp: &filter.CompareOp{Equal: true},
|
||||
want: test.Must(query.NewOrQuery(
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)),
|
||||
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)),
|
||||
)),
|
||||
},
|
||||
{
|
||||
name: "invalid operator",
|
||||
compareValue: &filter.CompValue{BooleanTrue: true},
|
||||
compOp: &filter.CompareOp{StartsWith: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid comp value",
|
||||
compareValue: &filter.CompValue{StringValue: gu.Ptr("foo")},
|
||||
compOp: &filter.CompareOp{Equal: true},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := buildActiveUserStateQuery(context.Background(), tt.compareValue, tt.compOp)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.want, got, "buildActiveUserStateQuery(%#v, %#v)", tt.compareValue, tt.compOp)
|
||||
})
|
||||
}
|
||||
}
|
@ -10,6 +10,7 @@ const (
|
||||
idPrefixZitadelMessages = "urn:ietf:params:scim:api:zitadel:messages:2.0:"
|
||||
|
||||
IdUser ScimSchemaType = idPrefixCore + "User"
|
||||
IdListResponse ScimSchemaType = idPrefixMessages + "ListResponse"
|
||||
IdError ScimSchemaType = idPrefixMessages + "Error"
|
||||
IdZitadelErrorDetail ScimSchemaType = idPrefixZitadelMessages + "ErrorDetail"
|
||||
|
||||
|
@ -49,6 +49,13 @@ const (
|
||||
// ScimTypeInvalidSyntax The request body message structure was invalid or did
|
||||
// not conform to the request schema.
|
||||
ScimTypeInvalidSyntax scimErrorType = "invalidSyntax"
|
||||
|
||||
// ScimTypeInvalidFilter The specified filter syntax as invalid, or the
|
||||
// specified attribute and filter comparison combination is not supported.
|
||||
ScimTypeInvalidFilter scimErrorType = "invalidFilter"
|
||||
|
||||
// ScimTypeUniqueness One or more of the attribute values are already in use or are reserved.
|
||||
ScimTypeUniqueness scimErrorType = "uniqueness"
|
||||
)
|
||||
|
||||
var translator *i18n.Translator
|
||||
@ -85,6 +92,22 @@ func ThrowInvalidSyntax(parent error) error {
|
||||
}
|
||||
}
|
||||
|
||||
func ThrowInvalidFilter(parent error) error {
|
||||
return &wrappedScimError{
|
||||
Parent: parent,
|
||||
ScimType: ScimTypeInvalidFilter,
|
||||
}
|
||||
}
|
||||
|
||||
func IsScimOrZitadelError(err error) bool {
|
||||
return IsScimError(err) || zerrors.IsZitadelError(err)
|
||||
}
|
||||
|
||||
func IsScimError(err error) bool {
|
||||
var scimErr *wrappedScimError
|
||||
return errors.As(err, &scimErr)
|
||||
}
|
||||
|
||||
func (err *scimError) Error() string {
|
||||
return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Detail)
|
||||
}
|
||||
@ -134,6 +157,8 @@ func mapErrorToScimErrorType(err error) scimErrorType {
|
||||
switch {
|
||||
case zerrors.IsErrorInvalidArgument(err):
|
||||
return ScimTypeInvalidValue
|
||||
case zerrors.IsErrorAlreadyExists(err):
|
||||
return ScimTypeUniqueness
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
@ -54,11 +54,26 @@ func mapResource[T sresources.ResourceHolder](router *mux.Router, mw zhttp_middl
|
||||
resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter()
|
||||
|
||||
resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost)
|
||||
resourceRouter.Handle("", mw(handleJsonResponse(adapter.List))).Methods(http.MethodGet)
|
||||
resourceRouter.Handle("/.search", mw(handleJsonResponse(adapter.List))).Methods(http.MethodPost)
|
||||
resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Get))).Methods(http.MethodGet)
|
||||
resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Replace))).Methods(http.MethodPut)
|
||||
resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete)
|
||||
}
|
||||
|
||||
func handleJsonResponse[T any](next func(r *http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
entity, err := next(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(entity)
|
||||
logging.OnError(err).Warn("scim json response encoding failed")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func handleResourceCreatedResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError {
|
||||
return func(w http.ResponseWriter, r *http.Request) error {
|
||||
entity, err := next(r)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -170,99 +169,3 @@ func diffProto(expected, actual proto.Message) string {
|
||||
}
|
||||
return "\n\nDiff:\n" + diff
|
||||
}
|
||||
|
||||
func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) {
|
||||
val, exists := m[key]
|
||||
assert.True(t, exists, "Key '%s' should exist in the map", key)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
|
||||
}
|
||||
|
||||
// PartiallyDeepEqual is similar to reflect.DeepEqual,
|
||||
// but only compares exported non-zero fields of the expectedValue
|
||||
func PartiallyDeepEqual(expected, actual interface{}) bool {
|
||||
if expected == nil {
|
||||
return actual == nil
|
||||
}
|
||||
|
||||
if actual == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
|
||||
}
|
||||
|
||||
func partiallyDeepEqual(expected, actual reflect.Value) bool {
|
||||
// Dereference pointers if needed
|
||||
if expected.Kind() == reflect.Ptr {
|
||||
if expected.IsNil() {
|
||||
return true
|
||||
}
|
||||
|
||||
expected = expected.Elem()
|
||||
}
|
||||
|
||||
if actual.Kind() == reflect.Ptr {
|
||||
if actual.IsNil() {
|
||||
return false
|
||||
}
|
||||
|
||||
actual = actual.Elem()
|
||||
}
|
||||
|
||||
if expected.Type() != actual.Type() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch expected.Kind() { //nolint:exhaustive
|
||||
case reflect.Struct:
|
||||
for i := 0; i < expected.NumField(); i++ {
|
||||
field := expected.Type().Field(i)
|
||||
if field.PkgPath != "" { // Skip unexported fields
|
||||
continue
|
||||
}
|
||||
|
||||
expectedField := expected.Field(i)
|
||||
actualField := actual.Field(i)
|
||||
|
||||
// Skip zero-value fields in expected
|
||||
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare fields recursively
|
||||
if !partiallyDeepEqual(expectedField, actualField) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
if expected.Len() > actual.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < expected.Len(); i++ {
|
||||
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
default:
|
||||
// Compare primitive types
|
||||
return reflect.DeepEqual(expected.Interface(), actual.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
func Must[T any](result T, error error) T {
|
||||
if error != nil {
|
||||
panic(error)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
@ -50,153 +50,3 @@ func TestAssertDetails(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartiallyDeepEqual(t *testing.T) {
|
||||
type SecondaryNestedType struct {
|
||||
Value int
|
||||
}
|
||||
type NestedType struct {
|
||||
Value int
|
||||
ValueSlice []int
|
||||
Nested SecondaryNestedType
|
||||
NestedPointer *SecondaryNestedType
|
||||
}
|
||||
|
||||
type args struct {
|
||||
expected interface{}
|
||||
actual interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{
|
||||
expected: nil,
|
||||
actual: nil,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar value",
|
||||
args: args{
|
||||
expected: 10,
|
||||
actual: 10,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different scalar value",
|
||||
args: args{
|
||||
expected: 11,
|
||||
actual: 10,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "string value",
|
||||
args: args{
|
||||
expected: "foo",
|
||||
actual: "foo",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different string value",
|
||||
args: args{
|
||||
expected: "foo2",
|
||||
actual: "foo",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "scalar only set in actual",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{},
|
||||
actual: &SecondaryNestedType{Value: 10},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar equal",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{Value: 10},
|
||||
actual: &SecondaryNestedType{Value: 10},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar only set in expected",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{Value: 10},
|
||||
actual: &SecondaryNestedType{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ptr only set in expected",
|
||||
args: args{
|
||||
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ptr only set in actual",
|
||||
args: args{
|
||||
expected: &NestedType{},
|
||||
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ptr equal",
|
||||
args: args{
|
||||
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nested equal",
|
||||
args: args{
|
||||
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "slice equal",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "slice additional in expected",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "slice additional in actual",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
|
||||
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,10 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@ -40,6 +43,44 @@ type ZitadelErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ListRequest struct {
|
||||
Count *int `json:"count,omitempty"`
|
||||
|
||||
// StartIndex An integer indicating the 1-based index of the first query result.
|
||||
StartIndex *int `json:"startIndex,omitempty"`
|
||||
|
||||
// Filter a scim filter expression to filter the query result.
|
||||
Filter *string `json:"filter,omitempty"`
|
||||
|
||||
SortBy *string `json:"sortBy,omitempty"`
|
||||
SortOrder *ListRequestSortOrder `json:"sortOrder,omitempty"`
|
||||
|
||||
SendAsPost bool
|
||||
}
|
||||
|
||||
type ListRequestSortOrder string
|
||||
|
||||
const (
|
||||
ListRequestSortOrderAsc ListRequestSortOrder = "ascending"
|
||||
ListRequestSortOrderDsc ListRequestSortOrder = "descending"
|
||||
)
|
||||
|
||||
type ListResponse[T any] struct {
|
||||
Schemas []schemas.ScimSchemaType `json:"schemas"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
TotalResults int `json:"totalResults"`
|
||||
StartIndex int `json:"startIndex"`
|
||||
Resources []T `json:"Resources"`
|
||||
}
|
||||
|
||||
const (
|
||||
listQueryParamSortBy = "sortBy"
|
||||
listQueryParamSortOrder = "sortOrder"
|
||||
listQueryParamCount = "count"
|
||||
listQueryParamStartIndex = "startIndex"
|
||||
listQueryParamFilter = "filter"
|
||||
)
|
||||
|
||||
func NewScimClient(target string) *Client {
|
||||
target = "http://" + target + schemas.HandlerPrefix
|
||||
client := &http.Client{}
|
||||
@ -60,6 +101,43 @@ func (c *ResourceClient[T]) Replace(ctx context.Context, orgID, id string, body
|
||||
return c.doWithBody(ctx, http.MethodPut, orgID, id, bytes.NewReader(body))
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) List(ctx context.Context, orgID string, req *ListRequest) (*ListResponse[*T], error) {
|
||||
if req.SendAsPost {
|
||||
listReq, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.doWithListResponse(ctx, http.MethodPost, orgID, ".search", bytes.NewReader(listReq))
|
||||
}
|
||||
|
||||
query, err := url.ParseQuery("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.SortBy != nil {
|
||||
query.Set(listQueryParamSortBy, *req.SortBy)
|
||||
}
|
||||
|
||||
if req.SortOrder != nil {
|
||||
query.Set(listQueryParamSortOrder, string(*req.SortOrder))
|
||||
}
|
||||
|
||||
if req.Count != nil {
|
||||
query.Set(listQueryParamCount, strconv.Itoa(*req.Count))
|
||||
}
|
||||
|
||||
if req.StartIndex != nil {
|
||||
query.Set(listQueryParamStartIndex, strconv.Itoa(*req.StartIndex))
|
||||
}
|
||||
|
||||
if req.Filter != nil {
|
||||
query.Set(listQueryParamFilter, *req.Filter)
|
||||
}
|
||||
|
||||
return c.doWithListResponse(ctx, http.MethodGet, orgID, "?"+query.Encode(), nil)
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) {
|
||||
return c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil)
|
||||
}
|
||||
@ -77,6 +155,17 @@ func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, url string) e
|
||||
return c.doReq(req, nil)
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) doWithListResponse(ctx context.Context, method, orgID, url string, body io.Reader) (*ListResponse[*T], error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
|
||||
response := new(ListResponse[*T])
|
||||
return response, c.doReq(req, response)
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
|
||||
if err != nil {
|
||||
@ -88,7 +177,7 @@ func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url s
|
||||
return responseEntity, c.doReq(req, responseEntity)
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity *T) error {
|
||||
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity interface{}) error {
|
||||
addTokenAsHeader(req)
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
@ -141,8 +230,8 @@ func readScimError(resp *http.Response) error {
|
||||
}
|
||||
|
||||
func (c *ResourceClient[T]) buildURL(orgID, segment string) string {
|
||||
if segment == "" {
|
||||
return c.baseUrl + "/" + path.Join(orgID, c.resourceName)
|
||||
if segment == "" || strings.HasPrefix(segment, "?") {
|
||||
return c.baseUrl + "/" + path.Join(orgID, c.resourceName) + segment
|
||||
}
|
||||
|
||||
return c.baseUrl + "/" + path.Join(orgID, c.resourceName, segment)
|
||||
|
@ -110,7 +110,7 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
|
||||
result := m.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
if cols[len(cols)-1] == "count" && len(row) == len(cols)-1 {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
|
@ -109,6 +109,10 @@ func NewOrQuery(queries ...SearchQuery) (*OrQuery, error) {
|
||||
return &OrQuery{queries: queries}, nil
|
||||
}
|
||||
|
||||
func (q *OrQuery) Prepend(queries ...SearchQuery) {
|
||||
q.queries = append(queries, q.queries...)
|
||||
}
|
||||
|
||||
func (q *OrQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
|
||||
return query.Where(q.comp())
|
||||
}
|
||||
@ -147,6 +151,10 @@ func (q *AndQuery) comp() sq.Sqlizer {
|
||||
return and
|
||||
}
|
||||
|
||||
func (q *AndQuery) Prepend(queries ...SearchQuery) {
|
||||
q.queries = append(queries, q.queries...)
|
||||
}
|
||||
|
||||
type NotQuery struct {
|
||||
query SearchQuery
|
||||
}
|
||||
@ -406,8 +414,12 @@ func (q *NumberQuery) comp() sq.Sqlizer {
|
||||
return sq.NotEq{q.Column.identifier(): q.Number}
|
||||
case NumberLess:
|
||||
return sq.Lt{q.Column.identifier(): q.Number}
|
||||
case NumberLessOrEqual:
|
||||
return sq.LtOrEq{q.Column.identifier(): q.Number}
|
||||
case NumberGreater:
|
||||
return sq.Gt{q.Column.identifier(): q.Number}
|
||||
case NumberGreaterOrEqual:
|
||||
return sq.GtOrEq{q.Column.identifier(): q.Number}
|
||||
case NumberListContains:
|
||||
return &listContains{col: q.Column, args: []interface{}{q.Number}}
|
||||
case numberCompareMax:
|
||||
@ -423,7 +435,9 @@ const (
|
||||
NumberEquals NumberComparison = iota
|
||||
NumberNotEquals
|
||||
NumberLess
|
||||
NumberLessOrEqual
|
||||
NumberGreater
|
||||
NumberGreaterOrEqual
|
||||
NumberListContains
|
||||
|
||||
numberCompareMax
|
||||
@ -588,6 +602,57 @@ func (q *BoolQuery) comp() sq.Sqlizer {
|
||||
return sq.Eq{q.Column.identifier(): q.Value}
|
||||
}
|
||||
|
||||
type BytesComparison int
|
||||
|
||||
const (
|
||||
BytesEquals BytesComparison = iota
|
||||
BytesNotEquals
|
||||
bytesCompareMax
|
||||
)
|
||||
|
||||
type BytesQuery struct {
|
||||
Column Column
|
||||
Compare BytesComparison
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func NewBytesQuery(col Column, values []byte, comparison BytesComparison) (*BytesQuery, error) {
|
||||
if col.isZero() {
|
||||
return nil, ErrMissingColumn
|
||||
}
|
||||
|
||||
if comparison < 0 || comparison >= bytesCompareMax {
|
||||
return nil, ErrInvalidCompare
|
||||
}
|
||||
|
||||
return &BytesQuery{
|
||||
Column: col,
|
||||
Value: values,
|
||||
Compare: comparison,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *BytesQuery) Col() Column {
|
||||
return q.Column
|
||||
}
|
||||
|
||||
func (q *BytesQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
|
||||
return query.Where(q.comp())
|
||||
}
|
||||
|
||||
func (q *BytesQuery) comp() sq.Sqlizer {
|
||||
switch q.Compare {
|
||||
case BytesEquals:
|
||||
return sq.Eq{q.Column.identifier(): q.Value}
|
||||
case BytesNotEquals:
|
||||
return sq.NotEq{q.Column.identifier(): q.Value}
|
||||
case bytesCompareMax:
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type TimestampComparison int
|
||||
|
||||
const (
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
@ -1540,6 +1541,17 @@ func TestNumberQuery_comp(t *testing.T) {
|
||||
query: sq.Lt{"test_table.test_col": 42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "less or equal",
|
||||
fields: fields{
|
||||
Column: testCol,
|
||||
Number: 42,
|
||||
Compare: NumberLessOrEqual,
|
||||
},
|
||||
want: want{
|
||||
query: sq.LtOrEq{"test_table.test_col": 42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "greater",
|
||||
fields: fields{
|
||||
@ -1551,6 +1563,17 @@ func TestNumberQuery_comp(t *testing.T) {
|
||||
query: sq.Gt{"test_table.test_col": 42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "greater or equal",
|
||||
fields: fields{
|
||||
Column: testCol,
|
||||
Number: 42,
|
||||
Compare: NumberGreaterOrEqual,
|
||||
},
|
||||
want: want{
|
||||
query: sq.GtOrEq{"test_table.test_col": 42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list containts",
|
||||
fields: fields{
|
||||
@ -2193,3 +2216,98 @@ func TestInTextQuery_comp(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytesQuery_comp(t *testing.T) {
|
||||
type fields struct {
|
||||
Column Column
|
||||
Value []byte
|
||||
Compare BytesComparison
|
||||
}
|
||||
type want struct {
|
||||
query interface{}
|
||||
err bool
|
||||
isNil bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "equals",
|
||||
fields: fields{
|
||||
Column: testCol,
|
||||
Value: []byte("foo"),
|
||||
Compare: BytesEquals,
|
||||
},
|
||||
want: want{
|
||||
query: sq.Eq{"test_table.test_col": []byte("foo")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not equals",
|
||||
fields: fields{
|
||||
Column: testCol,
|
||||
Value: []byte("foo"),
|
||||
Compare: BytesNotEquals,
|
||||
},
|
||||
want: want{
|
||||
query: sq.NotEq{"test_table.test_col": []byte("foo")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unknown comparison",
|
||||
fields: fields{
|
||||
Column: testCol,
|
||||
Value: []byte("foo"),
|
||||
Compare: -1,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
isNil: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "zero col",
|
||||
fields: fields{
|
||||
Column: Column{},
|
||||
Value: []byte("foo"),
|
||||
Compare: BytesEquals,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
query: sq.Eq{"": []byte("foo")},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := NewBytesQuery(tt.fields.Column, tt.fields.Value, tt.fields.Compare)
|
||||
if tt.want.err {
|
||||
require.Error(t, err)
|
||||
|
||||
// still test comp
|
||||
s = &BytesQuery{
|
||||
Column: tt.fields.Column,
|
||||
Value: tt.fields.Value,
|
||||
Compare: tt.fields.Compare,
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
query := s.comp()
|
||||
|
||||
if tt.want.isNil {
|
||||
require.Nil(t, query)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, query)
|
||||
|
||||
if !reflect.DeepEqual(query, tt.want.query) {
|
||||
t.Errorf("wrong query: want: %v, (%T), got: %v, (%T)", tt.want.query, tt.want.query, query, query)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -604,6 +604,27 @@ func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, queri
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (count uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
query, scan := prepareCountUsersQuery()
|
||||
eq := sq.Eq{UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()}
|
||||
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
|
||||
if err != nil {
|
||||
return 0, zerrors.ThrowInternal(err, "QUERY-w3Dx", "Errors.Query.SQLStatment")
|
||||
}
|
||||
|
||||
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
||||
count, err = scan(rows)
|
||||
return err
|
||||
}, stmt, args...)
|
||||
if err != nil {
|
||||
return 0, zerrors.ThrowInternal(err, "QUERY-AG4gs", "Errors.Internal")
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
|
||||
users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
|
||||
if err != nil {
|
||||
@ -1278,6 +1299,24 @@ func scanNotifyUser(row *sql.Row) (*NotifyUser, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func prepareCountUsersQuery() (sq.SelectBuilder, func(*sql.Rows) (uint64, error)) {
|
||||
return sq.Select(countColumn.identifier()).
|
||||
From(userTable.identifier()).
|
||||
LeftJoin(join(HumanUserIDCol, UserIDCol)).
|
||||
LeftJoin(join(MachineUserIDCol, UserIDCol)).
|
||||
PlaceholderFormat(sq.Dollar),
|
||||
func(rows *sql.Rows) (count uint64, err error) {
|
||||
// the count is implemented as a windowing function,
|
||||
// if it is zero, no row is returned at all.
|
||||
if !rows.Next() {
|
||||
return
|
||||
}
|
||||
|
||||
err = rows.Scan(&count)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func prepareUserUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (bool, error)) {
|
||||
return sq.Select(
|
||||
UserIDCol.identifier(),
|
||||
|
@ -24,6 +24,7 @@ type UserMetadataList struct {
|
||||
|
||||
type UserMetadata struct {
|
||||
CreationDate time.Time `json:"creation_date,omitempty"`
|
||||
UserID string `json:"-"`
|
||||
ChangeDate time.Time `json:"change_date,omitempty"`
|
||||
ResourceOwner string `json:"resource_owner,omitempty"`
|
||||
Sequence uint64 `json:"sequence,omitempty"`
|
||||
@ -107,6 +108,38 @@ func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bo
|
||||
return metadata, err
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUserMetadataForUsers(ctx context.Context, shouldTriggerBulk bool, userIDs []string, queries *UserMetadataSearchQueries) (metadata *UserMetadataList, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerUserMetadataProjection")
|
||||
ctx, err = projection.UserMetadataProjection.Trigger(ctx, handler.WithAwaitRunning())
|
||||
logging.OnError(err).Debug("trigger failed")
|
||||
traceSpan.EndWithError(err)
|
||||
}
|
||||
|
||||
query, scan := prepareUserMetadataListQuery(ctx, q.client)
|
||||
eq := sq.Eq{
|
||||
UserMetadataUserIDCol.identifier(): userIDs,
|
||||
UserMetadataInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
}
|
||||
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Egbgd", "Errors.Query.SQLStatment")
|
||||
}
|
||||
|
||||
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
||||
metadata, err = scan(rows)
|
||||
return err
|
||||
}, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metadata.State, err = q.latestState(ctx, userMetadataTable)
|
||||
return metadata, err
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool, userID string, queries *UserMetadataSearchQueries, withOwnerRemoved bool) (metadata *UserMetadataList, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
@ -164,6 +197,44 @@ func NewUserMetadataKeySearchQuery(value string, comparison TextComparison) (Sea
|
||||
return NewTextQuery(UserMetadataKeyCol, value, comparison)
|
||||
}
|
||||
|
||||
func NewUserMetadataExistsQuery(key string, value []byte, keyComparison TextComparison, valueComparison BytesComparison) (SearchQuery, error) {
|
||||
// linking queries for the subselect
|
||||
instanceQuery, err := NewColumnComparisonQuery(UserMetadataInstanceIDCol, UserInstanceIDCol, ColumnEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userIDQuery, err := NewColumnComparisonQuery(UserMetadataUserIDCol, UserIDCol, ColumnEquals)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// text query to select data from the linked sub select
|
||||
metadataKeyQuery, err := NewTextQuery(UserMetadataKeyCol, key, keyComparison)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// text query to select data from the linked sub select
|
||||
metadataValueQuery, err := NewBytesQuery(UserMetadataValueCol, value, valueComparison)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// full definition of the sub select
|
||||
subSelect, err := NewSubSelect(UserMetadataUserIDCol, []SearchQuery{instanceQuery, userIDQuery, metadataKeyQuery, metadataValueQuery})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// "WHERE * IN (*)" query with subquery as list-data provider
|
||||
return NewListQuery(
|
||||
UserIDCol,
|
||||
subSelect,
|
||||
ListIn,
|
||||
)
|
||||
}
|
||||
|
||||
func prepareUserMetadataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserMetadata, error)) {
|
||||
return sq.Select(
|
||||
UserMetadataCreationDateCol.identifier(),
|
||||
@ -200,6 +271,7 @@ func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.S
|
||||
return sq.Select(
|
||||
UserMetadataCreationDateCol.identifier(),
|
||||
UserMetadataChangeDateCol.identifier(),
|
||||
UserMetadataUserIDCol.identifier(),
|
||||
UserMetadataResourceOwnerCol.identifier(),
|
||||
UserMetadataSequenceCol.identifier(),
|
||||
UserMetadataKeyCol.identifier(),
|
||||
@ -215,6 +287,7 @@ func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.S
|
||||
err := rows.Scan(
|
||||
&m.CreationDate,
|
||||
&m.ChangeDate,
|
||||
&m.UserID,
|
||||
&m.ResourceOwner,
|
||||
&m.Sequence,
|
||||
&m.Key,
|
||||
|
@ -30,6 +30,7 @@ var (
|
||||
}
|
||||
userMetadataListQuery = `SELECT projections.user_metadata5.creation_date,` +
|
||||
` projections.user_metadata5.change_date,` +
|
||||
` projections.user_metadata5.user_id,` +
|
||||
` projections.user_metadata5.resource_owner,` +
|
||||
` projections.user_metadata5.sequence,` +
|
||||
` projections.user_metadata5.key,` +
|
||||
@ -39,6 +40,7 @@ var (
|
||||
userMetadataListCols = []string{
|
||||
"creation_date",
|
||||
"change_date",
|
||||
"user_id",
|
||||
"resource_owner",
|
||||
"sequence",
|
||||
"key",
|
||||
@ -148,6 +150,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
testNow,
|
||||
testNow,
|
||||
"1",
|
||||
"resource_owner",
|
||||
uint64(20211108),
|
||||
"key",
|
||||
@ -164,6 +167,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
UserID: "1",
|
||||
ResourceOwner: "resource_owner",
|
||||
Sequence: 20211108,
|
||||
Key: "key",
|
||||
@ -183,6 +187,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
testNow,
|
||||
testNow,
|
||||
"1",
|
||||
"resource_owner",
|
||||
uint64(20211108),
|
||||
"key",
|
||||
@ -191,6 +196,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
testNow,
|
||||
testNow,
|
||||
"2",
|
||||
"resource_owner",
|
||||
uint64(20211108),
|
||||
"key2",
|
||||
@ -207,6 +213,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
UserID: "1",
|
||||
ResourceOwner: "resource_owner",
|
||||
Sequence: 20211108,
|
||||
Key: "key",
|
||||
@ -215,6 +222,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
|
||||
{
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
UserID: "2",
|
||||
ResourceOwner: "resource_owner",
|
||||
Sequence: 20211108,
|
||||
Key: "key2",
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
@ -530,6 +531,8 @@ var (
|
||||
"access_token_type",
|
||||
"count",
|
||||
}
|
||||
countUsersQuery = "SELECT COUNT(*) OVER () FROM projections.users13"
|
||||
countUsersCols = []string{"count"}
|
||||
)
|
||||
|
||||
func Test_UserPrepares(t *testing.T) {
|
||||
@ -1508,10 +1511,67 @@ func Test_UserPrepares(t *testing.T) {
|
||||
},
|
||||
object: (*Users)(nil),
|
||||
},
|
||||
{
|
||||
name: "prepareCountUsersQuery no result",
|
||||
prepare: prepareCountUsersQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQuery(
|
||||
regexp.QuoteMeta(countUsersQuery),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
object: uint64(0),
|
||||
},
|
||||
{
|
||||
name: "prepareCountUsersQuery one result",
|
||||
prepare: prepareCountUsersQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(countUsersQuery),
|
||||
countUsersCols,
|
||||
[][]driver.Value{{uint64(1)}},
|
||||
),
|
||||
},
|
||||
object: uint64(1),
|
||||
},
|
||||
{
|
||||
name: "prepareCountUsersQuery multiple results",
|
||||
prepare: prepareCountUsersQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(countUsersQuery),
|
||||
countUsersCols,
|
||||
[][]driver.Value{{uint64(2)}},
|
||||
),
|
||||
},
|
||||
object: uint64(2),
|
||||
},
|
||||
{
|
||||
name: "prepareCountUsersQuery sql err",
|
||||
prepare: prepareCountUsersQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(countUsersQuery),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !errors.Is(err, sql.ErrConnDone) {
|
||||
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
|
||||
params := defaultPrepareArgs
|
||||
if reflect.TypeOf(tt.prepare).NumIn() == 0 {
|
||||
params = []reflect.Value{}
|
||||
}
|
||||
|
||||
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, params...)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
103
internal/test/assert.go
Normal file
103
internal/test/assert.go
Normal file
@ -0,0 +1,103 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) {
|
||||
val, exists := m[key]
|
||||
assert.True(t, exists, "Key '%s' should exist in the map", key)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
|
||||
}
|
||||
|
||||
// PartiallyDeepEqual is similar to reflect.DeepEqual,
|
||||
// but only compares exported non-zero fields of the expectedValue
|
||||
func PartiallyDeepEqual(expected, actual interface{}) bool {
|
||||
if expected == nil {
|
||||
return actual == nil
|
||||
}
|
||||
|
||||
if actual == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
|
||||
}
|
||||
|
||||
func partiallyDeepEqual(expected, actual reflect.Value) bool {
|
||||
// Dereference pointers if needed
|
||||
if expected.Kind() == reflect.Ptr {
|
||||
if expected.IsNil() {
|
||||
return true
|
||||
}
|
||||
|
||||
expected = expected.Elem()
|
||||
}
|
||||
|
||||
if actual.Kind() == reflect.Ptr {
|
||||
if actual.IsNil() {
|
||||
return false
|
||||
}
|
||||
|
||||
actual = actual.Elem()
|
||||
}
|
||||
|
||||
if expected.Type() != actual.Type() {
|
||||
return false
|
||||
}
|
||||
|
||||
switch expected.Kind() { //nolint:exhaustive
|
||||
case reflect.Struct:
|
||||
for i := 0; i < expected.NumField(); i++ {
|
||||
field := expected.Type().Field(i)
|
||||
if field.PkgPath != "" { // Skip unexported fields
|
||||
continue
|
||||
}
|
||||
|
||||
expectedField := expected.Field(i)
|
||||
actualField := actual.Field(i)
|
||||
|
||||
// Skip zero-value fields in expected
|
||||
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare fields recursively
|
||||
if !partiallyDeepEqual(expectedField, actualField) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
case reflect.Slice, reflect.Array:
|
||||
if expected.Len() > actual.Len() {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < expected.Len(); i++ {
|
||||
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
|
||||
default:
|
||||
// Compare primitive types
|
||||
return reflect.DeepEqual(expected.Interface(), actual.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
func Must[T any](result T, error error) T {
|
||||
if error != nil {
|
||||
panic(error)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
153
internal/test/assert_test.go
Normal file
153
internal/test/assert_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
package test
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPartiallyDeepEqual(t *testing.T) {
|
||||
type SecondaryNestedType struct {
|
||||
Value int
|
||||
}
|
||||
type NestedType struct {
|
||||
Value int
|
||||
ValueSlice []int
|
||||
Nested SecondaryNestedType
|
||||
NestedPointer *SecondaryNestedType
|
||||
}
|
||||
|
||||
type args struct {
|
||||
expected interface{}
|
||||
actual interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{
|
||||
expected: nil,
|
||||
actual: nil,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar value",
|
||||
args: args{
|
||||
expected: 10,
|
||||
actual: 10,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different scalar value",
|
||||
args: args{
|
||||
expected: 11,
|
||||
actual: 10,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "string value",
|
||||
args: args{
|
||||
expected: "foo",
|
||||
actual: "foo",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different string value",
|
||||
args: args{
|
||||
expected: "foo2",
|
||||
actual: "foo",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "scalar only set in actual",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{},
|
||||
actual: &SecondaryNestedType{Value: 10},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar equal",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{Value: 10},
|
||||
actual: &SecondaryNestedType{Value: 10},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "scalar only set in expected",
|
||||
args: args{
|
||||
expected: &SecondaryNestedType{Value: 10},
|
||||
actual: &SecondaryNestedType{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ptr only set in expected",
|
||||
args: args{
|
||||
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ptr only set in actual",
|
||||
args: args{
|
||||
expected: &NestedType{},
|
||||
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ptr equal",
|
||||
args: args{
|
||||
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nested equal",
|
||||
args: args{
|
||||
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
|
||||
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "slice equal",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "slice additional in expected",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "slice additional in actual",
|
||||
args: args{
|
||||
expected: &NestedType{ValueSlice: []int{10, 20}},
|
||||
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
|
||||
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user