Lars e15094cdea
feat: add scim v2 service provider configuration endpoints (#9258)
# Which Problems Are Solved
* Adds support for the service provider configuration SCIM v2 endpoints

# How the Problems Are Solved
* Implements the following SCIM v2 service provider configuration endpoints:
  * `GET /scim/v2/{orgId}/ServiceProviderConfig`
  * `GET /scim/v2/{orgId}/ResourceTypes`
  * `GET /scim/v2/{orgId}/ResourceTypes/{name}`
  * `GET /scim/v2/{orgId}/Schemas`
  * `GET /scim/v2/{orgId}/Schemas/{id}`

# Additional Context
Part of #8140

Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
2025-01-29 18:11:12 +00:00

343 lines
9.3 KiB
Go

package scim
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"github.com/zitadel/logging"
"google.golang.org/grpc/metadata"
zhttp "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/scim/middleware"
"github.com/zitadel/zitadel/internal/api/scim/resources"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
)
// Client is an HTTP client for the SCIM v2 API of a single target. All
// request URLs are built relative to baseURL (see NewScimClient); Users gives
// typed access to the "Users" resource collection.
type Client struct {
	// client is the HTTP client used for every request; NewScimClient shares
	// the same instance with Users.
	client *http.Client
	// baseURL is the SCIM root URL; endpoints are addressed as
	// baseURL + "/" + orgID + <endpoint path>.
	baseURL string
	// Users performs requests against the Users resource collection.
	Users *ResourceClient[resources.ScimUser]
}
// ResourceClient performs create, get, replace, update, delete and list
// requests against one SCIM resource collection and decodes single-resource
// responses into T.
type ResourceClient[T any] struct {
	// client is the HTTP client used for every request.
	client *http.Client
	// baseURL is the SCIM root URL; resource URLs are built by
	// buildResourceURL as baseURL/{orgID}/{resourceName}[/{segment}].
	baseURL string
	// resourceName is the collection path segment, e.g. "Users".
	resourceName string
}
// ScimError is the JSON error body that readScimError decodes from non-2xx
// responses. It implements error, so it is returned to callers directly.
type ScimError struct {
	// Schemas lists the schema URNs of the error message.
	Schemas []string `json:"schemas"`
	// ScimType is the SCIM error type keyword, if the server sent one.
	ScimType string `json:"scimType"`
	// Detail is the human-readable message; Error() is built from it.
	Detail string `json:"detail"`
	// Status carries the HTTP status as a JSON string (as RFC 7644 defines it).
	Status string `json:"status"`
	// ZitadelDetail holds Zitadel's error-detail extension when present.
	ZitadelDetail *ZitadelErrorDetail `json:"urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail,omitempty"`
}
// ZitadelErrorDetail is the payload of the Zitadel-specific error extension
// embedded in ScimError.
type ZitadelErrorDetail struct {
	// ID identifies the error; presumably a Zitadel error/message ID — confirm.
	ID string `json:"id"`
	// Message is the error message text.
	Message string `json:"message"`
}
// ListRequest describes a SCIM list query. ResourceClient.List either encodes
// its non-nil fields as URL query parameters of a GET request or, when
// SendAsPost is set, marshals it as the JSON body of a POST to ".search".
type ListRequest struct {
	// Schemas lists the schema URNs of the request message (sent only in the
	// POST form).
	Schemas []schemas.ScimSchemaType `json:"schemas"`
	// Count is the desired number of results per page (SCIM "count").
	Count *int `json:"count,omitempty"`
	// StartIndex An integer indicating the 1-based index of the first query result.
	StartIndex *int `json:"startIndex,omitempty"`
	// Filter a scim filter expression to filter the query result.
	Filter *string `json:"filter,omitempty"`
	// SortBy names the attribute the results are ordered by.
	SortBy *string `json:"sortBy,omitempty"`
	// SortOrder selects ascending or descending ordering.
	SortOrder *ListRequestSortOrder `json:"sortOrder,omitempty"`
	// SendAsPost selects the POST ".search" transport instead of GET. It is a
	// client-side switch only, so it is excluded from the JSON body; without
	// the "-" tag it would be sent to the server as an unknown attribute.
	SendAsPost bool `json:"-"`
}
// ListRequestSortOrder is the value of the SCIM "sortOrder" list parameter.
type ListRequestSortOrder string

const (
	// ListRequestSortOrderAsc requests results in ascending order.
	ListRequestSortOrderAsc = ListRequestSortOrder("ascending")
	// ListRequestSortOrderDsc requests results in descending order.
	ListRequestSortOrderDsc = ListRequestSortOrder("descending")
)
// ListResponse is the body of a SCIM list/search response; ResourceClient.List
// decodes into ListResponse[*T].
type ListResponse[T any] struct {
	// Schemas lists the schema URNs of the response message.
	Schemas []schemas.ScimSchemaType `json:"schemas"`
	// ItemsPerPage is the number of resources in this page.
	ItemsPerPage int `json:"itemsPerPage"`
	// TotalResults is the total number of matching resources.
	TotalResults int `json:"totalResults"`
	// StartIndex is the 1-based index of the first returned resource.
	StartIndex int `json:"startIndex"`
	// Resources holds the returned page; note the capitalized JSON key.
	Resources []T `json:"Resources"`
}
// BulkRequest is the JSON shape of a SCIM bulk request. Client.Bulk accepts an
// already-encoded []byte body, so this type is presumably marshalled by
// callers to build it — confirm at call sites.
type BulkRequest struct {
	// Schemas lists the schema URNs of the bulk request message.
	Schemas []schemas.ScimSchemaType `json:"schemas"`
	// FailOnErrors is the error count after which processing should stop;
	// without omitempty a nil value is encoded as JSON null.
	FailOnErrors *int `json:"failOnErrors"`
	// Operations are the individual bulk operations.
	Operations []*BulkRequestOperation `json:"Operations"`
}
// BulkRequestOperation is one operation inside a BulkRequest.
type BulkRequestOperation struct {
	// Method is the HTTP method of the operation.
	Method string `json:"method"`
	// BulkID is the client-chosen identifier of the operation.
	BulkID string `json:"bulkId"`
	// Path is the resource path the operation targets.
	Path string `json:"path"`
	// Data is the operation payload, kept as raw, already-encoded JSON.
	Data json.RawMessage `json:"data"`
}
// BulkResponse is the decoded body of a bulk request (see Client.Bulk).
type BulkResponse struct {
	// Schemas lists the schema URNs of the bulk response message.
	Schemas []schemas.ScimSchemaType `json:"schemas"`
	// Operations holds one result per processed operation.
	Operations []*BulkResponseOperation `json:"Operations"`
}
// BulkResponseOperation is the result of one operation inside a BulkResponse.
type BulkResponseOperation struct {
	// Method is the HTTP method of the operation.
	Method string `json:"method"`
	// BulkID echoes the operation's bulkId, when one was sent.
	BulkID string `json:"bulkId,omitempty"`
	// Location is the URL of the affected resource, when present.
	Location string `json:"location,omitempty"`
	// Response carries an error body; presumably only set for failed
	// operations — confirm against the server.
	Response *ScimError `json:"response,omitempty"`
	// Status is the operation's HTTP status as a string.
	Status string `json:"status"`
}
// URL query parameter names used by ResourceClient.List when a ListRequest is
// sent as a GET request.
const (
	listQueryParamFilter     = "filter"
	listQueryParamStartIndex = "startIndex"
	listQueryParamCount      = "count"
	listQueryParamSortOrder  = "sortOrder"
	listQueryParamSortBy     = "sortBy"
)
// NewScimClient returns a Client that speaks plain HTTP to target (presumably
// host[:port]) under the SCIM handler prefix. The Client and its Users
// resource client share a single *http.Client.
//
// NOTE(review): the http.Client has no Timeout; cancellation relies on the
// per-request context.
func NewScimClient(target string) *Client {
	baseURL := "http://" + target + schemas.HandlerPrefix
	httpClient := new(http.Client)

	users := &ResourceClient[resources.ScimUser]{
		client:       httpClient,
		baseURL:      baseURL,
		resourceName: "Users",
	}

	return &Client{client: httpClient, baseURL: baseURL, Users: users}
}
func (c *Client) GetServiceProviderConfig(ctx context.Context, orgID string) ([]byte, error) {
return c.getWithRawResponse(ctx, orgID, "/ServiceProviderConfig")
}
func (c *Client) GetSchemas(ctx context.Context, orgID string) ([]byte, error) {
return c.getWithRawResponse(ctx, orgID, "/Schemas")
}
// GetSchema returns the raw, undecoded response body of
// GET {baseURL}/{orgID}/Schemas/{schemaID}. schemaID is appended as-is,
// without path escaping.
func (c *Client) GetSchema(ctx context.Context, orgID, schemaID string) ([]byte, error) {
	endpoint := "/Schemas/" + schemaID
	return c.getWithRawResponse(ctx, orgID, endpoint)
}
func (c *Client) GetResourceTypes(ctx context.Context, orgID string) ([]byte, error) {
return c.getWithRawResponse(ctx, orgID, "/ResourceTypes")
}
// GetResourceType returns the raw, undecoded response body of
// GET {baseURL}/{orgID}/ResourceTypes/{name}. name is appended as-is,
// without path escaping.
func (c *Client) GetResourceType(ctx context.Context, orgID, name string) ([]byte, error) {
	endpoint := "/ResourceTypes/" + name
	return c.getWithRawResponse(ctx, orgID, endpoint)
}
// Bulk POSTs the already-encoded bulk request body to {baseURL}/{orgID}/Bulk
// with the SCIM content type and decodes the reply into a BulkResponse. Once
// the request has been built, the returned pointer is non-nil even when an
// error is returned.
func (c *Client) Bulk(ctx context.Context, orgID string, body []byte) (*BulkResponse, error) {
	endpoint := c.baseURL + "/" + orgID + "/Bulk"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)

	bulkResp := &BulkResponse{}
	err = doReq(c.client, req, bulkResp)
	return bulkResp, err
}
// Create POSTs body to the orgID's resource collection and decodes the
// response body into a new T.
func (c *ResourceClient[T]) Create(ctx context.Context, orgID string, body []byte) (*T, error) {
	payload := bytes.NewReader(body)
	return c.doWithBody(ctx, http.MethodPost, orgID, "", payload)
}
// Replace PUTs body to the resource identified by id within orgID and decodes
// the response body into a new T.
func (c *ResourceClient[T]) Replace(ctx context.Context, orgID, id string, body []byte) (*T, error) {
	payload := bytes.NewReader(body)
	return c.doWithBody(ctx, http.MethodPut, orgID, id, payload)
}
// Update sends body as a PATCH request to the resource identified by id
// within orgID. On success the response body is not decoded; non-2xx
// responses are returned as errors by doReq.
func (c *ResourceClient[T]) Update(ctx context.Context, orgID, id string, body []byte) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodPatch, c.buildResourceURL(orgID, id), bytes.NewReader(body))
	if err != nil {
		return err
	}
	// Declare the SCIM media type, like every other request in this client
	// that carries a body (Create, Replace, List via POST, Bulk).
	req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
	return doReq(c.client, req, nil)
}
// List queries the orgID's resource collection.
//
// When req.SendAsPost is set, req is marshalled as JSON and POSTed to the
// collection's ".search" endpoint. Otherwise, each non-nil field of req is
// encoded as a URL query parameter of a GET on the collection. A nil req is
// treated like an empty one. If the request itself fails, the (possibly
// partially decoded) response is still returned alongside the error.
func (c *ResourceClient[T]) List(ctx context.Context, orgID string, req *ListRequest) (*ListResponse[*T], error) {
	if req == nil {
		// Dereferencing a nil request would panic; an empty request sends
		// no query parameters.
		req = new(ListRequest)
	}

	listResponse := new(ListResponse[*T])
	if req.SendAsPost {
		listReq, err := json.Marshal(req)
		if err != nil {
			return nil, err
		}
		err = c.doWithResponse(ctx, http.MethodPost, orgID, ".search", bytes.NewReader(listReq), listResponse)
		return listResponse, err
	}

	// An empty url.Values is ready to use; there is nothing to parse.
	query := url.Values{}
	if req.SortBy != nil {
		query.Set(listQueryParamSortBy, *req.SortBy)
	}
	if req.SortOrder != nil {
		query.Set(listQueryParamSortOrder, string(*req.SortOrder))
	}
	if req.Count != nil {
		query.Set(listQueryParamCount, strconv.Itoa(*req.Count))
	}
	if req.StartIndex != nil {
		query.Set(listQueryParamStartIndex, strconv.Itoa(*req.StartIndex))
	}
	if req.Filter != nil {
		query.Set(listQueryParamFilter, *req.Filter)
	}

	err := c.doWithResponse(ctx, http.MethodGet, orgID, "?"+query.Encode(), nil, listResponse)
	return listResponse, err
}
// Get fetches the resource identified by resourceID within orgID and decodes
// it into a new T.
func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) {
	// A GET carries no request body.
	resource, err := c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil)
	return resource, err
}
// Delete removes the resource identified by id within orgID.
func (c *ResourceClient[T]) Delete(ctx context.Context, orgID, id string) error {
	err := c.do(ctx, http.MethodDelete, orgID, id)
	return err
}
// do sends a body-less request with the given method to the resource URL for
// segment. A successful response body is not decoded.
func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, segment string) error {
	target := c.buildResourceURL(orgID, segment)
	req, err := http.NewRequestWithContext(ctx, method, target, nil)
	if err != nil {
		return err
	}
	return doReq(c.client, req, nil)
}
// doWithResponse sends body with the SCIM content type to the resource URL for
// segment and JSON-decodes a successful response into response (skipped when
// response is nil).
func (c *ResourceClient[T]) doWithResponse(ctx context.Context, method, orgID, segment string, body io.Reader, response interface{}) error {
	req, err := http.NewRequestWithContext(ctx, method, c.buildResourceURL(orgID, segment), body)
	if err == nil {
		req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
		err = doReq(c.client, req, response)
	}
	return err
}
func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) {
req, err := http.NewRequestWithContext(ctx, method, c.buildResourceURL(orgID, url), body)
if err != nil {
return nil, err
}
req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim)
responseEntity := new(T)
return responseEntity, doReq(c.client, req, responseEntity)
}
// getWithRawResponse GETs baseURL + "/" + orgID + suffix and returns the 2xx
// response body unparsed; non-2xx responses are decoded by readScimError.
//
// Unlike doReq, this path does not call addTokenAsHeader, so no Authorization
// header is sent. NOTE(review): presumably intentional for the discovery
// endpoints — confirm against the server's auth requirements.
func (c *Client) getWithRawResponse(ctx context.Context, orgID, suffix string) ([]byte, error) {
	endpoint := c.baseURL + "/" + orgID + suffix
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, err
	}
	defer func() {
		closeErr := resp.Body.Close()
		logging.OnError(closeErr).Error("Failed to close response body")
	}()

	if resp.StatusCode/100 == 2 {
		return io.ReadAll(resp.Body)
	}
	return nil, readScimError(resp)
}
// doReq adds the Authorization header taken from the request context's
// outgoing metadata, sends req, and checks the status code. Non-2xx responses
// are returned as the error produced by readScimError. 2xx bodies are
// JSON-decoded into responseEntity unless it is nil.
func doReq(client *http.Client, req *http.Request, responseEntity interface{}) error {
	addTokenAsHeader(req)

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer func() {
		closeErr := resp.Body.Close()
		logging.OnError(closeErr).Error("Failed to close response body")
	}()

	switch {
	case resp.StatusCode/100 != 2:
		return readScimError(resp)
	case responseEntity == nil:
		return nil
	default:
		return readJson(responseEntity, resp)
	}
}
// addTokenAsHeader copies the first "Authorization" value from the outgoing
// gRPC metadata of req's context into the HTTP Authorization header. Requests
// whose context has no outgoing metadata, or metadata without an
// authorization entry, are left unchanged.
func addTokenAsHeader(req *http.Request) {
	md, ok := metadata.FromOutgoingContext(req.Context())
	if !ok {
		return
	}
	// md.Get returns an empty slice for a missing key; indexing [0] without
	// this check panics for contexts that carry other metadata only.
	values := md.Get("Authorization")
	if len(values) == 0 {
		return
	}
	req.Header.Set("Authorization", values[0])
}
// readJson decodes the JSON body of resp into entity and closes the body.
// Decode and close failures are routed through logging.OnError(...).Panic,
// which presumably panics on a non-nil error — so a normal return is expected
// to carry a nil error (confirm zitadel/logging semantics).
func readJson(entity interface{}, resp *http.Response) error {
	body := resp.Body
	defer func() {
		closeErr := body.Close()
		logging.OnError(closeErr).Panic("Failed to close response body")
	}()

	decodeErr := json.NewDecoder(body).Decode(entity)
	logging.OnError(decodeErr).Panic("Failed decoding entity")
	return decodeErr
}
func readScimError(resp *http.Response) error {
scimErr := new(ScimError)
readErr := readJson(scimErr, resp)
logging.OnError(readErr).Panic("Failed reading scim error")
return scimErr
}
// buildResourceURL returns the collection URL baseURL/{orgID}/{resourceName}.
// A path segment is joined with path.Join (and thus cleaned). An empty segment,
// or a segment that starts with "?" (a query string), is appended verbatim
// instead.
func (c *ResourceClient[T]) buildResourceURL(orgID, segment string) string {
	prefix := c.baseURL + "/"
	if segment != "" && !strings.HasPrefix(segment, "?") {
		return prefix + path.Join(orgID, c.resourceName, segment)
	}
	return prefix + path.Join(orgID, c.resourceName) + segment
}
// Error implements the error interface by prefixing the SCIM detail message.
func (e *ScimError) Error() string {
	const prefix = "scim error: "
	return prefix + e.Detail
}