fix: add domain as attribute to list user auth methods (#8718)

# Which Problems Are Solved

There is no option to only query auth methods related to specific
domains.

# How the Problems Are Solved

Add domain as attribute to the ListAuthenticationMethodTypes request.

# Additional Changes

OwnerRemoved column removed from the projection.

# Additional Context

Closes #8615

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Stefan Benz
2024-10-10 18:50:53 +02:00
committed by GitHub
parent df2033253d
commit 4d593dace2
29 changed files with 649 additions and 86 deletions

View File

@@ -63,8 +63,8 @@ var (
name: projection.UserAuthMethodTypeCol,
table: userAuthMethodTable,
}
UserAuthMethodColumnOwnerRemoved = Column{
name: projection.UserAuthMethodOwnerRemovedCol,
UserAuthMethodColumnDomain = Column{
name: projection.UserAuthMethodDomainCol,
table: userAuthMethodTable,
}
@@ -72,11 +72,8 @@ var (
authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable)
authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable)
authMethodTypeType = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable)
authMethodTypeTypes = Column{
name: "method_types",
table: authMethodTypeTable,
}
authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable)
authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable)
authMethodTypeDomain = UserAuthMethodColumnDomain.setTable(authMethodTypeTable)
userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count")
userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable)
@@ -140,7 +137,7 @@ func (q *UserAuthMethodSearchQueries) hasUserID() bool {
}
func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheck domain.PermissionCheck) (userAuthMethods *AuthMethods, err error) {
methods, err := q.searchUserAuthMethods(ctx, queries, false)
methods, err := q.searchUserAuthMethods(ctx, queries)
if err != nil {
return nil, err
}
@@ -157,16 +154,12 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe
return methods, nil
}
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, withOwnerRemoved bool) (userAuthMethods *AuthMethods, err error) {
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries) (userAuthMethods *AuthMethods, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodsQuery(ctx, q.client)
eq := sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[UserAuthMethodColumnOwnerRemoved.identifier()] = false
}
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql()
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest")
}
@@ -182,7 +175,7 @@ func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMe
return userAuthMethods, err
}
func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, activeOnly bool) (userAuthMethodTypes *AuthMethodTypes, err error) {
func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, activeOnly bool, includeWithoutDomain bool, queryDomain string) (userAuthMethodTypes *AuthMethodTypes, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
@@ -192,7 +185,7 @@ func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, ac
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodTypesQuery(ctx, q.client, activeOnly)
query, scan := prepareUserAuthMethodTypesQuery(ctx, q.client, activeOnly, includeWithoutDomain, queryDomain)
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -389,8 +382,8 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se
}
}
func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, activeOnly bool) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery(activeOnly)
func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, activeOnly bool, includeWithoutDomain bool, queryDomain string) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery(activeOnly, includeWithoutDomain, queryDomain)
if err != nil {
return sq.SelectBuilder{}, nil
}
@@ -504,7 +497,7 @@ func prepareAuthMethodsIDPsQuery() (string, error) {
return idpsQuery, err
}
func prepareAuthMethodQuery(activeOnly bool) (string, []interface{}, error) {
func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDomain string) (string, []interface{}, error) {
q := sq.Select(
"DISTINCT("+authMethodTypeType.identifier()+")",
authMethodTypeUserID.identifier(),
@@ -513,6 +506,17 @@ func prepareAuthMethodQuery(activeOnly bool) (string, []interface{}, error) {
if activeOnly {
q = q.Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady})
}
if queryDomain != "" {
conditions := sq.Or{
sq.Eq{authMethodTypeDomain.identifier(): nil},
sq.Eq{authMethodTypeDomain.identifier(): queryDomain},
}
if includeWithoutDomain {
conditions = append(conditions, sq.Eq{authMethodTypeDomain.identifier(): ""})
}
q = q.Where(conditions)
}
return q.ToSql()
}