feat: add possibility to set an expiration to a session (#6851)

* add lifetime to session api

* extend session with lifetime

* check session token expiration

* fix typo

* integration test to check session token expiration

* integration test to check session token expiration

* i18n

* cleanup

* improve tests

* prevent negative lifetime

* fix error message

* fix lifetime check
This commit is contained in:
Livio Spring
2023-11-06 11:48:28 +02:00
committed by GitHub
parent ce322323aa
commit f3b8a3aece
35 changed files with 608 additions and 151 deletions

View File

@@ -95,8 +95,8 @@ func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID,
if err != nil {
return nil, nil, err
}
if sessionWriteModel.State == domain.SessionStateUnspecified {
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting")
if err = sessionWriteModel.CheckIsActive(); err != nil {
return nil, nil, err
}
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
return nil, nil, err

View File

@@ -327,7 +327,67 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"),
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
},
},
{
"session expired",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "org1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow.Add(-5*time.Minute)),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow.Add(-5*time.Minute)),
),
eventFromEventPusher(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
2*time.Minute),
),
),
),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
},
res{
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"),
},
},
{
@@ -475,6 +535,10 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
eventFromEventPusherWithCreationDateNow(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
2*time.Minute),
),
),
expectPush(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
@@ -559,6 +623,10 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
eventFromEventPusherWithCreationDateNow(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
2*time.Minute),
),
),
expectPush(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,

View File

@@ -11,7 +11,6 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
@@ -159,8 +158,8 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID st
if err != nil {
return nil, err
}
if sessionWriteModel.State != domain.SessionStateActive {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
if err = sessionWriteModel.CheckIsActive(); err != nil {
return nil, err
}
resourceOwner, err := c.getResourceOwnerOfSessionUser(ctx, sessionWriteModel.UserID, sessionWriteModel.InstanceID)
if err != nil {

View File

@@ -119,7 +119,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
},
},
{
@@ -320,7 +320,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
},
},
{

View File

@@ -252,6 +252,17 @@ func (s *SessionCommands) ChangeMetadata(ctx context.Context, metadata map[strin
}
}
func (s *SessionCommands) SetLifetime(ctx context.Context, lifetime time.Duration) error {
if lifetime < 0 {
return caos_errs.ThrowInvalidArgument(nil, "COMMAND-asEG4", "Errors.Session.PositiveLifetime")
}
if lifetime == 0 {
return nil
}
s.eventCommands = append(s.eventCommands, session.NewLifetimeSetEvent(ctx, s.sessionWriteModel.aggregate, lifetime))
return nil
}
func (s *SessionCommands) gethumanWriteModel(ctx context.Context) (*HumanWriteModel, error) {
if s.sessionWriteModel.UserID == "" {
return nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-eeR2e", "Errors.User.UserIDMissing")
@@ -280,7 +291,7 @@ func (s *SessionCommands) commands(ctx context.Context) (string, []eventstore.Co
return token, s.eventCommands, nil
}
func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte, userAgent *domain.UserAgent) (set *SessionChanged, err error) {
func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte, userAgent *domain.UserAgent, lifetime time.Duration) (set *SessionChanged, err error) {
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
@@ -292,10 +303,10 @@ func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, met
}
cmd := c.NewSessionCommands(cmds, sessionWriteModel)
cmd.Start(ctx, userAgent)
return c.updateSession(ctx, cmd, metadata)
return c.updateSession(ctx, cmd, metadata, lifetime)
}
func (c *Commands) UpdateSession(ctx context.Context, sessionID, sessionToken string, cmds []SessionCommand, metadata map[string][]byte) (set *SessionChanged, err error) {
func (c *Commands) UpdateSession(ctx context.Context, sessionID, sessionToken string, cmds []SessionCommand, metadata map[string][]byte, lifetime time.Duration) (set *SessionChanged, err error) {
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
@@ -305,7 +316,7 @@ func (c *Commands) UpdateSession(ctx context.Context, sessionID, sessionToken st
return nil, err
}
cmd := c.NewSessionCommands(cmds, sessionWriteModel)
return c.updateSession(ctx, cmd, metadata)
return c.updateSession(ctx, cmd, metadata, lifetime)
}
func (c *Commands) TerminateSession(ctx context.Context, sessionID string, sessionToken string) (*domain.ObjectDetails, error) {
@@ -326,7 +337,7 @@ func (c *Commands) terminateSession(ctx context.Context, sessionID, sessionToken
return nil, err
}
}
if sessionWriteModel.State != domain.SessionStateActive {
if sessionWriteModel.CheckIsActive() != nil {
return writeModelToObjectDetails(&sessionWriteModel.WriteModel), nil
}
terminate := session.NewTerminateEvent(ctx, &session.NewAggregate(sessionWriteModel.AggregateID, sessionWriteModel.ResourceOwner).Aggregate)
@@ -342,15 +353,19 @@ func (c *Commands) terminateSession(ctx context.Context, sessionID, sessionToken
}
// updateSession execute the [SessionCommands] where new events will be created and as well as for metadata (changes)
func (c *Commands) updateSession(ctx context.Context, checks *SessionCommands, metadata map[string][]byte) (set *SessionChanged, err error) {
if checks.sessionWriteModel.State == domain.SessionStateTerminated {
return nil, caos_errs.ThrowPreconditionFailed(nil, "COMAND-SAjeh", "Errors.Session.Terminated")
func (c *Commands) updateSession(ctx context.Context, checks *SessionCommands, metadata map[string][]byte, lifetime time.Duration) (set *SessionChanged, err error) {
if err = checks.sessionWriteModel.CheckNotInvalidated(); err != nil {
return nil, err
}
if err := checks.Exec(ctx); err != nil {
// TODO: how to handle failed checks (e.g. pw wrong) https://github.com/zitadel/zitadel/issues/5807
return nil, err
}
checks.ChangeMetadata(ctx, metadata)
err = checks.SetLifetime(ctx, lifetime)
if err != nil {
return nil, err
}
sessionToken, cmds, err := checks.commands(ctx)
if err != nil {
return nil, err

View File

@@ -5,6 +5,7 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/session"
)
@@ -48,6 +49,7 @@ type SessionWriteModel struct {
WebAuthNUserVerified bool
Metadata map[string][]byte
State domain.SessionState
Expiration time.Time
WebAuthNChallenge *WebAuthNChallengeModel
OTPSMSCodeChallenge *OTPCode
@@ -94,6 +96,8 @@ func (wm *SessionWriteModel) Reduce() error {
wm.reduceOTPEmailChecked(e)
case *session.TokenSetEvent:
wm.reduceTokenSet(e)
case *session.LifetimeSetEvent:
wm.reduceLifetimeSet(e)
case *session.TerminateEvent:
wm.reduceTerminate()
}
@@ -120,6 +124,7 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
session.OTPEmailCheckedType,
session.TokenSetType,
session.MetadataSetType,
session.LifetimeSetType,
session.TerminateType,
).
Builder()
@@ -196,6 +201,10 @@ func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) {
wm.TokenID = e.TokenID
}
func (wm *SessionWriteModel) reduceLifetimeSet(e *session.LifetimeSetEvent) {
wm.Expiration = e.CreationDate().Add(e.Lifetime)
}
func (wm *SessionWriteModel) reduceTerminate() {
wm.State = domain.SessionStateTerminated
}
@@ -245,3 +254,23 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
}
return types
}
// CheckNotInvalidated checks that the session was not invalidated either manually ([session.TerminateType])
// or automatically (expired).
func (wm *SessionWriteModel) CheckNotInvalidated() error {
if wm.State == domain.SessionStateTerminated {
return errors.ThrowPreconditionFailed(nil, "COMMAND-Hewfq", "Errors.Session.Terminated")
}
if !wm.Expiration.IsZero() && wm.Expiration.Before(time.Now()) {
return errors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired")
}
return nil
}
// CheckIsActive checks that the session was not invalidated ([CheckNotInvalidated]) and actually already exists.
func (wm *SessionWriteModel) CheckIsActive() error {
if wm.State == domain.SessionStateUnspecified {
return errors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting")
}
return wm.CheckNotInvalidated()
}

View File

@@ -152,6 +152,7 @@ func TestCommands_CreateSession(t *testing.T) {
checks []SessionCommand
metadata map[string][]byte
userAgent *domain.UserAgent
lifetime time.Duration
}
type res struct {
want *SessionChanged
@@ -192,6 +193,33 @@ func TestCommands_CreateSession(t *testing.T) {
err: caos_errs.ThrowInternal(nil, "id", "filter failed"),
},
},
{
"negative lifetime",
fields{
idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"),
tokenCreator: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
},
args{
ctx: authz.NewMockContext("", "org1", ""),
userAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
lifetime: -10 * time.Minute,
},
[]expect{
expectFilter(),
},
res{
err: caos_errs.ThrowInvalidArgument(nil, "COMMAND-asEG4", "Errors.Session.PositiveLifetime"),
},
},
{
"empty session",
fields{
@@ -210,6 +238,7 @@ func TestCommands_CreateSession(t *testing.T) {
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
lifetime: 10 * time.Minute,
},
[]expect{
expectFilter(),
@@ -223,6 +252,7 @@ func TestCommands_CreateSession(t *testing.T) {
Header: http.Header{"foo": []string{"bar"}},
},
),
session.NewLifetimeSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, 10*time.Minute),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
@@ -245,7 +275,7 @@ func TestCommands_CreateSession(t *testing.T) {
idGenerator: tt.fields.idGenerator,
sessionTokenCreator: tt.fields.tokenCreator,
}
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata, tt.args.userAgent)
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata, tt.args.userAgent, tt.args.lifetime)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
@@ -263,6 +293,7 @@ func TestCommands_UpdateSession(t *testing.T) {
sessionToken string
checks []SessionCommand
metadata map[string][]byte
lifetime time.Duration
}
type res struct {
want *SessionChanged
@@ -368,7 +399,7 @@ func TestCommands_UpdateSession(t *testing.T) {
eventstore: tt.fields.eventstore,
sessionTokenVerifier: tt.fields.tokenVerifier,
}
got, err := c.UpdateSession(tt.args.ctx, tt.args.sessionID, tt.args.sessionToken, tt.args.checks, tt.args.metadata)
got, err := c.UpdateSession(tt.args.ctx, tt.args.sessionID, tt.args.sessionToken, tt.args.checks, tt.args.metadata, tt.args.lifetime)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
@@ -397,6 +428,7 @@ func TestCommands_updateSession(t *testing.T) {
ctx context.Context
checks *SessionCommands
metadata map[string][]byte
lifetime time.Duration
}
type res struct {
want *SessionChanged
@@ -420,7 +452,7 @@ func TestCommands_updateSession(t *testing.T) {
},
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMAND-SAjeh", "Errors.Session.Terminated"),
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Hewfq", "Errors.Session.Terminated"),
},
},
{
@@ -465,6 +497,73 @@ func TestCommands_updateSession(t *testing.T) {
},
},
},
{
"negative lifetime",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: context.Background(),
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "org1"),
sessionCommands: []SessionCommand{},
eventstore: eventstoreExpect(t),
createToken: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
now: func() time.Time {
return testNow
},
},
lifetime: -10 * time.Minute,
},
res{
err: caos_errs.ThrowInvalidArgument(nil, "COMMAND-asEG4", "Errors.Session.PositiveLifetime"),
},
},
{
"lifetime set",
fields{
eventstore: eventstoreExpect(t,
expectPush(
session.NewLifetimeSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
10*time.Minute,
),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
),
),
},
args{
ctx: context.Background(),
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "org1"),
sessionCommands: []SessionCommand{},
eventstore: eventstoreExpect(t),
createToken: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
now: func() time.Time {
return testNow
},
},
lifetime: 10 * time.Minute,
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{
ResourceOwner: "org1",
},
ID: "sessionID",
NewToken: "token",
},
},
},
{
"set user, password, metadata and token",
fields{
@@ -721,7 +820,7 @@ func TestCommands_updateSession(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := c.updateSession(tt.args.ctx, tt.args.checks, tt.args.metadata)
got, err := c.updateSession(tt.args.ctx, tt.args.checks, tt.args.metadata, tt.args.lifetime)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})