Merge branch 'next' into next-rc

This commit is contained in:
Livio Spring 2025-02-13 11:26:00 +01:00
commit b2fbf9ace3
No known key found for this signature in database
39 changed files with 586 additions and 224 deletions

View File

@ -23,5 +23,5 @@ func (mig *User11AddLowerFieldsToVerifiedEmail) Execute(ctx context.Context, _ e
}
func (mig *User11AddLowerFieldsToVerifiedEmail) String() string {
return "25_user13_add_lower_fields_to_verified_email"
return "25_user14_add_lower_fields_to_verified_email"
}

View File

@ -37,14 +37,14 @@ BEGIN
END;
-- Return the organizations where permission were granted thru org-level roles
SELECT array_agg(org_id) INTO org_ids
SELECT array_agg(sub.org_id) INTO org_ids
FROM (
SELECT DISTINCT om.org_id
FROM eventstore.org_members om
WHERE om.role = ANY(matched_roles)
AND om.instance_id = instanceID
AND om.user_id = userId
);
) AS sub;
RETURN;
END;
$$;

View File

@ -1,36 +1,29 @@
package middleware
import (
"context"
"strings"
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/stats"
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
)
type GRPCMethod string
func DefaultTracingClient() grpc.UnaryClientInterceptor {
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
func DefaultTracingClient() stats.Handler {
return TracingClient(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
}
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req, reply interface{},
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(method, string(ignoredMethod)) {
return invoker(ctx, method, req, reply, cc, opts...)
func TracingClient(ignoredMethods ...GRPCMethod) stats.Handler {
return grpc_trace.NewClientHandler(grpc_trace.WithFilter(
func(info *stats.RPCTagInfo) bool {
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
return false
}
}
}
return grpc_trace.UnaryClientInterceptor()(ctx, method, req, reply, cc, invoker, opts...)
}
return true
},
))
}

View File

@ -10,6 +10,7 @@ import (
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/zitadel/logging"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@ -56,6 +57,13 @@ var (
},
)
// we need the errorHandler to set the request URI pattern in case of an error
errorHandler = runtime.ErrorHandlerFunc(
func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
setRequestURIPattern(ctx)
runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err)
})
serveMuxOptions = func(hostHeaders []string) []runtime.ServeMuxOption {
return []runtime.ServeMuxOption{
runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler),
@ -65,6 +73,7 @@ var (
runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
runtime.WithForwardResponseOption(responseForwarder),
runtime.WithRoutingErrorHandler(httpErrorHandler),
runtime.WithErrorHandler(errorHandler),
}
}
@ -81,6 +90,7 @@ var (
}
responseForwarder = func(ctx context.Context, w http.ResponseWriter, resp proto.Message) error {
setRequestURIPattern(ctx)
t, ok := resp.(CustomHTTPResponse)
if ok {
// TODO: find a way to return a location header if needed w.Header().Set("location", t.Location())
@ -118,9 +128,9 @@ func CreateGatewayWithPrefix(
opts := []grpc.DialOption{
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithChainUnaryInterceptor(
client_middleware.DefaultTracingClient(),
client_middleware.UnaryActivityClientInterceptor(),
),
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
}
connection, err := dial(ctx, port, opts)
if err != nil {
@ -145,9 +155,9 @@ func CreateGateway(
[]grpc.DialOption{
grpc.WithTransportCredentials(grpcCredentials(tlsConfig)),
grpc.WithChainUnaryInterceptor(
client_middleware.DefaultTracingClient(),
client_middleware.UnaryActivityClientInterceptor(),
),
grpc.WithStatsHandler(client_middleware.DefaultTracingClient()),
})
if err != nil {
return nil, err
@ -260,3 +270,13 @@ func grpcCredentials(tlsConfig *tls.Config) credentials.TransportCredentials {
}
return creds
}
func setRequestURIPattern(ctx context.Context) {
pattern, ok := runtime.HTTPPathPattern(ctx)
if !ok {
return
}
span := trace.SpanFromContext(ctx)
span.SetName(pattern)
metrics.SetRequestURIPattern(ctx, pattern)
}

View File

@ -1,34 +1,29 @@
package middleware
import (
"context"
"strings"
grpc_trace "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/stats"
grpc_utils "github.com/zitadel/zitadel/internal/api/grpc"
)
type GRPCMethod string
func DefaultTracingServer() grpc.UnaryServerInterceptor {
func DefaultTracingServer() stats.Handler {
return TracingServer(grpc_utils.Healthz, grpc_utils.Readiness, grpc_utils.Validation)
}
func TracingServer(ignoredMethods ...GRPCMethod) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(info.FullMethod, string(ignoredMethod)) {
return handler(ctx, req)
func TracingServer(ignoredMethods ...GRPCMethod) stats.Handler {
return grpc_trace.NewServerHandler(grpc_trace.WithFilter(
func(info *stats.RPCTagInfo) bool {
for _, ignoredMethod := range ignoredMethods {
if strings.HasSuffix(info.FullMethodName, string(ignoredMethod)) {
return false
}
}
}
return grpc_trace.UnaryServerInterceptor()(ctx, req, info, handler)
}
return true
},
))
}

View File

@ -47,7 +47,6 @@ func CreateServer(
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
middleware.CallDurationHandler(),
middleware.DefaultTracingServer(),
middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
middleware.NoCacheInterceptor(),
middleware.InstanceInterceptor(queries, externalDomain, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName),
@ -63,6 +62,7 @@ func CreateServer(
middleware.ActivityInterceptor(),
),
),
grpc.StatsHandler(middleware.DefaultTracingServer()),
}
if tlsConfig != nil {
serverOptions = append(serverOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))

View File

@ -275,7 +275,7 @@ func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_
if err != nil {
return nil, err
}
details, err := s.command.RemoveUserV2(ctx, req.UserId, memberships, grants...)
details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...)
if err != nil {
return nil, err
}

View File

@ -278,7 +278,7 @@ func (s *Server) DeleteUser(ctx context.Context, req *user.DeleteUserRequest) (_
if err != nil {
return nil, err
}
details, err := s.command.RemoveUserV2(ctx, req.UserId, memberships, grants...)
details, err := s.command.RemoveUserV2(ctx, req.UserId, "", memberships, grants...)
if err != nil {
return nil, err
}

View File

@ -14,7 +14,6 @@ import (
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -24,6 +23,7 @@ import (
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
)
var (
@ -289,7 +289,7 @@ func TestBulk(t *testing.T) {
},
DisplayName: "scim-bulk-created-user-0-given-name scim-bulk-created-user-0-family-name",
PreferredLanguage: test.Must(language.Parse("en")),
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
Emails: []*resources.ScimEmail{
{
Value: "scim-bulk-created-user-0@example.com",
@ -308,7 +308,7 @@ func TestBulk(t *testing.T) {
DisplayName: "scim-bulk-created-user-1-given-name scim-bulk-created-user-1-family-name",
NickName: "scim-bulk-created-user-1-nickname-patched",
PreferredLanguage: test.Must(language.Parse("en")),
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
Emails: []*resources.ScimEmail{
{
Value: "scim-bulk-created-user-1@example.com",
@ -333,7 +333,7 @@ func TestBulk(t *testing.T) {
DisplayName: "scim-bulk-created-user-2-given-name scim-bulk-created-user-2-family-name",
NickName: "scim-bulk-created-user-2-nickname-patched",
PreferredLanguage: test.Must(language.Parse("en")),
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
Emails: []*resources.ScimEmail{
{
Value: "scim-bulk-created-user-2@example.com",
@ -696,3 +696,39 @@ func buildTooManyOperationsRequest() *scim.BulkRequest {
return req
}
func setProvisioningDomain(t require.TestingT, userID, provisioningDomain string) {
setAndEnsureMetadata(t, userID, "urn:zitadel:scim:provisioningDomain", provisioningDomain)
}
func setAndEnsureMetadata(t require.TestingT, userID, key, value string) {
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: userID,
Key: key,
Value: []byte(value),
})
require.NoError(t, err)
// ensure metadata is projected
ensureMetadataProjected(t, userID, key, value)
}
func ensureMetadataProjected(t require.TestingT, userID, key, value string) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
require.EventuallyWithT(t, func(tt *assert.CollectT) {
md, err := Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{
Id: userID,
Key: key,
})
require.NoError(tt, err)
require.Equal(tt, value, string(md.Metadata.Value))
}, retryDuration, tick)
}
func removeProvisioningDomain(t require.TestingT, userID string) {
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: userID,
Key: "urn:zitadel:scim:provisioningDomain",
})
require.NoError(t, err)
}

View File

@ -9,12 +9,16 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/org/v2"
)
var (
Instance *integration.Instance
CTX context.Context
Instance *integration.Instance
SecondaryOrganization *org.AddOrganizationResponse
CTX context.Context
// remove comments in the json, as the default golang json unmarshaler cannot handle them
// some test files (e.g. bulk, patch) are much easier to maintain with comments
@ -29,6 +33,10 @@ func TestMain(m *testing.M) {
Instance = integration.NewInstance(ctx)
CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner)
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
SecondaryOrganization = Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
return m.Run()
}())
}

View File

@ -3,7 +3,7 @@
"Operations": [
// add without path
{
"op": "add",
"op": "Add", // with PascalCase operation type
"value": {
"emails":[
{
@ -17,7 +17,7 @@
},
// add complex attribute with path
{
"op": "add",
"op": "add", // with camelCase operation type
"path": "name",
"value": {
"formatted": "added-formatted",
@ -30,7 +30,7 @@
},
// add complex attribute value
{
"op": "add",
"op": "ADD", // with UPPERCASE operation type
"path": "name.middlename",
"value": "added-middle-name-2"
},
@ -134,6 +134,14 @@
"op": "replace",
"path": "password",
"value": "Password2!"
},
// replace active state
{
"op": "replace",
"path": "active",
// quoted uppercase bool
// (ensure compatibility with Microsoft Entra)
"value": "True"
}
]
}

View File

@ -10,8 +10,6 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -155,7 +153,7 @@ var (
PreferredLanguage: language.MustParse("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
}
)
@ -164,6 +162,7 @@ func TestCreateUser(t *testing.T) {
name string
body []byte
ctx context.Context
orgID string
want *resources.ScimUser
wantErr bool
scimErrorType string
@ -191,7 +190,7 @@ func TestCreateUser(t *testing.T) {
name: "minimal inactive user",
body: minimalInactiveUserJson,
want: &resources.ScimUser{
Active: gu.Ptr(false),
Active: schemas.NewRelaxedBool(false),
},
},
{
@ -275,6 +274,13 @@ func TestCreateUser(t *testing.T) {
wantErr: true,
errorStatus: http.StatusNotFound,
},
{
name: "another org",
body: minimalUserJson,
orgID: SecondaryOrganization.OrganizationId,
wantErr: true,
errorStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -283,7 +289,12 @@ func TestCreateUser(t *testing.T) {
ctx = CTX
}
createdUser, err := Instance.Client.SCIM.Users.Create(ctx, Instance.DefaultOrg.Id, tt.body)
orgID := tt.orgID
if orgID == "" {
orgID = Instance.DefaultOrg.Id
}
createdUser, err := Instance.Client.SCIM.Users.Create(ctx, orgID, tt.body)
if (err != nil) != tt.wantErr {
t.Errorf("CreateUser() error = %v, wantErr %v", err, tt.wantErr)
return
@ -311,7 +322,7 @@ func TestCreateUser(t *testing.T) {
assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, createdUser.Resource.Schemas)
assert.Equal(t, schemas.ScimResourceTypeSingular("User"), createdUser.Resource.Meta.ResourceType)
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), createdUser.Resource.Meta.Location)
assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, orgID, "Users", createdUser.ID), createdUser.Resource.Meta.Location)
assert.Nil(t, createdUser.Password)
if tt.want != nil {
@ -384,12 +395,7 @@ func TestCreateUser_metadata(t *testing.T) {
}
func TestCreateUser_scopedExternalID(t *testing.T) {
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
Value: []byte("fooBar"),
})
require.NoError(t, err)
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
require.NoError(t, err)
@ -398,11 +404,7 @@ func TestCreateUser_scopedExternalID(t *testing.T) {
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
require.NoError(t, err)
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
})
require.NoError(t, err)
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
}()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
@ -423,9 +425,3 @@ func TestCreateUser_scopedExternalID(t *testing.T) {
assert.Equal(tt, "701984", string(md.Metadata.Value))
}, retryDuration, tick)
}
func TestCreateUser_anotherOrg(t *testing.T) {
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
_, err := Instance.Client.SCIM.Users.Create(CTX, org.OrganizationId, fullUserJson)
scim.RequireScimError(t, http.StatusNotFound, err)
}

View File

@ -8,7 +8,6 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
@ -22,6 +21,7 @@ func TestDeleteUser_errors(t *testing.T) {
tests := []struct {
name string
ctx context.Context
orgID string
errorStatus int
}{
{
@ -38,6 +38,17 @@ func TestDeleteUser_errors(t *testing.T) {
name: "unknown user id",
errorStatus: http.StatusNotFound,
},
{
name: "another org",
orgID: SecondaryOrganization.OrganizationId,
errorStatus: http.StatusNotFound,
},
{
name: "another org with permissions",
orgID: SecondaryOrganization.OrganizationId,
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
errorStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -46,7 +57,11 @@ func TestDeleteUser_errors(t *testing.T) {
ctx = CTX
}
err := Instance.Client.SCIM.Users.Delete(ctx, Instance.DefaultOrg.Id, "1")
orgID := tt.orgID
if orgID == "" {
orgID = Instance.DefaultOrg.Id
}
err := Instance.Client.SCIM.Users.Delete(ctx, orgID, "1")
statusCode := tt.errorStatus
if statusCode == 0 {
@ -81,10 +96,3 @@ func TestDeleteUser_ensureReallyDeleted(t *testing.T) {
integration.AssertGrpcStatus(tt, codes.NotFound, err)
}, retryDuration, tick)
}
func TestDeleteUser_anotherOrg(t *testing.T) {
createUserResp := Instance.CreateHumanUser(CTX)
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
err := Instance.Client.SCIM.Users.Delete(CTX, org.OrganizationId, createUserResp.UserId)
scim.RequireScimError(t, http.StatusNotFound, err)
}

View File

@ -9,8 +9,6 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -20,13 +18,13 @@ import (
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
func TestGetUser(t *testing.T) {
tests := []struct {
name string
orgID string
buildUserID func() string
cleanup func(userID string)
ctx context.Context
@ -46,6 +44,19 @@ func TestGetUser(t *testing.T) {
errorStatus: http.StatusNotFound,
wantErr: true,
},
{
name: "another org",
orgID: SecondaryOrganization.OrganizationId,
errorStatus: http.StatusNotFound,
wantErr: true,
},
{
name: "another org with permissions",
orgID: SecondaryOrganization.OrganizationId,
ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission),
errorStatus: http.StatusNotFound,
wantErr: true,
},
{
name: "unknown user id",
buildUserID: func() string {
@ -99,7 +110,7 @@ func TestGetUser(t *testing.T) {
PreferredLanguage: language.Make("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
Emails: []*resources.ScimEmail{
{
Value: "bjensen@example.com",
@ -191,31 +202,17 @@ func TestGetUser(t *testing.T) {
require.NoError(t, err)
// set provisioning domain of service user
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
Value: []byte("fooBar"),
})
require.NoError(t, err)
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
// set externalID for provisioning domain
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: createdUser.ID,
Key: "urn:zitadel:scim:fooBar:externalId",
Value: []byte("100-scopedExternalId"),
})
require.NoError(t, err)
setAndEnsureMetadata(t, createdUser.ID, "urn:zitadel:scim:fooBar:externalId", "100-scopedExternalId")
return createdUser.ID
},
cleanup: func(userID string) {
_, err := Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: userID})
require.NoError(t, err)
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
})
require.NoError(t, err)
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
},
want: &resources.ScimUser{
ExternalID: "100-scopedExternalId",
@ -237,11 +234,16 @@ func TestGetUser(t *testing.T) {
userID = createUserResp.UserId
}
orgID := tt.orgID
if orgID == "" {
orgID = Instance.DefaultOrg.Id
}
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
var fetchedUser *resources.ScimUser
var err error
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
fetchedUser, err = Instance.Client.SCIM.Users.Get(ctx, Instance.DefaultOrg.Id, userID)
fetchedUser, err = Instance.Client.SCIM.Users.Get(ctx, orgID, userID)
if tt.wantErr {
statusCode := tt.errorStatus
if statusCode == 0 {
@ -255,7 +257,7 @@ func TestGetUser(t *testing.T) {
assert.Equal(ttt, userID, fetchedUser.ID)
assert.EqualValues(ttt, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, fetchedUser.Schemas)
assert.Equal(ttt, schemas.ScimResourceTypeSingular("User"), fetchedUser.Resource.Meta.ResourceType)
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, orgID, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location)
assert.Nil(ttt, fetchedUser.Password)
if !test.PartiallyDeepEqual(tt.want, fetchedUser) {
ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want)
@ -268,10 +270,3 @@ func TestGetUser(t *testing.T) {
})
}
}
func TestGetUser_anotherOrg(t *testing.T) {
createUserResp := Instance.CreateHumanUser(CTX)
org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email())
_, err := Instance.Client.SCIM.Users.Get(CTX, org.OrganizationId, createUserResp.UserId)
scim.RequireScimError(t, http.StatusNotFound, err)
}

View File

@ -20,7 +20,6 @@ import (
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/scim"
"github.com/zitadel/zitadel/internal/test"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
@ -239,7 +238,7 @@ func TestListUser(t *testing.T) {
assert.Equal(t, 1, resp.StartIndex)
assert.Len(t, resp.Resources, 1)
assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-0"))
assert.False(t, *resp.Resources[0].Active)
assert.False(t, resp.Resources[0].Active.Bool())
},
},
{
@ -372,20 +371,10 @@ func TestListUser(t *testing.T) {
resp := createHumanUser(t, CTX, Instance.DefaultOrg.Id, 102)
// set provisioning domain of service user
_, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
Value: []byte("fooBar"),
})
require.NoError(t, err)
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBar")
// set externalID for provisioning domain
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: resp.UserId,
Key: "urn:zitadel:scim:fooBar:externalId",
Value: []byte("100-scopedExternalId"),
})
require.NoError(t, err)
setAndEnsureMetadata(t, resp.UserId, "urn:zitadel:scim:fooBar:externalId", "100-scopedExternalId")
return &scim.ListRequest{
Filter: gu.Ptr(fmt.Sprintf(`id eq "%s"`, resp.UserId)),
}
@ -396,11 +385,7 @@ func TestListUser(t *testing.T) {
},
cleanup: func(t require.TestingT) {
// delete provisioning domain of service user
_, err := Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
})
require.NoError(t, err)
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
},
},
}

View File

@ -10,7 +10,6 @@ import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -37,14 +36,16 @@ var (
func TestReplaceUser(t *testing.T) {
tests := []struct {
name string
body []byte
ctx context.Context
want *resources.ScimUser
wantErr bool
scimErrorType string
errorStatus int
zitadelErrID string
name string
body []byte
ctx context.Context
createUserOrgID string
replaceUserOrgID string
want *resources.ScimUser
wantErr bool
scimErrorType string
errorStatus int
zitadelErrID string
}{
{
name: "minimal user",
@ -165,7 +166,7 @@ func TestReplaceUser(t *testing.T) {
PreferredLanguage: language.MustParse("en-CH"),
Locale: "en-CH",
Timezone: "Europe/Zurich",
Active: gu.Ptr(false),
Active: schemas.NewRelaxedBool(false),
},
},
{
@ -207,10 +208,26 @@ func TestReplaceUser(t *testing.T) {
wantErr: true,
errorStatus: http.StatusNotFound,
},
{
name: "another org",
body: minimalUserJson,
replaceUserOrgID: SecondaryOrganization.OrganizationId,
wantErr: true,
errorStatus: http.StatusNotFound,
},
{
name: "another org with permissions",
body: minimalUserJson,
replaceUserOrgID: SecondaryOrganization.OrganizationId,
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
wantErr: true,
errorStatus: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson)
// use iam owner => we don't want to test permissions of the create endpoint.
createdUser, err := Instance.Client.SCIM.Users.Create(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), Instance.DefaultOrg.Id, fullUserJson)
require.NoError(t, err)
defer func() {
@ -223,7 +240,12 @@ func TestReplaceUser(t *testing.T) {
ctx = CTX
}
replacedUser, err := Instance.Client.SCIM.Users.Replace(ctx, Instance.DefaultOrg.Id, createdUser.ID, tt.body)
replaceUserOrgID := tt.replaceUserOrgID
if replaceUserOrgID == "" {
replaceUserOrgID = Instance.DefaultOrg.Id
}
replacedUser, err := Instance.Client.SCIM.Users.Replace(ctx, replaceUserOrgID, createdUser.ID, tt.body)
if (err != nil) != tt.wantErr {
t.Errorf("ReplaceUser() error = %v, wantErr %v", err, tt.wantErr)
}
@ -294,12 +316,7 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
require.NoError(t, err)
// set provisioning domain of service user
_, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
Value: []byte("fooBazz"),
})
require.NoError(t, err)
setProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID, "fooBazz")
// replace the user with provisioning domain set
_, err = Instance.Client.SCIM.Users.Replace(CTX, Instance.DefaultOrg.Id, createdUser.ID, minimalUserWithExternalIDJson)
@ -325,9 +342,5 @@ func TestReplaceUser_scopedExternalID(t *testing.T) {
_, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID})
require.NoError(t, err)
_, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{
Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID,
Key: "urn:zitadel:scim:provisioningDomain",
})
require.NoError(t, err)
removeProvisioningDomain(t, Instance.Users.Get(integration.UserTypeOrgOwner).ID)
}

View File

@ -10,8 +10,6 @@ import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -28,7 +26,7 @@ var (
//go:embed testdata/users_update_test_full.json
fullUserUpdateJson []byte
minimalUserUpdateJson = simpleReplacePatchBody("nickname", "foo")
minimalUserUpdateJson = simpleReplacePatchBody("nickname", "\"foo\"")
)
func init() {
@ -44,9 +42,6 @@ func TestUpdateUser(t *testing.T) {
require.NoError(t, err)
}()
iamOwnerCtx := Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
secondaryOrg := Instance.CreateOrganization(iamOwnerCtx, gofakeit.Name(), gofakeit.Email())
tests := []struct {
name string
body []byte
@ -74,7 +69,15 @@ func TestUpdateUser(t *testing.T) {
},
{
name: "other org",
orgID: secondaryOrg.OrganizationId,
orgID: SecondaryOrganization.OrganizationId,
body: minimalUserUpdateJson,
wantErr: true,
errorStatus: http.StatusNotFound,
},
{
name: "other org with permissions",
ctx: Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner),
orgID: SecondaryOrganization.OrganizationId,
body: minimalUserUpdateJson,
wantErr: true,
errorStatus: http.StatusNotFound,
@ -204,7 +207,7 @@ func TestUpdateUser(t *testing.T) {
PreferredLanguage: language.MustParse("en-US"),
Locale: "en-US",
Timezone: "America/Los_Angeles",
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
},
},
}

View File

@ -60,6 +60,9 @@ func (req *OperationRequest) Validate() error {
}
func (op *Operation) validate() error {
// ignore the casing, as some scim clients send these capitalized
op.Operation = OperationType(strings.ToLower(string(op.Operation)))
if !op.Operation.isValid() {
return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgumentf(nil, "SCIM-opty1", "Patch op %s not supported", op.Operation))
}

View File

@ -39,7 +39,7 @@ type ScimUser struct {
PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"`
Locale string `json:"locale,omitempty"`
Timezone string `json:"timezone,omitempty"`
Active *bool `json:"active,omitempty"`
Active *scim_schemas.RelaxedBool `json:"active,omitempty"`
Emails []*ScimEmail `json:"emails,omitempty" scim:"required"`
PhoneNumbers []*ScimPhoneNumber `json:"phoneNumbers,omitempty"`
Password *scim_schemas.WriteOnlyString `json:"password,omitempty"`
@ -154,7 +154,7 @@ func (h *UsersHandler) Create(ctx context.Context, user *ScimUser) (*ScimUser, e
return nil, err
}
err = h.command.AddUserHuman(ctx, orgID, addHuman, true, h.userCodeAlg)
err = h.command.AddUserHuman(ctx, orgID, addHuman, false, h.userCodeAlg)
if err != nil {
return nil, err
}
@ -180,7 +180,8 @@ func (h *UsersHandler) Replace(ctx context.Context, id string, user *ScimUser) (
}
func (h *UsersHandler) Update(ctx context.Context, id string, operations patch.OperationCollection) error {
userWM, err := h.command.UserHumanWriteModel(ctx, id, true, true, true, true, false, false, true)
orgID := authz.GetCtxData(ctx).OrgID
userWM, err := h.command.UserHumanWriteModel(ctx, id, orgID, true, true, true, true, false, false, true)
if err != nil {
return err
}
@ -191,6 +192,9 @@ func (h *UsersHandler) Update(ctx context.Context, id string, operations patch.O
return err
}
// ensure the identity of the user is not modified
changeHuman.ID = id
changeHuman.ResourceOwner = orgID
return h.command.ChangeUserHuman(ctx, changeHuman, h.userCodeAlg)
}
@ -200,12 +204,12 @@ func (h *UsersHandler) Delete(ctx context.Context, id string) error {
return err
}
_, err = h.command.RemoveUserV2(ctx, id, memberships, grants...)
_, err = h.command.RemoveUserV2(ctx, id, authz.GetCtxData(ctx).OrgID, memberships, grants...)
return err
}
func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) {
user, err := h.query.GetUserByID(ctx, false, id)
user, err := h.query.GetUserByIDWithResourceOwner(ctx, false, id, authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/scim/metadata"
"github.com/zitadel/zitadel/internal/api/scim/schemas"
"github.com/zitadel/zitadel/internal/command"
@ -73,8 +74,9 @@ func (h *UsersHandler) mapToAddHuman(ctx context.Context, scimUser *ScimUser) (*
func (h *UsersHandler) mapToChangeHuman(ctx context.Context, scimUser *ScimUser) (*command.ChangeHuman, error) {
human := &command.ChangeHuman{
ID: scimUser.ID,
Username: &scimUser.UserName,
ID: scimUser.ID,
ResourceOwner: authz.GetCtxData(ctx).OrgID,
Username: &scimUser.UserName,
Profile: &command.Profile{
NickName: &scimUser.NickName,
DisplayName: &scimUser.DisplayName,
@ -271,7 +273,7 @@ func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md m
FamilyName: user.Human.LastName,
GivenName: user.Human.FirstName,
},
Active: gu.Ptr(user.State.IsEnabled()),
Active: schemas.NewRelaxedBool(user.State.IsEnabled()),
}
if string(user.Human.Email) != "" {
@ -309,7 +311,7 @@ func (h *UsersHandler) mapWriteModelToScimUser(ctx context.Context, user *comman
FamilyName: user.LastName,
GivenName: user.FirstName,
},
Active: gu.Ptr(user.UserState.IsEnabled()),
Active: schemas.NewRelaxedBool(user.UserState.IsEnabled()),
}
if string(user.Email) != "" {

View File

@ -7,7 +7,6 @@ import (
"strings"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
@ -707,7 +706,7 @@ func TestOperationCollection_Apply(t *testing.T) {
PreferredLanguage: language.MustParse("en-US"),
Locale: "en-US",
Timezone: "America/New_York",
Active: gu.Ptr(true),
Active: schemas.NewRelaxedBool(true),
Emails: []*ScimEmail{
{
Value: "jeanie.pendleton@example.com",

View File

@ -0,0 +1,41 @@
package schemas
import (
"encoding/json"
"strings"
"github.com/muhlemmer/gu"
"github.com/zitadel/zitadel/internal/zerrors"
)
// RelaxedBool a bool which is more relaxed when it comes to json (un)marshaling.
// This ensures compatibility with some bugged scim providers,
// such as Microsoft Entry, which sends booleans as "True" or "False".
// See also https://learn.microsoft.com/en-us/entra/identity/app-provisioning/application-provisioning-config-problem-scim-compatibility.
type RelaxedBool bool
func NewRelaxedBool(value bool) *RelaxedBool {
return gu.Ptr(RelaxedBool(value))
}
func (b *RelaxedBool) MarshalJSON() ([]byte, error) {
return json.Marshal(bool(*b))
}
func (b *RelaxedBool) UnmarshalJSON(bytes []byte) error {
str := strings.ToLower(string(bytes))
switch {
case str == "true" || str == "\"true\"":
*b = true
case str == "false" || str == "\"false\"":
*b = false
default:
return zerrors.ThrowInvalidArgumentf(nil, "SCIM-BOO1", "bool expected, got %v", str)
}
return nil
}
func (b *RelaxedBool) Bool() bool {
return bool(*b)
}

View File

@ -0,0 +1,63 @@
package schemas
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRelaxedBool_MarshalJSON(t *testing.T) {
tests := []struct {
name string
input bool
expected string
}{
{name: "true", input: true, expected: "true"},
{name: "false", expected: "false"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
value := NewRelaxedBool(tt.input)
bytes, err := json.Marshal(value)
require.NoError(t, err)
assert.Equal(t, tt.expected, string(bytes))
})
}
}
func TestRelaxedBool_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expected bool
wantErr bool
}{
{name: "valid true", input: "true", expected: true},
{name: "valid false", input: "false"},
{name: "quoted true", input: `"true"`, expected: true},
{name: "quoted pascal case true", input: `"True"`, expected: true},
{name: "quoted upper case true", input: `"TRUE"`, expected: true},
{name: "quoted false", input: `"false"`},
{name: "quoted pascal case false", input: `"False"`},
{name: "quoted upper case false", input: `"FALSE"`},
{name: "invalid value", input: "invalid", wantErr: true},
{name: "number value", input: "1", wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
value := new(RelaxedBool)
err := json.Unmarshal([]byte(tt.input), &value)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, value.Bool())
}
})
}
}

View File

@ -45,6 +45,8 @@ var (
const (
envRequestPath = "/assets/environment.json"
// https://posthog.com/docs/advanced/content-security-policy
posthogCSPHost = "https://*.i.posthog.com"
)
var (
@ -106,12 +108,14 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
config.LongCache.MaxAge,
config.LongCache.SharedMaxAge,
)
security := middleware.SecurityHeaders(csp(), nil)
security := middleware.SecurityHeaders(csp(config.PostHog.URL), nil)
handler := mux.NewRouter()
handler.Use(security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
env := handler.NewRoute().Path(envRequestPath).Subrouter()
env.Use(callDurationInterceptor, middleware.TelemetryHandler(), instanceHandler)
env.HandleFunc("", func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure)
ctx := r.Context()
instance := authz.GetInstance(ctx)
@ -128,7 +132,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
}
_, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json")
})))
})
handler.SkipClean(true).PathPrefix("").Handler(cache(http.FileServer(&spaHandler{http.FS(fSys)})))
return handler, nil
}
@ -145,12 +149,22 @@ func templateInstanceManagementURL(templateableCookieValue string, instance auth
return cookieValue.String(), nil
}
func csp() *middleware.CSP {
func csp(posthogURL string) *middleware.CSP {
csp := middleware.DefaultSCP
csp.StyleSrc = csp.StyleSrc.AddInline()
csp.ScriptSrc = csp.ScriptSrc.AddEval()
csp.ConnectSrc = csp.ConnectSrc.AddOwnHost()
csp.ImgSrc = csp.ImgSrc.AddOwnHost().AddScheme("blob")
if posthogURL != "" {
// https://posthog.com/docs/advanced/content-security-policy#enabling-the-toolbar
csp.ScriptSrc = csp.ScriptSrc.AddHost(posthogCSPHost)
csp.ConnectSrc = csp.ConnectSrc.AddHost(posthogCSPHost)
csp.ImgSrc = csp.ImgSrc.AddHost(posthogCSPHost)
csp.StyleSrc = csp.StyleSrc.AddHost(posthogCSPHost)
csp.FontSrc = csp.FontSrc.AddHost(posthogCSPHost)
csp.MediaSrc = middleware.CSPSourceOpts().AddHost(posthogCSPHost)
}
return &csp
}

View File

@ -159,12 +159,12 @@ func (c *Commands) userStateWriteModel(ctx context.Context, userID string) (writ
return writeModel, nil
}
func (c *Commands) RemoveUserV2(ctx context.Context, userID string, cascadingUserMemberships []*CascadingMembership, cascadingGrantIDs ...string) (*domain.ObjectDetails, error) {
func (c *Commands) RemoveUserV2(ctx context.Context, userID, resourceOwner string, cascadingUserMemberships []*CascadingMembership, cascadingGrantIDs ...string) (*domain.ObjectDetails, error) {
if userID == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-vaipl7s13l", "Errors.User.UserIDMissing")
}
existingUser, err := c.userRemoveWriteModel(ctx, userID)
existingUser, err := c.userRemoveWriteModel(ctx, userID, resourceOwner)
if err != nil {
return nil, err
}
@ -210,11 +210,11 @@ func (c *Commands) RemoveUserV2(ctx context.Context, userID string, cascadingUse
return writeModelToObjectDetails(&existingUser.WriteModel), nil
}
func (c *Commands) userRemoveWriteModel(ctx context.Context, userID string) (writeModel *UserV2WriteModel, err error) {
func (c *Commands) userRemoveWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *UserV2WriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel = NewUserRemoveWriteModel(userID, "")
writeModel = NewUserRemoveWriteModel(userID, resourceOwner)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, err

View File

@ -14,12 +14,13 @@ import (
)
type ChangeHuman struct {
ID string
State *domain.UserState
Username *string
Profile *Profile
Email *Email
Phone *Phone
ID string
ResourceOwner string
State *domain.UserState
Username *string
Profile *Profile
Email *Email
Phone *Phone
Metadata []*domain.Metadata
MetadataKeysToRemove []string
@ -267,6 +268,7 @@ func (c *Commands) ChangeUserHuman(ctx context.Context, human *ChangeHuman, alg
existingHuman, err := c.UserHumanWriteModel(
ctx,
human.ID,
human.ResourceOwner,
human.Profile != nil,
human.Email != nil,
human.Phone != nil,
@ -525,11 +527,11 @@ func (c *Commands) userExistsWriteModel(ctx context.Context, userID string) (wri
return writeModel, nil
}
func (c *Commands) UserHumanWriteModel(ctx context.Context, userID string, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM bool) (writeModel *UserV2WriteModel, err error) {
func (c *Commands) UserHumanWriteModel(ctx context.Context, userID, resourceOwner string, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM bool) (writeModel *UserV2WriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel = NewUserHumanWriteModel(userID, "", profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM)
writeModel = NewUserHumanWriteModel(userID, resourceOwner, profileWM, emailWM, phoneWM, passwordWM, avatarWM, idpLinksWM, metadataWM)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, err

View File

@ -567,7 +567,7 @@ func TestCommandSide_userHumanWriteModel_profile(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, true, false, false, false, false, false, false)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", true, false, false, false, false, false, false)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -912,7 +912,7 @@ func TestCommandSide_userHumanWriteModel_email(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, true, false, false, false, false, false)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, true, false, false, false, false, false)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -1344,7 +1344,7 @@ func TestCommandSide_userHumanWriteModel_phone(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, true, false, false, false, false)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, true, false, false, false, false)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -1605,7 +1605,7 @@ func TestCommandSide_userHumanWriteModel_password(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, true, false, false, false)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, true, false, false, false)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -2132,7 +2132,7 @@ func TestCommandSide_userHumanWriteModel_avatar(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, false, true, false, false)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, false, true, false, false)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -2441,7 +2441,7 @@ func TestCommandSide_userHumanWriteModel_idpLinks(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.userRemoveWriteModel(tt.args.ctx, tt.args.userID)
wm, err := r.userRemoveWriteModel(tt.args.ctx, tt.args.userID, "")
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()
@ -2744,7 +2744,7 @@ func TestCommandSide_userHumanWriteModel_metadata(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
}
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, false, false, false, false, false, false, true)
wm, err := r.UserHumanWriteModel(tt.args.ctx, tt.args.userID, "", false, false, false, false, false, false, true)
if tt.res.err == nil {
if !assert.NoError(t, err) {
t.FailNow()

View File

@ -1358,7 +1358,7 @@ func TestCommandSide_RemoveUserV2(t *testing.T) {
eventstore: tt.fields.eventstore(t),
checkPermission: tt.fields.checkPermission,
}
got, err := r.RemoveUserV2(tt.args.ctx, tt.args.userID, tt.args.cascadingMemberships, tt.args.grantIDs...)
got, err := r.RemoveUserV2(tt.args.ctx, tt.args.userID, "", tt.args.cascadingMemberships, tt.args.grantIDs...)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -6,7 +6,6 @@ import (
"github.com/zitadel/logging"
"go.opentelemetry.io/otel/attribute"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/notification/channels"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
)
@ -18,18 +17,14 @@ func countMessages(ctx context.Context, channel channels.NotificationChannel, su
if err != nil {
metricName = errorMetricName
}
addCount(ctx, metricName, message, err)
addCount(ctx, metricName, message)
return err
})
}
func addCount(ctx context.Context, metricName string, message channels.Message, err error) {
func addCount(ctx context.Context, metricName string, message channels.Message) {
labels := map[string]attribute.Value{
"triggering_event_typey": attribute.StringValue(string(message.GetTriggeringEvent().Type())),
"instance": attribute.StringValue(authz.GetInstance(ctx).InstanceID()),
}
if err != nil {
labels["error"] = attribute.StringValue(err.Error())
"triggering_event_type": attribute.StringValue(string(message.GetTriggeringEvent().Type())),
}
addCountErr := metrics.AddCount(ctx, metricName, 1, labels)
logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed")

View File

@ -368,6 +368,10 @@ func (q *Queries) GetUserByIDWithPermission(ctx context.Context, shouldTriggerBu
}
func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userID string) (user *User, err error) {
return q.GetUserByIDWithResourceOwner(ctx, shouldTriggerBulk, userID, "")
}
func (q *Queries) GetUserByIDWithResourceOwner(ctx context.Context, shouldTriggerBulk bool, userID, resourceOwner string) (user *User, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@ -382,6 +386,7 @@ func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userI
},
userByIDQuery,
userID,
resourceOwner,
authz.GetInstance(ctx).InstanceID(),
)
return user, err

View File

@ -20,8 +20,8 @@ WITH login_names AS (SELECT
WHERE
u.instance_id = p.instance_id
AND (
(p.is_default IS TRUE AND p.instance_id = $2)
OR (p.instance_id = $2 AND p.resource_owner = u.resource_owner)
(p.is_default IS TRUE AND p.instance_id = $3)
OR (p.instance_id = $3 AND p.resource_owner = u.resource_owner)
)
ORDER BY is_default
LIMIT 1
@ -32,8 +32,9 @@ WITH login_names AS (SELECT
u.instance_id = d.instance_id
AND u.resource_owner = d.resource_owner
WHERE
u.instance_id = $2
AND u.id = $1
u.id = $1
AND (u.resource_owner = $2 OR $2 = '')
AND u.instance_id = $3
)
SELECT
u.id
@ -80,6 +81,7 @@ LEFT JOIN
AND u.instance_id = m.instance_id
WHERE
u.id = $1
AND u.instance_id = $2
AND (u.resource_owner = $2 OR $2 = '')
AND u.instance_id = $3
LIMIT 1
;

View File

@ -30,5 +30,5 @@ func TelemetryHandler(handler http.Handler, ignoredEndpoints ...string) http.Han
}
func spanNameFormatter(_ string, r *http.Request) string {
return r.Host + r.URL.EscapedPath()
return strings.Split(r.RequestURI, "?")[0]
}

View File

@ -1,6 +1,7 @@
package metrics
import (
"context"
"net/http"
"strings"
@ -35,7 +36,8 @@ const (
type StatusRecorder struct {
http.ResponseWriter
Status int
RequestURI *string
Status int
}
func (r *StatusRecorder) WriteHeader(status int) {
@ -56,6 +58,18 @@ func NewMetricsHandler(handler http.Handler, metricMethods []MetricType, ignored
return &h
}
type key int
const requestURI key = iota
func SetRequestURIPattern(ctx context.Context, pattern string) {
uri, ok := ctx.Value(requestURI).(*string)
if !ok {
return
}
*uri = pattern
}
// ServeHTTP serves HTTP requests (http.Handler)
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(h.methods) == 0 {
@ -69,13 +83,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
}
uri := strings.Split(r.RequestURI, "?")[0]
recorder := &StatusRecorder{
ResponseWriter: w,
RequestURI: &uri,
Status: 200,
}
r = r.WithContext(context.WithValue(r.Context(), requestURI, &uri))
h.handler.ServeHTTP(recorder, r)
if h.containsMetricsMethod(MetricTypeRequestCount) {
RegisterRequestCounter(r)
RegisterRequestCounter(recorder, r)
}
if h.containsMetricsMethod(MetricTypeTotalCount) {
RegisterTotalRequestCounter(r)
@ -94,9 +111,9 @@ func (h *Handler) containsMetricsMethod(method MetricType) bool {
return false
}
func RegisterRequestCounter(r *http.Request) {
func RegisterRequestCounter(recorder *StatusRecorder, r *http.Request) {
var labels = map[string]attribute.Value{
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]),
URI: attribute.StringValue(*recorder.RequestURI),
Method: attribute.StringValue(r.Method),
}
RegisterCounter(RequestCounter, RequestCountDescription)
@ -110,7 +127,7 @@ func RegisterTotalRequestCounter(r *http.Request) {
func RegisterRequestCodeCounter(recorder *StatusRecorder, r *http.Request) {
var labels = map[string]attribute.Value{
URI: attribute.StringValue(strings.Split(r.RequestURI, "?")[0]),
URI: attribute.StringValue(*recorder.RequestURI),
Method: attribute.StringValue(r.Method),
ReturnCode: attribute.IntValue(recorder.Status),
}

View File

@ -6,9 +6,11 @@ import (
"sync"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/prometheus"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/sdk/instrumentation"
sdk_metric "go.opentelemetry.io/otel/sdk/metric"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
@ -33,9 +35,19 @@ func NewMetrics(meterName string) (metrics.Metrics, error) {
if err != nil {
return &Metrics{}, err
}
// create a view to filter out unwanted attributes
view := sdk_metric.NewView(
sdk_metric.Instrument{
Scope: instrumentation.Scope{Name: otelhttp.ScopeName},
},
sdk_metric.Stream{
AttributeFilter: attribute.NewAllowKeysFilter("http.method", "http.status_code", "http.target"),
},
)
meterProvider := sdk_metric.NewMeterProvider(
sdk_metric.WithReader(exporter),
sdk_metric.WithResource(resource),
sdk_metric.WithView(view),
)
return &Metrics{
Provider: meterProvider,

View File

@ -28,7 +28,7 @@ type Tracer struct {
}
func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := texporter.New(texporter.WithProjectID(c.ProjectID))
if err != nil {
return err

View File

@ -26,7 +26,7 @@ type Tracer struct {
}
func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
sampler := otel.NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := stdout.New(stdout.WithPrettyPrint())
if err != nil {
return err

View File

@ -6,6 +6,7 @@ import (
otlpgrpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
api_trace "go.opentelemetry.io/otel/trace"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
@ -47,7 +48,7 @@ func FractionFromConfig(i interface{}) (float64, error) {
}
func (c *Config) NewTracer() error {
sampler := sdk_trace.ParentBased(sdk_trace.TraceIDRatioBased(c.Fraction))
sampler := NewSampler(sdk_trace.TraceIDRatioBased(c.Fraction))
exporter, err := otlpgrpc.New(context.Background(), otlpgrpc.WithEndpoint(c.Endpoint), otlpgrpc.WithInsecure())
if err != nil {
return err
@ -56,3 +57,19 @@ func (c *Config) NewTracer() error {
tracing.T, err = NewTracer(sampler, exporter)
return err
}
// NewSampler returns a sampler decorator which behaves differently,
// based on the parent of the span. If the span has no parent and is of kind server,
// the decorated sampler is used to make sampling decision.
// If the span has a parent, depending on whether the parent is remote and whether it
// is sampled, one of the following samplers will apply:
// - remote parent sampled -> always sample
// - remote parent not sampled -> sample based on the decorated sampler (fraction based)
// - local parent sampled -> always sample
// - local parent not sampled -> never sample
func NewSampler(sampler sdk_trace.Sampler) sdk_trace.Sampler {
return sdk_trace.ParentBased(
tracing.SpanKindBased(sampler, api_trace.SpanKindServer),
sdk_trace.WithRemoteParentNotSampled(sampler),
)
}

View File

@ -0,0 +1,46 @@
package tracing
import (
"fmt"
"slices"
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
type spanKindSampler struct {
sampler sdk_trace.Sampler
kinds []trace.SpanKind
}
// ShouldSample implements the [sdk_trace.Sampler] interface.
// It will not sample any spans which do not match the configured span kinds.
// For spans which do match, the decorated sampler is used to make the sampling decision.
func (sk spanKindSampler) ShouldSample(p sdk_trace.SamplingParameters) sdk_trace.SamplingResult {
psc := trace.SpanContextFromContext(p.ParentContext)
if !slices.Contains(sk.kinds, p.Kind) {
return sdk_trace.SamplingResult{
Decision: sdk_trace.Drop,
Tracestate: psc.TraceState(),
}
}
s := sk.sampler.ShouldSample(p)
return s
}
func (sk spanKindSampler) Description() string {
return fmt.Sprintf("SpanKindBased{sampler:%s,kinds:%v}",
sk.sampler.Description(),
sk.kinds,
)
}
// SpanKindBased returns a sampler decorator which behaves differently, based on the kind of the span.
// If the span kind does not match one of the configured kinds, it will not be sampled.
// If the span kind matches, the decorated sampler is used to make sampling decision.
func SpanKindBased(sampler sdk_trace.Sampler, kinds ...trace.SpanKind) sdk_trace.Sampler {
return spanKindSampler{
sampler: sampler,
kinds: kinds,
}
}

View File

@ -0,0 +1,80 @@
package tracing
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
sdk_trace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
)
func TestSpanKindBased(t *testing.T) {
type args struct {
sampler sdk_trace.Sampler
kinds []trace.SpanKind
}
type want struct {
description string
sampled int
}
tests := []struct {
name string
args args
want want
}{
{
"never sample, no sample",
args{
sampler: sdk_trace.NeverSample(),
kinds: []trace.SpanKind{trace.SpanKindServer},
},
want{
description: "SpanKindBased{sampler:AlwaysOffSampler,kinds:[server]}",
sampled: 0,
},
},
{
"always sample, no kind, no sample",
args{
sampler: sdk_trace.AlwaysSample(),
kinds: nil,
},
want{
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[]}",
sampled: 0,
},
},
{
"always sample, 2 kinds, 2 samples",
args{
sampler: sdk_trace.AlwaysSample(),
kinds: []trace.SpanKind{trace.SpanKindServer, trace.SpanKindClient},
},
want{
description: "SpanKindBased{sampler:AlwaysOnSampler,kinds:[server client]}",
sampled: 2,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sampler := SpanKindBased(tt.args.sampler, tt.args.kinds...)
assert.Equal(t, tt.want.description, sampler.Description())
p := sdk_trace.NewTracerProvider(sdk_trace.WithSampler(sampler))
tr := p.Tracer("test")
var sampled int
for i := trace.SpanKindUnspecified; i <= trace.SpanKindConsumer; i++ {
ctx := context.Background()
_, span := tr.Start(ctx, "test", trace.WithSpanKind(i))
if span.SpanContext().IsSampled() {
sampled++
}
}
assert.Equal(t, tt.want.sampled, sampled)
})
}
}