feat: list users scim v2 endpoint (#9187)

# Which Problems Are Solved
- Adds support for the list users SCIM v2 endpoint

# How the Problems Are Solved
- Adds support for the list users SCIM v2 endpoints under `GET
/scim/v2/{orgID}/Users` and `POST /scim/v2/{orgID}/Users/.search`

# Additional Changes
- adds a new function `SearchUserMetadataForUsers` to the query layer to
query a metadata keyset for given user ids
- adds a new function `NewUserMetadataExistsQuery` to the query layer to
query a given metadata key value pair exists
- adds a new function `CountUsers` to the query layer to count users
without reading any rows
- handle `ErrorAlreadyExists` as scim errors `uniqueness`
- adds `NumberLessOrEqual` and `NumberGreaterOrEqual` query comparison
methods
- adds `BytesQuery` with `BytesEquals` and `BytesNotEquals` query
comparison methods

# Additional Context
Part of #8140
Supported fields for scim filters:
* `meta.created`
* `meta.lastModified`
* `id`
* `username`
* `name.familyName`
* `name.givenName`
* `emails` and `emails.value`
* `active` only eq and ne
* `externalId` only eq and ne
This commit is contained in:
Lars 2025-01-21 13:31:54 +01:00 committed by GitHub
parent 926e7169b2
commit 1915d35605
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 4173 additions and 417 deletions

1
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.24.0
github.com/Masterminds/squirrel v1.5.4
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b
github.com/alecthomas/participle/v2 v2.1.1
github.com/alicebob/miniredis/v2 v2.33.0
github.com/benbjohnson/clock v1.3.5
github.com/boombuler/barcode v1.0.2

8
go.sum
View File

@ -49,6 +49,12 @@ github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm
github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk=
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw=
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
github.com/alecthomas/assert/v2 v2.3.0 h1:mAsH2wmvjsuvyBvAmCtm7zFsBlb8mIHx5ySLVdDZXL0=
github.com/alecthomas/assert/v2 v2.3.0/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
github.com/alecthomas/participle/v2 v2.1.1 h1:hrjKESvSqGHzRb4yW1ciisFJ4p3MGYih6icjJvbsmV8=
github.com/alecthomas/participle/v2 v2.1.1/go.mod h1:Y1+hAs8DHPmc3YUFzqllV+eSQ9ljPTk0ZkPMtEdAx2c=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
@ -400,6 +406,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg=
github.com/improbable-eng/grpc-web v0.15.0 h1:BN+7z6uNXZ1tQGcNAuaU1YjsLTApzkjt2tzCixLaUPQ=

View File

@ -1,6 +1,7 @@
package http
import (
"errors"
"net/http"
"github.com/gorilla/schema"
@ -26,3 +27,24 @@ func (p *Parser) Parse(r *http.Request, data interface{}) error {
return p.decoder.Decode(data, r.Form)
}
func (p *Parser) UnwrapParserError(err error) error {
if err == nil {
return nil
}
// try to unwrap the error
var multiErr schema.MultiError
if errors.As(err, &multiErr) && len(multiErr) == 1 {
for _, v := range multiErr {
var schemaErr schema.ConversionError
if errors.As(v, &schemaErr) {
return schemaErr.Err
}
return v
}
}
return err
}

View File

@ -0,0 +1,81 @@
package http
import (
"bytes"
"errors"
"net/http"
"net/url"
"testing"
gschema "github.com/gorilla/schema"
"github.com/stretchr/testify/require"
)
type SampleSchema struct {
Value *SampleSchemaValue `schema:"value"`
IntValue int `schema:"intvalue"`
}
type SampleSchemaValue struct{}
func (s *SampleSchemaValue) UnmarshalText(text []byte) error {
if string(text) == "foo" {
return nil
}
return errors.New("this is a test error")
}
func TestParser_UnwrapParserError(t *testing.T) {
tests := []struct {
name string
query string
wantErr bool
assertUnwrappedError func(err error, unwrappedErr error)
}{
{
name: "unwrap ok",
query: "value=test",
wantErr: true,
assertUnwrappedError: func(_, err error) {
require.Equal(t, "this is a test error", err.Error())
},
},
{
name: "multiple errors",
query: "value=test&intvalue=foo",
wantErr: true,
assertUnwrappedError: func(err error, unwrappedErr error) {
require.Equal(t, err, unwrappedErr)
},
},
{
name: "no error",
query: "value=foo&intvalue=1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewParser()
encodedFormData := url.Values{}.Encode()
r, err := http.NewRequest(http.MethodPost, "http://exmaple.com?"+tt.query, bytes.NewBufferString(encodedFormData))
require.NoError(t, err)
data := new(SampleSchema)
err = p.Parse(r, data)
if !tt.wantErr {
require.NoError(t, err)
require.Nil(t, p.UnwrapParserError(err))
return
}
require.Error(t, err)
require.IsType(t, gschema.MultiError{}, err)
unwrappedErr := p.UnwrapParserError(err)
require.Error(t, unwrappedErr)
tt.assertUnwrappedError(err, unwrappedErr)
})
}
}

View File

@ -10,6 +10,12 @@ var AuthMapping = authz.MethodMapping{
"POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": {
Permission: domain.PermissionUserWrite,
},
"POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users/.search": {
Permission: domain.PermissionUserRead,
},
"GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users": {
Permission: domain.PermissionUserRead,
},
"GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": {
Permission: domain.PermissionUserRead,
},

View File

@ -21,6 +21,7 @@ import (
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
@ -55,6 +56,104 @@ var (
//go:embed testdata/users_create_test_invalid_timezone.json
invalidTimeZoneUserJson []byte
fullUser = &resources.ScimUser{
ExternalID: "701984",
UserName: "bjensen@example.com",
Name: &resources.ScimUserName{
Formatted: "Babs Jensen", // DisplayName takes precedence in Zitadel
FamilyName: "Jensen",
GivenName: "Barbara",
MiddleName: "Jane",
HonorificPrefix: "Ms.",
HonorificSuffix: "III",
},
DisplayName: "Babs Jensen",
NickName: "Babs",
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
Emails: []*resources.ScimEmail{
{
Value: "bjensen@example.com",
Primary: true,
},
},
Addresses: []*resources.ScimAddress{
{
Type: "work",
StreetAddress: "100 Universal City Plaza",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA",
Primary: true,
},
{
Type: "home",
StreetAddress: "456 Hollywood Blvd",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA",
},
},
PhoneNumbers: []*resources.ScimPhoneNumber{
{
Value: "+415555555555",
Primary: true,
},
},
Ims: []*resources.ScimIms{
{
Value: "someaimhandle",
Type: "aim",
},
{
Value: "twitterhandle",
Type: "X",
},
},
Photos: []*resources.ScimPhoto{
{
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
Type: "photo",
},
},
Roles: []*resources.ScimRole{
{
Value: "my-role-1",
Display: "Rolle 1",
Type: "main-role",
Primary: true,
},
{
Value: "my-role-2",
Display: "Rolle 2",
Type: "secondary-role",
Primary: false,
},
},
Entitlements: []*resources.ScimEntitlement{
{
Value: "my-entitlement-1",
Display: "Entitlement 1",
Type: "main-entitlement",
Primary: true,
},
{
Value: "my-entitlement-2",
Display: "Entitlement 2",
Type: "secondary-entitlement",
Primary: false,
},
},
Title: "Tour Guide",
PreferredLanguage: language.MustParse("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
}
)
func TestCreateUser(t *testing.T) {
@ -95,103 +194,7 @@ func TestCreateUser(t *testing.T) {
{
name: "full user",
body: fullUserJson,
want: &resources.ScimUser{
ExternalID: "701984",
UserName: "bjensen@example.com",
Name: &resources.ScimUserName{
Formatted: "Babs Jensen", // DisplayName takes precedence in Zitadel
FamilyName: "Jensen",
GivenName: "Barbara",
MiddleName: "Jane",
HonorificPrefix: "Ms.",
HonorificSuffix: "III",
},
DisplayName: "Babs Jensen",
NickName: "Babs",
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
Emails: []*resources.ScimEmail{
{
Value: "bjensen@example.com",
Primary: true,
},
},
Addresses: []*resources.ScimAddress{
{
Type: "work",
StreetAddress: "100 Universal City Plaza",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA",
Primary: true,
},
{
Type: "home",
StreetAddress: "456 Hollywood Blvd",
Locality: "Hollywood",
Region: "CA",
PostalCode: "91608",
Country: "USA",
Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA",
},
},
PhoneNumbers: []*resources.ScimPhoneNumber{
{
Value: "+415555555555",
Primary: true,
},
},
Ims: []*resources.ScimIms{
{
Value: "someaimhandle",
Type: "aim",
},
{
Value: "twitterhandle",
Type: "X",
},
},
Photos: []*resources.ScimPhoto{
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
Type: "photo",
},
},
Roles: []*resources.ScimRole{
{
Value: "my-role-1",
Display: "Rolle 1",
Type: "main-role",
Primary: true,
},
{
Value: "my-role-2",
Display: "Rolle 2",
Type: "secondary-role",
Primary: false,
},
},
Entitlements: []*resources.ScimEntitlement{
{
Value: "my-entitlement-1",
Display: "Entitlement 1",
Type: "main-entitlement",
Primary: true,
},
{
Value: "my-entitlement-2",
Display: "Entitlement 2",
Type: "secondary-entitlement",
Primary: false,
},
},
Title: "Tour Guide",
PreferredLanguage: language.MustParse("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
},
want: fullUser,
},
{
name: "missing userName",
@ -290,7 +293,7 @@ func TestCreateUser(t *testing.T) {
assert.Nil(t, createdUser.Password)
if tt.want != nil {
if !integration.PartiallyDeepEqual(tt.want, createdUser) {
if !test.PartiallyDeepEqual(tt.want, createdUser) {
t.Errorf("CreateUser() got = %v, want %v", createdUser, tt.want)
}
@ -299,7 +302,7 @@ func TestCreateUser(t *testing.T) {
// ensure the user is really stored and not just returned to the caller
fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, createdUser.ID)
require.NoError(ttt, err)
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
ttt.Errorf("GetUser() got = %v, want %v", fetchedUser, tt.want)
}
}, retryDuration, tick)
@ -315,6 +318,7 @@ func TestCreateUser_duplicate(t *testing.T) {
_, err = Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson)
scimErr := scim.RequireScimError(t, http.StatusConflict, err)
assert.Equal(t, "User already exists", scimErr.Error.Detail)
assert.Equal(t, "uniqueness", scimErr.Error.ScimType)
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
require.NoError(t, err)
@ -341,19 +345,19 @@ func TestCreateUser_metadata(t *testing.T) {
mdMap[md.Result[i].Key] = string(md.Result[i].Value)
}
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`)
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`)
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`)
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.middleName", "Jane")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:title", "Tour Guide")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:locale", "en-US")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`)
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`)
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`)
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`)
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`)
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.middleName", "Jane")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:title", "Tour Guide")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:locale", "en-US")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`)
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`)
}, retryDuration, tick)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
@ -93,7 +94,7 @@ func TestGetUser(t *testing.T) {
},
DisplayName: "Babs Jensen",
NickName: "Babs",
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")),
Title: "Tour Guide",
PreferredLanguage: language.Make("en-US"),
Locale: "en-US",
@ -144,11 +145,11 @@ func TestGetUser(t *testing.T) {
},
Photos: []*resources.ScimPhoto{
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")),
Type: "photo",
},
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")),
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")),
Type: "thumbnail",
},
},
@ -256,7 +257,7 @@ func TestGetUser(t *testing.T) {
assert.Equal(ttt, schemas.ScimResourceTypeSingular("User"), fetchedUser.Resource.Meta.ResourceType)
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
assert.Nil(ttt, fetchedUser.Password)
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
}
}, retryDuration, tick)

View File

@ -0,0 +1,492 @@
//go:build integration
package integration_test
import (
"context"
"fmt"
"net/http"
"strings"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/scim/resources"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
var totalCountOfHumanUsers = 13
func TestListUser(t *testing.T) {
createdUserIDs := createUsers(t, CTX, Instance.DefaultOrg.Id)
defer func() {
// only the full user needs to be deleted, all others have random identification data
// fullUser is always the first one.
_, err := Instance.Client.UserV2.DeleteUser(CTX, &user_v2.DeleteUserRequest{
UserId: createdUserIDs[0],
})
require.NoError(t, err)
}()
// secondary organization with same set of users,
// these should never be modified.
// This allows testing list requests without filters.
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
secondaryOrg := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
secondaryOrgCreatedUserIDs := createUsers(t, iamOwnerCtx, secondaryOrg.OrganizationId)
testsInitializedUtc := time.Now().UTC()
// Wait one second to ensure a change in the least significant value of the timestamp.
time.Sleep(time.Second)
tests := []struct {
name string
ctx context.Context
orgID string
req *scim.ListRequest
prepare func(require.TestingT) *scim.ListRequest
wantErr bool
errorStatus int
errorType string
assert func(assert.TestingT, *scim.ListResponse[*resources.ScimUser])
cleanup func(require.TestingT)
}{
{
name: "not authenticated",
ctx: context.Background(),
req: new(scim.ListRequest),
wantErr: true,
errorStatus: http.StatusUnauthorized,
},
{
name: "no permissions",
ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission),
req: new(scim.ListRequest),
wantErr: true,
errorStatus: http.StatusNotFound,
},
{
name: "unknown sort order",
req: &scim.ListRequest{
SortBy: gu.Ptr("id"),
SortOrder: gu.Ptr(scim.ListRequestSortOrder("fooBar")),
},
wantErr: true,
errorType: "invalidValue",
},
{
name: "unknown sort field",
req: &scim.ListRequest{
SortBy: gu.Ptr("fooBar"),
},
wantErr: true,
errorType: "invalidValue",
},
{
name: "unknown filter field",
req: &scim.ListRequest{
Filter: gu.Ptr(`fooBar eq "10"`),
},
wantErr: true,
errorType: "invalidFilter",
},
{
name: "invalid filter",
req: &scim.ListRequest{
Filter: gu.Ptr(`fooBarBaz`),
},
wantErr: true,
errorType: "invalidFilter",
},
{
name: "list users without filter",
// use other org, modifications of users happens only on primary org
orgID: secondaryOrg.OrganizationId,
ctx: iamOwnerCtx,
req: new(scim.ListRequest),
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 100, resp.ItemsPerPage)
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, totalCountOfHumanUsers)
},
},
{
name: "list paged sorted users without filter",
// use other org, modifications of users happens only on primary org
orgID: secondaryOrg.OrganizationId,
ctx: iamOwnerCtx,
req: &scim.ListRequest{
Count: gu.Ptr(2),
StartIndex: gu.Ptr(5),
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
SortBy: gu.Ptr("username"),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 2, resp.ItemsPerPage)
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
assert.Equal(t, 5, resp.StartIndex)
assert.Len(t, resp.Resources, 2)
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-1: "))
assert.True(t, strings.HasPrefix(resp.Resources[1].UserName, "scim-username-2: "))
},
},
{
name: "list users with simple filter",
req: &scim.ListRequest{
Filter: gu.Ptr(`username sw "scim-username-1"`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 100, resp.ItemsPerPage)
assert.Equal(t, 2, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 2)
for _, resource := range resp.Resources {
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
}
},
},
{
name: "list paged sorted users with filter",
req: &scim.ListRequest{
Count: gu.Ptr(5),
StartIndex: gu.Ptr(1),
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
SortBy: gu.Ptr("username"),
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 5, resp.ItemsPerPage)
assert.Equal(t, 2, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 2)
for _, resource := range resp.Resources {
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
assert.Len(t, resource.Emails, 1)
assert.True(t, strings.HasPrefix(resource.Emails[0].Value, "scim-email-1"))
assert.True(t, strings.HasSuffix(resource.Emails[0].Value, "@example.com"))
}
},
},
{
name: "list paged sorted users with filter as post",
req: &scim.ListRequest{
Count: gu.Ptr(5),
StartIndex: gu.Ptr(1),
SortOrder: gu.Ptr(scim.ListRequestSortOrderAsc),
SortBy: gu.Ptr("username"),
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
SendAsPost: true,
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 5, resp.ItemsPerPage)
assert.Equal(t, 2, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 2)
for _, resource := range resp.Resources {
assert.True(t, strings.HasPrefix(resource.UserName, "scim-username-1"))
assert.Len(t, resource.Emails, 1)
assert.True(t, strings.HasPrefix(resource.Emails[0].Value, "scim-email-1"))
assert.True(t, strings.HasSuffix(resource.Emails[0].Value, "@example.com"))
}
},
},
{
name: "count users without filter",
// use other org, modifications of users happens only on primary org
orgID: secondaryOrg.OrganizationId,
ctx: iamOwnerCtx,
prepare: func(t require.TestingT) *scim.ListRequest {
return &scim.ListRequest{
Count: gu.Ptr(0),
}
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 0, resp.ItemsPerPage)
assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 0)
},
},
{
name: "list users with active filter",
req: &scim.ListRequest{
Filter: gu.Ptr(`active eq false`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 100, resp.ItemsPerPage)
assert.Equal(t, 1, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 1)
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-0"))
assert.False(t, *resp.Resources[0].Active)
},
},
{
name: "list users with externalid filter",
req: &scim.ListRequest{
Filter: gu.Ptr(`externalid eq "701984"`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 100, resp.ItemsPerPage)
assert.Equal(t, 1, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 1)
assert.Equal(t, resp.Resources[0].ExternalID, "701984")
},
},
{
name: "list users with externalid filter invalid operator",
req: &scim.ListRequest{
Filter: gu.Ptr(`externalid pr`),
},
wantErr: true,
errorType: "invalidFilter",
},
{
name: "list users with externalid complex filter",
req: &scim.ListRequest{
Filter: gu.Ptr(`externalid eq "701984" and username eq "bjensen@example.com"`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 100, resp.ItemsPerPage)
assert.Equal(t, 1, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 1)
assert.Equal(t, resp.Resources[0].UserName, "bjensen@example.com")
assert.Equal(t, resp.Resources[0].ExternalID, "701984")
},
},
{
name: "count users with filter",
req: &scim.ListRequest{
Count: gu.Ptr(0),
Filter: gu.Ptr(`emails sw "scim-email-1" and emails ew "@example.com"`),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Equal(t, 0, resp.ItemsPerPage)
assert.Equal(t, 2, resp.TotalResults)
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 0)
},
},
{
name: "list users with modification date filter",
prepare: func(t require.TestingT) *scim.ListRequest {
userID := createdUserIDs[len(createdUserIDs)-1] // use the last entry, as we use the others for other assertions
_, err := Instance.Client.UserV2.UpdateHumanUser(CTX, &user_v2.UpdateHumanUserRequest{
UserId: userID,
Profile: &user_v2.SetHumanProfile{
GivenName: "scim-user-given-name-modified-0: " + gofakeit.FirstName(),
FamilyName: "scim-user-family-name-modified-0: " + gofakeit.LastName(),
},
})
require.NoError(t, err)
return &scim.ListRequest{
// filter by id too to exclude other random users
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s" and meta.LASTMODIFIED gt "%s"`, userID, testsInitializedUtc.Format(time.RFC3339))),
}
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 1)
assert.Equal(t, resp.Resources[0].ID, createdUserIDs[len(createdUserIDs)-1])
assert.True(t, strings.HasPrefix(resp.Resources[0].Name.FamilyName, "scim-user-family-name-modified-0:"))
assert.True(t, strings.HasPrefix(resp.Resources[0].Name.GivenName, "scim-user-given-name-modified-0:"))
},
},
{
name: "list users with creation date filter",
prepare: func(t require.TestingT) *scim.ListRequest {
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 100)
return &scim.ListRequest{
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s" and meta.created gt "%s"`, resp.UserId, testsInitializedUtc.Format(time.RFC3339))),
}
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 1)
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-100:"))
},
},
{
name: "validate returned objects",
req: &scim.ListRequest{
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, createdUserIDs[0])),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 1)
if !test.PartiallyDeepEqual(fullUser, resp.Resources[0]) {
t.Errorf("got = %#v, want %#v", resp.Resources[0], fullUser)
}
},
},
{
name: "do not return user of other org",
req: &scim.ListRequest{
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, secondaryOrgCreatedUserIDs[0])),
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 0)
},
},
{
name: "do not count user of other org",
prepare: func(t require.TestingT) *scim.ListRequest {
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
org := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
resp := createHumanUser(t, iamOwnerCtx, org.OrganizationId, 102)
return &scim.ListRequest{
Count: gu.Ptr(0),
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
}
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 0)
},
},
{
name: "scoped externalID",
prepare: func(t require.TestingT) *scim.ListRequest {
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 102)
// set provisioning domain of service user
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioning_domain",
Value: []byte("fooBar"),
})
require.NoError(t, err)
// set externalID for provisioning domain
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: resp.UserId,
Key: "urn:zitadel:scim:fooBar:externalId",
Value: []byte("100-scopedExternalId"),
})
require.NoError(t, err)
return &scim.ListRequest{
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
}
},
assert: func(t assert.TestingT, resp *scim.ListResponse[*resources.ScimUser]) {
assert.Len(t, resp.Resources, 1)
assert.Equal(t, resp.Resources[0].ExternalID, "100-scopedExternalId")
},
cleanup: func(t require.TestingT) {
// delete provisioning domain of service user
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioning_domain",
})
require.NoError(t, err)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.ctx == nil {
tt.ctx = CTX
}
if tt.prepare != nil {
tt.req = tt.prepare(t)
}
if tt.orgID == "" {
tt.orgID = Instance.DefaultOrg.Id
}
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.ctx, time.Minute)
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
listResp, err := Instance.Client.SCIM.Users.List(tt.ctx, tt.orgID, tt.req)
if tt.wantErr {
statusCode := tt.errorStatus
if statusCode == 0 {
statusCode = http.StatusBadRequest
}
scimErr := scim.RequireScimError(ttt, statusCode, err)
if tt.errorType != "" {
assert.Equal(t, tt.errorType, scimErr.Error.ScimType)
}
return
}
require.NoError(t, err)
assert.EqualValues(ttt, []schemas.ScimSchemaType{"urn:ietf:params:scim:api:messages:2.0:ListResponse"}, listResp.Schemas)
if tt.assert != nil {
tt.assert(ttt, listResp)
}
}, retryDuration, tick)
if tt.cleanup != nil {
tt.cleanup(t)
}
})
}
}
func createUsers(t *testing.T, ctx context.Context, orgID string) []string {
count := totalCountOfHumanUsers - 1 // zitadel admin is always created by default
createdUserIDs := make([]string, 0, count)
// create the full scim user if on primary org
if orgID == Instance.DefaultOrg.Id {
fullUserCreatedResp, err := Instance.Client.SCIM.Users.Create(ctx, orgID, fullUserJson)
require.NoError(t, err)
createdUserIDs = append(createdUserIDs, fullUserCreatedResp.ID)
count--
}
// set the first user inactive
resp := createHumanUser(t, ctx, orgID, 0)
_, err := Instance.Client.UserV2.DeactivateUser(ctx, &user_v2.DeactivateUserRequest{
UserId: resp.UserId,
})
require.NoError(t, err)
createdUserIDs = append(createdUserIDs, resp.UserId)
for i := 1; i < count; i++ {
resp = createHumanUser(t, ctx, orgID, i)
createdUserIDs = append(createdUserIDs, resp.UserId)
}
return createdUserIDs
}
func createHumanUser(t require.TestingT, ctx context.Context, orgID string, i int) *user_v2.AddHumanUserResponse {
// create remaining minimal users with faker data
// no need to clean these up as identification attributes change each time
resp, err := Instance.Client.UserV2.AddHumanUser(ctx, &user_v2.AddHumanUserRequest{
Organization: &object.Organization{
Org: &object.Organization_OrgId{
OrgId: orgID,
},
},
Username: gu.Ptr(fmt.Sprintf("scim-username-%d: %s", i, gofakeit.Username())),
Profile: &user_v2.SetHumanProfile{
GivenName: fmt.Sprintf("scim-givenname-%d: %s", i, gofakeit.FirstName()),
FamilyName: fmt.Sprintf("scim-familyname-%d: %s", i, gofakeit.LastName()),
PreferredLanguage: gu.Ptr("en-US"),
Gender: gu.Ptr(user_v2.Gender_GENDER_MALE),
},
Email: &user_v2.SetHumanEmail{
Email: fmt.Sprintf("scim-email-%d-%d@example.com", i, gofakeit.Number(0, 1_000_000)),
},
})
require.NoError(t, err)
return resp
}

View File

@ -19,6 +19,7 @@ import (
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
@ -78,7 +79,7 @@ func TestReplaceUser(t *testing.T) {
},
DisplayName: "Babs Jensen-updated",
NickName: "Babs-updated",
ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen-updated")),
ProfileUrl: test.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen-updated")),
Emails: []*resources.ScimEmail{
{
Value: "bjensen-replaced-full@example.com",
@ -124,11 +125,11 @@ func TestReplaceUser(t *testing.T) {
},
Photos: []*resources.ScimPhoto{
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F-updated")),
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F-updated")),
Type: "photo-updated",
},
{
Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T-updated")),
Value: *test.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T-updated")),
Type: "thumbnail-updated",
},
},
@ -247,7 +248,7 @@ func TestReplaceUser(t *testing.T) {
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), replacedUser.Resource.Meta.Location)
assert.Nil(t, createdUser.Password)
if !integration.PartiallyDeepEqual(tt.want, replacedUser) {
if !test.PartiallyDeepEqual(tt.want, replacedUser) {
t.Errorf("ReplaceUser() got = %#v, want %#v", replacedUser, tt.want)
}
@ -256,7 +257,7 @@ func TestReplaceUser(t *testing.T) {
// ensure the user is really stored and not just returned to the caller
fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, replacedUser.ID)
require.NoError(ttt, err)
if !integration.PartiallyDeepEqual(tt.want, fetchedUser) {
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
}
}, retryDuration, tick)
@ -316,8 +317,8 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
}
// both external IDs should be present on the user
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:fooBazz:externalId", "replaced-external-id")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984")
test.AssertMapContains(tt, mdMap, "urn:zitadel:scim:fooBazz:externalId", "replaced-external-id")
}, retryDuration, tick)
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})

View File

@ -0,0 +1,340 @@
package filter
import (
"encoding/json"
"strconv"
"strings"
"github.com/alecthomas/participle/v2"
"github.com/alecthomas/participle/v2/lexer"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Filter The scim v2 filter
// Separation between FilterSegment and Filter is required
// due to the UnmarshalText method, which is used by the schema parser
// as well as the participle parser but should do different things here.
type Filter struct {
Root Segment
}
// Segment The root ast node for the filter grammar
// according to the filter ABNF of https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
// FILTER = attrExp / logExp / valuePath / *1"not" "(" FILTER ")"
// to reduce lookahead needs and reduce stack depth of the parser,
// always match log expressions with optional operators
type Segment struct {
OrExp OrLogExp `parser:"@@"`
}
// OrLogExp The logical expression according to the filter ABNF
// separated in OrLogExp and AndLogExp to simplify parser stack depth and precedence
// logExp = FILTER SP ("and" / "or") SP FILTER
type OrLogExp struct {
Left AndLogExp `parser:"@@"`
Right *OrLogExp `parser:"(Whitespace 'or' Whitespace @@)?"`
}
type AndLogExp struct {
Left ValueAtom `parser:"@@"`
Right *AndLogExp `parser:"(Whitespace 'and' Whitespace @@)?"`
}
type ValueAtom struct {
SubFilter *Segment `parser:"'(' @@ ')' |"`
Negation *Segment `parser:"'not' '(' @@ ')' |"`
ValuePath *ValuePath `parser:"@@ |"`
AttrExp *AttrExp `parser:"@@"`
}
// ValuePath The value path according to the filter ABNF
// valuePath = attrPath "[" valFilter "]"
// instead of a separate valFilter entity the LogExp
// is used to simplify parsing.
type ValuePath struct {
AttrPath AttrPath `parser:"@@"`
ValFilter OrLogExp `parser:"'[' @@ ']'"`
}
// AttrExp The attribute expression according to the filter ABNF
// attrExp = (attrPath SP "pr") / (attrPath SP compareOp SP compValue)
type AttrExp struct {
UnaryCondition *UnaryCondition `parser:"@@ |"`
BinaryCondition *BinaryCondition `parser:"@@"`
}
type UnaryCondition struct {
Left AttrPath `parser:"@@ Whitespace"`
Operator UnaryConditionOperator `parser:"@@"`
}
type UnaryConditionOperator struct {
Present bool `parser:"@'pr'"`
}
type BinaryCondition struct {
Left AttrPath `parser:"@@ Whitespace"`
Operator CompareOp `parser:"@@ Whitespace"`
Right CompValue `parser:"@@"`
}
// CompareOp according to the scim filter ABNF
// compareOp = "eq" / "ne" / "co" /
// "sw" / "ew" /
// "gt" / "lt" /
// "ge" / "le"
type CompareOp struct {
Equal bool `parser:"@'eq' |"`
NotEqual bool `parser:"@'ne' |"`
Contains bool `parser:"@'co' |"`
StartsWith bool `parser:"@'sw' |"`
EndsWith bool `parser:"@'ew' |"`
GreaterThan bool `parser:"@'gt' |"`
GreaterThanOrEqual bool `parser:"@'ge' |"`
LessThan bool `parser:"@'lt' |"`
LessThanOrEqual bool `parser:"@'le'"`
}
// CompValue the compare value according to the scim filter ABNF
// compValue = false / null / true / number / string
type CompValue struct {
Null bool `parser:"@'null' |"`
BooleanTrue bool `parser:"@'true' |"`
BooleanFalse bool `parser:"@'false' |"`
Int *int `parser:"@Int |"`
Float *float64 `parser:"@Float |"`
StringValue *string `parser:"@String"`
}
// AttrPath the attribute path according to the scim filter ABNF
// [URI ":"] AttrName *1subAttr
type AttrPath struct {
UrnAttributePrefix *string `parser:"(@UrnAttributePrefix)?"`
AttrName string `parser:"@AttrName"`
SubAttr *string `parser:"('.' @AttrName)?"`
}
const (
maxInputLength = 1000
)
var (
scimFilterLexer = lexer.MustSimple([]lexer.SimpleRule{
// simplified version of RFC8141, last part isn't matched as in scim this is the attribute name
// urn is additionally verified after parsing, use a more relaxed matching here
{Name: "UrnAttributePrefix", Pattern: `urn:([\w()+,.=@;$_!*'%/?#-]+:)+`},
{Name: "Float", Pattern: `[-+]?\d*\.\d+`},
{Name: "Int", Pattern: `[-+]?\d+`},
{Name: "Parenthesis", Pattern: `\(|\)|\[|\]`},
{Name: "Punctuation", Pattern: `\.`},
{Name: "String", Pattern: `"(\\"|[^"])*"`},
{Name: "AttrName", Pattern: `[a-zA-Z][\w-]*`}, // AttrName according to the scim ABNF
{Name: "Whitespace", Pattern: `[ \t\n\r]+`},
})
scimFilterParser = buildParser[Segment]()
)
func buildParser[T any]() *participle.Parser[T] {
return participle.MustBuild[T](
participle.Lexer(scimFilterLexer),
participle.Unquote("String"),
// Keyword literals are matched case-insensitive (according to https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2)
// Keywords are a subset of AttrName
participle.CaseInsensitive("AttrName"),
participle.Elide("Whitespace"),
participle.UseLookahead(participle.MaxLookahead),
)
}
func (f *Filter) UnmarshalText(text []byte) error {
if len(text) == 0 {
*f = Filter{}
return nil
}
parsedFilter, err := ParseFilter(string(text))
if err != nil {
return err
}
*f = *parsedFilter
return nil
}
func (f *Filter) UnmarshalJSON(data []byte) error {
var rawFilter string
if err := json.Unmarshal(data, &rawFilter); err != nil {
return err
}
return f.UnmarshalText([]byte(rawFilter))
}
func (f *Filter) IsZero() bool {
return f == nil || *f == Filter{}
}
func ParseFilter(filter string) (*Filter, error) {
if filter == "" {
return nil, nil
}
if len(filter) > maxInputLength {
logging.WithFields("len", len(filter)).Infof("scim: filter exceeds maximum allowed length: %d", maxInputLength)
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgumentf(nil, "SCIM-filt13", "filter exceeds maximum allowed length: %d", maxInputLength))
}
parsedFilter, err := scimFilterParser.ParseString("", filter)
if err != nil {
logging.WithError(err).Info("scim: failed to parse filter")
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(err, "SCIM-filt14", "failed to parse filter"))
}
return &Filter{Root: *parsedFilter}, nil
}
func (f *Filter) String() string {
return f.Root.String()
}
func (f *Segment) String() string {
return f.OrExp.String()
}
func (o *OrLogExp) String() string {
if o.Right == nil {
return o.Left.String()
}
return "((" + o.Left.String() + ") or (" + o.Right.String() + "))"
}
func (a *AndLogExp) String() string {
if a.Right == nil {
return a.Left.String()
}
return "((" + a.Left.String() + ") and (" + a.Right.String() + "))"
}
func (a *ValueAtom) String() string {
switch {
case a.SubFilter != nil:
return "(" + a.SubFilter.String() + ")"
case a.Negation != nil:
return "not(" + a.Negation.String() + ")"
case a.ValuePath != nil:
return a.ValuePath.String()
}
return a.AttrExp.String()
}
func (v *ValuePath) String() string {
return v.AttrPath.String() + "[" + v.ValFilter.String() + "]"
}
func (a *AttrExp) String() string {
if a.UnaryCondition != nil {
return a.UnaryCondition.String()
}
return a.BinaryCondition.String()
}
func (u *UnaryCondition) String() string {
return u.Left.String() + " " + u.Operator.String()
}
func (u *UnaryConditionOperator) String() string {
return "pr"
}
func (b *BinaryCondition) String() string {
return b.Left.String() + " " + b.Operator.String() + " " + b.Right.String()
}
func (c *CompareOp) String() string {
switch {
case c.Equal:
return "eq"
case c.NotEqual:
return "ne"
case c.Contains:
return "co"
case c.StartsWith:
return "sw"
case c.EndsWith:
return "ew"
case c.GreaterThan:
return "gt"
case c.GreaterThanOrEqual:
return "ge"
case c.LessThan:
return "lt"
case c.LessThanOrEqual:
return "le"
}
return "<unknown CompareOp>"
}
func (c *CompValue) String() string {
switch {
case c.Null:
return "null"
case c.BooleanTrue:
return "true"
case c.BooleanFalse:
return "false"
case c.Int != nil:
return strconv.Itoa(*c.Int)
case c.Float != nil:
return strconv.FormatFloat(*c.Float, 'f', -1, 64)
case c.StringValue != nil:
return "\"" + *c.StringValue + "\""
}
return "<unknown CompValue>"
}
func (a *AttrPath) String() string {
var s = ""
if a.UrnAttributePrefix != nil {
s += *a.UrnAttributePrefix
}
s += a.AttrName
if a.SubAttr != nil {
s += "." + *a.SubAttr
}
return s
}
func (a *AttrPath) validateSchema(expectedSchema schemas.ScimSchemaType) error {
if a.UrnAttributePrefix == nil || *a.UrnAttributePrefix == string(expectedSchema)+":" {
return nil
}
logging.WithFields("urnPrefix", *a.UrnAttributePrefix).Info("scim filter: Invalid filter expression: unknown urn attribute prefix")
return serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF431", "Invalid filter expression: unknown urn attribute prefix"))
}
func (a *AttrPath) Segments() []string {
// user lower, since attribute names in scim are always case-insensitive
if a.SubAttr != nil {
return []string{strings.ToLower(a.AttrName), strings.ToLower(*a.SubAttr)}
}
return []string{strings.ToLower(a.AttrName)}
}
func (a *AttrPath) FieldPath() string {
return strings.Join(a.Segments(), ".")
}

View File

@ -0,0 +1,868 @@
package filter
import (
"reflect"
"strings"
"testing"
"github.com/muhlemmer/gu"
)
var longString = ""
func init() {
var sb strings.Builder
for i := 0; i < maxInputLength+1; i++ {
sb.WriteRune('x')
}
longString = sb.String()
}
func TestParseFilter(t *testing.T) {
tests := []struct {
name string
filter string
want *Filter
wantErr bool
}{
{
name: "empty",
},
{
name: "too long",
filter: longString,
wantErr: true,
},
{
name: "invalid syntax",
filter: "fooBar[['baz']]",
wantErr: true,
},
{
name: "unknown binary operator",
filter: `userName fu "bjensen"`,
wantErr: true,
},
{
name: "unknown unary operator",
filter: `userName ok`,
wantErr: true,
},
// test cases from https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
{
name: "negation",
filter: `not(username pr)`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
Negation: &Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
UnaryCondition: &UnaryCondition{
Left: AttrPath{
AttrName: "username",
},
Operator: UnaryConditionOperator{
Present: true,
},
},
},
},
},
},
},
},
},
},
},
},
},
{
name: "number",
filter: `age gt 10`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "age",
},
Operator: CompareOp{
GreaterThan: true,
},
Right: CompValue{
Int: gu.Ptr(10),
},
},
},
},
},
},
},
},
},
{
name: "float",
filter: `age gt 10.5`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "age",
},
Operator: CompareOp{
GreaterThan: true,
},
Right: CompValue{
Float: gu.Ptr(10.5),
},
},
},
},
},
},
},
},
},
{
name: "null",
filter: `age eq null`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "age",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
Null: true,
},
},
},
},
},
},
},
},
},
{
name: "simple binary operator",
filter: `userName eq "bjensen"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userName",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("bjensen"),
},
},
},
},
},
},
},
},
},
{
name: "uppercase binary operator",
filter: `userName EQ "bjensen"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userName",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("bjensen"),
},
},
},
},
},
},
},
},
},
{
name: "case-insensitive literals and operators",
filter: `active Eq TRue`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "active",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
BooleanTrue: true,
},
},
},
},
},
},
},
},
},
{
name: "extra whitespace",
filter: `userName eq "bjensen"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userName",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("bjensen"),
},
},
},
},
},
},
},
},
},
{
name: "nested attribute binary operator",
filter: `name.familyName co "O'Malley"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "name",
SubAttr: gu.Ptr("familyName"),
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("O'Malley"),
},
},
},
},
},
},
},
},
},
{
name: "urn prefixed",
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
AttrName: "userName",
},
Operator: CompareOp{
StartsWith: true,
},
Right: CompValue{
StringValue: gu.Ptr("J"),
},
},
},
},
},
},
},
},
},
{
name: "urn prefixed nested",
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J" and urn:ietf:params:scim:schemas:core:2.0:User:emails.value ew "@example.com"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
AttrName: "userName",
},
Operator: CompareOp{
StartsWith: true,
},
Right: CompValue{
StringValue: gu.Ptr("J"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
UrnAttributePrefix: gu.Ptr("urn:ietf:params:scim:schemas:core:2.0:User:"),
AttrName: "emails",
SubAttr: gu.Ptr("value"),
},
Operator: CompareOp{
EndsWith: true,
},
Right: CompValue{
StringValue: gu.Ptr("@example.com"),
},
},
},
},
},
},
},
},
},
},
{
name: "unary operator",
filter: `title pr`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
UnaryCondition: &UnaryCondition{
Left: AttrPath{
AttrName: "title",
},
Operator: UnaryConditionOperator{
Present: true,
},
},
},
},
},
},
},
},
},
{
name: "binary nested date operator",
filter: `meta.lastModified gt "2011-05-13T04:42:34Z"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "meta",
SubAttr: gu.Ptr("lastModified"),
},
Operator: CompareOp{
GreaterThan: true,
},
Right: CompValue{
StringValue: gu.Ptr("2011-05-13T04:42:34Z"),
},
},
},
},
},
},
},
},
},
{
name: "and logical expression",
filter: `title pr and userType eq "Employee"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
UnaryCondition: &UnaryCondition{
Left: AttrPath{
AttrName: "title",
},
Operator: UnaryConditionOperator{
Present: true,
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userType",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("Employee"),
},
},
},
},
},
},
},
},
},
},
{
name: "nested and / or with grouping",
filter: `userType eq "Employee" and (emails co "example.com" or emails.value co "example.org")`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userType",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("Employee"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
SubFilter: &Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example.com"),
},
},
},
},
},
Right: &OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
SubAttr: gu.Ptr("value"),
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example.org"),
},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
{
name: "nested and / or without grouping",
filter: `userType eq "Employee" and emails co "example.com" or emails.value co "example2.org"`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userType",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("Employee"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example.com"),
},
},
},
},
},
},
Right: &OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
SubAttr: gu.Ptr("value"),
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example2.org"),
},
},
},
},
},
},
},
},
},
},
{
name: "nested and / or with negated grouping",
filter: `userType ne "Employee" and not (emails co "example.com" or emails.value co "example.org")`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userType",
},
Operator: CompareOp{
NotEqual: true,
},
Right: CompValue{
StringValue: gu.Ptr("Employee"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
Negation: &Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example.com"),
},
},
},
},
},
Right: &OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "emails",
SubAttr: gu.Ptr("value"),
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("example.org"),
},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
{
name: "nested value path path",
filter: `userType eq "Employee" and emails[type eq "work" and value co "@example.com"]`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "userType",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("Employee"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
ValuePath: &ValuePath{
AttrPath: AttrPath{
AttrName: "emails",
},
ValFilter: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "type",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("work"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "value",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("@example.com"),
},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
{
name: "complex value path filter",
filter: `emails[type eq "work" and value co "@example.com"] or ims[type eq "xmpp" and value co "@foo.com"]`,
want: &Filter{
Root: Segment{
OrExp: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
ValuePath: &ValuePath{
AttrPath: AttrPath{
AttrName: "emails",
},
ValFilter: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "type",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("work"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "value",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("@example.com"),
},
},
},
},
},
},
},
},
},
},
Right: &OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
ValuePath: &ValuePath{
AttrPath: AttrPath{
AttrName: "ims",
},
ValFilter: OrLogExp{
Left: AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "type",
},
Operator: CompareOp{
Equal: true,
},
Right: CompValue{
StringValue: gu.Ptr("xmpp"),
},
},
},
},
Right: &AndLogExp{
Left: ValueAtom{
AttrExp: &AttrExp{
BinaryCondition: &BinaryCondition{
Left: AttrPath{
AttrName: "value",
},
Operator: CompareOp{
Contains: true,
},
Right: CompValue{
StringValue: gu.Ptr("@foo.com"),
},
},
},
},
},
},
},
},
},
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseFilter(tt.filter)
if (err != nil) != tt.wantErr {
t.Errorf("ParseFilter() error = %#v, wantErr %#v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseFilter() got = %s, want %s", got, tt.want)
}
})
}
}

View File

@ -0,0 +1,347 @@
package filter
import (
"context"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
// FieldPathMapping maps lowercase json field names of the resource to the matching column in the projection
type FieldPathMapping map[string]*QueryFieldInfo
// queryBuilder builds a query for a filter based on the visitor pattern
type queryBuilder struct {
ctx context.Context
schema schemas.ScimSchemaType
fieldPathMapping FieldPathMapping
// attrPathPrefixes prefixes of attributes that
// should also take into account when resolving an attr path to a column.
// This is used for "a[b eq 10]" expressions, when resolving b, a would be the prefix.
attrPathPrefixStack []*AttrPath
}
type MappedQueryBuilderFunc func(ctx context.Context, compareValue *CompValue, op *CompareOp) (query.SearchQuery, error)
type QueryFieldInfo struct {
Column query.Column
FieldType FieldType
BuildMappedQuery MappedQueryBuilderFunc
}
type FieldType int
const (
FieldTypeCustom FieldType = iota
FieldTypeString
FieldTypeNumber
FieldTypeBoolean
FieldTypeTimestamp
)
func (m FieldPathMapping) Resolve(path string) (*QueryFieldInfo, error) {
info, ok := m[strings.ToLower(path)]
if !ok {
logging.WithFields("fieldPath", path).Info("scim filter: Invalid filter expression: unknown or unsupported field")
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgumentf(nil, "SCIM-FF433", "Invalid filter expression: unknown or unsupported field %s", path))
}
return info, nil
}
func (f *Filter) BuildQuery(ctx context.Context, schema schemas.ScimSchemaType, fieldPathColumnMapping FieldPathMapping) (query.SearchQuery, error) {
builder := &queryBuilder{
ctx: ctx,
schema: schema,
fieldPathMapping: fieldPathColumnMapping,
}
return builder.visitSegment(&f.Root)
}
func (b *queryBuilder) pushAttrPath(path *AttrPath) {
b.attrPathPrefixStack = append(b.attrPathPrefixStack, path)
}
func (b *queryBuilder) popAttrPath() {
b.attrPathPrefixStack = b.attrPathPrefixStack[:len(b.attrPathPrefixStack)-1]
}
func (b *queryBuilder) visitSegment(s *Segment) (query.SearchQuery, error) {
return b.visitOr(&s.OrExp)
}
func (b *queryBuilder) visitOr(or *OrLogExp) (query.SearchQuery, error) {
left, err := b.visitAnd(&or.Left)
if err != nil {
return nil, err
}
if or.Right == nil {
return left, nil
}
right, err := b.visitOr(or.Right)
if err != nil {
return nil, err
}
// flatten nested or queries
if rightOr, ok := right.(*query.OrQuery); ok {
rightOr.Prepend(left)
return rightOr, nil
}
return query.NewOrQuery(left, right)
}
func (b *queryBuilder) visitAnd(and *AndLogExp) (query.SearchQuery, error) {
left, err := b.visitAtom(&and.Left)
if err != nil {
return nil, err
}
if and.Right == nil {
return left, nil
}
right, err := b.visitAnd(and.Right)
if err != nil {
return nil, err
}
// flatten nested and queries
if rightAnd, ok := right.(*query.AndQuery); ok {
rightAnd.Prepend(left)
return rightAnd, nil
}
return query.NewAndQuery(left, right)
}
func (b *queryBuilder) visitAtom(atom *ValueAtom) (query.SearchQuery, error) {
switch {
case atom.SubFilter != nil:
return b.visitSegment(atom.SubFilter)
case atom.Negation != nil:
f, err := b.visitSegment(atom.Negation)
if err != nil {
return nil, err
}
return query.NewNotQuery(f)
case atom.ValuePath != nil:
return b.visitValuePath(atom.ValuePath)
case atom.AttrExp != nil:
return b.visitAttrExp(atom.AttrExp)
}
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF412", "Invalid filter expression"))
}
func (b *queryBuilder) visitValuePath(path *ValuePath) (query.SearchQuery, error) {
b.pushAttrPath(&path.AttrPath)
defer b.popAttrPath()
return b.visitOr(&path.ValFilter)
}
func (b *queryBuilder) visitAttrExp(exp *AttrExp) (query.SearchQuery, error) {
switch {
case exp.UnaryCondition != nil:
return b.visitUnaryCondition(exp.UnaryCondition)
case exp.BinaryCondition != nil:
return b.visitBinaryCondition(exp.BinaryCondition)
}
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF413", "Invalid filter expression"))
}
func (b *queryBuilder) visitUnaryCondition(condition *UnaryCondition) (query.SearchQuery, error) {
// only supported unary operator is present
if !condition.Operator.Present {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF419", "Unknown unary filter operator"))
}
field, err := b.visitAttrPath(&condition.Left)
if err != nil {
return nil, err
}
if field.FieldType == FieldTypeCustom {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FXX49", "Unsupported attribute for unary filter operator"))
}
return query.NewNotNullQuery(field.Column)
}
func (b *queryBuilder) visitBinaryCondition(condition *BinaryCondition) (query.SearchQuery, error) {
left, err := b.visitAttrPath(&condition.Left)
if err != nil {
return nil, err
}
if condition.Operator.Equal && condition.Right.Null {
return query.NewIsNullQuery(left.Column)
}
if condition.Operator.NotEqual && condition.Right.Null {
return query.NewNotNullQuery(left.Column)
}
switch left.FieldType {
case FieldTypeCustom:
return left.BuildMappedQuery(b.ctx, &condition.Right, &condition.Operator)
case FieldTypeTimestamp:
return b.buildTimestampQuery(left, condition.Right, &condition.Operator)
case FieldTypeString:
return b.buildTextQuery(left, condition.Right, &condition.Operator)
case FieldTypeNumber:
return b.buildNumberQuery(left, condition.Right, &condition.Operator)
case FieldTypeBoolean:
return b.buildBooleanQuery(left, condition.Right, &condition.Operator)
}
logging.WithFields("fieldType", left.FieldType).Error("Unknown field type")
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF417", "Unknown filter expression field type"))
}
func (b *queryBuilder) buildTimestampQuery(left *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
if right.StringValue == nil {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF451", "Invalid filter expression: the compare value for a timestamp has to be a RFC3339 string"))
}
timestamp, err := time.Parse(time.RFC3339, *right.StringValue)
if err != nil {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(err, "SCIM-FF421", "Invalid filter expression: the compare value for a timestamp has to be a RFC3339 string"))
}
var comp query.TimestampComparison
switch {
case op.Equal:
comp = query.TimestampEquals
case op.GreaterThan:
comp = query.TimestampGreater
case op.GreaterThanOrEqual:
comp = query.TimestampGreaterOrEquals
case op.LessThan:
comp = query.TimestampLess
case op.LessThanOrEqual:
comp = query.TimestampLessOrEquals
default:
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF422", "Invalid filter expression: unsupported comparison operator for timestamp fields"))
}
return query.NewTimestampQuery(left.Column, timestamp, comp)
}
func (b *queryBuilder) buildNumberQuery(left *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
if right.Int == nil && right.Float == nil {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF423", "Invalid filter expression: unsupported comparison value for numeric fields"))
}
var comp query.NumberComparison
switch {
case op.Equal:
comp = query.NumberEquals
case op.NotEqual:
comp = query.NumberNotEquals
case op.GreaterThan:
comp = query.NumberGreater
case op.GreaterThanOrEqual:
comp = query.NumberGreaterOrEqual
case op.LessThan:
comp = query.NumberLess
case op.LessThanOrEqual:
comp = query.NumberLessOrEqual
default:
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF424", "Invalid filter expression: unsupported comparison operator for number fields"))
}
var value interface{}
if right.Int != nil {
value = *right.Int
} else {
value = *right.Float
}
return query.NewNumberQuery(left.Column, value, comp)
}
func (b *queryBuilder) buildBooleanQuery(field *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
if !right.BooleanTrue && !right.BooleanFalse {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF428", "Invalid filter expression: unsupported comparison value for boolean field"))
}
if !op.Equal && !op.NotEqual {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF427", "Invalid filter expression: unsupported comparison operator for boolean field"))
}
return query.NewBoolQuery(field.Column, (op.Equal && right.BooleanTrue) || (op.NotEqual && right.BooleanFalse))
}
func (b *queryBuilder) buildTextQuery(field *QueryFieldInfo, right CompValue, op *CompareOp) (query.SearchQuery, error) {
if right.StringValue == nil {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF429", "Invalid filter expression: unsupported comparison value for text field"))
}
var comp query.TextComparison
switch {
case op.Equal:
comp = query.TextEquals
case op.NotEqual:
comp = query.TextNotEquals
case op.Contains:
comp = query.TextContains
case op.StartsWith:
comp = query.TextStartsWith
case op.EndsWith:
comp = query.TextEndsWith
default:
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF425", "Invalid filter expression: unsupported comparison operator for text fields"))
}
return query.NewTextQuery(field.Column, *right.StringValue, comp)
}
func (b *queryBuilder) visitAttrPath(attrPath *AttrPath) (*QueryFieldInfo, error) {
b.pushAttrPath(attrPath)
defer b.popAttrPath()
field, err := b.reduceAttrPaths(b.attrPathPrefixStack)
if err != nil {
return nil, err
}
return b.fieldPathMapping.Resolve(field)
}
// reduceAttrPaths reduces a slice of AttrPath
// to a simple urn + fieldPath combination.
// The urn is ensured to be unique across all segments and either to be empty or to match the schema of the builder.
// The resulting fieldPath is in the form of a.b.c with a minimum of one path segment.
func (b *queryBuilder) reduceAttrPaths(attrPaths []*AttrPath) (fieldPath string, err error) {
if len(attrPaths) == 0 {
err = serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-FF431", "Invalid filter expression: unknown urn attribute prefix"))
return fieldPath, err
}
sb := strings.Builder{}
for _, p := range attrPaths {
if err = p.validateSchema(b.schema); err != nil {
return
}
sb.WriteString(p.FieldPath())
sb.WriteRune('.')
}
fieldPath = sb.String()
fieldPath = strings.TrimRight(fieldPath, ".") // trim very last '.'
return fieldPath, err
}

View File

@ -0,0 +1,497 @@
package filter
import (
"context"
"reflect"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/test"
)
var fieldPathColumnMapping = FieldPathMapping{
// a timestamp field
"meta.lastmodified": {
Column: query.UserChangeDateCol,
FieldType: FieldTypeTimestamp,
},
// a string field
"username": {
Column: query.UserUsernameCol,
FieldType: FieldTypeString,
},
// a nested string field
"name.familyname": {
Column: query.HumanLastNameCol,
FieldType: FieldTypeString,
},
// a field which is a list in scim
"emails": {
Column: query.HumanEmailCol,
FieldType: FieldTypeString,
},
// the default value field
"emails.value": {
Column: query.HumanEmailCol,
FieldType: FieldTypeString,
},
// pseudo field to test number queries
"age": {
Column: query.HumanGenderCol,
FieldType: FieldTypeNumber,
},
// pseudo field to test boolean queries
"locked": {
Column: query.HumanPasswordChangeRequiredCol,
FieldType: FieldTypeBoolean,
},
// mapped field
"active": {
Column: query.UserStateCol,
FieldType: FieldTypeCustom,
BuildMappedQuery: func(ctx context.Context, compareValue *CompValue, op *CompareOp) (query.SearchQuery, error) {
// very simple mock implementation
return query.NewTextQuery(query.UserUsernameCol, "fooBar", query.TextContains)
},
},
}
func TestFilter_BuildQuery(t *testing.T) {
tests := []struct {
name string
filter string
want query.SearchQuery
wantErr bool
}{
{
name: "unknown attribute",
filter: `foobar eq "bjensen"`,
wantErr: true,
},
{
name: "simple binary operator",
filter: `userName eq "bjensen"`,
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "bjensen", query.TextEquals)),
},
{
name: "binary operator equals null",
filter: `userName eq null`,
want: test.Must(query.NewIsNullQuery(query.UserUsernameCol)),
},
{
name: "binary operator not equals null",
filter: `userName ne null`,
want: test.Must(query.NewNotNullQuery(query.UserUsernameCol)),
},
{
name: "binary number operator on string field",
filter: `userName gt 10`,
wantErr: true,
},
{
name: "binary number operator greater",
filter: `age gt 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberGreater)),
},
{
name: "binary number operator greater equal",
filter: `age ge 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberGreaterOrEqual)),
},
{
name: "binary number operator less",
filter: `age lt 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberLess)),
},
{
name: "binary number operator less float",
filter: `age lt 10.5`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10.5, query.NumberLess)),
},
{
name: "binary number unsupported operator",
filter: `age co 10.5`,
wantErr: true,
},
{
name: "binary number unsupported comparison value",
filter: `age gt "foo"`,
wantErr: true,
},
{
name: "binary number operator less equal",
filter: `age le 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberLessOrEqual)),
},
{
name: "binary number operator equals",
filter: `age eq 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberEquals)),
},
{
name: "binary number operator not equals",
filter: `age ne 10`,
want: test.Must(query.NewNumberQuery(query.HumanGenderCol, 10, query.NumberNotEquals)),
},
{
name: "binary bool operator equals string",
filter: `locked eq "foo"`,
wantErr: true,
},
{
name: "binary bool operator startswith bool",
filter: `locked sw true`,
wantErr: true,
},
{
name: "binary bool operator equals",
filter: `locked eq true`,
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, true)),
},
{
name: "binary bool operator not equals",
filter: `locked ne true`,
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, false)),
},
{
name: "binary bool operator not equals false",
filter: `locked ne false`,
want: test.Must(query.NewBoolQuery(query.HumanPasswordChangeRequiredCol, true)),
},
{
name: "binary string invalid operator",
filter: `username gt "test"`,
wantErr: true,
},
{
name: "nested attribute binary operator",
filter: `name.familyName co "O'Malley"`,
want: test.Must(query.NewTextQuery(query.HumanLastNameCol, "O'Malley", query.TextContains)),
},
{
name: "urn prefixed binary operator",
filter: `urn:ietf:params:scim:schemas:core:2.0:User:userName sw "J"`,
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "J", query.TextStartsWith)),
},
{
name: "urn prefixed nested binary operator",
filter: `urn:ietf:params:scim:schemas:core:2.0:User:emails[value sw "hans.peter@"]`,
want: test.Must(query.NewTextQuery(query.HumanEmailCol, "hans.peter@", query.TextStartsWith)),
},
{
name: "invalid urn prefixed nested binary operator",
filter: `urn:ietf:params:scim:schemas:core:2.0:UserFoo:emails[value sw "hans.peter@"]`,
wantErr: true,
},
{
name: "unary operator",
filter: `name.familyName pr`,
want: test.Must(query.NewNotNullQuery(query.HumanLastNameCol)),
},
{
name: "and logical expression",
filter: `name.familyName pr and userName eq "bjensen"`,
want: test.Must(query.NewAndQuery(test.Must(query.NewNotNullQuery(query.HumanLastNameCol)), test.Must(query.NewTextQuery(query.UserUsernameCol, "bjensen", query.TextEquals)))),
},
{
name: "timestamp condition equal",
filter: `meta.lastModified eq "2011-05-13T04:42:34Z"`,
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampEquals)),
},
{
name: "timestamp condition greater equals",
filter: `meta.lastModified ge "2011-05-13T04:42:34Z"`,
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampGreaterOrEquals)),
},
{
name: "timestamp condition greater",
filter: `meta.lastModified gt "2011-05-13T04:42:34Z"`,
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampGreater)),
},
{
name: "timestamp condition less equals",
filter: `meta.lastModified le "2011-05-13T04:42:34Z"`,
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampLessOrEquals)),
},
{
name: "timestamp condition less",
filter: `meta.lastModified lt "2011-05-13T04:42:34Z"`,
want: test.Must(query.NewTimestampQuery(query.UserChangeDateCol, time.Date(2011, time.May, 13, 4, 42, 34, 0, time.UTC), query.TimestampLess)),
},
{
name: "timestamp condition invalid operator",
filter: `meta.lastModified ew "2011-05-13T04:42:34Z"`,
wantErr: true,
},
{
name: "timestamp condition invalid format",
filter: `meta.lastModified ge "2011-05-13T0:34Z"`,
wantErr: true,
},
{
name: "timestamp condition invalid comparison value",
filter: `meta.lastModified ge 15`,
wantErr: true,
},
{
name: "nested and / or without grouping",
filter: `userName eq "rudolpho" and emails co "example.com" or emails.value co "example2.org"`,
want: test.Must(query.NewOrQuery(
test.Must(query.NewAndQuery(
test.Must(query.NewTextQuery(query.UserUsernameCol, "rudolpho", query.TextEquals)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.com", query.TextContains))),
),
test.Must(query.NewTextQuery(query.HumanEmailCol, "example2.org", query.TextContains)))),
},
{
name: "nested and / or with grouping",
filter: `userName ne "rudolpho" and (emails co "example.com" or emails.value co "example.org")`,
want: test.Must(query.NewAndQuery(
test.Must(query.NewTextQuery(query.UserUsernameCol, "rudolpho", query.TextNotEquals)),
test.Must(query.NewOrQuery(
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.com", query.TextContains)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "example.org", query.TextContains)),
)),
)),
},
{
name: "nested value path path",
filter: `userName eq "Hans" and emails[value ew "@example.org" or value ew "@example.com"]`,
want: test.Must(query.NewAndQuery(
test.Must(query.NewTextQuery(query.UserUsernameCol, "Hans", query.TextEquals)),
test.Must(query.NewOrQuery(
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
)),
)),
},
{
name: "or value path filter",
filter: `emails[value ew "@example.org" and value co "@example.com"] or emails[value sw "hans" or value sw "peter"]`,
want: test.Must(query.NewOrQuery(
test.Must(query.NewAndQuery(
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextContains)),
)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "hans", query.TextStartsWith)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "peter", query.TextStartsWith)),
)),
},
{
name: "and value path filter",
filter: `emails[value ew "@example.com"] and name.familyname co "hans" and username co "peter"`,
want: test.Must(query.NewAndQuery(
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
test.Must(query.NewTextQuery(query.HumanLastNameCol, "hans", query.TextContains)),
test.Must(query.NewTextQuery(query.UserUsernameCol, "peter", query.TextContains)),
)),
},
{
name: "negation",
filter: `not(username eq "foo")`,
want: test.Must(query.NewNotQuery(test.Must(query.NewTextQuery(query.UserUsernameCol, "foo", query.TextEquals)))),
},
{
name: "negation with complex filter",
filter: `not(emails[value ew "@example.com"])`,
want: test.Must(query.NewNotQuery(test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)))),
},
{
name: "nested negation",
filter: `emails[not(value ew "@example.org" or value ew "@example.com")]`,
want: test.Must(query.NewNotQuery(
test.Must(query.NewOrQuery(
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.org", query.TextEndsWith)),
test.Must(query.NewTextQuery(query.HumanEmailCol, "@example.com", query.TextEndsWith)),
)),
)),
},
{
name: "mapped field",
filter: `active eq true`,
want: test.Must(query.NewTextQuery(query.UserUsernameCol, "fooBar", query.TextContains)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f, err := ParseFilter(tt.filter)
require.NoError(t, err)
got, err := f.BuildQuery(context.Background(), schemas.IdUser, fieldPathColumnMapping)
if (err != nil) != tt.wantErr {
t.Errorf("BuildQuery() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("BuildQuery() got = %#v, want %#v", got, tt.want)
}
})
}
}
func Test_queryBuilder_reduceAttrPaths(t *testing.T) {
tests := []struct {
name string
schema string
attrPaths []*AttrPath
wantFieldPath string
wantErr bool
}{
{
name: "empty",
attrPaths: []*AttrPath{},
wantErr: true,
},
{
name: "simple",
attrPaths: []*AttrPath{
{
AttrName: "foo",
},
},
wantFieldPath: "foo",
},
{
name: "multiple simple",
attrPaths: []*AttrPath{
{
AttrName: "foo",
},
{
AttrName: "bar",
},
},
wantFieldPath: "foo.bar",
},
{
name: "with sub attr",
attrPaths: []*AttrPath{
{
AttrName: "foo",
SubAttr: gu.Ptr("bar"),
},
},
wantFieldPath: "foo.bar",
},
{
name: "multiple with sub attr",
attrPaths: []*AttrPath{
{
AttrName: "foo",
SubAttr: gu.Ptr("bar"),
},
{
AttrName: "baz",
SubAttr: gu.Ptr("woo"),
},
},
wantFieldPath: "foo.bar.baz.woo",
},
{
name: "with urn and sub attr",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
AttrName: "foo",
SubAttr: gu.Ptr("bar"),
},
},
wantFieldPath: "foo.bar",
},
{
name: "multiple with urn and sub attr",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
AttrName: "foo",
SubAttr: gu.Ptr("bar"),
},
{
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
AttrName: "foo2",
SubAttr: gu.Ptr("bar2"),
},
},
wantFieldPath: "foo.bar.foo2.bar2",
},
{
name: "secondary with urn and sub attr",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
AttrName: "foo",
SubAttr: gu.Ptr("bar"),
},
{
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
AttrName: "foo2",
SubAttr: gu.Ptr("bar2"),
},
},
wantFieldPath: "foo.bar.foo2.bar2",
},
{
name: "urn mismatch",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
AttrName: "foo",
},
},
wantErr: true,
},
{
name: "nested urn mismatch",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
UrnAttributePrefix: gu.Ptr("urn:foo:bar:"),
AttrName: "foo",
},
{
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
AttrName: "foo2",
},
},
wantErr: true,
},
{
name: "secondary urn mismatch",
schema: "urn:foo:bar",
attrPaths: []*AttrPath{
{
AttrName: "foo",
},
{
UrnAttributePrefix: gu.Ptr("urn:foo:baz"),
AttrName: "foo2",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := &queryBuilder{
schema: schemas.ScimSchemaType(tt.schema),
}
gotFieldPath, err := b.reduceAttrPaths(tt.attrPaths)
if (err != nil) != tt.wantErr {
t.Errorf("reduceAttrPaths() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotFieldPath != tt.wantFieldPath {
t.Errorf("reduceAttrPaths() gotFieldPath = %v, want %v", gotFieldPath, tt.wantFieldPath)
}
})
}
}

View File

@ -22,6 +22,7 @@ type ResourceHandler[T ResourceHolder] interface {
Replace(ctx context.Context, id string, resource T) (T, error)
Delete(ctx context.Context, id string) error
Get(ctx context.Context, id string) (T, error)
List(ctx context.Context, request *ListRequest) (*ListResponse[T], error)
}
type Resource struct {

View File

@ -7,7 +7,6 @@ import (
"github.com/gorilla/mux"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -16,22 +15,6 @@ type ResourceHandlerAdapter[T ResourceHolder] struct {
handler ResourceHandler[T]
}
type ListRequest struct {
// Count An integer indicating the desired maximum number of query results per page. OPTIONAL.
Count uint64 `json:"count" schema:"count"`
// StartIndex An integer indicating the 1-based index of the first query result. Optional.
StartIndex uint64 `json:"startIndex" schema:"startIndex"`
}
type ListResponse[T any] struct {
Schemas []schemas.ScimSchemaType `json:"schemas"`
ItemsPerPage uint64 `json:"itemsPerPage"`
TotalResults uint64 `json:"totalResults"`
StartIndex uint64 `json:"startIndex"`
Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase...
}
func NewResourceHandlerAdapter[T ResourceHolder](handler ResourceHandler[T]) *ResourceHandlerAdapter[T] {
return &ResourceHandlerAdapter[T]{
handler,
@ -62,6 +45,15 @@ func (adapter *ResourceHandlerAdapter[T]) Delete(r *http.Request) error {
return adapter.handler.Delete(r.Context(), id)
}
func (adapter *ResourceHandlerAdapter[T]) List(r *http.Request) (*ListResponse[T], error) {
request, err := readListRequest(r)
if err != nil {
return nil, err
}
return adapter.handler.List(r.Context(), request)
}
func (adapter *ResourceHandlerAdapter[T]) Get(r *http.Request) (T, error) {
id := mux.Vars(r)["id"]
return adapter.handler.Get(r.Context(), id)
@ -71,7 +63,7 @@ func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T
entity := adapter.handler.NewResource()
err := json.NewDecoder(r.Body).Decode(entity)
if err != nil {
if zerrors.IsZitadelError(err) {
if serrors.IsScimOrZitadelError(err) {
return entity, err
}

View File

@ -0,0 +1,148 @@
package resources
import (
"encoding/json"
"net/http"
zhttp "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
type ListRequest struct {
// Count An integer indicating the desired maximum number of query results per page.
Count int64 `json:"count" schema:"count"`
// StartIndex An integer indicating the 1-based index of the first query result.
StartIndex int64 `json:"startIndex" schema:"startIndex"`
// Filter a scim filter expression to filter the query result.
Filter *filter.Filter `json:"filter,omitempty" schema:"filter"`
// SortBy attribute path to the sort attribute
SortBy string `json:"sortBy" schema:"sortBy"`
SortOrder ListRequestSortOrder `json:"sortOrder" schema:"sortOrder"`
}
type ListResponse[T ResourceHolder] struct {
Schemas []schemas.ScimSchemaType `json:"schemas"`
ItemsPerPage uint64 `json:"itemsPerPage"`
TotalResults uint64 `json:"totalResults"`
StartIndex uint64 `json:"startIndex"`
Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase...
}
type ListRequestSortOrder string
const (
ListRequestSortOrderAsc ListRequestSortOrder = "ascending"
ListRequestSortOrderDsc ListRequestSortOrder = "descending"
defaultListCount = 100
maxListCount = 100
)
var parser = zhttp.NewParser()
func (o ListRequestSortOrder) isDefined() bool {
switch o {
case ListRequestSortOrderAsc, ListRequestSortOrderDsc:
return true
default:
return false
}
}
func (o ListRequestSortOrder) IsAscending() bool {
return o == ListRequestSortOrderAsc
}
func newListResponse[T ResourceHolder](totalResultCount uint64, q query.SearchRequest, resources []T) *ListResponse[T] {
return &ListResponse[T]{
Schemas: []schemas.ScimSchemaType{schemas.IdListResponse},
ItemsPerPage: q.Limit,
TotalResults: totalResultCount,
StartIndex: q.Offset + 1, // start index is 1 based
Resources: resources,
}
}
func readListRequest(r *http.Request) (*ListRequest, error) {
request := &ListRequest{
Count: defaultListCount,
StartIndex: 1,
SortOrder: ListRequestSortOrderAsc,
}
switch r.Method {
case http.MethodGet:
if err := parser.Parse(r, request); err != nil {
err = parser.UnwrapParserError(err)
if serrors.IsScimOrZitadelError(err) {
return nil, err
}
return nil, zerrors.ThrowInvalidArgument(nil, "SCIM-ullform", "Could not decode form: "+err.Error())
}
case http.MethodPost:
if err := json.NewDecoder(r.Body).Decode(request); err != nil {
if serrors.IsScimOrZitadelError(err) {
return nil, err
}
return nil, zerrors.ThrowInvalidArgument(nil, "SCIM-ulljson", "Could not decode json: "+err.Error())
}
// json deserialization initializes this field if an empty string is provided
// to not special case this in the resource implementation,
// set it to nil here.
if request.Filter.IsZero() {
request.Filter = nil
}
}
return request, request.validate()
}
func (r *ListRequest) toSearchRequest(defaultSortCol query.Column, fieldPathColumnMapping filter.FieldPathMapping) (query.SearchRequest, error) {
sr := query.SearchRequest{
Offset: uint64(r.StartIndex - 1), // start index is 1 based
Limit: uint64(r.Count),
Asc: r.SortOrder.IsAscending(),
}
if r.SortBy == "" {
// set a default sort to ensure consistent results
sr.SortingColumn = defaultSortCol
} else if sortCol, err := fieldPathColumnMapping.Resolve(r.SortBy); err != nil {
return sr, serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-SRT1", "SortBy field is unknown or not supported"))
} else {
sr.SortingColumn = sortCol.Column
}
return sr, nil
}
func (r *ListRequest) validate() error {
// according to the spec values < 1 are treated as 1
if r.StartIndex < 1 {
r.StartIndex = 1
}
// according to the spec values < 0 are treated as 0
if r.Count < 0 {
r.Count = 0
} else if r.Count > maxListCount {
return zerrors.ThrowInvalidArgumentf(nil, "SCIM-ucr", "Limit count exceeded, set a count <= %v", maxListCount)
}
if !r.SortOrder.isDefined() {
return zerrors.ThrowInvalidArgument(nil, "SCIM-ucx", "Invalid sort order")
}
return nil
}

View File

@ -0,0 +1,79 @@
package resources
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestListRequest_validate(t *testing.T) {
tests := []struct {
name string
req *ListRequest
want *ListRequest
wantErr bool
}{
{
name: "valid",
req: &ListRequest{
SortOrder: ListRequestSortOrderAsc,
},
},
{
name: "invalid sort order",
req: &ListRequest{
SortOrder: "fooBar",
},
wantErr: true,
},
{
name: "count too big",
req: &ListRequest{
Count: 99999999,
SortOrder: ListRequestSortOrderAsc,
},
wantErr: true,
},
{
name: "negative start index",
req: &ListRequest{
StartIndex: -1,
Count: 10,
SortOrder: ListRequestSortOrderAsc,
},
want: &ListRequest{
StartIndex: 1,
Count: 10,
SortOrder: ListRequestSortOrderAsc,
},
},
{
name: "negative count",
req: &ListRequest{
StartIndex: 10,
Count: -1,
SortOrder: ListRequestSortOrderAsc,
},
want: &ListRequest{
StartIndex: 10,
Count: 0,
SortOrder: ListRequestSortOrderAsc,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
if tt.want != nil && !reflect.DeepEqual(tt.req, tt.want) {
t.Errorf("got: %#v, want: %#v", tt.req, tt.want)
}
})
}
}

View File

@ -183,6 +183,35 @@ func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) {
return h.mapToScimUser(ctx, user, metadata), nil
}
func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListResponse[*ScimUser], error) {
q, err := h.buildListQuery(ctx, request)
if err != nil {
return nil, err
}
if request.Count == 0 {
count, err := h.query.CountUsers(ctx, q)
if err != nil {
return nil, err
}
return newListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
}
users, err := h.query.SearchUsers(ctx, q, nil)
if err != nil {
return nil, err
}
metadata, err := h.queryMetadataForUsers(ctx, usersToIDs(users.Users))
if err != nil {
return nil, err
}
scimUsers := h.mapToScimUsers(ctx, users.Users, metadata)
return newListResponse(users.SearchResponse.Count, q.SearchRequest, scimUsers), nil
}
func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) {
userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID)
if err != nil {

View File

@ -208,6 +208,20 @@ func (h *UsersHandler) mapChangeCommandToScimUser(ctx context.Context, user *Sci
}
}
func (h *UsersHandler) mapToScimUsers(ctx context.Context, users []*query.User, md map[string]map[metadata.ScopedKey][]byte) []*ScimUser {
result := make([]*ScimUser, len(users))
for i, user := range users {
userMetadata, ok := md[user.ID]
if !ok {
userMetadata = make(map[metadata.ScopedKey][]byte)
}
result[i] = h.mapToScimUser(ctx, user, userMetadata)
}
return result
}
func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md map[metadata.ScopedKey][]byte) *ScimUser {
scimUser := &ScimUser{
Resource: h.buildResourceForQuery(ctx, user),
@ -364,3 +378,11 @@ func userGrantsToIDs(userGrants []*query.UserGrant) []string {
}
return converted
}
func usersToIDs(users []*query.User) []string {
ids := make([]string, len(users))
for i, user := range users {
ids[i] = user.ID
}
return ids
}

View File

@ -20,6 +20,28 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
func (h *UsersHandler) queryMetadataForUsers(ctx context.Context, userIds []string) (map[string]map[metadata.ScopedKey][]byte, error) {
queries := h.buildMetadataQueries(ctx)
md, err := h.query.SearchUserMetadataForUsers(ctx, false, userIds, queries)
if err != nil {
return nil, err
}
metadataMap := make(map[string]map[metadata.ScopedKey][]byte, len(md.Metadata))
for _, entry := range md.Metadata {
userMetadata, ok := metadataMap[entry.UserID]
if !ok {
userMetadata = make(map[metadata.ScopedKey][]byte)
metadataMap[entry.UserID] = userMetadata
}
userMetadata[metadata.ScopedKey(entry.Key)] = entry.Value
}
return metadataMap, nil
}
func (h *UsersHandler) queryMetadataForUser(ctx context.Context, id string) (map[metadata.ScopedKey][]byte, error) {
queries := h.buildMetadataQueries(ctx)
@ -108,15 +130,11 @@ func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) {
switch key {
// json values
case metadata.KeyEntitlements:
fallthrough
case metadata.KeyIms:
fallthrough
case metadata.KeyPhotos:
fallthrough
case metadata.KeyAddresses:
fallthrough
case metadata.KeyRoles:
case metadata.KeyRoles,
metadata.KeyAddresses,
metadata.KeyEntitlements,
metadata.KeyIms,
metadata.KeyPhotos:
val, err := json.Marshal(value)
if err != nil {
return nil, err
@ -134,21 +152,14 @@ func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) {
return []byte(value.(*schemas.HttpURL).String()), nil
// raw values
case metadata.KeyProvisioningDomain:
fallthrough
case metadata.KeyExternalId:
fallthrough
case metadata.KeyMiddleName:
fallthrough
case metadata.KeyHonorificSuffix:
fallthrough
case metadata.KeyHonorificPrefix:
fallthrough
case metadata.KeyTitle:
fallthrough
case metadata.KeyLocale:
fallthrough
case metadata.KeyTimezone:
case metadata.KeyTimezone,
metadata.KeyLocale,
metadata.KeyTitle,
metadata.KeyHonorificPrefix,
metadata.KeyHonorificSuffix,
metadata.KeyMiddleName,
metadata.KeyExternalId,
metadata.KeyProvisioningDomain:
valueStr := value.(string)
if valueStr == "" {
return nil, nil

View File

@ -0,0 +1,160 @@
package resources
import (
"context"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/scim/metadata"
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
// fieldPathColumnMapping maps lowercase json field names of the scim user to the matching column in the projection
// only a limited set of fields is supported
// to ensure database performance.
var fieldPathColumnMapping = filter.FieldPathMapping{
"meta.created": {
Column: query.UserCreationDateCol,
FieldType: filter.FieldTypeTimestamp,
},
"meta.lastmodified": {
Column: query.UserChangeDateCol,
FieldType: filter.FieldTypeTimestamp,
},
"id": {
Column: query.UserIDCol,
FieldType: filter.FieldTypeString,
},
"username": {
Column: query.UserUsernameCol,
FieldType: filter.FieldTypeString,
},
"name.familyname": {
Column: query.HumanLastNameCol,
FieldType: filter.FieldTypeString,
},
"name.givenname": {
Column: query.HumanFirstNameCol,
FieldType: filter.FieldTypeString,
},
"emails": {
Column: query.HumanEmailCol,
FieldType: filter.FieldTypeString,
},
"emails.value": {
Column: query.HumanEmailCol,
FieldType: filter.FieldTypeString,
},
"active": {
FieldType: filter.FieldTypeCustom,
BuildMappedQuery: buildActiveUserStateQuery,
},
"externalid": {
FieldType: filter.FieldTypeCustom,
BuildMappedQuery: newMetadataQueryBuilder(metadata.KeyExternalId),
},
}
func (h *UsersHandler) buildListQuery(ctx context.Context, request *ListRequest) (*query.UserSearchQueries, error) {
searchRequest, err := request.toSearchRequest(query.UserIDCol, fieldPathColumnMapping)
if err != nil {
return nil, err
}
q := &query.UserSearchQueries{
SearchRequest: searchRequest,
}
// the zitadel scim implementation only supports humans for now
userTypeQuery, err := query.NewUserTypeSearchQuery(int32(domain.UserTypeHuman))
if err != nil {
return nil, err
}
// the scim service is always limited to one organization
// the organization is the resource owner
orgIDQuery, err := query.NewUserResourceOwnerSearchQuery(authz.GetCtxData(ctx).OrgID, query.TextEquals)
if err != nil {
return nil, err
}
q.Queries = append(q.Queries, orgIDQuery, userTypeQuery)
if request.Filter == nil {
return q, nil
}
filterQuery, err := request.Filter.BuildQuery(ctx, h.SchemaType(), fieldPathColumnMapping)
if err != nil {
return nil, err
}
q.Queries = append(q.Queries, filterQuery)
return q, nil
}
func newMetadataQueryBuilder(key metadata.Key) filter.MappedQueryBuilderFunc {
return func(ctx context.Context, compareValue *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
return buildMetadataQuery(ctx, key, compareValue, op)
}
}
func buildMetadataQuery(ctx context.Context, key metadata.Key, value *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
if value.StringValue == nil {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-EXid1", "invalid filter expression: unsupported comparison value"))
}
var comparisonOperator query.BytesComparison
switch {
case op.Equal:
comparisonOperator = query.BytesEquals
case op.NotEqual:
comparisonOperator = query.BytesNotEquals
default:
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-EXid1", "invalid filter expression: unsupported comparison operator"))
}
scopedKey := string(metadata.ScopeKey(ctx, key))
return query.NewUserMetadataExistsQuery(scopedKey, []byte(*value.StringValue), query.TextEquals, comparisonOperator)
}
func buildActiveUserStateQuery(_ context.Context, compareValue *filter.CompValue, op *filter.CompareOp) (query.SearchQuery, error) {
if !op.Equal && !op.NotEqual {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-MGdg", "invalid filter expression: active unsupported comparison operator"))
}
if !compareValue.BooleanTrue && !compareValue.BooleanFalse {
return nil, serrors.ThrowInvalidFilter(zerrors.ThrowInvalidArgument(nil, "SCIM-MGdr", "invalid filter expression: active unsupported comparison value"))
}
active := compareValue.BooleanTrue && op.Equal || compareValue.BooleanFalse && op.NotEqual
if active {
activeQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)
if err != nil {
return nil, err
}
initialQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)
if err != nil {
return nil, err
}
return query.NewOrQuery(initialQuery, activeQuery)
}
activeQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)
if err != nil {
return nil, err
}
initialQuery, err := query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)
if err != nil {
return nil, err
}
return query.NewAndQuery(initialQuery, activeQuery)
}

View File

@ -0,0 +1,144 @@
package resources
import (
"context"
"reflect"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/scim/metadata"
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/test"
)
func Test_buildMetadataQuery(t *testing.T) {
tests := []struct {
name string
key metadata.Key
value *filter.CompValue
op *filter.CompareOp
want query.SearchQuery
wantErr bool
}{
{
name: "equals",
key: "foo",
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
op: &filter.CompareOp{Equal: true},
want: test.Must(query.NewUserMetadataExistsQuery("foo", []byte("bar"), query.TextEquals, query.BytesEquals)),
wantErr: false,
},
{
name: "not equals",
key: "foo",
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
op: &filter.CompareOp{NotEqual: true},
want: test.Must(query.NewUserMetadataExistsQuery("foo", []byte("bar"), query.TextEquals, query.BytesNotEquals)),
wantErr: false,
},
{
name: "unsupported operator",
key: "foo",
value: &filter.CompValue{StringValue: gu.Ptr("bar")},
op: &filter.CompareOp{StartsWith: true},
wantErr: true,
},
{
name: "unsupported comparison value",
key: "foo",
value: &filter.CompValue{Int: gu.Ptr(10)},
op: &filter.CompareOp{Equal: true},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := buildMetadataQuery(context.Background(), tt.key, tt.value, tt.op)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("buildMetadataQuery() got = %#v, want %#v", got, tt.want)
}
})
}
}
func Test_buildActiveUserStateQuery(t *testing.T) {
tests := []struct {
name string
compareValue *filter.CompValue
compOp *filter.CompareOp
want query.SearchQuery
wantErr bool
}{
{
name: "eq true",
compareValue: &filter.CompValue{BooleanTrue: true},
compOp: &filter.CompareOp{Equal: true},
want: test.Must(query.NewOrQuery(
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)),
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)),
)),
},
{
name: "eq false",
compareValue: &filter.CompValue{BooleanFalse: true},
compOp: &filter.CompareOp{Equal: true},
want: test.Must(query.NewAndQuery(
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)),
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)),
)),
},
{
name: "ne true",
compareValue: &filter.CompValue{BooleanTrue: true},
compOp: &filter.CompareOp{NotEqual: true},
want: test.Must(query.NewAndQuery(
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberNotEquals)),
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberNotEquals)),
)),
},
{
name: "ne false",
compareValue: &filter.CompValue{BooleanTrue: true},
compOp: &filter.CompareOp{Equal: true},
want: test.Must(query.NewOrQuery(
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateInitial), query.NumberEquals)),
test.Must(query.NewNumberQuery(query.UserStateCol, int32(domain.UserStateActive), query.NumberEquals)),
)),
},
{
name: "invalid operator",
compareValue: &filter.CompValue{BooleanTrue: true},
compOp: &filter.CompareOp{StartsWith: true},
wantErr: true,
},
{
name: "invalid comp value",
compareValue: &filter.CompValue{StringValue: gu.Ptr("foo")},
compOp: &filter.CompareOp{Equal: true},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := buildActiveUserStateQuery(context.Background(), tt.compareValue, tt.compOp)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equalf(t, tt.want, got, "buildActiveUserStateQuery(%#v, %#v)", tt.compareValue, tt.compOp)
})
}
}

View File

@ -10,6 +10,7 @@ const (
idPrefixZitadelMessages = "urn:ietf:params:scim:api:zitadel:messages:2.0:"
IdUser ScimSchemaType = idPrefixCore + "User"
IdListResponse ScimSchemaType = idPrefixMessages + "ListResponse"
IdError ScimSchemaType = idPrefixMessages + "Error"
IdZitadelErrorDetail ScimSchemaType = idPrefixZitadelMessages + "ErrorDetail"

View File

@ -49,6 +49,13 @@ const (
// ScimTypeInvalidSyntax The request body message structure was invalid or did
// not conform to the request schema.
ScimTypeInvalidSyntax scimErrorType = "invalidSyntax"
// ScimTypeInvalidFilter The specified filter syntax as invalid, or the
// specified attribute and filter comparison combination is not supported.
ScimTypeInvalidFilter scimErrorType = "invalidFilter"
// ScimTypeUniqueness One or more of the attribute values are already in use or are reserved.
ScimTypeUniqueness scimErrorType = "uniqueness"
)
var translator *i18n.Translator
@ -85,6 +92,22 @@ func ThrowInvalidSyntax(parent error) error {
}
}
func ThrowInvalidFilter(parent error) error {
return &wrappedScimError{
Parent: parent,
ScimType: ScimTypeInvalidFilter,
}
}
func IsScimOrZitadelError(err error) bool {
return IsScimError(err) || zerrors.IsZitadelError(err)
}
func IsScimError(err error) bool {
var scimErr *wrappedScimError
return errors.As(err, &scimErr)
}
func (err *scimError) Error() string {
return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Detail)
}
@ -134,6 +157,8 @@ func mapErrorToScimErrorType(err error) scimErrorType {
switch {
case zerrors.IsErrorInvalidArgument(err):
return ScimTypeInvalidValue
case zerrors.IsErrorAlreadyExists(err):
return ScimTypeUniqueness
default:
return ""
}

View File

@ -54,11 +54,26 @@ func mapResource[T sresources.ResourceHolder](router *mux.Router, mw zhttp_middl
resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter()
resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost)
resourceRouter.Handle("", mw(handleJsonResponse(adapter.List))).Methods(http.MethodGet)
resourceRouter.Handle("/.search", mw(handleJsonResponse(adapter.List))).Methods(http.MethodPost)
resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Get))).Methods(http.MethodGet)
resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Replace))).Methods(http.MethodPut)
resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete)
}
func handleJsonResponse[T any](next func(r *http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError {
return func(w http.ResponseWriter, r *http.Request) error {
entity, err := next(r)
if err != nil {
return err
}
err = json.NewEncoder(w).Encode(entity)
logging.OnError(err).Warn("scim json response encoding failed")
return nil
}
}
func handleResourceCreatedResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError {
return func(w http.ResponseWriter, r *http.Request) error {
entity, err := next(r)

View File

@ -1,7 +1,6 @@
package integration
import (
"reflect"
"testing"
"time"
@ -170,99 +169,3 @@ func diffProto(expected, actual proto.Message) string {
}
return "\n\nDiff:\n" + diff
}
func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) {
val, exists := m[key]
assert.True(t, exists, "Key '%s' should exist in the map", key)
if !exists {
return
}
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
}
// PartiallyDeepEqual is similar to reflect.DeepEqual,
// but only compares exported non-zero fields of the expectedValue
func PartiallyDeepEqual(expected, actual interface{}) bool {
if expected == nil {
return actual == nil
}
if actual == nil {
return false
}
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func partiallyDeepEqual(expected, actual reflect.Value) bool {
// Dereference pointers if needed
if expected.Kind() == reflect.Ptr {
if expected.IsNil() {
return true
}
expected = expected.Elem()
}
if actual.Kind() == reflect.Ptr {
if actual.IsNil() {
return false
}
actual = actual.Elem()
}
if expected.Type() != actual.Type() {
return false
}
switch expected.Kind() { //nolint:exhaustive
case reflect.Struct:
for i := 0; i < expected.NumField(); i++ {
field := expected.Type().Field(i)
if field.PkgPath != "" { // Skip unexported fields
continue
}
expectedField := expected.Field(i)
actualField := actual.Field(i)
// Skip zero-value fields in expected
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
continue
}
// Compare fields recursively
if !partiallyDeepEqual(expectedField, actualField) {
return false
}
}
return true
case reflect.Slice, reflect.Array:
if expected.Len() > actual.Len() {
return false
}
for i := 0; i < expected.Len(); i++ {
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
return false
}
}
return true
default:
// Compare primitive types
return reflect.DeepEqual(expected.Interface(), actual.Interface())
}
}
func Must[T any](result T, error error) T {
if error != nil {
panic(error)
}
return result
}

View File

@ -50,153 +50,3 @@ func TestAssertDetails(t *testing.T) {
})
}
}
func TestPartiallyDeepEqual(t *testing.T) {
type SecondaryNestedType struct {
Value int
}
type NestedType struct {
Value int
ValueSlice []int
Nested SecondaryNestedType
NestedPointer *SecondaryNestedType
}
type args struct {
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil",
args: args{
expected: nil,
actual: nil,
},
want: true,
},
{
name: "scalar value",
args: args{
expected: 10,
actual: 10,
},
want: true,
},
{
name: "different scalar value",
args: args{
expected: 11,
actual: 10,
},
want: false,
},
{
name: "string value",
args: args{
expected: "foo",
actual: "foo",
},
want: true,
},
{
name: "different string value",
args: args{
expected: "foo2",
actual: "foo",
},
want: false,
},
{
name: "scalar only set in actual",
args: args{
expected: &SecondaryNestedType{},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar equal",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar only set in expected",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{},
},
want: false,
},
{
name: "ptr only set in expected",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{},
},
want: false,
},
{
name: "ptr only set in actual",
args: args{
expected: &NestedType{},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "ptr equal",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "nested equal",
args: args{
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "slice equal",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: true,
},
{
name: "slice additional in expected",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: false,
},
{
name: "slice additional in actual",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -6,7 +6,10 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"github.com/zitadel/logging"
"google.golang.org/grpc/metadata"
@ -40,6 +43,44 @@ type ZitadelErrorDetail struct {
Message string `json:"message"`
}
type ListRequest struct {
Count *int `json:"count,omitempty"`
// StartIndex An integer indicating the 1-based index of the first query result.
StartIndex *int `json:"startIndex,omitempty"`
// Filter a scim filter expression to filter the query result.
Filter *string `json:"filter,omitempty"`
SortBy *string `json:"sortBy,omitempty"`
SortOrder *ListRequestSortOrder `json:"sortOrder,omitempty"`
SendAsPost bool
}
type ListRequestSortOrder string
const (
ListRequestSortOrderAsc ListRequestSortOrder = "ascending"
ListRequestSortOrderDsc ListRequestSortOrder = "descending"
)
type ListResponse[T any] struct {
Schemas []schemas.ScimSchemaType `json:"schemas"`
ItemsPerPage int `json:"itemsPerPage"`
TotalResults int `json:"totalResults"`
StartIndex int `json:"startIndex"`
Resources []T `json:"Resources"`
}
const (
listQueryParamSortBy = "sortBy"
listQueryParamSortOrder = "sortOrder"
listQueryParamCount = "count"
listQueryParamStartIndex = "startIndex"
listQueryParamFilter = "filter"
)
func NewScimClient(target string) *Client {
target = "http://" + target + schemas.HandlerPrefix
client := &http.Client{}
@ -60,6 +101,43 @@ func (c *ResourceClient[T]) Replace(ctx context.Context, orgID, id string, body
return c.doWithBody(ctx, http.MethodPut, orgID, id, bytes.NewReader(body))
}
func (c *ResourceClient[T]) List(ctx context.Context, orgID string, req *ListRequest) (*ListResponse[*T], error) {
if req.SendAsPost {
listReq, err := json.Marshal(req)
if err != nil {
return nil, err
}
return c.doWithListResponse(ctx, http.MethodPost, orgID, ".search", bytes.NewReader(listReq))
}
query, err := url.ParseQuery("")
if err != nil {
return nil, err
}
if req.SortBy != nil {
query.Set(listQueryParamSortBy, *req.SortBy)
}
if req.SortOrder != nil {
query.Set(listQueryParamSortOrder, string(*req.SortOrder))
}
if req.Count != nil {
query.Set(listQueryParamCount, strconv.Itoa(*req.Count))
}
if req.StartIndex != nil {
query.Set(listQueryParamStartIndex, strconv.Itoa(*req.StartIndex))
}
if req.Filter != nil {
query.Set(listQueryParamFilter, *req.Filter)
}
return c.doWithListResponse(ctx, http.MethodGet, orgID, "?"+query.Encode(), nil)
}
func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) {
return c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil)
}
@ -77,6 +155,17 @@ func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, url string) e
return c.doReq(req, nil)
}
func (c *ResourceClient[T]) doWithListResponse(ctx context.Context, method, orgID, url string, body io.Reader) (*ListResponse[*T], error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
if err != nil {
return nil, err
}
req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
response := new(ListResponse[*T])
return response, c.doReq(req, response)
}
func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
if err != nil {
@ -88,7 +177,7 @@ func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url s
return responseEntity, c.doReq(req, responseEntity)
}
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity *T) error {
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity interface{}) error {
addTokenAsHeader(req)
resp, err := c.client.Do(req)
@ -141,8 +230,8 @@ func readScimError(resp *http.Response) error {
}
func (c *ResourceClient[T]) buildURL(orgID, segment string) string {
if segment == "" {
return c.baseUrl + "/" + path.Join(orgID, c.resourceName)
if segment == "" || strings.HasPrefix(segment, "?") {
return c.baseUrl + "/" + path.Join(orgID, c.resourceName) + segment
}
return c.baseUrl + "/" + path.Join(orgID, c.resourceName, segment)

View File

@ -110,7 +110,7 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
if cols[len(cols)-1] == "count" && len(row) == len(cols)-1 {
row = append(row, count)
}
result.AddRow(row...)

View File

@ -109,6 +109,10 @@ func NewOrQuery(queries ...SearchQuery) (*OrQuery, error) {
return &OrQuery{queries: queries}, nil
}
func (q *OrQuery) Prepend(queries ...SearchQuery) {
q.queries = append(queries, q.queries...)
}
func (q *OrQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
@ -147,6 +151,10 @@ func (q *AndQuery) comp() sq.Sqlizer {
return and
}
func (q *AndQuery) Prepend(queries ...SearchQuery) {
q.queries = append(queries, q.queries...)
}
type NotQuery struct {
query SearchQuery
}
@ -406,8 +414,12 @@ func (q *NumberQuery) comp() sq.Sqlizer {
return sq.NotEq{q.Column.identifier(): q.Number}
case NumberLess:
return sq.Lt{q.Column.identifier(): q.Number}
case NumberLessOrEqual:
return sq.LtOrEq{q.Column.identifier(): q.Number}
case NumberGreater:
return sq.Gt{q.Column.identifier(): q.Number}
case NumberGreaterOrEqual:
return sq.GtOrEq{q.Column.identifier(): q.Number}
case NumberListContains:
return &listContains{col: q.Column, args: []interface{}{q.Number}}
case numberCompareMax:
@ -423,7 +435,9 @@ const (
NumberEquals NumberComparison = iota
NumberNotEquals
NumberLess
NumberLessOrEqual
NumberGreater
NumberGreaterOrEqual
NumberListContains
numberCompareMax
@ -588,6 +602,57 @@ func (q *BoolQuery) comp() sq.Sqlizer {
return sq.Eq{q.Column.identifier(): q.Value}
}
type BytesComparison int
const (
BytesEquals BytesComparison = iota
BytesNotEquals
bytesCompareMax
)
type BytesQuery struct {
Column Column
Compare BytesComparison
Value []byte
}
func NewBytesQuery(col Column, values []byte, comparison BytesComparison) (*BytesQuery, error) {
if col.isZero() {
return nil, ErrMissingColumn
}
if comparison < 0 || comparison >= bytesCompareMax {
return nil, ErrInvalidCompare
}
return &BytesQuery{
Column: col,
Value: values,
Compare: comparison,
}, nil
}
func (q *BytesQuery) Col() Column {
return q.Column
}
func (q *BytesQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (q *BytesQuery) comp() sq.Sqlizer {
switch q.Compare {
case BytesEquals:
return sq.Eq{q.Column.identifier(): q.Value}
case BytesNotEquals:
return sq.NotEq{q.Column.identifier(): q.Value}
case bytesCompareMax:
return nil
}
return nil
}
type TimestampComparison int
const (

View File

@ -6,6 +6,7 @@ import (
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
)
@ -1540,6 +1541,17 @@ func TestNumberQuery_comp(t *testing.T) {
query: sq.Lt{"test_table.test_col": 42},
},
},
{
name: "less or equal",
fields: fields{
Column: testCol,
Number: 42,
Compare: NumberLessOrEqual,
},
want: want{
query: sq.LtOrEq{"test_table.test_col": 42},
},
},
{
name: "greater",
fields: fields{
@ -1551,6 +1563,17 @@ func TestNumberQuery_comp(t *testing.T) {
query: sq.Gt{"test_table.test_col": 42},
},
},
{
name: "greater or equal",
fields: fields{
Column: testCol,
Number: 42,
Compare: NumberGreaterOrEqual,
},
want: want{
query: sq.GtOrEq{"test_table.test_col": 42},
},
},
{
name: "list containts",
fields: fields{
@ -2193,3 +2216,98 @@ func TestInTextQuery_comp(t *testing.T) {
})
}
}
func TestBytesQuery_comp(t *testing.T) {
type fields struct {
Column Column
Value []byte
Compare BytesComparison
}
type want struct {
query interface{}
err bool
isNil bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "equals",
fields: fields{
Column: testCol,
Value: []byte("foo"),
Compare: BytesEquals,
},
want: want{
query: sq.Eq{"test_table.test_col": []byte("foo")},
},
},
{
name: "not equals",
fields: fields{
Column: testCol,
Value: []byte("foo"),
Compare: BytesNotEquals,
},
want: want{
query: sq.NotEq{"test_table.test_col": []byte("foo")},
},
},
{
name: "unknown comparison",
fields: fields{
Column: testCol,
Value: []byte("foo"),
Compare: -1,
},
want: want{
err: true,
isNil: true,
},
},
{
name: "zero col",
fields: fields{
Column: Column{},
Value: []byte("foo"),
Compare: BytesEquals,
},
want: want{
err: true,
query: sq.Eq{"": []byte("foo")},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := NewBytesQuery(tt.fields.Column, tt.fields.Value, tt.fields.Compare)
if tt.want.err {
require.Error(t, err)
// still test comp
s = &BytesQuery{
Column: tt.fields.Column,
Value: tt.fields.Value,
Compare: tt.fields.Compare,
}
} else {
require.NoError(t, err)
}
query := s.comp()
if tt.want.isNil {
require.Nil(t, query)
return
}
require.NotNil(t, query)
if !reflect.DeepEqual(query, tt.want.query) {
t.Errorf("wrong query: want: %v, (%T), got: %v, (%T)", tt.want.query, tt.want.query, query, query)
}
})
}
}

View File

@ -604,6 +604,27 @@ func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, queri
return user, err
}
func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (count uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareCountUsersQuery()
eq := sq.Eq{UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()}
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
if err != nil {
return 0, zerrors.ThrowInternal(err, "QUERY-w3Dx", "Errors.Query.SQLStatment")
}
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
count, err = scan(rows)
return err
}, stmt, args...)
if err != nil {
return 0, zerrors.ThrowInternal(err, "QUERY-AG4gs", "Errors.Internal")
}
return count, err
}
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
if err != nil {
@ -1278,6 +1299,24 @@ func scanNotifyUser(row *sql.Row) (*NotifyUser, error) {
return u, nil
}
func prepareCountUsersQuery() (sq.SelectBuilder, func(*sql.Rows) (uint64, error)) {
return sq.Select(countColumn.identifier()).
From(userTable.identifier()).
LeftJoin(join(HumanUserIDCol, UserIDCol)).
LeftJoin(join(MachineUserIDCol, UserIDCol)).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (count uint64, err error) {
// the count is implemented as a windowing function,
// if it is zero, no row is returned at all.
if !rows.Next() {
return
}
err = rows.Scan(&count)
return
}
}
func prepareUserUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (bool, error)) {
return sq.Select(
UserIDCol.identifier(),

View File

@ -24,6 +24,7 @@ type UserMetadataList struct {
type UserMetadata struct {
CreationDate time.Time `json:"creation_date,omitempty"`
UserID string `json:"-"`
ChangeDate time.Time `json:"change_date,omitempty"`
ResourceOwner string `json:"resource_owner,omitempty"`
Sequence uint64 `json:"sequence,omitempty"`
@ -107,6 +108,38 @@ func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bo
return metadata, err
}
func (q *Queries) SearchUserMetadataForUsers(ctx context.Context, shouldTriggerBulk bool, userIDs []string, queries *UserMetadataSearchQueries) (metadata *UserMetadataList, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if shouldTriggerBulk {
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerUserMetadataProjection")
ctx, err = projection.UserMetadataProjection.Trigger(ctx, handler.WithAwaitRunning())
logging.OnError(err).Debug("trigger failed")
traceSpan.EndWithError(err)
}
query, scan := prepareUserMetadataListQuery(ctx, q.client)
eq := sq.Eq{
UserMetadataUserIDCol.identifier(): userIDs,
UserMetadataInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Egbgd", "Errors.Query.SQLStatment")
}
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
metadata, err = scan(rows)
return err
}, stmt, args...)
if err != nil {
return nil, err
}
metadata.State, err = q.latestState(ctx, userMetadataTable)
return metadata, err
}
func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool, userID string, queries *UserMetadataSearchQueries, withOwnerRemoved bool) (metadata *UserMetadataList, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@ -164,6 +197,44 @@ func NewUserMetadataKeySearchQuery(value string, comparison TextComparison) (Sea
return NewTextQuery(UserMetadataKeyCol, value, comparison)
}
func NewUserMetadataExistsQuery(key string, value []byte, keyComparison TextComparison, valueComparison BytesComparison) (SearchQuery, error) {
// linking queries for the subselect
instanceQuery, err := NewColumnComparisonQuery(UserMetadataInstanceIDCol, UserInstanceIDCol, ColumnEquals)
if err != nil {
return nil, err
}
userIDQuery, err := NewColumnComparisonQuery(UserMetadataUserIDCol, UserIDCol, ColumnEquals)
if err != nil {
return nil, err
}
// text query to select data from the linked sub select
metadataKeyQuery, err := NewTextQuery(UserMetadataKeyCol, key, keyComparison)
if err != nil {
return nil, err
}
// text query to select data from the linked sub select
metadataValueQuery, err := NewBytesQuery(UserMetadataValueCol, value, valueComparison)
if err != nil {
return nil, err
}
// full definition of the sub select
subSelect, err := NewSubSelect(UserMetadataUserIDCol, []SearchQuery{instanceQuery, userIDQuery, metadataKeyQuery, metadataValueQuery})
if err != nil {
return nil, err
}
// "WHERE * IN (*)" query with subquery as list-data provider
return NewListQuery(
UserIDCol,
subSelect,
ListIn,
)
}
func prepareUserMetadataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserMetadata, error)) {
return sq.Select(
UserMetadataCreationDateCol.identifier(),
@ -200,6 +271,7 @@ func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.S
return sq.Select(
UserMetadataCreationDateCol.identifier(),
UserMetadataChangeDateCol.identifier(),
UserMetadataUserIDCol.identifier(),
UserMetadataResourceOwnerCol.identifier(),
UserMetadataSequenceCol.identifier(),
UserMetadataKeyCol.identifier(),
@ -215,6 +287,7 @@ func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.S
err := rows.Scan(
&m.CreationDate,
&m.ChangeDate,
&m.UserID,
&m.ResourceOwner,
&m.Sequence,
&m.Key,

View File

@ -30,6 +30,7 @@ var (
}
userMetadataListQuery = `SELECT projections.user_metadata5.creation_date,` +
` projections.user_metadata5.change_date,` +
` projections.user_metadata5.user_id,` +
` projections.user_metadata5.resource_owner,` +
` projections.user_metadata5.sequence,` +
` projections.user_metadata5.key,` +
@ -39,6 +40,7 @@ var (
userMetadataListCols = []string{
"creation_date",
"change_date",
"user_id",
"resource_owner",
"sequence",
"key",
@ -148,6 +150,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
testNow,
testNow,
"1",
"resource_owner",
uint64(20211108),
"key",
@ -164,6 +167,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
CreationDate: testNow,
ChangeDate: testNow,
UserID: "1",
ResourceOwner: "resource_owner",
Sequence: 20211108,
Key: "key",
@ -183,6 +187,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
testNow,
testNow,
"1",
"resource_owner",
uint64(20211108),
"key",
@ -191,6 +196,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
testNow,
testNow,
"2",
"resource_owner",
uint64(20211108),
"key2",
@ -207,6 +213,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
CreationDate: testNow,
ChangeDate: testNow,
UserID: "1",
ResourceOwner: "resource_owner",
Sequence: 20211108,
Key: "key",
@ -215,6 +222,7 @@ func Test_UserMetadataPrepares(t *testing.T) {
{
CreationDate: testNow,
ChangeDate: testNow,
UserID: "2",
ResourceOwner: "resource_owner",
Sequence: 20211108,
Key: "key2",

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
"testing"
@ -530,6 +531,8 @@ var (
"access_token_type",
"count",
}
countUsersQuery = "SELECT COUNT(*) OVER () FROM projections.users13"
countUsersCols = []string{"count"}
)
func Test_UserPrepares(t *testing.T) {
@ -1508,10 +1511,67 @@ func Test_UserPrepares(t *testing.T) {
},
object: (*Users)(nil),
},
{
name: "prepareCountUsersQuery no result",
prepare: prepareCountUsersQuery,
want: want{
sqlExpectations: mockQuery(
regexp.QuoteMeta(countUsersQuery),
nil,
nil,
),
},
object: uint64(0),
},
{
name: "prepareCountUsersQuery one result",
prepare: prepareCountUsersQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(countUsersQuery),
countUsersCols,
[][]driver.Value{{uint64(1)}},
),
},
object: uint64(1),
},
{
name: "prepareCountUsersQuery multiple results",
prepare: prepareCountUsersQuery,
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(countUsersQuery),
countUsersCols,
[][]driver.Value{{uint64(2)}},
),
},
object: uint64(2),
},
{
name: "prepareCountUsersQuery sql err",
prepare: prepareCountUsersQuery,
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(countUsersQuery),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
params := defaultPrepareArgs
if reflect.TypeOf(tt.prepare).NumIn() == 0 {
params = []reflect.Value{}
}
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, params...)
})
}
}

103
internal/test/assert.go Normal file
View File

@ -0,0 +1,103 @@
package test
import (
"reflect"
"github.com/stretchr/testify/assert"
)
func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) {
val, exists := m[key]
assert.True(t, exists, "Key '%s' should exist in the map", key)
if !exists {
return
}
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
}
// PartiallyDeepEqual is similar to reflect.DeepEqual,
// but only compares exported non-zero fields of the expectedValue
func PartiallyDeepEqual(expected, actual interface{}) bool {
if expected == nil {
return actual == nil
}
if actual == nil {
return false
}
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func partiallyDeepEqual(expected, actual reflect.Value) bool {
// Dereference pointers if needed
if expected.Kind() == reflect.Ptr {
if expected.IsNil() {
return true
}
expected = expected.Elem()
}
if actual.Kind() == reflect.Ptr {
if actual.IsNil() {
return false
}
actual = actual.Elem()
}
if expected.Type() != actual.Type() {
return false
}
switch expected.Kind() { //nolint:exhaustive
case reflect.Struct:
for i := 0; i < expected.NumField(); i++ {
field := expected.Type().Field(i)
if field.PkgPath != "" { // Skip unexported fields
continue
}
expectedField := expected.Field(i)
actualField := actual.Field(i)
// Skip zero-value fields in expected
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
continue
}
// Compare fields recursively
if !partiallyDeepEqual(expectedField, actualField) {
return false
}
}
return true
case reflect.Slice, reflect.Array:
if expected.Len() > actual.Len() {
return false
}
for i := 0; i < expected.Len(); i++ {
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
return false
}
}
return true
default:
// Compare primitive types
return reflect.DeepEqual(expected.Interface(), actual.Interface())
}
}
func Must[T any](result T, error error) T {
if error != nil {
panic(error)
}
return result
}

View File

@ -0,0 +1,153 @@
package test
import "testing"
func TestPartiallyDeepEqual(t *testing.T) {
type SecondaryNestedType struct {
Value int
}
type NestedType struct {
Value int
ValueSlice []int
Nested SecondaryNestedType
NestedPointer *SecondaryNestedType
}
type args struct {
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil",
args: args{
expected: nil,
actual: nil,
},
want: true,
},
{
name: "scalar value",
args: args{
expected: 10,
actual: 10,
},
want: true,
},
{
name: "different scalar value",
args: args{
expected: 11,
actual: 10,
},
want: false,
},
{
name: "string value",
args: args{
expected: "foo",
actual: "foo",
},
want: true,
},
{
name: "different string value",
args: args{
expected: "foo2",
actual: "foo",
},
want: false,
},
{
name: "scalar only set in actual",
args: args{
expected: &SecondaryNestedType{},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar equal",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar only set in expected",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{},
},
want: false,
},
{
name: "ptr only set in expected",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{},
},
want: false,
},
{
name: "ptr only set in actual",
args: args{
expected: &NestedType{},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "ptr equal",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "nested equal",
args: args{
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "slice equal",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: true,
},
{
name: "slice additional in expected",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: false,
},
{
name: "slice additional in actual",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
}
})
}
}