mirror of
https://github.com/zitadel/zitadel.git
synced 2025-04-16 10:11:29 +00:00

This change moves the IDPLoginPolicyLinks call out of the scan function of the login policy queries to prevent potential deadlocks.
475 lines
15 KiB
Go
475 lines
15 KiB
Go
package query
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
|
|
sq "github.com/Masterminds/squirrel"
|
|
|
|
"github.com/zitadel/logging"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
"github.com/zitadel/zitadel/internal/api/call"
|
|
"github.com/zitadel/zitadel/internal/database"
|
|
"github.com/zitadel/zitadel/internal/domain"
|
|
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
|
"github.com/zitadel/zitadel/internal/query/projection"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
type LoginPolicy struct {
|
|
OrgID string
|
|
CreationDate time.Time
|
|
ChangeDate time.Time
|
|
Sequence uint64
|
|
AllowRegister bool
|
|
AllowUsernamePassword bool
|
|
AllowExternalIDPs bool
|
|
ForceMFA bool
|
|
ForceMFALocalOnly bool
|
|
SecondFactors database.Array[domain.SecondFactorType]
|
|
MultiFactors database.Array[domain.MultiFactorType]
|
|
PasswordlessType domain.PasswordlessType
|
|
IsDefault bool
|
|
HidePasswordReset bool
|
|
IgnoreUnknownUsernames bool
|
|
AllowDomainDiscovery bool
|
|
DisableLoginWithEmail bool
|
|
DisableLoginWithPhone bool
|
|
DefaultRedirectURI string
|
|
PasswordCheckLifetime time.Duration
|
|
ExternalLoginCheckLifetime time.Duration
|
|
MFAInitSkipLifetime time.Duration
|
|
SecondFactorCheckLifetime time.Duration
|
|
MultiFactorCheckLifetime time.Duration
|
|
IDPLinks []*IDPLoginPolicyLink
|
|
}
|
|
|
|
type SecondFactors struct {
|
|
SearchResponse
|
|
Factors database.Array[domain.SecondFactorType]
|
|
}
|
|
|
|
type MultiFactors struct {
|
|
SearchResponse
|
|
Factors database.Array[domain.MultiFactorType]
|
|
}
|
|
|
|
var (
|
|
loginPolicyTable = table{
|
|
name: projection.LoginPolicyTable,
|
|
instanceIDCol: projection.LoginPolicyInstanceIDCol,
|
|
}
|
|
LoginPolicyColumnOrgID = Column{
|
|
name: projection.LoginPolicyIDCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnInstanceID = Column{
|
|
name: projection.LoginPolicyInstanceIDCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnCreationDate = Column{
|
|
name: projection.LoginPolicyCreationDateCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnChangeDate = Column{
|
|
name: projection.LoginPolicyChangeDateCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnSequence = Column{
|
|
name: projection.LoginPolicySequenceCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnAllowRegister = Column{
|
|
name: projection.LoginPolicyAllowRegisterCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnAllowUsernamePassword = Column{
|
|
name: projection.LoginPolicyAllowUsernamePasswordCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnAllowExternalIDPs = Column{
|
|
name: projection.LoginPolicyAllowExternalIDPsCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnForceMFA = Column{
|
|
name: projection.LoginPolicyForceMFACol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnForceMFALocalOnly = Column{
|
|
name: projection.LoginPolicyForceMFALocalOnlyCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnSecondFactors = Column{
|
|
name: projection.LoginPolicy2FAsCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnMultiFactors = Column{
|
|
name: projection.LoginPolicyMFAsCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnPasswordlessType = Column{
|
|
name: projection.LoginPolicyPasswordlessTypeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnIsDefault = Column{
|
|
name: projection.LoginPolicyIsDefaultCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnHidePasswordReset = Column{
|
|
name: projection.LoginPolicyHidePWResetCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnIgnoreUnknownUsernames = Column{
|
|
name: projection.IgnoreUnknownUsernames,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnAllowDomainDiscovery = Column{
|
|
name: projection.AllowDomainDiscovery,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnDisableLoginWithEmail = Column{
|
|
name: projection.DisableLoginWithEmail,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnDisableLoginWithPhone = Column{
|
|
name: projection.DisableLoginWithPhone,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnDefaultRedirectURI = Column{
|
|
name: projection.DefaultRedirectURI,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnPasswordCheckLifetime = Column{
|
|
name: projection.PasswordCheckLifetimeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnExternalLoginCheckLifetime = Column{
|
|
name: projection.ExternalLoginCheckLifetimeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnMFAInitSkipLifetime = Column{
|
|
name: projection.MFAInitSkipLifetimeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnSecondFactorCheckLifetime = Column{
|
|
name: projection.SecondFactorCheckLifetimeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnMultiFactorCheckLifetime = Column{
|
|
name: projection.MultiFactorCheckLifetimeCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
LoginPolicyColumnOwnerRemoved = Column{
|
|
name: projection.LoginPolicyOwnerRemovedCol,
|
|
table: loginPolicyTable,
|
|
}
|
|
)
|
|
|
|
func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *LoginPolicy, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
if shouldTriggerBulk {
|
|
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerLoginPolicyProjection")
|
|
ctx, err = projection.LoginPolicyProjection.Trigger(ctx, handler.WithAwaitRunning())
|
|
logging.OnError(err).Debug("trigger failed")
|
|
traceSpan.EndWithError(err)
|
|
}
|
|
eq := sq.Eq{LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
|
|
if !withOwnerRemoved {
|
|
eq[LoginPolicyColumnOwnerRemoved.identifier()] = false
|
|
}
|
|
|
|
query, scan := prepareLoginPolicyQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(
|
|
sq.And{
|
|
eq,
|
|
sq.Or{
|
|
sq.Eq{LoginPolicyColumnOrgID.identifier(): orgID},
|
|
sq.Eq{LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID()},
|
|
},
|
|
}).Limit(1).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
|
policy, err = scan(rows)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-SWgr3", "Errors.Internal")
|
|
}
|
|
return policy, q.addLinksToLoginPolicy(ctx, policy)
|
|
}
|
|
|
|
func (q *Queries) addLinksToLoginPolicy(ctx context.Context, policy *LoginPolicy) error {
|
|
links, err := q.IDPLoginPolicyLinks(ctx, policy.OrgID, &IDPLoginPolicyLinksSearchQuery{}, false)
|
|
if err != nil {
|
|
return zerrors.ThrowInternal(err, "QUERY-aa4Ve", "Errors.Internal")
|
|
}
|
|
policy.IDPLinks = append(policy.IDPLinks, links.Links...)
|
|
return nil
|
|
}
|
|
|
|
func (q *Queries) DefaultLoginPolicy(ctx context.Context) (policy *LoginPolicy, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
query, scan := prepareLoginPolicyQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-t4TBK", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
|
policy, err = scan(rows)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-SArt2", "Errors.Internal")
|
|
}
|
|
return policy, q.addLinksToLoginPolicy(ctx, policy)
|
|
}
|
|
|
|
func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (factors *SecondFactors, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(
|
|
sq.And{
|
|
sq.Eq{
|
|
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
},
|
|
sq.Or{
|
|
sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): orgID,
|
|
},
|
|
sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
},
|
|
},
|
|
}).
|
|
OrderBy(LoginPolicyColumnIsDefault.identifier()).
|
|
Limit(1).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
|
factors, err = scan(row)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
factors.State, err = q.latestState(ctx, loginPolicyTable)
|
|
return factors, err
|
|
}
|
|
|
|
func (q *Queries) DefaultSecondFactors(ctx context.Context) (factors *SecondFactors, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-CZ2Nv", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
|
factors, err = scan(row)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
factors.State, err = q.latestState(ctx, loginPolicyTable)
|
|
return factors, err
|
|
}
|
|
|
|
func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (factors *MultiFactors, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(
|
|
sq.And{
|
|
sq.Eq{
|
|
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
},
|
|
sq.Or{
|
|
sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): orgID,
|
|
},
|
|
sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
},
|
|
},
|
|
}).
|
|
OrderBy(LoginPolicyColumnIsDefault.identifier()).
|
|
Limit(1).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-B4o7h", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
|
factors, err = scan(row)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
factors.State, err = q.latestState(ctx, loginPolicyTable)
|
|
return factors, err
|
|
}
|
|
|
|
func (q *Queries) DefaultMultiFactors(ctx context.Context) (factors *MultiFactors, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client)
|
|
stmt, args, err := query.Where(sq.Eq{
|
|
LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
}).OrderBy(LoginPolicyColumnIsDefault.identifier()).ToSql()
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-WxYjr", "Errors.Query.SQLStatement")
|
|
}
|
|
|
|
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
|
factors, err = scan(row)
|
|
return err
|
|
}, stmt, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
factors.State, err = q.latestState(ctx, loginPolicyTable)
|
|
return factors, err
|
|
}
|
|
|
|
func prepareLoginPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy, error)) {
|
|
return sq.Select(
|
|
LoginPolicyColumnOrgID.identifier(),
|
|
LoginPolicyColumnCreationDate.identifier(),
|
|
LoginPolicyColumnChangeDate.identifier(),
|
|
LoginPolicyColumnSequence.identifier(),
|
|
LoginPolicyColumnAllowRegister.identifier(),
|
|
LoginPolicyColumnAllowUsernamePassword.identifier(),
|
|
LoginPolicyColumnAllowExternalIDPs.identifier(),
|
|
LoginPolicyColumnForceMFA.identifier(),
|
|
LoginPolicyColumnForceMFALocalOnly.identifier(),
|
|
LoginPolicyColumnSecondFactors.identifier(),
|
|
LoginPolicyColumnMultiFactors.identifier(),
|
|
LoginPolicyColumnPasswordlessType.identifier(),
|
|
LoginPolicyColumnIsDefault.identifier(),
|
|
LoginPolicyColumnHidePasswordReset.identifier(),
|
|
LoginPolicyColumnIgnoreUnknownUsernames.identifier(),
|
|
LoginPolicyColumnAllowDomainDiscovery.identifier(),
|
|
LoginPolicyColumnDisableLoginWithEmail.identifier(),
|
|
LoginPolicyColumnDisableLoginWithPhone.identifier(),
|
|
LoginPolicyColumnDefaultRedirectURI.identifier(),
|
|
LoginPolicyColumnPasswordCheckLifetime.identifier(),
|
|
LoginPolicyColumnExternalLoginCheckLifetime.identifier(),
|
|
LoginPolicyColumnMFAInitSkipLifetime.identifier(),
|
|
LoginPolicyColumnSecondFactorCheckLifetime.identifier(),
|
|
LoginPolicyColumnMultiFactorCheckLifetime.identifier(),
|
|
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
|
|
PlaceholderFormat(sq.Dollar),
|
|
func(rows *sql.Rows) (*LoginPolicy, error) {
|
|
p := new(LoginPolicy)
|
|
defaultRedirectURI := sql.NullString{}
|
|
for rows.Next() {
|
|
err := rows.Scan(
|
|
&p.OrgID,
|
|
&p.CreationDate,
|
|
&p.ChangeDate,
|
|
&p.Sequence,
|
|
&p.AllowRegister,
|
|
&p.AllowUsernamePassword,
|
|
&p.AllowExternalIDPs,
|
|
&p.ForceMFA,
|
|
&p.ForceMFALocalOnly,
|
|
&p.SecondFactors,
|
|
&p.MultiFactors,
|
|
&p.PasswordlessType,
|
|
&p.IsDefault,
|
|
&p.HidePasswordReset,
|
|
&p.IgnoreUnknownUsernames,
|
|
&p.AllowDomainDiscovery,
|
|
&p.DisableLoginWithEmail,
|
|
&p.DisableLoginWithPhone,
|
|
&defaultRedirectURI,
|
|
&p.PasswordCheckLifetime,
|
|
&p.ExternalLoginCheckLifetime,
|
|
&p.MFAInitSkipLifetime,
|
|
&p.SecondFactorCheckLifetime,
|
|
&p.MultiFactorCheckLifetime,
|
|
)
|
|
if err != nil {
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-YcC53", "Errors.Internal")
|
|
}
|
|
}
|
|
if p.OrgID == "" {
|
|
return nil, zerrors.ThrowNotFound(nil, "QUERY-QsUBJ", "Errors.LoginPolicy.NotFound")
|
|
}
|
|
p.DefaultRedirectURI = defaultRedirectURI.String
|
|
return p, nil
|
|
}
|
|
}
|
|
|
|
func prepareLoginPolicy2FAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecondFactors, error)) {
|
|
return sq.Select(
|
|
LoginPolicyColumnSecondFactors.identifier(),
|
|
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
|
|
PlaceholderFormat(sq.Dollar),
|
|
func(row *sql.Row) (*SecondFactors, error) {
|
|
p := new(SecondFactors)
|
|
err := row.Scan(
|
|
&p.Factors,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, zerrors.ThrowNotFound(err, "QUERY-yPqIZ", "Errors.LoginPolicy.NotFound")
|
|
}
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-Mr6H3", "Errors.Internal")
|
|
}
|
|
|
|
p.Count = uint64(len(p.Factors))
|
|
return p, nil
|
|
}
|
|
}
|
|
|
|
func prepareLoginPolicyMFAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MultiFactors, error)) {
|
|
return sq.Select(
|
|
LoginPolicyColumnMultiFactors.identifier(),
|
|
).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))).
|
|
PlaceholderFormat(sq.Dollar),
|
|
func(row *sql.Row) (*MultiFactors, error) {
|
|
p := new(MultiFactors)
|
|
err := row.Scan(
|
|
&p.Factors,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, zerrors.ThrowNotFound(err, "QUERY-yPqIZ", "Errors.LoginPolicy.NotFound")
|
|
}
|
|
return nil, zerrors.ThrowInternal(err, "QUERY-Mr6H3", "Errors.Internal")
|
|
}
|
|
|
|
p.Count = uint64(len(p.Factors))
|
|
return p, nil
|
|
}
|
|
}
|