mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 19:07:30 +00:00
feat(oidc): sid claim for id_tokens issued through login V1 (#8525)
# Which Problems Are Solved id_tokens issued for auth requests created through the login UI currently do not provide a sid claim. This is due to the fact that (SSO) sessions for the login UI do not have one and are only computed by the userAgent(ID), the user(ID) and the authentication checks of the latter. This prevents client to track sessions and terminate specific session on the end_session_endpoint. # How the Problems Are Solved - An `id` column is added to the `auth.user_sessions` table. - The `id` (prefixed with `V1_`) is set whenever a session is added or updated to active (from terminated) - The id is passed to the `oidc session` (as v2 sessionIDs), to expose it as `sid` claim # Additional Changes - refactored `getUpdateCols` to handle different column value types and add arguments for query # Additional Context - closes #8499 - relates to #8501
This commit is contained in:
27
cmd/setup/32.go
Normal file
27
cmd/setup/32.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package setup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/database"
|
||||||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed 32.sql
|
||||||
|
addAuthSessionID string
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddAuthSessionID struct {
|
||||||
|
dbClient *database.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mig *AddAuthSessionID) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||||
|
_, err := mig.dbClient.ExecContext(ctx, addAuthSessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mig *AddAuthSessionID) String() string {
|
||||||
|
return "32_add_auth_sessionID"
|
||||||
|
}
|
3
cmd/setup/32.sql
Normal file
3
cmd/setup/32.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE IF EXISTS auth.user_sessions ADD COLUMN IF NOT EXISTS id TEXT;
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS user_session_id ON auth.user_sessions (id, instance_id);
|
@@ -116,6 +116,7 @@ type Steps struct {
|
|||||||
s29FillFieldsForProjectGrant *FillFieldsForProjectGrant
|
s29FillFieldsForProjectGrant *FillFieldsForProjectGrant
|
||||||
s30FillFieldsForOrgDomainVerified *FillFieldsForOrgDomainVerified
|
s30FillFieldsForOrgDomainVerified *FillFieldsForOrgDomainVerified
|
||||||
s31AddAggregateIndexToFields *AddAggregateIndexToFields
|
s31AddAggregateIndexToFields *AddAggregateIndexToFields
|
||||||
|
s32AddAuthSessionID *AddAuthSessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustNewSteps(v *viper.Viper) *Steps {
|
func MustNewSteps(v *viper.Viper) *Steps {
|
||||||
|
@@ -160,6 +160,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
|||||||
steps.s29FillFieldsForProjectGrant = &FillFieldsForProjectGrant{eventstore: eventstoreClient}
|
steps.s29FillFieldsForProjectGrant = &FillFieldsForProjectGrant{eventstore: eventstoreClient}
|
||||||
steps.s30FillFieldsForOrgDomainVerified = &FillFieldsForOrgDomainVerified{eventstore: eventstoreClient}
|
steps.s30FillFieldsForOrgDomainVerified = &FillFieldsForOrgDomainVerified{eventstore: eventstoreClient}
|
||||||
steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: esPusherDBClient}
|
steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: esPusherDBClient}
|
||||||
|
steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient}
|
||||||
|
|
||||||
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||||
logging.OnError(err).Fatal("unable to start projections")
|
logging.OnError(err).Fatal("unable to start projections")
|
||||||
@@ -216,6 +217,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
|||||||
steps.s21AddBlockFieldToLimits,
|
steps.s21AddBlockFieldToLimits,
|
||||||
steps.s25User11AddLowerFieldsToVerifiedEmail,
|
steps.s25User11AddLowerFieldsToVerifiedEmail,
|
||||||
steps.s27IDPTemplate6SAMLNameIDFormat,
|
steps.s27IDPTemplate6SAMLNameIDFormat,
|
||||||
|
steps.s32AddAuthSessionID,
|
||||||
} {
|
} {
|
||||||
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
|
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
|
||||||
}
|
}
|
||||||
|
@@ -564,6 +564,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
|
|||||||
domain.TokenReasonAuthRequest,
|
domain.TokenReasonAuthRequest,
|
||||||
nil,
|
nil,
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
|
authReq.SessionID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||||
|
@@ -45,6 +45,7 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
|
|||||||
domain.TokenReasonClientCredentials,
|
domain.TokenReasonClientCredentials,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@@ -85,6 +85,7 @@ func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.A
|
|||||||
domain.TokenReasonAuthRequest,
|
domain.TokenReasonAuthRequest,
|
||||||
nil,
|
nil,
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
|
authReq.SessionID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@@ -298,6 +298,7 @@ func (s *Server) createExchangeAccessToken(
|
|||||||
reason,
|
reason,
|
||||||
actor,
|
actor,
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", 0, err
|
return "", "", "", 0, err
|
||||||
@@ -342,6 +343,7 @@ func (s *Server) createExchangeJWT(
|
|||||||
reason,
|
reason,
|
||||||
actor,
|
actor,
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
|
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -53,6 +53,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
|
|||||||
domain.TokenReasonJWTProfile,
|
domain.TokenReasonJWTProfile,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@@ -67,6 +67,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
|
|||||||
domain.TokenReasonRefresh,
|
domain.TokenReasonRefresh,
|
||||||
refreshToken.Actor,
|
refreshToken.Actor,
|
||||||
true,
|
true,
|
||||||
|
"",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@@ -162,7 +162,7 @@ func (l *Login) handleDeviceAuthAction(w http.ResponseWriter, r *http.Request) {
|
|||||||
action := mux.Vars(r)["action"]
|
action := mux.Vars(r)["action"]
|
||||||
switch action {
|
switch action {
|
||||||
case deviceAuthAllowed:
|
case deviceAuthAllowed:
|
||||||
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.DeviceCode, authReq.UserID, authReq.UserOrgID, authReq.UserAuthMethodTypes(), authReq.AuthTime, authReq.PreferredLanguage, authReq.ToUserAgent())
|
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.DeviceCode, authReq.UserID, authReq.UserOrgID, authReq.UserAuthMethodTypes(), authReq.AuthTime, authReq.PreferredLanguage, authReq.ToUserAgent(), authReq.SessionID)
|
||||||
case deviceAuthDenied:
|
case deviceAuthDenied:
|
||||||
_, err = l.command.CancelDeviceAuth(r.Context(), authDev.DeviceCode, domain.DeviceAuthCanceledDenied)
|
_, err = l.command.CancelDeviceAuth(r.Context(), authDev.DeviceCode, domain.DeviceAuthCanceledDenied)
|
||||||
default:
|
default:
|
||||||
|
@@ -1048,6 +1048,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
request.SessionID = userSession.ID
|
||||||
request.DisplayName = userSession.DisplayName
|
request.DisplayName = userSession.DisplayName
|
||||||
request.AvatarKey = userSession.AvatarKey
|
request.AvatarKey = userSession.AvatarKey
|
||||||
if user.HumanView != nil && user.HumanView.PreferredLanguage != "" {
|
if user.HumanView != nil && user.HumanView.PreferredLanguage != "" {
|
||||||
|
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/zitadel/zitadel/internal/eventstore"
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||||
handler2 "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
handler2 "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||||
|
"github.com/zitadel/zitadel/internal/id"
|
||||||
query2 "github.com/zitadel/zitadel/internal/query"
|
query2 "github.com/zitadel/zitadel/internal/query"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ func Register(ctx context.Context, configs Config, view *view.View, queries *que
|
|||||||
configs.overwrite("UserSession"),
|
configs.overwrite("UserSession"),
|
||||||
view,
|
view,
|
||||||
queries,
|
queries,
|
||||||
|
id.SonyFlakeGenerator(),
|
||||||
))
|
))
|
||||||
|
|
||||||
projections = append(projections, newToken(ctx,
|
projections = append(projections, newToken(ctx,
|
||||||
|
@@ -2,12 +2,14 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
|
auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
|
||||||
"github.com/zitadel/zitadel/internal/domain"
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
"github.com/zitadel/zitadel/internal/eventstore"
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||||
|
"github.com/zitadel/zitadel/internal/id"
|
||||||
query2 "github.com/zitadel/zitadel/internal/query"
|
query2 "github.com/zitadel/zitadel/internal/query"
|
||||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||||
"github.com/zitadel/zitadel/internal/repository/org"
|
"github.com/zitadel/zitadel/internal/repository/org"
|
||||||
@@ -18,12 +20,15 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
userSessionTable = "auth.user_sessions"
|
userSessionTable = "auth.user_sessions"
|
||||||
|
|
||||||
|
IDPrefixV1 = "V1_"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserSession struct {
|
type UserSession struct {
|
||||||
queries *query2.Queries
|
queries *query2.Queries
|
||||||
view *auth_view.View
|
view *auth_view.View
|
||||||
es handler.EventStore
|
es handler.EventStore
|
||||||
|
idGenerator id.Generator
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ handler.Projection = (*UserSession)(nil)
|
var _ handler.Projection = (*UserSession)(nil)
|
||||||
@@ -33,14 +38,16 @@ func newUserSession(
|
|||||||
config handler.Config,
|
config handler.Config,
|
||||||
view *auth_view.View,
|
view *auth_view.View,
|
||||||
queries *query2.Queries,
|
queries *query2.Queries,
|
||||||
|
idGenerator id.Generator,
|
||||||
) *handler.Handler {
|
) *handler.Handler {
|
||||||
return handler.NewHandler(
|
return handler.NewHandler(
|
||||||
ctx,
|
ctx,
|
||||||
&config,
|
&config,
|
||||||
&UserSession{
|
&UserSession{
|
||||||
queries: queries,
|
queries: queries,
|
||||||
view: view,
|
view: view,
|
||||||
es: config.Eventstore,
|
es: config.Eventstore,
|
||||||
|
idGenerator: idGenerator,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -187,7 +194,7 @@ func (s *UserSession) Reducers() []handler.AggregateReducer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func sessionColumns(event eventstore.Event, columns ...handler.Column) ([]handler.Column, error) {
|
func (u *UserSession) sessionColumns(event eventstore.Event, columns ...handler.Column) ([]handler.Column, error) {
|
||||||
userAgent, err := agentIDFromSession(event)
|
userAgent, err := agentIDFromSession(event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -203,14 +210,34 @@ func sessionColumns(event eventstore.Event, columns ...handler.Column) ([]handle
|
|||||||
}, columns...), nil
|
}, columns...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *UserSession) sessionColumnsActivate(event eventstore.Event, columns ...handler.Column) ([]handler.Column, error) {
|
||||||
|
sessionID, err := u.idGenerator.Next()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessionID = IDPrefixV1 + sessionID
|
||||||
|
columns = slices.Grow(columns, 2)
|
||||||
|
columns = append(columns,
|
||||||
|
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
||||||
|
handler.NewCol(view_model.UserSessionKeyID,
|
||||||
|
handler.OnlySetValueInCase(userSessionTable, sessionID,
|
||||||
|
handler.ConditionOr(
|
||||||
|
handler.ColumnChangedCondition(userSessionTable, view_model.UserSessionKeyState, domain.UserSessionStateTerminated, domain.UserSessionStateActive),
|
||||||
|
handler.ColumnIsNullCondition(userSessionTable, view_model.UserSessionKeyID),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return u.sessionColumns(event, columns...)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err error) {
|
func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err error) {
|
||||||
// in case anything needs to be change here check if appendEvent function needs the change as well
|
// in case anything needs to be change here check if appendEvent function needs the change as well
|
||||||
switch event.Type() {
|
switch event.Type() {
|
||||||
case user.UserV1PasswordCheckSucceededType,
|
case user.UserV1PasswordCheckSucceededType,
|
||||||
user.HumanPasswordCheckSucceededType:
|
user.HumanPasswordCheckSucceededType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeyPasswordVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -218,9 +245,8 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
||||||
case user.UserV1PasswordCheckFailedType,
|
case user.UserV1PasswordCheckFailedType,
|
||||||
user.HumanPasswordCheckFailedType:
|
user.HumanPasswordCheckFailedType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -228,10 +254,9 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
||||||
case user.UserV1MFAOTPCheckSucceededType,
|
case user.UserV1MFAOTPCheckSucceededType,
|
||||||
user.HumanMFAOTPCheckSucceededType:
|
user.HumanMFAOTPCheckSucceededType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerificationType, domain.MFATypeTOTP),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerificationType, domain.MFATypeTOTP),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -240,9 +265,8 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
case user.UserV1MFAOTPCheckFailedType,
|
case user.UserV1MFAOTPCheckFailedType,
|
||||||
user.HumanMFAOTPCheckFailedType,
|
user.HumanMFAOTPCheckFailedType,
|
||||||
user.HumanU2FTokenCheckFailedType:
|
user.HumanU2FTokenCheckFailedType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -250,7 +274,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
return handler.NewUpsertStatement(event, columns[0:3], columns), nil
|
||||||
case user.UserV1SignedOutType,
|
case user.UserV1SignedOutType,
|
||||||
user.HumanSignedOutType:
|
user.HumanSignedOutType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumns(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}),
|
||||||
@@ -270,10 +294,9 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyExternalLoginVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeyExternalLoginVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeySelectedIDPConfigID, data.SelectedIDPConfigID),
|
handler.NewCol(view_model.UserSessionKeySelectedIDPConfigID, data.SelectedIDPConfigID),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -285,10 +308,9 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeySecondFactorVerificationType, domain.MFATypeU2F),
|
handler.NewCol(view_model.UserSessionKeySecondFactorVerificationType, domain.MFATypeU2F),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -300,11 +322,10 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, event.CreatedAt()),
|
||||||
handler.NewCol(view_model.UserSessionKeyMultiFactorVerificationType, domain.MFATypeU2FUserVerification),
|
handler.NewCol(view_model.UserSessionKeyMultiFactorVerificationType, domain.MFATypeU2FUserVerification),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -316,10 +337,9 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}),
|
handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}),
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -429,8 +449,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
},
|
},
|
||||||
), nil
|
), nil
|
||||||
case user.HumanRegisteredType:
|
case user.HumanRegisteredType:
|
||||||
columns, err := sessionColumns(event,
|
columns, err := u.sessionColumnsActivate(event,
|
||||||
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
|
|
||||||
handler.NewCol(view_model.UserSessionKeyPasswordVerification, event.CreatedAt()),
|
handler.NewCol(view_model.UserSessionKeyPasswordVerification, event.CreatedAt()),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -50,6 +50,7 @@ func (c *Commands) ApproveDeviceAuth(
|
|||||||
authTime time.Time,
|
authTime time.Time,
|
||||||
preferredLanguage *language.Tag,
|
preferredLanguage *language.Tag,
|
||||||
userAgent *domain.UserAgent,
|
userAgent *domain.UserAgent,
|
||||||
|
sessionID string,
|
||||||
) (*domain.ObjectDetails, error) {
|
) (*domain.ObjectDetails, error) {
|
||||||
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -58,7 +59,7 @@ func (c *Commands) ApproveDeviceAuth(
|
|||||||
if !model.State.Exists() {
|
if !model.State.Exists() {
|
||||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound")
|
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound")
|
||||||
}
|
}
|
||||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, model.aggregate, userID, userOrgID, authMethods, authTime, preferredLanguage, userAgent))
|
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, model.aggregate, userID, userOrgID, authMethods, authTime, preferredLanguage, userAgent, sessionID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -151,7 +152,7 @@ func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCo
|
|||||||
cmd.AddSession(ctx,
|
cmd.AddSession(ctx,
|
||||||
deviceAuthModel.UserID,
|
deviceAuthModel.UserID,
|
||||||
deviceAuthModel.UserOrgID,
|
deviceAuthModel.UserOrgID,
|
||||||
"",
|
deviceAuthModel.SessionID,
|
||||||
deviceAuthModel.ClientID,
|
deviceAuthModel.ClientID,
|
||||||
deviceAuthModel.Audience,
|
deviceAuthModel.Audience,
|
||||||
deviceAuthModel.Scopes,
|
deviceAuthModel.Scopes,
|
||||||
|
@@ -28,6 +28,7 @@ type DeviceAuthWriteModel struct {
|
|||||||
PreferredLanguage *language.Tag
|
PreferredLanguage *language.Tag
|
||||||
UserAgent *domain.UserAgent
|
UserAgent *domain.UserAgent
|
||||||
NeedRefreshToken bool
|
NeedRefreshToken bool
|
||||||
|
SessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
|
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
|
||||||
@@ -60,6 +61,7 @@ func (m *DeviceAuthWriteModel) Reduce() error {
|
|||||||
m.AuthTime = e.AuthTime
|
m.AuthTime = e.AuthTime
|
||||||
m.PreferredLanguage = e.PreferredLanguage
|
m.PreferredLanguage = e.PreferredLanguage
|
||||||
m.UserAgent = e.UserAgent
|
m.UserAgent = e.UserAgent
|
||||||
|
m.SessionID = e.SessionID
|
||||||
case *deviceauth.CanceledEvent:
|
case *deviceauth.CanceledEvent:
|
||||||
m.State = e.Reason.State()
|
m.State = e.Reason.State()
|
||||||
case *deviceauth.DoneEvent:
|
case *deviceauth.DoneEvent:
|
||||||
|
@@ -137,6 +137,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
authTime time.Time
|
authTime time.Time
|
||||||
preferredLanguage *language.Tag
|
preferredLanguage *language.Tag
|
||||||
userAgent *domain.UserAgent
|
userAgent *domain.UserAgent
|
||||||
|
sessionID string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -161,6 +162,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
},
|
},
|
||||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
|
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
|
||||||
},
|
},
|
||||||
@@ -188,6 +190,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -201,6 +204,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
},
|
},
|
||||||
wantErr: pushErr,
|
wantErr: pushErr,
|
||||||
},
|
},
|
||||||
@@ -228,6 +232,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -241,6 +246,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
},
|
},
|
||||||
wantDetails: &domain.ObjectDetails{
|
wantDetails: &domain.ObjectDetails{
|
||||||
ResourceOwner: "instance1",
|
ResourceOwner: "instance1",
|
||||||
@@ -252,7 +258,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
|||||||
c := &Commands{
|
c := &Commands{
|
||||||
eventstore: tt.fields.eventstore,
|
eventstore: tt.fields.eventstore,
|
||||||
}
|
}
|
||||||
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent)
|
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent, tt.args.sessionID)
|
||||||
require.ErrorIs(t, err, tt.wantErr)
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
assertObjectDetails(t, tt.wantDetails, gotDetails)
|
assertObjectDetails(t, tt.wantDetails, gotDetails)
|
||||||
})
|
})
|
||||||
@@ -607,13 +613,14 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
expectFilter(), // token lifetime
|
expectFilter(), // token lifetime
|
||||||
expectPush(
|
expectPush(
|
||||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
"userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||||
FingerprintID: gu.Ptr("fp1"),
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
IP: net.ParseIP("1.2.3.4"),
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
@@ -657,7 +664,8 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
Reason: domain.TokenReasonAuthRequest,
|
Reason: domain.TokenReasonAuthRequest,
|
||||||
|
SessionID: "sessionID",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -687,13 +695,14 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
|||||||
Description: gu.Ptr("firefox"),
|
Description: gu.Ptr("firefox"),
|
||||||
Header: http.Header{"foo": []string{"bar"}},
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
},
|
},
|
||||||
|
"sessionID",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
expectFilter(), // token lifetime
|
expectFilter(), // token lifetime
|
||||||
expectPush(
|
expectPush(
|
||||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
"userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||||
FingerprintID: gu.Ptr("fp1"),
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
IP: net.ParseIP("1.2.3.4"),
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
@@ -742,6 +751,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
Reason: domain.TokenReasonAuthRequest,
|
Reason: domain.TokenReasonAuthRequest,
|
||||||
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
||||||
|
SessionID: "sessionID",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@@ -136,6 +136,7 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
|
|||||||
reason domain.TokenReason,
|
reason domain.TokenReason,
|
||||||
actor *domain.TokenActor,
|
actor *domain.TokenActor,
|
||||||
needRefreshToken bool,
|
needRefreshToken bool,
|
||||||
|
sessionID string,
|
||||||
) (session *OIDCSession, err error) {
|
) (session *OIDCSession, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
@@ -151,7 +152,7 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
|
|||||||
cmd.UserImpersonated(ctx, userID, resourceOwner, clientID, actor)
|
cmd.UserImpersonated(ctx, userID, resourceOwner, clientID, actor)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.AddSession(ctx, userID, resourceOwner, "", clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
|
cmd.AddSession(ctx, userID, resourceOwner, sessionID, clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
|
||||||
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
|
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -479,6 +479,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
|||||||
reason domain.TokenReason
|
reason domain.TokenReason
|
||||||
actor *domain.TokenActor
|
actor *domain.TokenActor
|
||||||
needRefreshToken bool
|
needRefreshToken bool
|
||||||
|
sessionID string
|
||||||
}
|
}
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -684,6 +685,89 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
|||||||
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "with sessionID",
|
||||||
|
fields: fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(), // token lifetime
|
||||||
|
expectPush(
|
||||||
|
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||||
|
"userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||||
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans,
|
||||||
|
&domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
oidcsession.NewAccessTokenAddedEvent(context.Background(),
|
||||||
|
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||||
|
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest,
|
||||||
|
&domain.TokenActor{
|
||||||
|
UserID: "user2",
|
||||||
|
Issuer: "foo.com",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
|
||||||
|
defaultAccessTokenLifetime: time.Hour,
|
||||||
|
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||||
|
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||||
|
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
userID: "userID",
|
||||||
|
resourceOwner: "org1",
|
||||||
|
clientID: "clientID",
|
||||||
|
audience: []string{"audience"},
|
||||||
|
scope: []string{"openid", "offline_access"},
|
||||||
|
authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
authTime: testNow,
|
||||||
|
nonce: "nonce",
|
||||||
|
preferredLanguage: &language.Afrikaans,
|
||||||
|
userAgent: &domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
reason: domain.TokenReasonAuthRequest,
|
||||||
|
actor: &domain.TokenActor{
|
||||||
|
UserID: "user2",
|
||||||
|
Issuer: "foo.com",
|
||||||
|
},
|
||||||
|
needRefreshToken: false,
|
||||||
|
sessionID: "sessionID",
|
||||||
|
},
|
||||||
|
want: &OIDCSession{
|
||||||
|
TokenID: "V2_oidcSessionID-at_accessTokenID",
|
||||||
|
ClientID: "clientID",
|
||||||
|
UserID: "userID",
|
||||||
|
Audience: []string{"audience"},
|
||||||
|
Expiration: time.Time{}.Add(time.Hour),
|
||||||
|
Scope: []string{"openid", "offline_access"},
|
||||||
|
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
AuthTime: testNow,
|
||||||
|
Nonce: "nonce",
|
||||||
|
PreferredLanguage: &language.Afrikaans,
|
||||||
|
UserAgent: &domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
Reason: domain.TokenReasonAuthRequest,
|
||||||
|
Actor: &domain.TokenActor{
|
||||||
|
UserID: "user2",
|
||||||
|
Issuer: "foo.com",
|
||||||
|
},
|
||||||
|
SessionID: "sessionID",
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "impersonation not allowed",
|
name: "impersonation not allowed",
|
||||||
fields: fields{
|
fields: fields{
|
||||||
@@ -839,6 +923,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
|||||||
tt.args.reason,
|
tt.args.reason,
|
||||||
tt.args.actor,
|
tt.args.actor,
|
||||||
tt.args.needRefreshToken,
|
tt.args.needRefreshToken,
|
||||||
|
tt.args.sessionID,
|
||||||
)
|
)
|
||||||
require.ErrorIs(t, err, tt.wantErr)
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
if got != nil {
|
if got != nil {
|
||||||
|
@@ -62,6 +62,8 @@ type AuthRequest struct {
|
|||||||
SAMLRequestID string
|
SAMLRequestID string
|
||||||
// orgID the policies were last loaded with
|
// orgID the policies were last loaded with
|
||||||
policyOrgID string
|
policyOrgID string
|
||||||
|
// SessionID is set to the computed sessionID of the login session table
|
||||||
|
SessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AuthRequest) SetPolicyOrgID(id string) {
|
func (a *AuthRequest) SetPolicyOrgID(id string) {
|
||||||
|
@@ -146,18 +146,17 @@ func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []
|
|||||||
conflictTarget[i] = col.Name
|
conflictTarget[i] = col.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
config := execConfig{
|
config := execConfig{}
|
||||||
args: args,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
config.err = ErrNoValues
|
config.err = ErrNoValues
|
||||||
}
|
}
|
||||||
|
|
||||||
updateCols, updateVals := getUpdateCols(values, conflictTarget)
|
updateCols, updateVals, args := getUpdateCols(values, conflictTarget, params, args)
|
||||||
if len(updateCols) == 0 || len(updateVals) == 0 {
|
if len(updateCols) == 0 || len(updateVals) == 0 {
|
||||||
config.err = ErrNoValues
|
config.err = ErrNoValues
|
||||||
}
|
}
|
||||||
|
config.args = args
|
||||||
|
|
||||||
q := func(config execConfig) string {
|
q := func(config execConfig) string {
|
||||||
var updateStmt string
|
var updateStmt string
|
||||||
@@ -194,18 +193,78 @@ func OnlySetValueOnInsert(table string, value interface{}) *onlySetValueOnInsert
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUpdateCols(cols []Column, conflictTarget []string) (updateCols, updateVals []string) {
|
type onlySetValueInCase struct {
|
||||||
|
Table string
|
||||||
|
Value interface{}
|
||||||
|
Condition Condition
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *onlySetValueInCase) GetValue() interface{} {
|
||||||
|
return c.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnChangedCondition checks the current value and if it changed to a specific new value
|
||||||
|
func ColumnChangedCondition(table, column string, currentValue, newValue interface{}) Condition {
|
||||||
|
return func(param string) (string, []any) {
|
||||||
|
index, _ := strconv.Atoi(param)
|
||||||
|
return fmt.Sprintf("%[1]s.%[2]s = $%[3]d AND EXCLUDED.%[2]s = $%[4]d", table, column, index, index+1), []any{currentValue, newValue}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ColumnIsNullCondition checks if the current value is null
|
||||||
|
func ColumnIsNullCondition(table, column string) Condition {
|
||||||
|
return func(param string) (string, []any) {
|
||||||
|
return fmt.Sprintf("%[1]s.%[2]s IS NULL", table, column), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConditionOr links multiple Conditions by OR
|
||||||
|
func ConditionOr(conditions ...Condition) Condition {
|
||||||
|
return func(param string) (_ string, args []any) {
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
b := strings.Builder{}
|
||||||
|
s, arg := conditions[0](param)
|
||||||
|
b.WriteString(s)
|
||||||
|
args = append(args, arg...)
|
||||||
|
for i := 1; i < len(conditions); i++ {
|
||||||
|
b.WriteString(" OR ")
|
||||||
|
s, condArgs := conditions[i](param)
|
||||||
|
b.WriteString(s)
|
||||||
|
args = append(args, condArgs...)
|
||||||
|
}
|
||||||
|
return b.String(), args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnlySetValueInCase will only update to the desired value if the condition applies
|
||||||
|
func OnlySetValueInCase(table string, value interface{}, condition Condition) *onlySetValueInCase {
|
||||||
|
return &onlySetValueInCase{
|
||||||
|
Table: table,
|
||||||
|
Value: value,
|
||||||
|
Condition: condition,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUpdateCols(cols []Column, conflictTarget, params []string, args []interface{}) (updateCols, updateVals []string, updatedArgs []interface{}) {
|
||||||
updateCols = make([]string, len(cols))
|
updateCols = make([]string, len(cols))
|
||||||
updateVals = make([]string, len(cols))
|
updateVals = make([]string, len(cols))
|
||||||
|
updatedArgs = args
|
||||||
|
|
||||||
for i := len(cols) - 1; i >= 0; i-- {
|
for i := len(cols) - 1; i >= 0; i-- {
|
||||||
col := cols[i]
|
col := cols[i]
|
||||||
table := "EXCLUDED"
|
|
||||||
if onlyOnInsert, ok := col.Value.(*onlySetValueOnInsert); ok {
|
|
||||||
table = onlyOnInsert.Table
|
|
||||||
}
|
|
||||||
updateCols[i] = col.Name
|
updateCols[i] = col.Name
|
||||||
updateVals[i] = table + "." + col.Name
|
switch v := col.Value.(type) {
|
||||||
|
case *onlySetValueOnInsert:
|
||||||
|
updateVals[i] = v.Table + "." + col.Name
|
||||||
|
case *onlySetValueInCase:
|
||||||
|
s, condArgs := v.Condition(strconv.Itoa(len(params) + 1))
|
||||||
|
updatedArgs = append(updatedArgs, condArgs...)
|
||||||
|
updateVals[i] = fmt.Sprintf("CASE WHEN %[1]s THEN EXCLUDED.%[2]s ELSE %[3]s.%[2]s END", s, col.Name, v.Table)
|
||||||
|
default:
|
||||||
|
updateVals[i] = "EXCLUDED" + "." + col.Name
|
||||||
|
}
|
||||||
for _, conflict := range conflictTarget {
|
for _, conflict := range conflictTarget {
|
||||||
if conflict == col.Name {
|
if conflict == col.Name {
|
||||||
copy(updateCols[i:], updateCols[i+1:])
|
copy(updateCols[i:], updateCols[i+1:])
|
||||||
@@ -221,7 +280,7 @@ func getUpdateCols(cols []Column, conflictTarget []string) (updateCols, updateVa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return updateCols, updateVals
|
return updateCols, updateVals, updatedArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
|
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
|
||||||
|
@@ -451,6 +451,55 @@ func TestNewUpsertStatement(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "correct all *onlySetValueInCase",
|
||||||
|
args: args{
|
||||||
|
table: "my_table",
|
||||||
|
event: &testEvent{
|
||||||
|
aggregateType: "agg",
|
||||||
|
sequence: 1,
|
||||||
|
previousSequence: 0,
|
||||||
|
},
|
||||||
|
conflictCols: []Column{
|
||||||
|
NewCol("col1", nil),
|
||||||
|
},
|
||||||
|
values: []Column{
|
||||||
|
{
|
||||||
|
Name: "col1",
|
||||||
|
Value: "val1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "col2",
|
||||||
|
Value: &onlySetValueInCase{
|
||||||
|
Table: "some.table",
|
||||||
|
Value: "val2",
|
||||||
|
Condition: ConditionOr(
|
||||||
|
ColumnChangedCondition("some.table", "val3", 0, 1),
|
||||||
|
ColumnIsNullCondition("some.table", "val3"),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
table: "my_table",
|
||||||
|
aggregateType: "agg",
|
||||||
|
sequence: 1,
|
||||||
|
previousSequence: 1,
|
||||||
|
executer: &wantExecuter{
|
||||||
|
params: []params{
|
||||||
|
{
|
||||||
|
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET col2 = CASE WHEN some.table.val3 = $3 AND EXCLUDED.val3 = $4 OR some.table.val3 IS NULL THEN EXCLUDED.col2 ELSE some.table.col2 END",
|
||||||
|
args: []interface{}{"val1", "val2", 0, 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
shouldExecute: true,
|
||||||
|
},
|
||||||
|
isErr: func(err error) bool {
|
||||||
|
return err == nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@@ -72,6 +72,7 @@ type ApprovedEvent struct {
|
|||||||
AuthTime time.Time
|
AuthTime time.Time
|
||||||
PreferredLanguage *language.Tag
|
PreferredLanguage *language.Tag
|
||||||
UserAgent *domain.UserAgent
|
UserAgent *domain.UserAgent
|
||||||
|
SessionID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ApprovedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
func (e *ApprovedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
|
||||||
@@ -95,17 +96,19 @@ func NewApprovedEvent(
|
|||||||
authTime time.Time,
|
authTime time.Time,
|
||||||
preferredLanguage *language.Tag,
|
preferredLanguage *language.Tag,
|
||||||
userAgent *domain.UserAgent,
|
userAgent *domain.UserAgent,
|
||||||
|
sessionID string,
|
||||||
) *ApprovedEvent {
|
) *ApprovedEvent {
|
||||||
return &ApprovedEvent{
|
return &ApprovedEvent{
|
||||||
eventstore.NewBaseEventForPush(
|
BaseEvent: eventstore.NewBaseEventForPush(
|
||||||
ctx, aggregate, ApprovedEventType,
|
ctx, aggregate, ApprovedEventType,
|
||||||
),
|
),
|
||||||
userID,
|
UserID: userID,
|
||||||
userOrgID,
|
UserOrgID: userOrgID,
|
||||||
userAuthMethods,
|
UserAuthMethods: userAuthMethods,
|
||||||
authTime,
|
AuthTime: authTime,
|
||||||
preferredLanguage,
|
PreferredLanguage: preferredLanguage,
|
||||||
userAgent,
|
UserAgent: userAgent,
|
||||||
|
SessionID: sessionID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -27,6 +27,7 @@ type UserSessionView struct {
|
|||||||
MultiFactorVerification time.Time
|
MultiFactorVerification time.Time
|
||||||
MultiFactorVerificationType domain.MFAType
|
MultiFactorVerificationType domain.MFAType
|
||||||
Sequence uint64
|
Sequence uint64
|
||||||
|
ID string
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserSessionSearchRequest struct {
|
type UserSessionSearchRequest struct {
|
||||||
|
@@ -32,6 +32,7 @@ const (
|
|||||||
UserSessionKeyPasswordlessVerification = "passwordless_verification"
|
UserSessionKeyPasswordlessVerification = "passwordless_verification"
|
||||||
UserSessionKeyExternalLoginVerification = "external_login_verification"
|
UserSessionKeyExternalLoginVerification = "external_login_verification"
|
||||||
UserSessionKeySelectedIDPConfigID = "selected_idp_config_id"
|
UserSessionKeySelectedIDPConfigID = "selected_idp_config_id"
|
||||||
|
UserSessionKeyID = "id"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserSessionView struct {
|
type UserSessionView struct {
|
||||||
@@ -59,6 +60,7 @@ type UserSessionView struct {
|
|||||||
MultiFactorVerificationType sql.NullInt32 `json:"-" gorm:"column:multi_factor_verification_type"`
|
MultiFactorVerificationType sql.NullInt32 `json:"-" gorm:"column:multi_factor_verification_type"`
|
||||||
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
||||||
InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"`
|
InstanceID string `json:"instanceID" gorm:"column:instance_id;primary_key"`
|
||||||
|
ID sql.NullString `json:"id" gorm:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type userAgentIDPayload struct {
|
type userAgentIDPayload struct {
|
||||||
@@ -95,6 +97,7 @@ func UserSessionToModel(userSession *UserSessionView) *model.UserSessionView {
|
|||||||
MultiFactorVerification: userSession.MultiFactorVerification.Time,
|
MultiFactorVerification: userSession.MultiFactorVerification.Time,
|
||||||
MultiFactorVerificationType: domain.MFAType(userSession.MultiFactorVerificationType.Int32),
|
MultiFactorVerificationType: domain.MFAType(userSession.MultiFactorVerificationType.Int32),
|
||||||
Sequence: userSession.Sequence,
|
Sequence: userSession.Sequence,
|
||||||
|
ID: userSession.ID.String,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -17,7 +17,8 @@ SELECT s.creation_date,
|
|||||||
s.multi_factor_verification,
|
s.multi_factor_verification,
|
||||||
s.multi_factor_verification_type,
|
s.multi_factor_verification_type,
|
||||||
s.sequence,
|
s.sequence,
|
||||||
s.instance_id
|
s.instance_id,
|
||||||
|
s.id
|
||||||
FROM auth.user_sessions s
|
FROM auth.user_sessions s
|
||||||
LEFT JOIN projections.users13 u ON s.user_id = u.id AND s.instance_id = u.instance_id
|
LEFT JOIN projections.users13 u ON s.user_id = u.id AND s.instance_id = u.instance_id
|
||||||
LEFT JOIN projections.users13_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id
|
LEFT JOIN projections.users13_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id
|
||||||
|
@@ -65,6 +65,7 @@ func scanUserSession(row *sql.Row) (*model.UserSessionView, error) {
|
|||||||
&session.MultiFactorVerificationType,
|
&session.MultiFactorVerificationType,
|
||||||
&session.Sequence,
|
&session.Sequence,
|
||||||
&session.InstanceID,
|
&session.InstanceID,
|
||||||
|
&session.ID,
|
||||||
)
|
)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, zerrors.ThrowNotFound(nil, "VIEW-NGBs1", "Errors.UserSession.NotFound")
|
return nil, zerrors.ThrowNotFound(nil, "VIEW-NGBs1", "Errors.UserSession.NotFound")
|
||||||
@@ -97,6 +98,7 @@ func scanUserSessions(rows *sql.Rows) ([]*model.UserSessionView, error) {
|
|||||||
&session.MultiFactorVerificationType,
|
&session.MultiFactorVerificationType,
|
||||||
&session.Sequence,
|
&session.Sequence,
|
||||||
&session.InstanceID,
|
&session.InstanceID,
|
||||||
|
&session.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@@ -17,7 +17,8 @@ SELECT s.creation_date,
|
|||||||
s.multi_factor_verification,
|
s.multi_factor_verification,
|
||||||
s.multi_factor_verification_type,
|
s.multi_factor_verification_type,
|
||||||
s.sequence,
|
s.sequence,
|
||||||
s.instance_id
|
s.instance_id,
|
||||||
|
s.id
|
||||||
FROM auth.user_sessions s
|
FROM auth.user_sessions s
|
||||||
LEFT JOIN projections.users13 u ON s.user_id = u.id AND s.instance_id = u.instance_id
|
LEFT JOIN projections.users13 u ON s.user_id = u.id AND s.instance_id = u.instance_id
|
||||||
LEFT JOIN projections.users13_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id
|
LEFT JOIN projections.users13_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id
|
||||||
|
Reference in New Issue
Block a user