mirror of
https://github.com/zitadel/zitadel.git
synced 2025-02-28 23:27:23 +00:00
Merge branch 'next' into next-rc
This commit is contained in:
commit
b2fbf9ace3
@ -23,5 +23,5 @@ func (mig *User11AddLowerFieldsToVerifiedEmail) Execute(ctx context.Context, _ e
|
||||
}
|
||||
|
||||
func (mig *User11AddLowerFieldsToVerifiedEmail) String() string {
|
||||
return "25_user13_add_lower_fields_to_verified_email"
|
||||
return "25_user14_add_lower_fields_to_verified_email"
|
||||
}
|
||||
|
@ -37,14 +37,14 @@ BEGIN
|
||||
END;
|
||||
|
||||
-- Return the organizations where permission were granted thru org-level roles
|
||||
SELECT array_agg(org_id) INTO org_ids
|
||||
SELECT array_agg(sub.org_id) INTO org_ids
|
||||
FROM (
|
||||
SELECT DISTINCT om.org_id
|
||||
FROM eventstore.org_members om
|
||||
WHERE om.role = ANY(matched_roles)
|
||||
AND om.instance_id = instanceID
|
||||
AND om.user_id = userId
|
||||
);
|
||||
) AS sub;
|
||||
RETURN;
|
||||
END;
|
||||
$$;
|
||||
|
@ -1,36 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/stats"
|
||||
|
||||
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
|
||||
)
|
||||
|
||||
type GRPCMethod string
|
||||
|
||||
func DefaultTracingClient() grpc.UnaryClientInterceptor {
|
||||
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
|
||||
func DefaultTracingClient() stats.Handler {
|
||||
return TracingClient(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
|
||||
}
|
||||
|
||||
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryClientInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
method string,
|
||||
req, reply interface{},
|
||||
cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker,
|
||||
opts ...grpc.CallOption,
|
||||
) error {
|
||||
|
||||
for _, ignoredMethod := range ignoredMethods {
|
||||
if strings.HasSuffix(method, string(ignoredMethod)) {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
func TracingClient(ignoredMethods ...GRPCMethod) stats.Handler {
|
||||
return grpc_trace.NewClientHandler(grpc_trace.WithFilter(
|
||||
func(info *stats.RPCTagInfo) bool {
|
||||
for _, ignoredMethod := range ignoredMethods {
|
||||
if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return grpc_trace.UnaryClientInterceptor()(ctx, method, req, reply, cc, invoker, opts...)
|
||||
}
|
||||
return true
|
||||
},
|
||||
))
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/zitadel/logging"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@ -56,6 +57,13 @@ var (
|
||||
},
|
||||
)
|
||||
|
||||
// we need the errorHandler to set the request URI pattern in case of an error
|
||||
errorHandler = runtime.ErrorHandlerFunc(
|
||||
func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
|
||||
setRequestURIPattern(ctx)
|
||||
runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err)
|
||||
})
|
||||
|
||||
serveMuxOptions = func(hostHeaders []string) []runtime.ServeMuxOption {
|
||||
return []runtime.ServeMuxOption{
|
||||
runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler),
|
||||
@ -65,6 +73,7 @@ var (
|
||||
runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
|
||||
runtime.WithForwardResponseOption(responseForwarder),
|
||||
runtime.WithRoutingErrorHandler(httpErrorHandler),
|
||||
runtime.WithErrorHandler(errorHandler),
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,6 +90,7 @@ var (
|
||||
}
|
||||
|
||||
responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error {
|
||||
setRequestURIPattern(ctx)
|
||||
t, ok := resp.(CustomHTTPResponse)
|
||||
if ok {
|
||||
// TODO: find a way to return a location header if needed w.Header().Set("location", t.Location())
|
||||
@ -118,9 +128,9 @@ func CreateGatewayWithPrefix(
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
|
||||
grpc.WithChainUnaryInterceptor(
|
||||
client_middleware.DefaultTracingClient(),
|
||||
client_middleware.UnaryActivityClientInterceptor(),
|
||||
),
|
||||
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
|
||||
}
|
||||
connection, err := dial(ctx, port, opts)
|
||||
if err != nil {
|
||||
@ -145,9 +155,9 @@ func CreateGateway(
|
||||
[]grpc.DialOption{
|
||||
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
|
||||
grpc.WithChainUnaryInterceptor(
|
||||
client_middleware.DefaultTracingClient(),
|
||||
client_middleware.UnaryActivityClientInterceptor(),
|
||||
),
|
||||
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -260,3 +270,13 @@ func grpcCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
func setRequestURIPattern(ctx context.Context) {
|
||||
pattern, ok := runtime.HTTPPathPattern(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(pattern)
|
||||
metrics.SetRequestURIPattern(ctx, pattern)
|
||||
}
|
||||
|
@ -1,34 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/stats"
|
||||
|
||||
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
|
||||
)
|
||||
|
||||
type GRPCMethod string
|
||||
|
||||
func DefaultTracingServer() grpc.UnaryServerInterceptor {
|
||||
func DefaultTracingServer() stats.Handler {
|
||||
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
|
||||
}
|
||||
|
||||
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryServerInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
|
||||
for _, ignoredMethod := range ignoredMethods {
|
||||
if strings.HasSuffix(info.FullMethod, string(ignoredMethod)) {
|
||||
return handler(ctx, req)
|
||||
func TracingServer(ignoredMethods ...GRPCMethod) stats.Handler {
|
||||
return grpc_trace.NewServerHandler(grpc_trace.WithFilter(
|
||||
func(info *stats.RPCTagInfo) bool {
|
||||
for _, ignoredMethod := range ignoredMethods {
|
||||
if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return grpc_trace.UnaryServerInterceptor()(ctx, req, info, handler)
|
||||
}
|
||||
return true
|
||||
},
|
||||
))
|
||||
}
|
||||
|
@ -47,7 +47,6 @@ func CreateServer(
|
||||
grpc.UnaryInterceptor(
|
||||
grpc_middleware.ChainUnaryServer(
|
||||
middleware.CallDurationHandler(),
|
||||
middleware.DefaultTracingServer(),
|
||||
middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
|
||||
middleware.NoCacheInterceptor(),
|
||||
middleware.InstanceInterceptor(queries, externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName),
|
||||
@ -63,6 +62,7 @@ func CreateServer(
|
||||
middleware.ActivityInterceptor(),
|
||||
),
|
||||
),
|
||||
grpc.StatsHandler(middleware.DefaultTracingServer()),
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
|
@ -275,7 +275,7 @@ func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
details, err := s.command.RemoveUserV2(ctx, req.UserId, memberships, grants...)
|
||||
details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -278,7 +278,7 @@ func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
details, err := s.command.RemoveUserV2(ctx, req.UserId, memberships, grants...)
|
||||
details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -24,6 +23,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -289,7 +289,7 @@ func TestBulk(t *testing.T) {
|
||||
},
|
||||
DisplayName: "scim-bulk-created-user-0-given-name scim-bulk-created-user-0-family-name",
|
||||
PreferredLanguage: test.Must(language.Parse("en")),
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "scim-bulk-created-user-0@example.com",
|
||||
@ -308,7 +308,7 @@ func TestBulk(t *testing.T) {
|
||||
DisplayName: "scim-bulk-created-user-1-given-name scim-bulk-created-user-1-family-name",
|
||||
NickName: "scim-bulk-created-user-1-nickname-patched",
|
||||
PreferredLanguage: test.Must(language.Parse("en")),
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "scim-bulk-created-user-1@example.com",
|
||||
@ -333,7 +333,7 @@ func TestBulk(t *testing.T) {
|
||||
DisplayName: "scim-bulk-created-user-2-given-name scim-bulk-created-user-2-family-name",
|
||||
NickName: "scim-bulk-created-user-2-nickname-patched",
|
||||
PreferredLanguage: test.Must(language.Parse("en")),
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "scim-bulk-created-user-2@example.com",
|
||||
@ -696,3 +696,39 @@ func buildTooManyOperationsRequest() *scim.BulkRequest {
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func setProvisioningDomain(t require.TestingT, userID, provisioningDomain string) {
|
||||
setAndEnsureMetadata(t, userID, "urn:zitadel:scim:provisioningDomain", provisioningDomain)
|
||||
}
|
||||
|
||||
func setAndEnsureMetadata(t require.TestingT, userID, key, value string) {
|
||||
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: userID,
|
||||
Key: key,
|
||||
Value: []byte(value),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// ensure metadata is projected
|
||||
ensureMetadataProjected(t, userID, key, value)
|
||||
}
|
||||
|
||||
func ensureMetadataProjected(t require.TestingT, userID, key, value string) {
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
require.EventuallyWithT(t, func(tt *assert.CollectT) {
|
||||
md, err := Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{
|
||||
Id: userID,
|
||||
Key: key,
|
||||
})
|
||||
require.NoError(tt, err)
|
||||
require.Equal(tt, value, string(md.Metadata.Value))
|
||||
}, retryDuration, tick)
|
||||
}
|
||||
|
||||
func removeProvisioningDomain(t require.TestingT, userID string) {
|
||||
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: userID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -9,12 +9,16 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/org/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
Instance *integration.Instance
|
||||
CTX context.Context
|
||||
Instance *integration.Instance
|
||||
SecondaryOrganization *org.AddOrganizationResponse
|
||||
CTX context.Context
|
||||
|
||||
// remove comments in the json, as the default golang json unmarshaler cannot handle them
|
||||
// some test files (e.g. bulk, patch) are much easier to maintain with comments
|
||||
@ -29,6 +33,10 @@ func TestMain(m *testing.M) {
|
||||
Instance = integration.NewInstance(ctx)
|
||||
|
||||
CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner)
|
||||
|
||||
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
|
||||
SecondaryOrganization = Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
|
||||
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
"Operations": [
|
||||
// add without path
|
||||
{
|
||||
"op": "add",
|
||||
"op": "Add", // with PascalCase operation type
|
||||
"value": {
|
||||
"emails":[
|
||||
{
|
||||
@ -17,7 +17,7 @@
|
||||
},
|
||||
// add complex attribute with path
|
||||
{
|
||||
"op": "add",
|
||||
"op": "add", // with camelCase operation type
|
||||
"path": "name",
|
||||
"value": {
|
||||
"formatted": "added-formatted",
|
||||
@ -30,7 +30,7 @@
|
||||
},
|
||||
// add complex attribute value
|
||||
{
|
||||
"op": "add",
|
||||
"op": "ADD", // with UPPERCASE operation type
|
||||
"path": "name.middlename",
|
||||
"value": "added-middle-name-2"
|
||||
},
|
||||
@ -134,6 +134,14 @@
|
||||
"op": "replace",
|
||||
"path": "password",
|
||||
"value": "Password2!"
|
||||
},
|
||||
// replace active state
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "active",
|
||||
// quoted uppercase bool
|
||||
// (ensure compatibility with Microsoft Entra)
|
||||
"value": "True"
|
||||
}
|
||||
]
|
||||
}
|
@ -10,8 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -155,7 +153,7 @@ var (
|
||||
PreferredLanguage: language.MustParse("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/Los_Angeles",
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
}
|
||||
)
|
||||
|
||||
@ -164,6 +162,7 @@ func TestCreateUser(t *testing.T) {
|
||||
name string
|
||||
body []byte
|
||||
ctx context.Context
|
||||
orgID string
|
||||
want *resources.ScimUser
|
||||
wantErr bool
|
||||
scimErrorType string
|
||||
@ -191,7 +190,7 @@ func TestCreateUser(t *testing.T) {
|
||||
name: "minimal inactive user",
|
||||
body: minimalInactiveUserJson,
|
||||
want: &resources.ScimUser{
|
||||
Active: gu.Ptr(false),
|
||||
Active: schemas.NewRelaxedBool(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -275,6 +274,13 @@ func TestCreateUser(t *testing.T) {
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "another org",
|
||||
body: minimalUserJson,
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -283,7 +289,12 @@ func TestCreateUser(t *testing.T) {
|
||||
ctx = CTX
|
||||
}
|
||||
|
||||
createdUser, err := Instance.Client.SCIM.Users.Create(ctx, Instance.DefaultOrg.Id, tt.body)
|
||||
orgID := tt.orgID
|
||||
if orgID == "" {
|
||||
orgID = Instance.DefaultOrg.Id
|
||||
}
|
||||
|
||||
createdUser, err := Instance.Client.SCIM.Users.Create(ctx, orgID, tt.body)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CreateUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@ -311,7 +322,7 @@ func TestCreateUser(t *testing.T) {
|
||||
|
||||
assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, createdUser.Resource.Schemas)
|
||||
assert.Equal(t, schemas.ScimResourceTypeSingular("User"), createdUser.Resource.Meta.ResourceType)
|
||||
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), createdUser.Resource.Meta.Location)
|
||||
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, orgID, "Users", createdUser.ID), createdUser.Resource.Meta.Location)
|
||||
assert.Nil(t, createdUser.Password)
|
||||
|
||||
if tt.want != nil {
|
||||
@ -384,12 +395,7 @@ func TestCreateUser_metadata(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateUser_scopedExternalID(t *testing.T) {
|
||||
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
Value: []byte("fooBar"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
|
||||
|
||||
createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
|
||||
require.NoError(t, err)
|
||||
@ -398,11 +404,7 @@ func TestCreateUser_scopedExternalID(t *testing.T) {
|
||||
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
|
||||
}()
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
@ -423,9 +425,3 @@ func TestCreateUser_scopedExternalID(t *testing.T) {
|
||||
assert.Equal(tt, "701984", string(md.Metadata.Value))
|
||||
}, retryDuration, tick)
|
||||
}
|
||||
|
||||
func TestCreateUser_anotherOrg(t *testing.T) {
|
||||
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
|
||||
_, err := Instance.Client.SCIM.Users.Create(CTX, org.OrganizationId, fullUserJson)
|
||||
scim.RequireScimError(t, http.StatusNotFound, err)
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -22,6 +21,7 @@ func TestDeleteUser_errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
orgID string
|
||||
errorStatus int
|
||||
}{
|
||||
{
|
||||
@ -38,6 +38,17 @@ func TestDeleteUser_errors(t *testing.T) {
|
||||
name: "unknown user id",
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "another org",
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "another org with permissions",
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -46,7 +57,11 @@ func TestDeleteUser_errors(t *testing.T) {
|
||||
ctx = CTX
|
||||
}
|
||||
|
||||
err := Instance.Client.SCIM.Users.Delete(ctx, Instance.DefaultOrg.Id, "1")
|
||||
orgID := tt.orgID
|
||||
if orgID == "" {
|
||||
orgID = Instance.DefaultOrg.Id
|
||||
}
|
||||
err := Instance.Client.SCIM.Users.Delete(ctx, orgID, "1")
|
||||
|
||||
statusCode := tt.errorStatus
|
||||
if statusCode == 0 {
|
||||
@ -81,10 +96,3 @@ func TestDeleteUser_ensureReallyDeleted(t *testing.T) {
|
||||
integration.AssertGrpcStatus(tt, codes.NotFound, err)
|
||||
}, retryDuration, tick)
|
||||
}
|
||||
|
||||
func TestDeleteUser_anotherOrg(t *testing.T) {
|
||||
createUserResp := Instance.CreateHumanUser(CTX)
|
||||
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
|
||||
err := Instance.Client.SCIM.Users.Delete(CTX, org.OrganizationId, createUserResp.UserId)
|
||||
scim.RequireScimError(t, http.StatusNotFound, err)
|
||||
}
|
||||
|
@ -9,8 +9,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -20,13 +18,13 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
|
||||
func TestGetUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
orgID string
|
||||
buildUserID func() string
|
||||
cleanup func(userID string)
|
||||
ctx context.Context
|
||||
@ -46,6 +44,19 @@ func TestGetUser(t *testing.T) {
|
||||
errorStatus: http.StatusNotFound,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "another org",
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
errorStatus: http.StatusNotFound,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "another org with permissions",
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission),
|
||||
errorStatus: http.StatusNotFound,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown user id",
|
||||
buildUserID: func() string {
|
||||
@ -99,7 +110,7 @@ func TestGetUser(t *testing.T) {
|
||||
PreferredLanguage: language.Make("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/Los_Angeles",
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
Emails: []*resources.ScimEmail{
|
||||
{
|
||||
Value: "bjensen@example.com",
|
||||
@ -191,31 +202,17 @@ func TestGetUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// set provisioning domain of service user
|
||||
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
Value: []byte("fooBar"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
|
||||
|
||||
// set externalID for provisioning domain
|
||||
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: createdUser.ID,
|
||||
Key: "urn:zitadel:scim:fooBar:externalId",
|
||||
Value: []byte("100-scopedExternalId"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setAndEnsureMetadata(t, createdUser.ID, "urn:zitadel:scim:fooBar:externalId", "100-scopedExternalId")
|
||||
return createdUser.ID
|
||||
},
|
||||
cleanup: func(userID string) {
|
||||
_, err := Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
|
||||
},
|
||||
want: &resources.ScimUser{
|
||||
ExternalID: "100-scopedExternalId",
|
||||
@ -237,11 +234,16 @@ func TestGetUser(t *testing.T) {
|
||||
userID = createUserResp.UserId
|
||||
}
|
||||
|
||||
orgID := tt.orgID
|
||||
if orgID == "" {
|
||||
orgID = Instance.DefaultOrg.Id
|
||||
}
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
var fetchedUser *resources.ScimUser
|
||||
var err error
|
||||
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
fetchedUser, err = Instance.Client.SCIM.Users.Get(ctx, Instance.DefaultOrg.Id, userID)
|
||||
fetchedUser, err = Instance.Client.SCIM.Users.Get(ctx, orgID, userID)
|
||||
if tt.wantErr {
|
||||
statusCode := tt.errorStatus
|
||||
if statusCode == 0 {
|
||||
@ -255,7 +257,7 @@ func TestGetUser(t *testing.T) {
|
||||
assert.Equal(ttt, userID, fetchedUser.ID)
|
||||
assert.EqualValues(ttt, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, fetchedUser.Schemas)
|
||||
assert.Equal(ttt, schemas.ScimResourceTypeSingular("User"), fetchedUser.Resource.Meta.ResourceType)
|
||||
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
|
||||
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, orgID, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
|
||||
assert.Nil(ttt, fetchedUser.Password)
|
||||
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
|
||||
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
|
||||
@ -268,10 +270,3 @@ func TestGetUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUser_anotherOrg(t *testing.T) {
|
||||
createUserResp := Instance.CreateHumanUser(CTX)
|
||||
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
|
||||
_, err := Instance.Client.SCIM.Users.Get(CTX, org.OrganizationId, createUserResp.UserId)
|
||||
scim.RequireScimError(t, http.StatusNotFound, err)
|
||||
}
|
||||
|
@ -20,7 +20,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/integration/scim"
|
||||
"github.com/zitadel/zitadel/internal/test"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
|
||||
user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
@ -239,7 +238,7 @@ func TestListUser(t *testing.T) {
|
||||
assert.Equal(t, 1, resp.StartIndex)
|
||||
assert.Len(t, resp.Resources, 1)
|
||||
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-0"))
|
||||
assert.False(t, *resp.Resources[0].Active)
|
||||
assert.False(t, resp.Resources[0].Active.Bool())
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -372,20 +371,10 @@ func TestListUser(t *testing.T) {
|
||||
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 102)
|
||||
|
||||
// set provisioning domain of service user
|
||||
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
Value: []byte("fooBar"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
|
||||
|
||||
// set externalID for provisioning domain
|
||||
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: resp.UserId,
|
||||
Key: "urn:zitadel:scim:fooBar:externalId",
|
||||
Value: []byte("100-scopedExternalId"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setAndEnsureMetadata(t, resp.UserId, "urn:zitadel:scim:fooBar:externalId", "100-scopedExternalId")
|
||||
return &scim.ListRequest{
|
||||
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
|
||||
}
|
||||
@ -396,11 +385,7 @@ func TestListUser(t *testing.T) {
|
||||
},
|
||||
cleanup: func(t require.TestingT) {
|
||||
// delete provisioning domain of service user
|
||||
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -37,14 +36,16 @@ var (
|
||||
|
||||
func TestReplaceUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ctx context.Context
|
||||
want *resources.ScimUser
|
||||
wantErr bool
|
||||
scimErrorType string
|
||||
errorStatus int
|
||||
zitadelErrID string
|
||||
name string
|
||||
body []byte
|
||||
ctx context.Context
|
||||
createUserOrgID string
|
||||
replaceUserOrgID string
|
||||
want *resources.ScimUser
|
||||
wantErr bool
|
||||
scimErrorType string
|
||||
errorStatus int
|
||||
zitadelErrID string
|
||||
}{
|
||||
{
|
||||
name: "minimal user",
|
||||
@ -165,7 +166,7 @@ func TestReplaceUser(t *testing.T) {
|
||||
PreferredLanguage: language.MustParse("en-CH"),
|
||||
Locale: "en-CH",
|
||||
Timezone: "Europe/Zurich",
|
||||
Active: gu.Ptr(false),
|
||||
Active: schemas.NewRelaxedBool(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -207,10 +208,26 @@ func TestReplaceUser(t *testing.T) {
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "another org",
|
||||
body: minimalUserJson,
|
||||
replaceUserOrgID: SecondaryOrganization.OrganizationId,
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "another org with permissions",
|
||||
body: minimalUserJson,
|
||||
replaceUserOrgID: SecondaryOrganization.OrganizationId,
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
|
||||
// use iam owner => we don't want to test permissions of the create endpoint.
|
||||
createdUser, err := Instance.Client.SCIM.Users.Create(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), Instance.DefaultOrg.Id, fullUserJson)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
@ -223,7 +240,12 @@ func TestReplaceUser(t *testing.T) {
|
||||
ctx = CTX
|
||||
}
|
||||
|
||||
replacedUser, err := Instance.Client.SCIM.Users.Replace(ctx, Instance.DefaultOrg.Id, createdUser.ID, tt.body)
|
||||
replaceUserOrgID := tt.replaceUserOrgID
|
||||
if replaceUserOrgID == "" {
|
||||
replaceUserOrgID = Instance.DefaultOrg.Id
|
||||
}
|
||||
|
||||
replacedUser, err := Instance.Client.SCIM.Users.Replace(ctx, replaceUserOrgID, createdUser.ID, tt.body)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReplaceUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
@ -294,12 +316,7 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// set provisioning domain of service user
|
||||
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
Value: []byte("fooBazz"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBazz")
|
||||
|
||||
// replace the user with provisioning domain set
|
||||
_, err = Instance.Client.SCIM.Users.Replace(CTX, Instance.DefaultOrg.Id, createdUser.ID, minimalUserWithExternalIDJson)
|
||||
@ -325,9 +342,5 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
|
||||
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
|
||||
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
|
||||
Key: "urn:zitadel:scim:provisioningDomain",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
|
||||
}
|
||||
|
@ -10,8 +10,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -28,7 +26,7 @@ var (
|
||||
//go:embed testdata/users_update_test_full.json
|
||||
fullUserUpdateJson []byte
|
||||
|
||||
minimalUserUpdateJson = simpleReplacePatchBody("nickname", "foo")
|
||||
minimalUserUpdateJson = simpleReplacePatchBody("nickname", "\"foo\"")
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -44,9 +42,6 @@ func TestUpdateUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
|
||||
secondaryOrg := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
@ -74,7 +69,15 @@ func TestUpdateUser(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "other org",
|
||||
orgID: secondaryOrg.OrganizationId,
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
body: minimalUserUpdateJson,
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "other org with permissions",
|
||||
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
|
||||
orgID: SecondaryOrganization.OrganizationId,
|
||||
body: minimalUserUpdateJson,
|
||||
wantErr: true,
|
||||
errorStatus: http.StatusNotFound,
|
||||
@ -204,7 +207,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
PreferredLanguage: language.MustParse("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/Los_Angeles",
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -60,6 +60,9 @@ func (req *OperationRequest) Validate() error {
|
||||
}
|
||||
|
||||
func (op *Operation) validate() error {
|
||||
// ignore the casing, as some scim clients send these capitalized
|
||||
op.Operation = OperationType(strings.ToLower(string(op.Operation)))
|
||||
|
||||
if !op.Operation.isValid() {
|
||||
return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgumentf(nil, "SCIM-opty1", "Patch op %s not supported", op.Operation))
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ type ScimUser struct {
|
||||
PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
Timezone string `json:"timezone,omitempty"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
Active *scim_schemas.RelaxedBool `json:"active,omitempty"`
|
||||
Emails []*ScimEmail `json:"emails,omitempty" scim:"required"`
|
||||
PhoneNumbers []*ScimPhoneNumber `json:"phoneNumbers,omitempty"`
|
||||
Password *scim_schemas.WriteOnlyString `json:"password,omitempty"`
|
||||
@ -154,7 +154,7 @@ func (h *UsersHandler) Create(ctx context.Context, user *ScimUser) (*ScimUser, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = h.command.AddUserHuman(ctx, orgID, addHuman, true, h.userCodeAlg)
|
||||
err = h.command.AddUserHuman(ctx, orgID, addHuman, false, h.userCodeAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -180,7 +180,8 @@ func (h *UsersHandler) Replace(ctx context.Context, id string, user *ScimUser) (
|
||||
}
|
||||
|
||||
func (h *UsersHandler) Update(ctx context.Context, id string, operations patch.OperationCollection) error {
|
||||
userWM, err := h.command.UserHumanWriteModel(ctx, id, true, true, true, true, false, false, true)
|
||||
orgID := authz.GetCtxData(ctx).OrgID
|
||||
userWM, err := h.command.UserHumanWriteModel(ctx, id, orgID, true, true, true, true, false, false, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -191,6 +192,9 @@ func (h *UsersHandler) Update(ctx context.Context, id string, operations patch.O
|
||||
return err
|
||||
}
|
||||
|
||||
// ensure the identity of the user is not modified
|
||||
changeHuman.ID = id
|
||||
changeHuman.ResourceOwner = orgID
|
||||
return h.command.ChangeUserHuman(ctx, changeHuman, h.userCodeAlg)
|
||||
}
|
||||
|
||||
@ -200,12 +204,12 @@ func (h *UsersHandler) Delete(ctx context.Context, id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = h.command.RemoveUserV2(ctx, id, memberships, grants...)
|
||||
_, err = h.command.RemoveUserV2(ctx, id, authz.GetCtxData(ctx).OrgID, memberships, grants...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) {
|
||||
user, err := h.query.GetUserByID(ctx, false, id)
|
||||
user, err := h.query.GetUserByIDWithResourceOwner(ctx, false, id, authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/metadata"
|
||||
"github.com/zitadel/zitadel/internal/api/scim/schemas"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
@ -73,8 +74,9 @@ func (h *UsersHandler) mapToAddHuman(ctx context.Context, scimUser *ScimUser) (*
|
||||
|
||||
func (h *UsersHandler) mapToChangeHuman(ctx context.Context, scimUser *ScimUser) (*command.ChangeHuman, error) {
|
||||
human := &command.ChangeHuman{
|
||||
ID: scimUser.ID,
|
||||
Username: &scimUser.UserName,
|
||||
ID: scimUser.ID,
|
||||
ResourceOwner: authz.GetCtxData(ctx).OrgID,
|
||||
Username: &scimUser.UserName,
|
||||
Profile: &command.Profile{
|
||||
NickName: &scimUser.NickName,
|
||||
DisplayName: &scimUser.DisplayName,
|
||||
@ -271,7 +273,7 @@ func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md m
|
||||
FamilyName: user.Human.LastName,
|
||||
GivenName: user.Human.FirstName,
|
||||
},
|
||||
Active: gu.Ptr(user.State.IsEnabled()),
|
||||
Active: schemas.NewRelaxedBool(user.State.IsEnabled()),
|
||||
}
|
||||
|
||||
if string(user.Human.Email) != "" {
|
||||
@ -309,7 +311,7 @@ func (h *UsersHandler) mapWriteModelToScimUser(ctx context.Context, user *comman
|
||||
FamilyName: user.LastName,
|
||||
GivenName: user.FirstName,
|
||||
},
|
||||
Active: gu.Ptr(user.UserState.IsEnabled()),
|
||||
Active: schemas.NewRelaxedBool(user.UserState.IsEnabled()),
|
||||
}
|
||||
|
||||
if string(user.Email) != "" {
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@ -707,7 +706,7 @@ func TestOperationCollection_Apply(t *testing.T) {
|
||||
PreferredLanguage: language.MustParse("en-US"),
|
||||
Locale: "en-US",
|
||||
Timezone: "America/New_York",
|
||||
Active: gu.Ptr(true),
|
||||
Active: schemas.NewRelaxedBool(true),
|
||||
Emails: []*ScimEmail{
|
||||
{
|
||||
Value: "jeanie.pendleton@example.com",
|
||||
|
41
internal/api/scim/schemas/bool.go
Normal file
41
internal/api/scim/schemas/bool.go
Normal file
@ -0,0 +1,41 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// RelaxedBool a bool which is more relaxed when it comes to json (un)marshaling.
|
||||
// This ensures compatibility with some bugged scim providers,
|
||||
// such as Microsoft Entry, which sends booleans as "True" or "False".
|
||||
// See also https://learn.microsoft.com/en-us/entra/identity/app-provisioning/application-provisioning-config-problem-scim-compatibility.
|
||||
type RelaxedBool bool
|
||||
|
||||
func NewRelaxedBool(value bool) *RelaxedBool {
|
||||
return gu.Ptr(RelaxedBool(value))
|
||||
}
|
||||
|
||||
func (b *RelaxedBool) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(bool(*b))
|
||||
}
|
||||
|
||||
func (b *RelaxedBool) UnmarshalJSON(bytes []byte) error {
|
||||
str := strings.ToLower(string(bytes))
|
||||
switch {
|
||||
case str == "true" || str == "\"true\"":
|
||||
*b = true
|
||||
case str == "false" || str == "\"false\"":
|
||||
*b = false
|
||||
default:
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "SCIM-BOO1", "bool expected, got %v", str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RelaxedBool) Bool() bool {
|
||||
return bool(*b)
|
||||
}
|
63
internal/api/scim/schemas/bool_test.go
Normal file
63
internal/api/scim/schemas/bool_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRelaxedBool_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input bool
|
||||
expected string
|
||||
}{
|
||||
{name: "true", input: true, expected: "true"},
|
||||
{name: "false", expected: "false"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
value := NewRelaxedBool(tt.input)
|
||||
bytes, err := json.Marshal(value)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, string(bytes))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxedBool_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "valid true", input: "true", expected: true},
|
||||
{name: "valid false", input: "false"},
|
||||
{name: "quoted true", input: `"true"`, expected: true},
|
||||
{name: "quoted pascal case true", input: `"True"`, expected: true},
|
||||
{name: "quoted upper case true", input: `"TRUE"`, expected: true},
|
||||
{name: "quoted false", input: `"false"`},
|
||||
{name: "quoted pascal case false", input: `"False"`},
|
||||
{name: "quoted upper case false", input: `"FALSE"`},
|
||||
{name: "invalid value", input: "invalid", wantErr: true},
|
||||
{name: "number value", input: "1", wantErr: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
value := new(RelaxedBool)
|
||||
err := json.Unmarshal([]byte(tt.input), &value)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, value.Bool())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -45,6 +45,8 @@ var (
|
||||
|
||||
const (
|
||||
envRequestPath = "/assets/environment.json"
|
||||
// https://posthog.com/docs/advanced/content-security-policy
|
||||
posthogCSPHost = "https://*.i.posthog.com"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -106,12 +108,14 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
|
||||
config.LongCache.MaxAge,
|
||||
config.LongCache.SharedMaxAge,
|
||||
)
|
||||
security := middleware.SecurityHeaders(csp(), nil)
|
||||
security := middleware.SecurityHeaders(csp(config.PostHog.URL), nil)
|
||||
|
||||
handler := mux.NewRouter()
|
||||
handler.Use(security, limitingAccessInterceptor.WithoutLimiting().Handle)
|
||||
|
||||
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
|
||||
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
env := handler.NewRoute().Path(envRequestPath).Subrouter()
|
||||
env.Use(callDurationInterceptor, middleware.TelemetryHandler(), instanceHandler)
|
||||
env.HandleFunc("", func(w http.ResponseWriter, r *http.Request) {
|
||||
url := http_util.BuildOrigin(r.Host, externalSecure)
|
||||
ctx := r.Context()
|
||||
instance := authz.GetInstance(ctx)
|
||||
@ -128,7 +132,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
|
||||
}
|
||||
_, err = w.Write(environmentJSON)
|
||||
logging.OnError(err).Error("error serving environment.json")
|
||||
})))
|
||||
})
|
||||
handler.SkipClean(true).PathPrefix("").Handler(cache(http.FileServer(&spaHandler{http.FS(fSys)})))
|
||||
return handler, nil
|
||||
}
|
||||
@ -145,12 +149,22 @@ func templateInstanceManagementURL(templateableCookieValue string, instance auth
|
||||
return cookieValue.String(), nil
|
||||
}
|
||||
|
||||
func csp() *middleware.CSP {
|
||||
func csp(posthogURL string) *middleware.CSP {
|
||||
csp := middleware.DefaultSCP
|
||||
csp.StyleSrc = csp.StyleSrc.AddInline()
|
||||
csp.ScriptSrc = csp.ScriptSrc.AddEval()
|
||||
csp.ConnectSrc = csp.ConnectSrc.AddOwnHost()
|
||||
csp.ImgSrc = csp.ImgSrc.AddOwnHost().AddScheme("blob")
|
||||
if posthogURL != "" {
|
||||
// https://posthog.com/docs/advanced/content-security-policy#enabling-the-toolbar
|
||||
csp.ScriptSrc = csp.ScriptSrc.AddHost(posthogCSPHost)
|
||||
csp.ConnectSrc = csp.ConnectSrc.AddHost(posthogCSPHost)
|
||||
csp.ImgSrc = csp.ImgSrc.AddHost(posthogCSPHost)
|
||||
csp.StyleSrc = csp.StyleSrc.AddHost(posthogCSPHost)
|
||||
csp.FontSrc = csp.FontSrc.AddHost(posthogCSPHost)
|
||||
csp.MediaSrc = middleware.CSPSourceOpts().AddHost(posthogCSPHost)
|
||||
}
|
||||
|
||||
return &csp
|
||||
}
|
||||
|
||||
|
@ -159,12 +159,12 @@ func (c *Commands) userStateWriteModel(ctx context.Context, userID string) (writ
|
||||
return writeModel, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RemoveUserV2(ctx context.Context, userID string, cascadingUserMemberships []*CascadingMembership, cascadingGrantIDs ...string) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) RemoveUserV2(ctx context.Context, userID, resourceOwner string, cascadingUserMemberships []*CascadingMembership, cascadingGrantIDs ...string) (*domain.ObjectDetails, error) {
|
||||
if userID == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-vaipl7s13l", "Errors.User.UserIDMissing")
|
||||
}
|
||||
|
||||
existingUser, err := c.userRemoveWriteModel(ctx, userID)
|
||||
existingUser, err := c.userRemoveWriteModel(ctx, userID, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -210,11 +210,11 @@ func (c *Commands) RemoveUserV2(ctx context.Context, userID string, cascadingUse
|
||||
return writeModelToObjectDetails(&existingUser.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) userRemoveWriteModel(ctx context.Context, userID string) (writeModel *UserV2WriteModel, err error) {
|
||||
func (c *Commands) userRemoveWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *UserV2WriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel = NewUserRemoveWriteModel(userID, "")
|
||||
writeModel = NewUserRemoveWriteModel(userID, resourceOwner)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -14,12 +14,13 @@ import (
|
||||
)
|
||||
|
||||
type ChangeHuman struct {
|
||||
ID string
|
||||
State *domain.UserState
|
||||
Username *string
|
||||
Profile *Profile
|
||||
Email *Email
|
||||
Phone *Phone
|
||||
ID string
|
||||
ResourceOwner string
|
||||
State *domain.UserState
|
||||
Username *string
|
||||
Profile *Profile
|
||||
Email *Email
|
||||
Phone *Phone
|
||||
|
||||
Metadata []*domain.Metadata
|
||||
MetadataKeysToRemove []string
|
||||
@ -267,6 +268,7 @@ func (c *Commands) ChangeUserHuman(ctx context.Context, human *ChangeHuman, alg
|
||||
existingHuman, err := c.UserHumanWriteModel(
|
||||
ctx,
|
||||
human.ID,
|
||||
human.ResourceOwner,
|
||||
human.Profile != nil,
|
||||
human.Email != nil,
|
||||
human.Phone != nil,
|
||||
@ -525,11 +527,11 @@ func (c *Commands) userExistsWriteModel(ctx context.Context, userID string) (wri
|
||||
return writeModel, nil
|
||||
}
|
||||
|
||||
func (c *Commands) UserHumanWriteModel(ctx context.Context, userID string, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM bool) (writeModel *UserV2WriteModel, err error) {
|
||||
func (c *Commands) UserHumanWriteModel(ctx context.Context, userID, resourceOwner string, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM bool) (writeModel *UserV2WriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel = NewUserHumanWriteModel(userID, "", profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM)
|
||||
writeModel = NewUserHumanWriteModel(userID, resourceOwner, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -567,7 +567,7 @@ func TestCommandSide_userHumanWriteModel_profile(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, true, false, false, false, false, false, false)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", true, false, false, false, false, false, false)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -912,7 +912,7 @@ func TestCommandSide_userHumanWriteModel_email(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, true, false, false, false, false, false)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, true, false, false, false, false, false)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -1344,7 +1344,7 @@ func TestCommandSide_userHumanWriteModel_phone(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, true, false, false, false, false)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, true, false, false, false, false)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -1605,7 +1605,7 @@ func TestCommandSide_userHumanWriteModel_password(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, true, false, false, false)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, true, false, false, false)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -2132,7 +2132,7 @@ func TestCommandSide_userHumanWriteModel_avatar(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, false, true, false, false)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, false, true, false, false)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -2441,7 +2441,7 @@ func TestCommandSide_userHumanWriteModel_idpLinks(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.userRemoveWriteModel(tt.args.ctx, tt.args.userID)
|
||||
wm, err := r.userRemoveWriteModel(tt.args.ctx, tt.args.userID, "")
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
@ -2744,7 +2744,7 @@ func TestCommandSide_userHumanWriteModel_metadata(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, false, false, false, true)
|
||||
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, false, false, false, true)
|
||||
if tt.res.err == nil {
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
|
@ -1358,7 +1358,7 @@ func TestCommandSide_RemoveUserV2(t *testing.T) {
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
checkPermission: tt.fields.checkPermission,
|
||||
}
|
||||
got, err := r.RemoveUserV2(tt.args.ctx, tt.args.userID, tt.args.cascadingMemberships, tt.args.grantIDs...)
|
||||
got, err := r.RemoveUserV2(tt.args.ctx, tt.args.userID, "", tt.args.cascadingMemberships, tt.args.grantIDs...)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"github.com/zitadel/logging"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/notification/channels"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
)
|
||||
@ -18,18 +17,14 @@ func countMessages(ctx context.Context, channel channels.NotificationChannel, su
|
||||
if err != nil {
|
||||
metricName = errorMetricName
|
||||
}
|
||||
addCount(ctx, metricName, message, err)
|
||||
addCount(ctx, metricName, message)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func addCount(ctx context.Context, metricName string, message channels.Message, err error) {
|
||||
func addCount(ctx context.Context, metricName string, message channels.Message) {
|
||||
labels := map[string]attribute.Value{
|
||||
"triggering_event_typey": attribute.StringValue(string(message.GetTriggeringEvent().Type())),
|
||||
"instance": attribute.StringValue(authz.GetInstance(ctx).InstanceID()),
|
||||
}
|
||||
if err != nil {
|
||||
labels["error"] = attribute.StringValue(err.Error())
|
||||
"triggering_event_type": attribute.StringValue(string(message.GetTriggeringEvent().Type())),
|
||||
}
|
||||
addCountErr := metrics.AddCount(ctx, metricName, 1, labels)
|
||||
logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed")
|
||||
|
@ -368,6 +368,10 @@ func (q *Queries) GetUserByIDWithPermission(ctx context.Context, shouldTriggerBu
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userID string) (user *User, err error) {
|
||||
return q.GetUserByIDWithResourceOwner(ctx, shouldTriggerBulk, userID, "")
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserByIDWithResourceOwner(ctx context.Context, shouldTriggerBulk bool, userID, resourceOwner string) (user *User, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -382,6 +386,7 @@ func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userI
|
||||
},
|
||||
userByIDQuery,
|
||||
userID,
|
||||
resourceOwner,
|
||||
authz.GetInstance(ctx).InstanceID(),
|
||||
)
|
||||
return user, err
|
||||
|
@ -20,8 +20,8 @@ WITH login_names AS (SELECT
|
||||
WHERE
|
||||
u.instance_id = p.instance_id
|
||||
AND (
|
||||
(p.is_default IS TRUE AND p.instance_id = $2)
|
||||
OR (p.instance_id = $2 AND p.resource_owner = u.resource_owner)
|
||||
(p.is_default IS TRUE AND p.instance_id = $3)
|
||||
OR (p.instance_id = $3 AND p.resource_owner = u.resource_owner)
|
||||
)
|
||||
ORDER BY is_default
|
||||
LIMIT 1
|
||||
@ -32,8 +32,9 @@ WITH login_names AS (SELECT
|
||||
u.instance_id = d.instance_id
|
||||
AND u.resource_owner = d.resource_owner
|
||||
WHERE
|
||||
u.instance_id = $2
|
||||
AND u.id = $1
|
||||
u.id = $1
|
||||
AND (u.resource_owner = $2 OR $2 = '')
|
||||
AND u.instance_id = $3
|
||||
)
|
||||
SELECT
|
||||
u.id
|
||||
@ -80,6 +81,7 @@ LEFT JOIN
|
||||
AND u.instance_id = m.instance_id
|
||||
WHERE
|
||||
u.id = $1
|
||||
AND u.instance_id = $2
|
||||
AND (u.resource_owner = $2 OR $2 = '')
|
||||
AND u.instance_id = $3
|
||||
LIMIT 1
|
||||
;
|
@ -30,5 +30,5 @@ func TelemetryHandler(handler http.Handler, ignoredEndpoints ...string) http.Han
|
||||
}
|
||||
|
||||
func spanNameFormatter(_ string, r *http.Request) string {
|
||||
return r.Host + r.URL.EscapedPath()
|
||||
return strings.Split(r.RequestURI, "?")[0]
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -35,7 +36,8 @@ const (
|
||||
|
||||
type StatusRecorder struct {
|
||||
http.ResponseWriter
|
||||
Status int
|
||||
RequestURI *string
|
||||
Status int
|
||||
}
|
||||
|
||||
func (r *StatusRecorder) WriteHeader(status int) {
|
||||
@ -56,6 +58,18 @@ func NewMetricsHandler(handler http.Handler, metricMethods []MetricType, ignored
|
||||
return &h
|
||||
}
|
||||
|
||||
type key int
|
||||
|
||||
const requestURI key = iota
|
||||
|
||||
func SetRequestURIPattern(ctx context.Context, pattern string) {
|
||||
uri, ok := ctx.Value(requestURI).(*string)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
*uri = pattern
|
||||
}
|
||||
|
||||
// ServeHTTP serves HTTP requests (http.Handler)
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(h.methods) == 0 {
|
||||
@ -69,13 +83,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
uri := strings.Split(r.RequestURI, "?")[0]
|
||||
recorder := &StatusRecorder{
|
||||
ResponseWriter: w,
|
||||
RequestURI: &uri,
|
||||
Status: 200,
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), requestURI, &uri))
|
||||
h.handler.ServeHTTP(recorder, r)
|
||||
if h.containsMetricsMethod(MetricTypeRequestCount) {
|
||||
RegisterRequestCounter(r)
|
||||
RegisterRequestCounter(recorder, r)
|
||||
}
|
||||
if h.containsMetricsMethod(MetricTypeTotalCount) {
|
||||
RegisterTotalRequestCounter(r)
|
||||
@ -94,9 +111,9 @@ func (h *Handler) containsMetricsMethod(method MetricType) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func RegisterRequestCounter(r *http.Request) {
|
||||
func RegisterRequestCounter(recorder *StatusRecorder, r *http.Request) {
|
||||
var labels = map[string]attribute.Value{
|
||||
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]),
|
||||
URI: attribute.StringValue(*recorder.RequestURI),
|
||||
Method: attribute.StringValue(r.Method),
|
||||
}
|
||||
RegisterCounter(RequestCounter, RequestCountDescription)
|
||||
@ -110,7 +127,7 @@ func RegisterTotalRequestCounter(r *http.Request) {
|
||||
|
||||
func RegisterRequestCodeCounter(recorder *StatusRecorder, r *http.Request) {
|
||||
var labels = map[string]attribute.Value{
|
||||
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]),
|
||||
URI: attribute.StringValue(*recorder.RequestURI),
|
||||
Method: attribute.StringValue(r.Method),
|
||||
ReturnCode: attribute.IntValue(recorder.Status),
|
||||
}
|
||||
|
@ -6,9 +6,11 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/prometheus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/sdk/instrumentation"
|
||||
sdk_metric "go.opentelemetry.io/otel/sdk/metric"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
@ -33,9 +35,19 @@ func NewMetrics(meterName string) (metrics.Metrics, error) {
|
||||
if err != nil {
|
||||
return &Metrics{}, err
|
||||
}
|
||||
// create a view to filter out unwanted attributes
|
||||
view := sdk_metric.NewView(
|
||||
sdk_metric.Instrument{
|
||||
Scope: instrumentation.Scope{Name: otelhttp.ScopeName},
|
||||
},
|
||||
sdk_metric.Stream{
|
||||
AttributeFilter: attribute.NewAllowKeysFilter("http.method", "http.status_code", "http.target"),
|
||||
},
|
||||
)
|
||||
meterProvider := sdk_metric.NewMeterProvider(
|
||||
sdk_metric.WithReader(exporter),
|
||||
sdk_metric.WithResource(resource),
|
||||
sdk_metric.WithView(view),
|
||||
)
|
||||
return &Metrics{
|
||||
Provider: meterProvider,
|
||||
|
@ -28,7 +28,7 @@ type Tracer struct {
|
||||
}
|
||||
|
||||
func (c *Config) NewTracer() error {
|
||||
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
exporter, err := texporter.New(texporter.WithProjectID(c.ProjectID))
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -26,7 +26,7 @@ type Tracer struct {
|
||||
}
|
||||
|
||||
func (c *Config) NewTracer() error {
|
||||
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
exporter, err := stdout.New(stdout.WithPrettyPrint())
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
otlpgrpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
|
||||
api_trace "go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -47,7 +48,7 @@ func FractionFromConfig(i interface{}) (float64, error) {
|
||||
}
|
||||
|
||||
func (c *Config) NewTracer() error {
|
||||
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
sampler := NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
|
||||
exporter, err := otlpgrpc.New(context.Background(), otlpgrpc.WithEndpoint(c.Endpoint), otlpgrpc.WithInsecure())
|
||||
if err != nil {
|
||||
return err
|
||||
@ -56,3 +57,19 @@ func (c *Config) NewTracer() error {
|
||||
tracing.T, err = NewTracer(sampler, exporter)
|
||||
return err
|
||||
}
|
||||
|
||||
// NewSampler returns a sampler decorator which behaves differently,
|
||||
// based on the parent of the span. If the span has no parent and is of kind server,
|
||||
// the decorated sampler is used to make sampling decision.
|
||||
// If the span has a parent, depending on whether the parent is remote and whether it
|
||||
// is sampled, one of the following samplers will apply:
|
||||
// - remote parent sampled -> always sample
|
||||
// - remote parent not sampled -> sample based on the decorated sampler (fraction based)
|
||||
// - local parent sampled -> always sample
|
||||
// - local parent not sampled -> never sample
|
||||
func NewSampler(sampler sdk_trace.Sampler) sdk_trace.Sampler {
|
||||
return sdk_trace.ParentBased(
|
||||
tracing.SpanKindBased(sampler, api_trace.SpanKindServer),
|
||||
sdk_trace.WithRemoteParentNotSampled(sampler),
|
||||
)
|
||||
}
|
||||
|
46
internal/telemetry/tracing/sampler.go
Normal file
46
internal/telemetry/tracing/sampler.go
Normal file
@ -0,0 +1,46 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type spanKindSampler struct {
|
||||
sampler sdk_trace.Sampler
|
||||
kinds []trace.SpanKind
|
||||
}
|
||||
|
||||
// ShouldSample implements the [sdk_trace.Sampler] interface.
|
||||
// It will not sample any spans which do not match the configured span kinds.
|
||||
// For spans which do match, the decorated sampler is used to make the sampling decision.
|
||||
func (sk spanKindSampler) ShouldSample(p sdk_trace.SamplingParameters) sdk_trace.SamplingResult {
|
||||
psc := trace.SpanContextFromContext(p.ParentContext)
|
||||
if !slices.Contains(sk.kinds, p.Kind) {
|
||||
return sdk_trace.SamplingResult{
|
||||
Decision: sdk_trace.Drop,
|
||||
Tracestate: psc.TraceState(),
|
||||
}
|
||||
}
|
||||
s := sk.sampler.ShouldSample(p)
|
||||
return s
|
||||
}
|
||||
|
||||
func (sk spanKindSampler) Description() string {
|
||||
return fmt.Sprintf("SpanKindBased{sampler:%s,kinds:%v}",
|
||||
sk.sampler.Description(),
|
||||
sk.kinds,
|
||||
)
|
||||
}
|
||||
|
||||
// SpanKindBased returns a sampler decorator which behaves differently, based on the kind of the span.
|
||||
// If the span kind does not match one of the configured kinds, it will not be sampled.
|
||||
// If the span kind matches, the decorated sampler is used to make sampling decision.
|
||||
func SpanKindBased(sampler sdk_trace.Sampler, kinds ...trace.SpanKind) sdk_trace.Sampler {
|
||||
return spanKindSampler{
|
||||
sampler: sampler,
|
||||
kinds: kinds,
|
||||
}
|
||||
}
|
80
internal/telemetry/tracing/sampler_test.go
Normal file
80
internal/telemetry/tracing/sampler_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
func TestSpanKindBased(t *testing.T) {
|
||||
type args struct {
|
||||
sampler sdk_trace.Sampler
|
||||
kinds []trace.SpanKind
|
||||
}
|
||||
type want struct {
|
||||
description string
|
||||
sampled int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
"never sample, no sample",
|
||||
args{
|
||||
sampler: sdk_trace.NeverSample(),
|
||||
kinds: []trace.SpanKind{trace.SpanKindServer},
|
||||
},
|
||||
want{
|
||||
description: "SpanKindBased{sampler:AlwaysOffSampler,kinds:[server]}",
|
||||
sampled: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"always sample, no kind, no sample",
|
||||
args{
|
||||
sampler: sdk_trace.AlwaysSample(),
|
||||
kinds: nil,
|
||||
},
|
||||
want{
|
||||
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[]}",
|
||||
sampled: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"always sample, 2 kinds, 2 samples",
|
||||
args{
|
||||
sampler: sdk_trace.AlwaysSample(),
|
||||
kinds: []trace.SpanKind{trace.SpanKindServer, trace.SpanKindClient},
|
||||
},
|
||||
want{
|
||||
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[server client]}",
|
||||
sampled: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sampler := SpanKindBased(tt.args.sampler, tt.args.kinds...)
|
||||
assert.Equal(t, tt.want.description, sampler.Description())
|
||||
|
||||
p := sdk_trace.NewTracerProvider(sdk_trace.WithSampler(sampler))
|
||||
tr := p.Tracer("test")
|
||||
|
||||
var sampled int
|
||||
for i := trace.SpanKindUnspecified; i <= trace.SpanKindConsumer; i++ {
|
||||
ctx := context.Background()
|
||||
_, span := tr.Start(ctx, "test", trace.WithSpanKind(i))
|
||||
if span.SpanContext().IsSampled() {
|
||||
sampled++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.want.sampled, sampled)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user