Lars 189f9770c6
feat: patch user scim v2 endpoint (#9219)
# Which Problems Are Solved
* Adds support for the patch user SCIM v2 endpoint

# How the Problems Are Solved
* Adds support for the patch user SCIM v2 endpoint under `PATCH
/scim/v2/{orgID}/Users/{id}`

# Additional Context
Part of #8140
2025-01-27 13:36:07 +01:00

274 lines
7.6 KiB
Go

package patch
import (
"encoding/json"
"reflect"
"slices"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/scim/resources/filter"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/api/scim/serrors"
"github.com/zitadel/zitadel/internal/zerrors"
)
// OperationRequest is the body of a SCIM PATCH request (RFC 7644, section 3.5.2).
//
// Schemas must list schemas.IdPatchOperation (enforced by Validate);
// Operations holds the patch operations in request order.
//
// The RFC names the attributes "schemas" and "Operations". encoding/json
// matches keys case-insensitively when decoding, so the tags matter mainly
// for encoding, which now emits the spec-compliant "schemas" key.
type OperationRequest struct {
	Schemas    []schemas.ScimSchemaType `json:"schemas"`
	Operations []*Operation             `json:"Operations"`
}
type (
	// Operation is a single entry of the "Operations" array of a SCIM PATCH request.
	// An absent or empty path is normalized to nil by validate.
	Operation struct {
		Operation OperationType   `json:"op"`
		Path      *filter.Path    `json:"path"`
		Value     json.RawMessage `json:"value"`

		// valueIsArray reports whether Value holds a JSON array.
		// It is derived from Value by validate (or when flattening) and is not
		// part of the JSON payload (unexported fields are ignored by encoding/json).
		valueIsArray bool
	}

	// OperationCollection is an ordered list of patch operations.
	OperationCollection []*Operation

	// OperationType is the value of an operation's "op" attribute.
	OperationType string
)
// Supported values of the "op" attribute of a SCIM patch operation.
const (
	OperationTypeAdd     OperationType = "add"
	OperationTypeRemove  OperationType = "remove"
	OperationTypeReplace OperationType = "replace"
)

// Struct field names used for reflective field lookups on patched elements
// (fieldNamePrimary is resolved via reflect.Value.FieldByName by isPrimary
// and ensureSinglePrimary).
const (
	fieldNamePrimary = "Primary"
	fieldNameValue   = "Value"
)
// ResourcePatcher is passed through OperationCollection.Apply to the
// per-operation patch helpers (applyAddPatch, applyReplacePatch, applyRemovePatch).
// NOTE(review): none of its methods are called in this file; presumably the
// helpers use FilterEvaluator to resolve filtered paths and report each
// modified attribute via Added/Replaced/Removed with its attribute path — verify
// against the apply*Patch implementations.
type ResourcePatcher interface {
	FilterEvaluator() *filter.Evaluator
	Added(attributePath []string) error
	Replaced(attributePath []string) error
	Removed(attributePath []string) error
}
// Validate checks a decoded PATCH request body.
// The PatchOp message schema must be listed in Schemas, and every operation
// must pass Operation.validate, which also normalizes the operation in place.
// The first failing check determines the returned error.
func (req *OperationRequest) Validate() error {
	hasPatchSchema := slices.Contains(req.Schemas, schemas.IdPatchOperation)
	if !hasPatchSchema {
		return serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgumentf(nil, "SCIM-xy1schema", "Expected schema %v is not provided", schemas.IdPatchOperation))
	}

	for i := range req.Operations {
		if err := req.Operations[i].validate(); err != nil {
			return err
		}
	}
	return nil
}
// validate checks that op is a supported patch operation and normalizes it
// for application: a zero path is reset to nil and valueIsArray is derived
// from the raw value. It mutates op; it is invoked both by
// OperationRequest.Validate and by OperationCollection.Apply.
func (op *Operation) validate() error {
	// A JSON null entry in the "Operations" array decodes to a nil pointer;
	// reject it instead of panicking on the field access below.
	if op == nil {
		return serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgument(nil, "SCIM-opnil1", "Patch operation must not be null"))
	}

	if !op.Operation.isValid() {
		return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgumentf(nil, "SCIM-opty1", "Patch op %s not supported", op.Operation))
	}

	// json deserialization initializes this field if an empty string is provided
	// to not special case this in the further code,
	// set it to nil here.
	if op.Path.IsZero() {
		op.Path = nil
	}

	// JSON allows any amount of space, tab, LF and CR before a value,
	// not just a single space.
	op.valueIsArray = strings.HasPrefix(strings.TrimLeft(string(op.Value), " \t\r\n"), "[")
	return nil
}
// Apply validates and applies each operation to value in order, stopping at
// the first error. Operations preceding a failing one have already been applied.
func (ops OperationCollection) Apply(patcher ResourcePatcher, value interface{}) error {
	for _, operation := range ops {
		err := operation.validate()
		if err == nil {
			err = operation.apply(patcher, value)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// apply dispatches op to the patch helper matching its operation type.
// Unknown operation types yield an InvalidArgument error.
func (op *Operation) apply(patcher ResourcePatcher, value interface{}) error {
	switch op.Operation {
	case OperationTypeAdd:
		return applyAddPatch(patcher, op, value)
	case OperationTypeReplace:
		return applyReplacePatch(patcher, op, value)
	case OperationTypeRemove:
		return applyRemovePatch(patcher, op, value)
	default:
		return zerrors.ThrowInvalidArgumentf(nil, "SCIM-opty3", "SCIM patch: Invalid operation %v", op.Operation)
	}
}
// isValid reports whether o is one of the supported operation types
// (add, remove or replace).
func (o OperationType) isValid() bool {
	return o == OperationTypeAdd || o == OperationTypeRemove || o == OperationTypeReplace
}
// flattenAndApplyPatchOperations expands a path-less op into one operation per
// attribute (see flattenPatchOperations) and applies them one after another,
// stopping at the first error.
func flattenAndApplyPatchOperations(patcher ResourcePatcher, op *Operation, value interface{}) error {
	flattened, err := flattenPatchOperations(op)
	if err != nil {
		return err
	}

	for i := range flattened {
		if applyErr := flattened[i].apply(patcher, value); applyErr != nil {
			return applyErr
		}
	}
	return nil
}
// flattenPatchOperations flattens a patch operation without a path into one
// operation per attribute.
// It converts an op { "op": "add", "value": { "path1": "value1", "path2": "value2" } }
// into [ { "op": "add", "path": "path1", "value": "value1" }, { "op": "add", "path": "path2", "value": "value2" } ].
//
// The resulting operations are ordered by attribute name, so applying them is
// deterministic (Go map iteration order is randomized).
// It returns an InvalidArgument error if the value is not a JSON object, and
// panics if op has a path, since that indicates a programming error.
func flattenPatchOperations(op *Operation) ([]*Operation, error) {
	if op.Path != nil {
		panic("Only operations without a path can be flattened")
	}

	patches := map[string]json.RawMessage{}
	if err := json.Unmarshal(op.Value, &patches); err != nil {
		logging.WithError(err).Error("SCIM: Invalid patch value while flattening")
		return nil, zerrors.ThrowInvalidArgument(err, "SCIM-ioyl1", "Invalid patch value")
	}

	paths := make([]string, 0, len(patches))
	for path := range patches {
		paths = append(paths, path)
	}
	slices.Sort(paths)

	result := make([]*Operation, 0, len(paths))
	for _, path := range paths {
		value := patches[path]
		result = append(result, &Operation{
			Operation: op.Operation,
			Path: &filter.Path{
				AttrPath: &filter.AttrPath{
					AttrName: path,
				},
			},
			Value: value,
			// same whitespace-tolerant array detection as Operation.validate
			valueIsArray: strings.HasPrefix(strings.TrimLeft(string(value), " \t\r\n"), "["),
		})
	}
	return result, nil
}
// unmarshalPatchValuesSlice unmarshals the raw json value (a scalar value, object or array)
// into a new slice whose element type is elementTypePtr.
//
// elementTypePtr must be a pointer type; otherwise logging.Panicf is called.
// If valueIsArray is false, value is decoded into a single freshly allocated
// element and returned as a slice of length one. Otherwise value is decoded as
// a JSON array. Null array entries are rejected with an InvalidArgument error,
// because they would decode to nil pointers that later reflective field access
// cannot handle.
func unmarshalPatchValuesSlice(elementTypePtr reflect.Type, value json.RawMessage, valueIsArray bool) (reflect.Value, error) {
	if elementTypePtr.Kind() != reflect.Ptr {
		// %v prints the full type; Type.Name() is empty for unnamed types.
		logging.Panicf("elementType must be a pointer to a struct, but is %v", elementTypePtr)
		return reflect.Value{}, nil
	}

	if !valueIsArray {
		newElement := reflect.New(elementTypePtr.Elem())
		if err := unmarshalPatchValue(value, newElement); err != nil {
			return reflect.Value{}, err
		}
		newSlice := reflect.MakeSlice(reflect.SliceOf(elementTypePtr), 1, 1)
		newSlice.Index(0).Set(newElement)
		return newSlice, nil
	}

	newSlicePtr := reflect.New(reflect.SliceOf(elementTypePtr))
	newSlice := newSlicePtr.Elem()
	if err := json.Unmarshal(value, newSlicePtr.Interface()); err != nil {
		logging.WithError(err).Error("SCIM: Invalid patch values")
		return reflect.Value{}, zerrors.ThrowInvalidArgument(err, "SCIM-opxx8", "Invalid patch values")
	}

	for i := 0; i < newSlice.Len(); i++ {
		if newSlice.Index(i).IsNil() {
			return reflect.Value{}, zerrors.ThrowInvalidArgument(nil, "SCIM-opxn1", "Invalid patch values: null elements are not allowed")
		}
	}
	return newSlice, nil
}
// unmarshalPatchValue decodes newValue into targetElement.
//
// A non-pointer targetElement is decoded through its address, so it must be
// addressable (reflect.Value.Addr panics otherwise). A nil pointer
// targetElement is first set to a freshly allocated value, so it must be
// settable. Decoding failures are logged and returned as an InvalidArgument error.
func unmarshalPatchValue(newValue json.RawMessage, targetElement reflect.Value) error {
	switch {
	case targetElement.Kind() != reflect.Ptr:
		targetElement = targetElement.Addr()
	case targetElement.IsNil():
		targetElement.Set(reflect.New(targetElement.Type().Elem()))
	}

	err := json.Unmarshal(newValue, targetElement.Interface())
	if err == nil {
		return nil
	}
	logging.WithError(err).Error("SCIM: Invalid patch value")
	return zerrors.ThrowInvalidArgument(err, "SCIM-opty9", "Invalid patch value")
}
// ensureSinglePrimary ensures the modification on a slice results in max one primary element.
// modifiedSlice contains the patched slice.
// modifiedElementsSlice contains only the modified elements.
// modifiedElementIndexes marks the indexes in modifiedSlice that hold modified
// elements; those are never demoted (a nil map marks none).
// If a modified element has Primary set to true and an existing element is also
// primary, the Primary flag of the first such existing element is set to false.
// Existing elements that are nil pointers or not structs are skipped.
// Returns an error if multiple modified elements have a primary value of true.
func ensureSinglePrimary(modifiedSlice reflect.Value, modifiedElementsSlice []reflect.Value, modifiedElementIndexes map[int]bool) error {
	if len(modifiedElementsSlice) == 0 {
		return nil
	}

	hasPrimary, err := isAnyPrimary(modifiedElementsSlice)
	if err != nil || !hasPrimary {
		return err
	}

	for i := 0; i < modifiedSlice.Len(); i++ {
		// a missing key (or a nil map) reads as false
		if modifiedElementIndexes[i] {
			continue
		}

		sliceElement := modifiedSlice.Index(i)
		// isPrimary tolerates nil pointers and non-struct elements,
		// on which FieldByName would panic.
		if !isPrimary(sliceElement) {
			continue
		}

		reflect.Indirect(sliceElement).FieldByName(fieldNamePrimary).SetBool(false)
		// we can stop at the first primary,
		// since there can only be one primary in a slice.
		return nil
	}
	return nil
}
// isAnyPrimary reports whether at least one of elements is primary (see isPrimary).
// As soon as a second primary element is found, it returns true together with
// an InvalidArgument error, since one patch operation may add at most one primary value.
func isAnyPrimary(elements []reflect.Value) (bool, error) {
	primaryCount := 0
	for _, element := range elements {
		if isPrimary(element) {
			primaryCount++
		}
		if primaryCount > 1 {
			return true, zerrors.ThrowInvalidArgument(nil, "SCIM-1d23", "Cannot add multiple primary values in one patch operation")
		}
	}
	return primaryCount > 0, nil
}
// isPrimary reports whether element (a struct or a pointer to one) has a
// Primary field set to true. Nil pointers, non-struct values and structs
// without a Primary field are never primary.
func isPrimary(element reflect.Value) bool {
	// Indirect dereferences pointers (nil yields the zero Value) and leaves other values untouched.
	target := reflect.Indirect(element)
	if target.Kind() != reflect.Struct {
		return false
	}

	primaryField := target.FieldByName(fieldNamePrimary)
	if !primaryField.IsValid() {
		return false
	}
	return primaryField.Bool()
}