fix: store auth methods instead of AMR in auth request linking and OIDC Session (#6192)

This PR changes the information stored on the SessionLinkedEvent and (OIDC Session) AddedEvent from OIDC AMR strings to domain.UserAuthMethodTypes, so no information is lost in the process (e.g. authentication with an IDP)
This commit is contained in:
Livio Spring
2023-07-12 14:24:01 +02:00
committed by GitHub
parent a3a1e245ad
commit ee26f99ebf
15 changed files with 156 additions and 174 deletions

View File

@@ -5,7 +5,6 @@ import (
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/repository/authrequest"
@@ -32,10 +31,10 @@ type AuthRequest struct {
type CurrentAuthRequest struct {
*AuthRequest
SessionID string
UserID string
AMR []string
AuthTime time.Time
SessionID string
UserID string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
const IDPrefixV2 = "V2_"
@@ -108,7 +107,7 @@ func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID,
sessionID,
sessionWriteModel.UserID,
sessionWriteModel.AuthenticationTime(),
amr.List(sessionWriteModel),
sessionWriteModel.AuthMethodTypes(),
)); err != nil {
return nil, nil, err
}
@@ -187,10 +186,10 @@ func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel
LoginHint: writeModel.LoginHint,
HintUserID: writeModel.HintUserID,
},
SessionID: writeModel.SessionID,
UserID: writeModel.UserID,
AMR: writeModel.AMR,
AuthTime: writeModel.AuthTime,
SessionID: writeModel.SessionID,
UserID: writeModel.UserID,
AuthMethods: writeModel.AuthMethods,
AuthTime: writeModel.AuthTime,
}
}

View File

@@ -32,7 +32,7 @@ type AuthRequestWriteModel struct {
SessionID string
UserID string
AuthTime time.Time
AMR []string
AuthMethods []domain.UserAuthMethodType
AuthRequestState domain.AuthRequestState
}
@@ -68,7 +68,7 @@ func (m *AuthRequestWriteModel) Reduce() error {
m.SessionID = e.SessionID
m.UserID = e.UserID
m.AuthTime = e.AuthTime
m.AMR = e.AMR
m.AuthMethods = e.AuthMethods
case *authrequest.CodeAddedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeAdded
case *authrequest.FailedEvent:

View File

@@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
@@ -463,7 +462,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
)}),
),
@@ -492,9 +491,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
@@ -542,7 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
)}),
),
@@ -572,9 +571,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
@@ -798,7 +797,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
),
@@ -930,7 +929,7 @@ func TestCommands_ExchangeAuthCode(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
@@ -972,9 +971,9 @@ func TestCommands_ExchangeAuthCode(t *testing.T) {
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{"pwd"},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},

View File

@@ -7,7 +7,6 @@ import (
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@@ -195,7 +194,7 @@ func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope,
amr.List(c.sessionWriteModel),
c.sessionWriteModel.AuthMethodTypes(),
c.sessionWriteModel.AuthenticationTime(),
))
}

View File

@@ -17,7 +17,7 @@ type OIDCSessionWriteModel struct {
ClientID string
Audience []string
Scope []string
AuthMethodsReferences []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
State domain.OIDCSessionState
AccessTokenExpiration time.Time
@@ -79,7 +79,7 @@ func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.ClientID = e.ClientID
wm.Audience = e.Audience
wm.Scope = e.Scope
wm.AuthMethodsReferences = e.AuthMethodsReferences
wm.AuthMethods = e.AuthMethods
wm.AuthTime = e.AuthTime
wm.State = domain.OIDCSessionStateActive
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
@@ -99,7 +98,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
@@ -151,7 +150,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
@@ -179,7 +178,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -293,7 +292,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
@@ -345,7 +344,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
@@ -373,7 +372,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -491,7 +490,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -517,7 +516,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -547,7 +546,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -670,7 +669,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -695,7 +694,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -724,7 +723,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@@ -753,7 +752,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
ClientID: "clientID",
Audience: []string{"audience"},
Scope: []string{"openid", "profile", "offline_access"},
AuthMethodsReferences: []string{amr.PWD},
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow,
State: domain.OIDCSessionStateActive,
RefreshTokenID: "refreshTokenID",
@@ -783,7 +782,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
assert.Equal(t, tt.res.model.ClientID, got.ClientID)
assert.Equal(t, tt.res.model.Audience, got.Audience)
assert.Equal(t, tt.res.model.Scope, got.Scope)
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences)
assert.Equal(t, tt.res.model.AuthMethods, got.AuthMethods)
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
assert.Equal(t, tt.res.model.State, got.State)
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)

View File

@@ -52,24 +52,6 @@ type SessionWriteModel struct {
aggregate *eventstore.Aggregate
}
func (wm *SessionWriteModel) IsPasswordChecked() bool {
return !wm.PasswordCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
return !wm.PasskeyCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsU2FChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func (wm *SessionWriteModel) IsOTPChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
return &SessionWriteModel{
WriteModel: eventstore.WriteModel{
@@ -244,3 +226,27 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time {
}
return authTime
}
// AuthMethodTypes returns a list of UserAuthMethodTypes based on succeeded checks
func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
types := make([]domain.UserAuthMethodType, 0, domain.UserAuthMethodTypeIDP)
if !wm.PasswordCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePassword)
}
if !wm.PasskeyCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePasswordless)
}
if !wm.IntentCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeIDP)
}
// TODO: add checks with https://github.com/zitadel/zitadel/issues/5477
/*
if !wm.TOTPCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeOTP)
}
if !wm.U2FCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeU2F)
}
*/
return types
}