feat: list users scim v2 endpoint (#9187)

# Which Problems Are Solved
- Adds support for the list users SCIM v2 endpoint

# How the Problems Are Solved
- Adds support for the list users SCIM v2 endpoints under `GET
/scim/v2/{orgID}/Users` and `POST /scim/v2/{orgID}/Users/.search`

# Additional Changes
- adds a new function `SearchUserMetadataForUsers` to the query layer to
query a metadata keyset for given user ids
- adds a new function `NewUserMetadataExistsQuery` to the query layer to
query a given metadata key value pair exists
- adds a new function `CountUsers` to the query layer to count users
without reading any rows
- handle `ErrorAlreadyExists` as scim errors `uniqueness`
- adds `NumberLessOrEqual` and `NumberGreaterOrEqual` query comparison
methods
- adds `BytesQuery` with `BytesEquals` and `BytesNotEquals` query
comparison methods

# Additional Context
Part of #8140
Supported fields for scim filters:
* `meta.created`
* `meta.lastModified`
* `id`
* `username`
* `name.familyName`
* `name.givenName`
* `emails` and `emails.value`
* `active` only eq and ne
* `externalId` only eq and ne
This commit is contained in:
Lars
2025-01-21 13:31:54 +01:00
committed by GitHub
parent 926e7169b2
commit 1915d35605
37 changed files with 4173 additions and 417 deletions

View File

@@ -1,7 +1,6 @@
package integration
import (
"reflect"
"testing"
"time"
@@ -170,99 +169,3 @@ func diffProto(expected, actual proto.Message) string {
}
return "\n\nDiff:\n" + diff
}
func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) {
val, exists := m[key]
assert.True(t, exists, "Key '%s' should exist in the map", key)
if !exists {
return
}
assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue)
}
// PartiallyDeepEqual is similar to reflect.DeepEqual,
// but only compares exported non-zero fields of the expectedValue
func PartiallyDeepEqual(expected, actual interface{}) bool {
if expected == nil {
return actual == nil
}
if actual == nil {
return false
}
return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual))
}
func partiallyDeepEqual(expected, actual reflect.Value) bool {
// Dereference pointers if needed
if expected.Kind() == reflect.Ptr {
if expected.IsNil() {
return true
}
expected = expected.Elem()
}
if actual.Kind() == reflect.Ptr {
if actual.IsNil() {
return false
}
actual = actual.Elem()
}
if expected.Type() != actual.Type() {
return false
}
switch expected.Kind() { //nolint:exhaustive
case reflect.Struct:
for i := 0; i < expected.NumField(); i++ {
field := expected.Type().Field(i)
if field.PkgPath != "" { // Skip unexported fields
continue
}
expectedField := expected.Field(i)
actualField := actual.Field(i)
// Skip zero-value fields in expected
if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) {
continue
}
// Compare fields recursively
if !partiallyDeepEqual(expectedField, actualField) {
return false
}
}
return true
case reflect.Slice, reflect.Array:
if expected.Len() > actual.Len() {
return false
}
for i := 0; i < expected.Len(); i++ {
if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) {
return false
}
}
return true
default:
// Compare primitive types
return reflect.DeepEqual(expected.Interface(), actual.Interface())
}
}
func Must[T any](result T, error error) T {
if error != nil {
panic(error)
}
return result
}

View File

@@ -50,153 +50,3 @@ func TestAssertDetails(t *testing.T) {
})
}
}
func TestPartiallyDeepEqual(t *testing.T) {
type SecondaryNestedType struct {
Value int
}
type NestedType struct {
Value int
ValueSlice []int
Nested SecondaryNestedType
NestedPointer *SecondaryNestedType
}
type args struct {
expected interface{}
actual interface{}
}
tests := []struct {
name string
args args
want bool
}{
{
name: "nil",
args: args{
expected: nil,
actual: nil,
},
want: true,
},
{
name: "scalar value",
args: args{
expected: 10,
actual: 10,
},
want: true,
},
{
name: "different scalar value",
args: args{
expected: 11,
actual: 10,
},
want: false,
},
{
name: "string value",
args: args{
expected: "foo",
actual: "foo",
},
want: true,
},
{
name: "different string value",
args: args{
expected: "foo2",
actual: "foo",
},
want: false,
},
{
name: "scalar only set in actual",
args: args{
expected: &SecondaryNestedType{},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar equal",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{Value: 10},
},
want: true,
},
{
name: "scalar only set in expected",
args: args{
expected: &SecondaryNestedType{Value: 10},
actual: &SecondaryNestedType{},
},
want: false,
},
{
name: "ptr only set in expected",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{},
},
want: false,
},
{
name: "ptr only set in actual",
args: args{
expected: &NestedType{},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "ptr equal",
args: args{
expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "nested equal",
args: args{
expected: &NestedType{Nested: SecondaryNestedType{Value: 10}},
actual: &NestedType{Nested: SecondaryNestedType{Value: 10}},
},
want: true,
},
{
name: "slice equal",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: true,
},
{
name: "slice additional in expected",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20, 30}},
actual: &NestedType{ValueSlice: []int{10, 20}},
},
want: false,
},
{
name: "slice additional in actual",
args: args{
expected: &NestedType{ValueSlice: []int{10, 20}},
actual: &NestedType{ValueSlice: []int{10, 20, 30}},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want {
t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -6,7 +6,10 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"github.com/zitadel/logging"
"google.golang.org/grpc/metadata"
@@ -40,6 +43,44 @@ type ZitadelErrorDetail struct {
Message string `json:"message"`
}
type ListRequest struct {
Count *int `json:"count,omitempty"`
// StartIndex An integer indicating the 1-based index of the first query result.
StartIndex *int `json:"startIndex,omitempty"`
// Filter a scim filter expression to filter the query result.
Filter *string `json:"filter,omitempty"`
SortBy *string `json:"sortBy,omitempty"`
SortOrder *ListRequestSortOrder `json:"sortOrder,omitempty"`
SendAsPost bool
}
type ListRequestSortOrder string
const (
ListRequestSortOrderAsc ListRequestSortOrder = "ascending"
ListRequestSortOrderDsc ListRequestSortOrder = "descending"
)
type ListResponse[T any] struct {
Schemas []schemas.ScimSchemaType `json:"schemas"`
ItemsPerPage int `json:"itemsPerPage"`
TotalResults int `json:"totalResults"`
StartIndex int `json:"startIndex"`
Resources []T `json:"Resources"`
}
const (
listQueryParamSortBy = "sortBy"
listQueryParamSortOrder = "sortOrder"
listQueryParamCount = "count"
listQueryParamStartIndex = "startIndex"
listQueryParamFilter = "filter"
)
func NewScimClient(target string) *Client {
target = "http://" + target + schemas.HandlerPrefix
client := &http.Client{}
@@ -60,6 +101,43 @@ func (c *ResourceClient[T]) Replace(ctx context.Context, orgID, id string, body
return c.doWithBody(ctx, http.MethodPut, orgID, id, bytes.NewReader(body))
}
func (c *ResourceClient[T]) List(ctx context.Context, orgID string, req *ListRequest) (*ListResponse[*T], error) {
if req.SendAsPost {
listReq, err := json.Marshal(req)
if err != nil {
return nil, err
}
return c.doWithListResponse(ctx, http.MethodPost, orgID, ".search", bytes.NewReader(listReq))
}
query, err := url.ParseQuery("")
if err != nil {
return nil, err
}
if req.SortBy != nil {
query.Set(listQueryParamSortBy, *req.SortBy)
}
if req.SortOrder != nil {
query.Set(listQueryParamSortOrder, string(*req.SortOrder))
}
if req.Count != nil {
query.Set(listQueryParamCount, strconv.Itoa(*req.Count))
}
if req.StartIndex != nil {
query.Set(listQueryParamStartIndex, strconv.Itoa(*req.StartIndex))
}
if req.Filter != nil {
query.Set(listQueryParamFilter, *req.Filter)
}
return c.doWithListResponse(ctx, http.MethodGet, orgID, "?"+query.Encode(), nil)
}
func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) {
return c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil)
}
@@ -77,6 +155,17 @@ func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, url string) e
return c.doReq(req, nil)
}
func (c *ResourceClient[T]) doWithListResponse(ctx context.Context, method, orgID, url string, body io.Reader) (*ListResponse[*T], error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
if err != nil {
return nil, err
}
req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
response := new(ListResponse[*T])
return response, c.doReq(req, response)
}
func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body)
if err != nil {
@@ -88,7 +177,7 @@ func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url s
return responseEntity, c.doReq(req, responseEntity)
}
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity *T) error {
func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity interface{}) error {
addTokenAsHeader(req)
resp, err := c.client.Do(req)
@@ -141,8 +230,8 @@ func readScimError(resp *http.Response) error {
}
func (c *ResourceClient[T]) buildURL(orgID, segment string) string {
if segment == "" {
return c.baseUrl + "/" + path.Join(orgID, c.resourceName)
if segment == "" || strings.HasPrefix(segment, "?") {
return c.baseUrl + "/" + path.Join(orgID, c.resourceName) + segment
}
return c.baseUrl + "/" + path.Join(orgID, c.resourceName, segment)