mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-01 15:53:42 +00:00
perf(oidc): optimize token creation (#7822)
* implement code exchange * port tokenexchange to v2 tokens * implement refresh token * implement client credentials * implement jwt profile * implement device token * cleanup unused code * fix current unit tests * add user agent unit test * unit test domain package * need refresh token as argument * test commands create oidc session * test commands device auth * fix device auth build error * implicit for oidc session API * implement authorize callback handler for legacy implicit mode * upgrade oidc module to working draft * add missing auth methods and time * handle all errors in defer * do not fail auth request on error the oauth2 Go client automagically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with the Errors.AuthRequest.NoCode, because the auth request state is already set to failed. The original error is then already lost and the oauth2 library does not return the original error. Therefore we should not fail the auth request. Might be worth discussing and perhaps send a bug report to Oauth2? * fix code flow tests by explicitly setting code exchanged * fix unit tests in command package * return allowed scope from client credential client * add device auth done reducer * carry nonce thru session into ID token * fix token exchange integration tests * allow project role scope prefix in client credentials client * gci formatting * do not return refresh token in client credentials and jwt profile * check org scope * solve linting issue on authorize callback error * end session based on v2 session ID * use preferred language and user agent ID for v2 access tokens * pin oidc v3.23.2 * add integration test for jwt profile and client credentials with org scopes * refresh token v1 to v2 * add user token v2 audit event * add activity trigger * cleanup and set panics for unused methods * use the encrypted code for v1 auth request get by code * add missing event translation * fix pipeline errors (hopefully) * fix another test * revert pointer usage of preferred language * solve browser info panic in device auth * remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim * revoke v1 refresh token to prevent reuse * fix terminate oidc session * always return a new refresh toke in refresh token grant --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
@@ -12,21 +12,22 @@ import (
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
type CurrentAuthRequest struct {
|
||||
@@ -69,6 +70,7 @@ func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest)
|
||||
authRequest.MaxAge,
|
||||
authRequest.LoginHint,
|
||||
authRequest.HintUserID,
|
||||
authRequest.NeedRefreshToken,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -148,25 +150,6 @@ func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code s
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
|
||||
if code == "" {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
|
||||
return &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
|
||||
@@ -34,6 +34,7 @@ type AuthRequestWriteModel struct {
|
||||
AuthTime time.Time
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthRequestState domain.AuthRequestState
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
|
||||
@@ -64,6 +65,7 @@ func (m *AuthRequestWriteModel) Reduce() error {
|
||||
m.LoginHint = e.LoginHint
|
||||
m.HintUserID = e.HintUserID
|
||||
m.AuthRequestState = domain.AuthRequestStateAdded
|
||||
m.NeedRefreshToken = e.NeedRefreshToken
|
||||
case *authrequest.SessionLinkedEvent:
|
||||
m.SessionID = e.SessionID
|
||||
m.UserID = e.UserID
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -59,6 +60,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
false,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -96,6 +98,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
false,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -223,6 +226,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
@@ -263,6 +267,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -301,6 +306,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -338,6 +344,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -354,7 +361,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow.Add(-5*time.Minute)),
|
||||
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -398,6 +405,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -447,6 +455,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -463,7 +472,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -532,6 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -548,7 +558,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -674,6 +684,7 @@ func TestCommands_FailAuthRequest(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -771,6 +782,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
true,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -807,6 +819,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
@@ -841,166 +854,3 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeAuthCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
authRequest *CurrentAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "",
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"no code added error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"code exchanged",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
authRequest: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_authRequestID",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
|
||||
assert.ErrorIs(t, tt.res.err, err)
|
||||
|
||||
if err == nil {
|
||||
// equal on time won't work -> test separately and clear it before comparing the rest
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authRequest, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,16 +2,20 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string, needRefreshToken bool) (*domain.ObjectDetails, error) {
|
||||
aggr := deviceauth.NewAggregate(deviceCode, authz.GetInstance(ctx).InstanceID())
|
||||
model := NewDeviceAuthWriteModel(deviceCode, aggr.ResourceOwner)
|
||||
|
||||
@@ -24,6 +28,7 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
|
||||
expires,
|
||||
scopes,
|
||||
audience,
|
||||
needRefreshToken,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -36,7 +41,16 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
|
||||
return writeModelToObjectDetails(&model.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject string, authMethods []domain.UserAuthMethodType, authTime time.Time) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) ApproveDeviceAuth(
|
||||
ctx context.Context,
|
||||
deviceCode,
|
||||
userID,
|
||||
userOrgID string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) (*domain.ObjectDetails, error) {
|
||||
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -44,9 +58,7 @@ func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject st
|
||||
if !model.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound")
|
||||
}
|
||||
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject, authMethods, authTime))
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, model.aggregate, userID, userOrgID, authMethods, authTime, preferredLanguage, userAgent))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,9 +78,7 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
|
||||
if !model.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound")
|
||||
}
|
||||
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, aggr, reason))
|
||||
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, model.aggregate, reason))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,10 +91,89 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
|
||||
}
|
||||
|
||||
func (c *Commands) getDeviceAuthWriteModelByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuthWriteModel, error) {
|
||||
model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: deviceCode}}
|
||||
model := &DeviceAuthWriteModel{
|
||||
WriteModel: eventstore.WriteModel{AggregateID: deviceCode},
|
||||
}
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.aggregate = deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
type DeviceAuthStateError domain.DeviceAuthState
|
||||
|
||||
func (e DeviceAuthStateError) Error() string {
|
||||
return fmt.Sprintf("device auth state not approved: %s", domain.DeviceAuthState(e).String())
|
||||
}
|
||||
|
||||
// CreateOIDCSessionFromDeviceAuth creates a new OIDC session if the device authorization
|
||||
// flow is completed (user logged in).
|
||||
// A [DeviceAuthStateError] is returned if the device authorization was not approved,
|
||||
// containing a [domain.DeviceAuthState] which can be used to inform the client about the state.
|
||||
//
|
||||
// As devices can poll at various intervals, an explicit state takes precedence over expiry.
|
||||
// This is to prevent cases where users might approve or deny the authorization on time, but the next poll
|
||||
// happens after expiry.
|
||||
func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCode string) (_ *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
deviceAuthModel, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch deviceAuthModel.State {
|
||||
case domain.DeviceAuthStateApproved:
|
||||
break
|
||||
case domain.DeviceAuthStateUndefined:
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound")
|
||||
|
||||
case domain.DeviceAuthStateInitiated:
|
||||
if deviceAuthModel.Expires.Before(time.Now()) {
|
||||
c.asyncPush(ctx, deviceauth.NewCanceledEvent(ctx, deviceAuthModel.aggregate, domain.DeviceAuthCanceledExpired))
|
||||
return nil, DeviceAuthStateError(domain.DeviceAuthStateExpired)
|
||||
}
|
||||
fallthrough
|
||||
case domain.DeviceAuthStateDenied, domain.DeviceAuthStateExpired, domain.DeviceAuthStateDone:
|
||||
fallthrough
|
||||
default:
|
||||
return nil, DeviceAuthStateError(deviceAuthModel.State)
|
||||
}
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, deviceAuthModel.UserOrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx,
|
||||
deviceAuthModel.UserID,
|
||||
deviceAuthModel.UserOrgID,
|
||||
"",
|
||||
deviceAuthModel.ClientID,
|
||||
deviceAuthModel.Audience,
|
||||
deviceAuthModel.Scopes,
|
||||
deviceAuthModel.UserAuthMethods,
|
||||
deviceAuthModel.AuthTime,
|
||||
"",
|
||||
deviceAuthModel.PreferredLanguage,
|
||||
deviceAuthModel.UserAgent,
|
||||
)
|
||||
if err = cmd.AddAccessToken(ctx, deviceAuthModel.Scopes, deviceAuthModel.UserID, deviceAuthModel.UserOrgID, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if deviceAuthModel.NeedRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, deviceAuthModel.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
cmd.DeviceAuthRequestDone(ctx, deviceAuthModel.aggregate)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
func (cmd *OIDCSessionEvents) DeviceAuthRequestDone(ctx context.Context, deviceAuthAggregate *eventstore.Aggregate) {
|
||||
cmd.events = append(cmd.events, deviceauth.NewDoneEvent(ctx, deviceAuthAggregate))
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
@@ -10,16 +12,22 @@ import (
|
||||
|
||||
type DeviceAuthWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
aggregate *eventstore.Aggregate
|
||||
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
State domain.DeviceAuthState
|
||||
Subject string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
ClientID string
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Expires time.Time
|
||||
Scopes []string
|
||||
Audience []string
|
||||
State domain.DeviceAuthState
|
||||
UserID string
|
||||
UserOrgID string
|
||||
UserAuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
NeedRefreshToken bool
|
||||
}
|
||||
|
||||
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
|
||||
@@ -28,6 +36,7 @@ func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteM
|
||||
AggregateID: deviceCode,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
aggregate: deviceauth.NewAggregate(deviceCode, resourceOwner),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,14 +49,21 @@ func (m *DeviceAuthWriteModel) Reduce() error {
|
||||
m.UserCode = e.UserCode
|
||||
m.Expires = e.Expires
|
||||
m.Scopes = e.Scopes
|
||||
m.Audience = e.Audience
|
||||
m.State = e.State
|
||||
m.NeedRefreshToken = e.NeedRefreshToken
|
||||
case *deviceauth.ApprovedEvent:
|
||||
m.State = domain.DeviceAuthStateApproved
|
||||
m.Subject = e.Subject
|
||||
m.UserID = e.UserID
|
||||
m.UserOrgID = e.UserOrgID
|
||||
m.UserAuthMethods = e.UserAuthMethods
|
||||
m.AuthTime = e.AuthTime
|
||||
m.PreferredLanguage = e.PreferredLanguage
|
||||
m.UserAgent = e.UserAgent
|
||||
case *deviceauth.CanceledEvent:
|
||||
m.State = e.Reason.State()
|
||||
case *deviceauth.DoneEvent:
|
||||
m.State = domain.DeviceAuthStateDone
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,27 @@ package command
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/deviceauth"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@@ -25,16 +36,17 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
require.Len(t, unique, 2)
|
||||
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
expires time.Time
|
||||
scopes []string
|
||||
audience []string
|
||||
ctx context.Context
|
||||
clientID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
expires time.Time
|
||||
scopes []string
|
||||
audience []string
|
||||
needRefreshToken bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -46,24 +58,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t, expectPush(
|
||||
eventstore: expectEventstore(expectPush(
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
needRefreshToken: true,
|
||||
},
|
||||
wantDetails: &domain.ObjectDetails{
|
||||
ResourceOwner: "instance1",
|
||||
@@ -72,24 +85,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
{
|
||||
name: "push error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t, expectPushFailed(pushErr,
|
||||
eventstore: expectEventstore(expectPushFailed(pushErr,
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, false,
|
||||
)),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
clientID: "client_id",
|
||||
deviceCode: "123",
|
||||
userCode: "456",
|
||||
expires: now,
|
||||
scopes: []string{"a", "b", "c"},
|
||||
audience: []string{"projectID", "clientID"},
|
||||
needRefreshToken: false,
|
||||
},
|
||||
wantErr: pushErr,
|
||||
},
|
||||
@@ -97,9 +111,9 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience)
|
||||
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience, tt.args.needRefreshToken)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantDetails, gotDetails)
|
||||
})
|
||||
@@ -115,11 +129,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
subject string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
ctx context.Context
|
||||
id string
|
||||
userID string
|
||||
userOrgID string
|
||||
authMethods []domain.UserAuthMethodType
|
||||
authTime time.Time
|
||||
preferredLanguage *language.Tag
|
||||
userAgent *domain.UserAgent
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -136,9 +153,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
|
||||
},
|
||||
@@ -153,22 +175,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPushFailed(pushErr,
|
||||
deviceauth.NewApprovedEvent(
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantErr: pushErr,
|
||||
},
|
||||
@@ -183,22 +215,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
deviceauth.NewApprovedEvent(
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
|
||||
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx, "123", "subj",
|
||||
ctx, "123", "subj", "orgID",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
time.Unix(123, 456),
|
||||
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
},
|
||||
wantDetails: &domain.ObjectDetails{
|
||||
ResourceOwner: "instance1",
|
||||
@@ -210,7 +252,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject, tt.args.authMethods, tt.args.authTime)
|
||||
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, gotDetails, tt.wantDetails)
|
||||
})
|
||||
@@ -258,7 +300,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPushFailed(pushErr,
|
||||
@@ -283,7 +325,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
@@ -310,7 +352,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"client_id", "123", "456", now,
|
||||
[]string{"a", "b", "c"},
|
||||
[]string{"projectID", "clientID"},
|
||||
[]string{"projectID", "clientID"}, true,
|
||||
),
|
||||
)),
|
||||
expectPush(
|
||||
@@ -338,3 +380,392 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instance1")
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
deviceCode string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *OIDCSession
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "device auth filter error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"device1",
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "not yet approved",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateInitiated),
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "expired",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPushSlow(time.Second, deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledExpired,
|
||||
)),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
|
||||
},
|
||||
{
|
||||
name: "already expired",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledExpired,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
|
||||
},
|
||||
{
|
||||
name: "denied",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledDenied,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDenied),
|
||||
},
|
||||
{
|
||||
name: "already done",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewCanceledEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
domain.DeviceAuthCanceledDenied,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone),
|
||||
},
|
||||
{
|
||||
name: "approved, success",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, false,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewApprovedEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"userID", "org1",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
testNow, &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(),
|
||||
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
|
||||
),
|
||||
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
want: &OIDCSession{
|
||||
TokenID: "V2_oidcSessionID-at_accessTokenID",
|
||||
ClientID: "clientID",
|
||||
UserID: "userID",
|
||||
Audience: []string{"audience"},
|
||||
Expiration: time.Time{}.Add(time.Hour),
|
||||
Scope: []string{"openid", "offline_access"},
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
AuthTime: testNow,
|
||||
PreferredLanguage: &language.Afrikaans,
|
||||
UserAgent: &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
Reason: domain.TokenReasonAuthRequest,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "approved, with refresh token",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewAddedEvent(
|
||||
ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"clientID", "123", "456", time.Now().Add(-time.Minute),
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"}, true,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"instance1",
|
||||
deviceauth.NewApprovedEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
"userID", "org1",
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
testNow, &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(),
|
||||
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
|
||||
),
|
||||
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
|
||||
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour,
|
||||
),
|
||||
deviceauth.NewDoneEvent(ctx,
|
||||
deviceauth.NewAggregate("123", "instance1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx,
|
||||
"123",
|
||||
},
|
||||
want: &OIDCSession{
|
||||
TokenID: "V2_oidcSessionID-at_accessTokenID",
|
||||
ClientID: "clientID",
|
||||
UserID: "userID",
|
||||
Audience: []string{"audience"},
|
||||
Expiration: time.Time{}.Add(time.Hour),
|
||||
Scope: []string{"openid", "offline_access"},
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
AuthTime: testNow,
|
||||
PreferredLanguage: &language.Afrikaans,
|
||||
UserAgent: &domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
Reason: domain.TokenReasonAuthRequest,
|
||||
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.CreateOIDCSessionFromDeviceAuth(tt.args.ctx, tt.args.deviceCode)
|
||||
c.jobs.Wait()
|
||||
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
if got != nil {
|
||||
assert.WithinRange(t, got.AuthTime, tt.want.AuthTime.Add(-time.Second), tt.want.AuthTime.Add(time.Second))
|
||||
got.AuthTime = time.Time{}
|
||||
tt.want.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -17,6 +19,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@@ -28,60 +31,175 @@ const (
|
||||
oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
|
||||
)
|
||||
|
||||
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
|
||||
return accessTokenID, accessTokenExpiration, err
|
||||
type OIDCSession struct {
|
||||
SessionID string
|
||||
TokenID string
|
||||
ClientID string
|
||||
UserID string
|
||||
Audience []string
|
||||
Expiration time.Time
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
Reason domain.TokenReason
|
||||
Actor *domain.TokenActor
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
|
||||
type AuthRequestComplianceChecker func(context.Context, *AuthRequestWriteModel) error
|
||||
|
||||
// CreateOIDCSessionFromAuthRequest creates a new OIDC Session, creates an access token and refresh token.
|
||||
// It returns the access token id, expiration and the refresh token.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReqId string, complianceCheck AuthRequestComplianceChecker, needRefreshToken bool) (session *OIDCSession, state string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if authReqId == "" {
|
||||
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
|
||||
authReqModel, err := c.getAuthRequestWriteModel(ctx, authReqId)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, "", err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
|
||||
if authReqModel.ResponseType == domain.OIDCResponseTypeCode && authReqModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
if err = cmd.AddRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
|
||||
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if err = sessionModel.CheckIsActive(); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, sessionModel.UserResourceOwner)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if authReqModel.ResponseType == domain.OIDCResponseTypeCode {
|
||||
if err = cmd.SetAuthRequestCodeExchanged(ctx, authReqModel); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
if err = complianceCheck(ctx, authReqModel); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx,
|
||||
sessionModel.UserID,
|
||||
sessionModel.UserResourceOwner,
|
||||
sessionModel.AggregateID,
|
||||
authReqModel.ClientID,
|
||||
authReqModel.Audience,
|
||||
authReqModel.Scope,
|
||||
authReqModel.AuthMethods,
|
||||
authReqModel.AuthTime,
|
||||
authReqModel.Nonce,
|
||||
sessionModel.PreferredLanguage,
|
||||
sessionModel.UserAgent,
|
||||
)
|
||||
|
||||
if authReqModel.ResponseType != domain.OIDCResponseTypeIDToken {
|
||||
if err = cmd.AddAccessToken(ctx, authReqModel.Scope, sessionModel.UserID, sessionModel.UserResourceOwner, domain.TokenReasonAuthRequest, nil); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
if authReqModel.NeedRefreshToken && needRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, sessionModel.UserID); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
|
||||
session, err = cmd.PushEvents(ctx)
|
||||
return session, authReqModel.State, err
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOIDCSession(ctx context.Context,
|
||||
userID,
|
||||
resourceOwner,
|
||||
clientID string,
|
||||
scope,
|
||||
audience []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
nonce string,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
needRefreshToken bool,
|
||||
) (session *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reason == domain.TokenReasonImpersonation {
|
||||
if err := c.checkPermission(ctx, "impersonation", resourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd.UserImpersonated(ctx, userID, resourceOwner, clientID, actor)
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx, userID, resourceOwner, "", clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
|
||||
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needRefreshToken {
|
||||
if err = cmd.AddRefreshToken(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
|
||||
|
||||
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
|
||||
// It returns the access token id and expiration and the new refresh token.
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, refreshToken string, scope []string, complianceCheck RefreshTokenComplianceChecker) (_ *OIDCSession, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.AddAccessToken(ctx, scope, domain.TokenReasonRefresh, nil); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
scope, err = complianceCheck(ctx, cmd.oidcSessionWriteModel, scope)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = cmd.AddAccessToken(ctx, scope,
|
||||
cmd.oidcSessionWriteModel.UserID,
|
||||
cmd.oidcSessionWriteModel.UserResourceOwner,
|
||||
domain.TokenReasonRefresh,
|
||||
cmd.oidcSessionWriteModel.AccessTokenActor,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.RenewRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
|
||||
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (_ *OIDCSessionWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -146,26 +264,7 @@ func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID s
|
||||
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
|
||||
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetInstance(ctx).InstanceID())
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = sessionWriteModel.CheckIsActive(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resourceOwner, err := c.getResourceOwnerOfSessionUser(ctx, sessionWriteModel.UserID, sessionWriteModel.InstanceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, resourceOwner string, pending ...eventstore.Command) (*OIDCSessionEvents, error) {
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -179,42 +278,24 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID st
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
events: pending,
|
||||
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner),
|
||||
sessionWriteModel: sessionWriteModel,
|
||||
authRequestWriteModel: authRequestWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) getResourceOwnerOfSessionUser(ctx context.Context, userID, instanceID string) (string, error) {
|
||||
events, err := c.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
InstanceID(instanceID).
|
||||
AllowTimeTravel().
|
||||
OrderAsc().
|
||||
Limit(1).
|
||||
AddQuery().
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(userID).
|
||||
Builder())
|
||||
if err != nil || len(events) != 1 {
|
||||
return "", zerrors.ThrowInternal(err, "OIDCS-sferh", "Errors.Internal")
|
||||
}
|
||||
return events[0].Aggregate().ResourceOwner, nil
|
||||
}
|
||||
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshTokenID string, err error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
|
||||
if err != nil {
|
||||
return "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
|
||||
return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
_, refreshTokenID, err = parseRefreshToken(decrypted)
|
||||
return refreshTokenID, err
|
||||
return parseRefreshToken(decrypted)
|
||||
}
|
||||
|
||||
func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) {
|
||||
@@ -227,8 +308,8 @@ func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID strin
|
||||
return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
oidcSessionID, refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -255,13 +336,12 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID
|
||||
}
|
||||
|
||||
type OIDCSessionEvents struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
sessionWriteModel *SessionWriteModel
|
||||
authRequestWriteModel *AuthRequestWriteModel
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
|
||||
accessTokenLifetime time.Duration
|
||||
refreshTokenLifeTime time.Duration
|
||||
refreshTokenIdleLifetime time.Duration
|
||||
@@ -270,44 +350,75 @@ type OIDCSessionEvents struct {
|
||||
accessTokenID string
|
||||
|
||||
// refreshToken is set by the command
|
||||
refreshToken string
|
||||
refreshTokenID string
|
||||
refreshToken string
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
|
||||
func (c *OIDCSessionEvents) AddSession(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
clientID string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
nonce string,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) {
|
||||
c.events = append(c.events, oidcsession.NewAddedEvent(
|
||||
ctx,
|
||||
c.oidcSessionWriteModel.aggregate,
|
||||
c.sessionWriteModel.UserID,
|
||||
c.sessionWriteModel.AggregateID,
|
||||
c.authRequestWriteModel.ClientID,
|
||||
c.authRequestWriteModel.Audience,
|
||||
c.authRequestWriteModel.Scope,
|
||||
c.sessionWriteModel.AuthMethodTypes(),
|
||||
c.sessionWriteModel.AuthenticationTime(),
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
clientID,
|
||||
audience,
|
||||
scope,
|
||||
authMethods,
|
||||
authTime,
|
||||
nonce,
|
||||
preferredLanguage,
|
||||
userAgent,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
|
||||
func (c *OIDCSessionEvents) SetAuthRequestCodeExchanged(ctx context.Context, model *AuthRequestWriteModel) error {
|
||||
event := authrequest.NewCodeExchangedEvent(ctx, model.aggregate)
|
||||
model.AppendEvents(event)
|
||||
c.events = append(c.events, event)
|
||||
return model.Reduce()
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, reason domain.TokenReason, actor *domain.TokenActor) error {
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context, authRequestAggregate *eventstore.Aggregate) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, authRequestAggregate))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestFailed(ctx context.Context, authRequestAggregate *eventstore.Aggregate, err error) {
|
||||
c.events = append(c.events, authrequest.NewFailedEvent(ctx, authRequestAggregate, domain.OIDCErrorReasonFromError(err)))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, userID, resourceOwner string, reason domain.TokenReason, actor *domain.TokenActor) error {
|
||||
accessTokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.accessTokenID = AccessTokenPrefix + accessTokenID
|
||||
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor))
|
||||
c.events = append(c.events,
|
||||
oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor),
|
||||
user.NewUserTokenV2AddedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, c.accessTokenID), // for user audit log
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID)
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context, userID string) (err error) {
|
||||
c.refreshTokenID, c.refreshToken, err = c.generateRefreshToken(userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -321,6 +432,10 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) UserImpersonated(ctx context.Context, userID, resourceOwner, clientID string, actor *domain.TokenActor) {
|
||||
c.events = append(c.events, user.NewUserImpersonatedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, clientID, actor))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) {
|
||||
refreshTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
@@ -334,18 +449,38 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
|
||||
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
|
||||
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
return nil, err
|
||||
}
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
session := &OIDCSession{
|
||||
SessionID: c.oidcSessionWriteModel.SessionID,
|
||||
ClientID: c.oidcSessionWriteModel.ClientID,
|
||||
UserID: c.oidcSessionWriteModel.UserID,
|
||||
Audience: c.oidcSessionWriteModel.Audience,
|
||||
Expiration: c.oidcSessionWriteModel.AccessTokenExpiration,
|
||||
Scope: c.oidcSessionWriteModel.Scope,
|
||||
AuthMethods: c.oidcSessionWriteModel.AuthMethods,
|
||||
AuthTime: c.oidcSessionWriteModel.AuthTime,
|
||||
Nonce: c.oidcSessionWriteModel.Nonce,
|
||||
PreferredLanguage: c.oidcSessionWriteModel.PreferredLanguage,
|
||||
UserAgent: c.oidcSessionWriteModel.UserAgent,
|
||||
Reason: c.oidcSessionWriteModel.AccessTokenReason,
|
||||
Actor: c.oidcSessionWriteModel.AccessTokenActor,
|
||||
RefreshToken: c.refreshToken,
|
||||
}
|
||||
if c.accessTokenID != "" {
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
|
||||
}
|
||||
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
|
||||
@@ -368,3 +503,14 @@ func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime
|
||||
}
|
||||
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
|
||||
}
|
||||
|
||||
func tokenReasonToActivityMethodType(r domain.TokenReason) activity.TriggerMethod {
|
||||
if r == domain.TokenReasonUnspecified {
|
||||
return activity.Unspecified
|
||||
}
|
||||
if r == domain.TokenReasonRefresh {
|
||||
return activity.OIDCRefreshToken
|
||||
}
|
||||
// all other reasons result in an access token
|
||||
return activity.OIDCAccessToken
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
@@ -13,12 +15,16 @@ type OIDCSessionWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
UserResourceOwner string
|
||||
PreferredLanguage *language.Tag
|
||||
SessionID string
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
Nonce string
|
||||
UserAgent *domain.UserAgent
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenID string
|
||||
AccessTokenCreation time.Time
|
||||
@@ -85,12 +91,16 @@ func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.UserResourceOwner = e.UserResourceOwner
|
||||
wm.SessionID = e.SessionID
|
||||
wm.ClientID = e.ClientID
|
||||
wm.Audience = e.Audience
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethods = e.AuthMethods
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.Nonce = e.Nonce
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
wm.UserAgent = e.UserAgent
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
// the write model might be initialized without resource owner,
|
||||
// so update the aggregate
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
@@ -56,12 +58,12 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
|
||||
}
|
||||
|
||||
// CheckUser defines a user check to be executed for a session update
|
||||
func CheckUser(id string, resourceOwner string) SessionCommand {
|
||||
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
|
||||
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
|
||||
}
|
||||
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now())
|
||||
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,8 +173,8 @@ func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent
|
||||
s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate, userAgent))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time) error {
|
||||
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt))
|
||||
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time, preferredLanguage *language.Tag) error {
|
||||
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt, preferredLanguage))
|
||||
// set the userID so other checks can use it
|
||||
s.sessionWriteModel.UserID = userID
|
||||
s.sessionWriteModel.UserResourceOwner = resourceOwner
|
||||
|
||||
@@ -3,6 +3,8 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@@ -40,6 +42,7 @@ type SessionWriteModel struct {
|
||||
TokenID string
|
||||
UserID string
|
||||
UserResourceOwner string
|
||||
PreferredLanguage *language.Tag
|
||||
UserCheckedAt time.Time
|
||||
PasswordCheckedAt time.Time
|
||||
IntentCheckedAt time.Time
|
||||
@@ -50,6 +53,7 @@ type SessionWriteModel struct {
|
||||
WebAuthNUserVerified bool
|
||||
Metadata map[string][]byte
|
||||
State domain.SessionState
|
||||
UserAgent *domain.UserAgent
|
||||
Expiration time.Time
|
||||
|
||||
WebAuthNChallenge *WebAuthNChallengeModel
|
||||
@@ -137,12 +141,14 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
|
||||
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
|
||||
wm.State = domain.SessionStateActive
|
||||
wm.UserAgent = e.UserAgent
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceUserChecked(e *session.UserCheckedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.UserResourceOwner = e.UserResourceOwner
|
||||
wm.UserCheckedAt = e.CheckedAt
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reducePasswordChecked(e *session.PasswordCheckedEvent) {
|
||||
|
||||
@@ -566,7 +566,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectPush(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow,
|
||||
"userID", "org1", testNow, &language.Afrikaans,
|
||||
),
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow,
|
||||
@@ -585,7 +585,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckPassword("password"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@@ -634,7 +634,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@@ -673,7 +673,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@@ -723,7 +723,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent2", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t),
|
||||
@@ -751,7 +751,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectPush(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
session.NewIntentCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow),
|
||||
session.NewMetadataSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -766,7 +766,7 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1"),
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckIntent("intent", "aW50ZW50"),
|
||||
},
|
||||
eventstore: eventstoreExpect(t,
|
||||
@@ -1188,7 +1188,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"user1", "org1", testNow),
|
||||
"user1", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -1229,7 +1229,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow),
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
@@ -1271,7 +1271,7 @@ func TestCommands_TerminateSession(t *testing.T) {
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,
|
||||
"userID", "", testNow),
|
||||
"userID", "", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
@@ -13,7 +12,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@@ -232,35 +230,6 @@ func (c *Commands) RemoveUser(ctx context.Context, userID, resourceOwner string,
|
||||
return writeModelToObjectDetails(&existingUser.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddUserToken(
|
||||
ctx context.Context,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID,
|
||||
userID string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
lifetime time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (*domain.Token, error) {
|
||||
if userID == "" { //do not check for empty orgID (JWT Profile requests won't provide it, so service user requests fail)
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dbge4", "Errors.IDMissing")
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, "", audience, scopes, authMethodsReferences, lifetime, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, cmds...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
|
||||
removeEvent, accessTokenWriteModel, err := c.removeAccessToken(ctx, userID, orgID, tokenID)
|
||||
if err != nil {
|
||||
@@ -277,61 +246,6 @@ func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID
|
||||
return writeModelToObjectDetails(&accessTokenWriteModel.WriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) addUserToken(ctx context.Context, userWriteModel *UserWriteModel, agentID, clientID, refreshTokenID string, audience, scopes, authMethodsReferences []string, lifetime time.Duration, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) ([]eventstore.Command, *domain.Token, error) {
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, userWriteModel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if userWriteModel.UserState != domain.UserStateActive {
|
||||
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-1d6Gg", "Errors.User.NotFound")
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
userAgg := UserAggregateFromWriteModel(&userWriteModel.WriteModel)
|
||||
|
||||
var cmds []eventstore.Command
|
||||
if reason == domain.TokenReasonImpersonation {
|
||||
if err := c.checkPermission(ctx, "impersonation", userWriteModel.ResourceOwner, userWriteModel.AggregateID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cmds = append(cmds, user.NewUserImpersonatedEvent(ctx, userAgg, clientID, actor))
|
||||
}
|
||||
|
||||
preferredLanguage := ""
|
||||
existingHuman, err := c.getHumanWriteModelByID(ctx, userWriteModel.AggregateID, userWriteModel.ResourceOwner)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if existingHuman != nil {
|
||||
preferredLanguage = existingHuman.PreferredLanguage.String()
|
||||
}
|
||||
expiration := time.Now().UTC().Add(lifetime)
|
||||
tokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cmds = append(cmds,
|
||||
user.NewUserTokenAddedEvent(ctx, userAgg, tokenID, clientID, agentID, preferredLanguage, refreshTokenID, audience, scopes, authMethodsReferences, authTime, expiration, reason, actor),
|
||||
)
|
||||
|
||||
return cmds, &domain.Token{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: userWriteModel.AggregateID,
|
||||
},
|
||||
TokenID: tokenID,
|
||||
UserAgentID: agentID,
|
||||
ApplicationID: clientID,
|
||||
RefreshTokenID: refreshTokenID,
|
||||
Audience: audience,
|
||||
Scopes: scopes,
|
||||
Expiration: expiration,
|
||||
PreferredLanguage: preferredLanguage,
|
||||
Reason: reason,
|
||||
Actor: actor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) removeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*user.UserTokenRemovedEvent, *UserAccessTokenWriteModel, error) {
|
||||
if userID == "" || orgID == "" || tokenID == "" {
|
||||
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dng42", "Errors.IDMissing")
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
)
|
||||
@@ -81,16 +78,6 @@ func writeModelToAddress(wm *HumanAddressWriteModel) *domain.Address {
|
||||
}
|
||||
}
|
||||
|
||||
func writeModelToMachine(wm *MachineWriteModel) *domain.Machine {
|
||||
return &domain.Machine{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
Username: wm.UserName,
|
||||
Name: wm.Name,
|
||||
Description: wm.Description,
|
||||
State: wm.UserState,
|
||||
}
|
||||
}
|
||||
|
||||
func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
|
||||
return &domain.MachineKey{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
@@ -100,18 +87,6 @@ func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
|
||||
}
|
||||
}
|
||||
|
||||
func personalTokenWriteModelToToken(wm *PersonalAccessTokenWriteModel, algorithm crypto.EncryptionAlgorithm) (*domain.Token, string, error) {
|
||||
encrypted, err := algorithm.Encrypt([]byte(wm.TokenID + ":" + wm.AggregateID))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return &domain.Token{
|
||||
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
|
||||
TokenID: wm.TokenID,
|
||||
Expiration: wm.ExpirationDate,
|
||||
}, base64.RawURLEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
func readModelToWebAuthNTokens(readModel HumanWebAuthNTokensReadModel) []*domain.WebAuthNToken {
|
||||
tokens := make([]*domain.WebAuthNToken, len(readModel.GetWebAuthNTokens()))
|
||||
for i, token := range readModel.GetWebAuthNTokens() {
|
||||
|
||||
@@ -2,7 +2,6 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@@ -10,98 +9,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (c *Commands) AddAccessAndRefreshToken(
|
||||
ctx context.Context,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID,
|
||||
userID,
|
||||
refreshToken string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
accessLifetime,
|
||||
refreshIdleExpiration,
|
||||
refreshExpiration time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
if refreshToken == "" {
|
||||
return c.AddNewRefreshTokenAndAccessToken(ctx, userID, orgID, agentID, clientID, audience, scopes, authMethodsReferences, refreshExpiration, accessLifetime, refreshIdleExpiration, authTime, reason, actor)
|
||||
}
|
||||
return c.RenewRefreshTokenAndAccessToken(ctx, userID, orgID, refreshToken, agentID, clientID, audience, scopes, refreshIdleExpiration, accessLifetime, actor)
|
||||
}
|
||||
|
||||
func (c *Commands) AddNewRefreshTokenAndAccessToken(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
orgID,
|
||||
agentID,
|
||||
clientID string,
|
||||
audience,
|
||||
scopes,
|
||||
authMethodsReferences []string,
|
||||
refreshExpiration,
|
||||
accessLifetime,
|
||||
refreshIdleExpiration time.Duration,
|
||||
authTime time.Time,
|
||||
reason domain.TokenReason,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
if userID == "" || clientID == "" {
|
||||
return nil, "", zerrors.ThrowInvalidArgument(nil, "COMMAND-adg4r", "Errors.IDMissing")
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
refreshTokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, refreshTokenID, audience, scopes, authMethodsReferences, accessLifetime, authTime, reason, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
refreshTokenEvent, newRefreshToken, err := c.addRefreshToken(ctx, accessToken, authMethodsReferences, authTime, refreshIdleExpiration, refreshExpiration, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
cmds = append(cmds, refreshTokenEvent)
|
||||
_, err = c.eventstore.Push(ctx, cmds...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return accessToken, newRefreshToken, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RenewRefreshTokenAndAccessToken(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
orgID,
|
||||
refreshToken,
|
||||
agentID,
|
||||
clientID string,
|
||||
audience,
|
||||
scopes []string,
|
||||
idleExpiration,
|
||||
accessLifetime time.Duration,
|
||||
actor *domain.TokenActor,
|
||||
) (accessToken *domain.Token, newRefreshToken string, err error) {
|
||||
renewed, err := c.renewRefreshToken(ctx, userID, orgID, refreshToken, idleExpiration)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
userWriteModel := NewUserWriteModel(userID, orgID)
|
||||
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, renewed.tokenID, audience, scopes, renewed.authMethodsReferences, accessLifetime, renewed.authTime, domain.TokenReasonRefresh, actor)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, append(cmds, renewed.event)...)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return accessToken, renewed.token, nil
|
||||
}
|
||||
|
||||
func (c *Commands) RevokeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
|
||||
removeEvent, refreshTokenWriteModel, err := c.removeRefreshToken(ctx, userID, orgID, tokenID)
|
||||
if err != nil {
|
||||
@@ -134,70 +41,6 @@ func (c *Commands) RevokeRefreshTokens(ctx context.Context, userID, orgID string
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) addRefreshToken(ctx context.Context, accessToken *domain.Token, authMethodsReferences []string, authTime time.Time, idleExpiration, expiration time.Duration, actor *domain.TokenActor) (*user.HumanRefreshTokenAddedEvent, string, error) {
|
||||
refreshToken, err := domain.NewRefreshToken(accessToken.AggregateID, accessToken.RefreshTokenID, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(accessToken.AggregateID, accessToken.ResourceOwner, accessToken.RefreshTokenID)
|
||||
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
|
||||
return user.NewHumanRefreshTokenAddedEvent(ctx, userAgg, accessToken.RefreshTokenID, accessToken.ApplicationID, accessToken.UserAgentID,
|
||||
accessToken.PreferredLanguage, accessToken.Audience, accessToken.Scopes, authMethodsReferences, authTime, idleExpiration, expiration, actor),
|
||||
refreshToken, nil
|
||||
}
|
||||
|
||||
type renewedRefreshToken struct {
|
||||
event *user.HumanRefreshTokenRenewedEvent
|
||||
authTime time.Time
|
||||
authMethodsReferences []string
|
||||
tokenID string
|
||||
token string
|
||||
}
|
||||
|
||||
func (c *Commands) renewRefreshToken(ctx context.Context, userID, orgID, refreshToken string, idleExpiration time.Duration) (*renewedRefreshToken, error) {
|
||||
if refreshToken == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-DHrr3", "Errors.IDMissing")
|
||||
}
|
||||
|
||||
tokenUserID, tokenID, token, err := domain.FromRefreshToken(refreshToken, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenUserID != userID {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Ht2g2", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(userID, orgID, tokenID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, refreshTokenWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshTokenWriteModel.UserState != domain.UserStateActive {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-BHnhs", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
if refreshTokenWriteModel.RefreshToken != token ||
|
||||
refreshTokenWriteModel.IdleExpiration.Before(time.Now()) ||
|
||||
refreshTokenWriteModel.Expiration.Before(time.Now()) {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Vr43e", "Errors.User.RefreshToken.Invalid")
|
||||
}
|
||||
|
||||
newToken, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRefreshToken, err := domain.RefreshToken(userID, tokenID, newToken, c.keyAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
|
||||
return &renewedRefreshToken{
|
||||
event: user.NewHumanRefreshTokenRenewedEvent(ctx, userAgg, tokenID, newToken, idleExpiration),
|
||||
authTime: refreshTokenWriteModel.AuthTime,
|
||||
authMethodsReferences: refreshTokenWriteModel.AuthMethodsReferences,
|
||||
tokenID: tokenID,
|
||||
token: newRefreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) removeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*user.HumanRefreshTokenRemovedEvent, *HumanRefreshTokenWriteModel, error) {
|
||||
if userID == "" || orgID == "" || tokenID == "" {
|
||||
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-GVDgf", "Errors.IDMissing")
|
||||
|
||||
@@ -2,316 +2,18 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestCommands_AddAccessAndRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
orgID string
|
||||
agentID string
|
||||
clientID string
|
||||
userID string
|
||||
refreshToken string
|
||||
audience []string
|
||||
scopes []string
|
||||
authMethodsReferences []string
|
||||
lifetime time.Duration
|
||||
authTime time.Time
|
||||
refreshIdleExpiration time.Duration
|
||||
refreshExpiration time.Duration
|
||||
reason domain.TokenReason
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
type res struct {
|
||||
token *domain.Token
|
||||
refreshToken string
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "missing ID, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args: args{},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add refresh token, user deactivated, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewUserDeactivatedEvent(context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshTokenID1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "orgID",
|
||||
agentID: "agentID",
|
||||
userID: "userID",
|
||||
clientID: "clientID",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, invalid token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, invalid token (invalid userID), error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, token inactive, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "renew refresh token, token expired, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
-1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
//fails because of timestamp equality
|
||||
//{
|
||||
// name: "push failed, error",
|
||||
// fields: fields{
|
||||
// eventstore: eventstoreExpect(t,
|
||||
// expectFilter(
|
||||
// eventFromEventPusher(user.NewHumanAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "username",
|
||||
// "firstname",
|
||||
// "lastname",
|
||||
// "nickname",
|
||||
// "displayname",
|
||||
// language.German,
|
||||
// domain.GenderUnspecified,
|
||||
// "email",
|
||||
// true,
|
||||
// )),
|
||||
// ),
|
||||
// expectFilter(
|
||||
// eventFromEventPusherWithCreationDateNow(user.NewHumanAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "username",
|
||||
// "firstname",
|
||||
// "lastname",
|
||||
// "nickname",
|
||||
// "displayname",
|
||||
// language.German,
|
||||
// domain.GenderUnspecified,
|
||||
// "email",
|
||||
// true,
|
||||
// )),
|
||||
// ),
|
||||
// expectFilter(
|
||||
// eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "tokenID",
|
||||
// "applicationID",
|
||||
// "userAgentID",
|
||||
// "de",
|
||||
// []string{"clientID1"},
|
||||
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// []string{"password"},
|
||||
// time.Now(),
|
||||
// 1*time.Hour,
|
||||
// 24*time.Hour,
|
||||
// )),
|
||||
// ),
|
||||
// expectPushFailed(
|
||||
// zerrors.ThrowInternal(nil, "ERROR", "internal"),
|
||||
// []*repository.Event{
|
||||
// eventFromEventPusher(user.NewUserTokenAddedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "accessTokenID1",
|
||||
// "clientID",
|
||||
// "agentID",
|
||||
// "de",
|
||||
// []string{"clientID1"},
|
||||
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// time.Now().Add(5*time.Minute),
|
||||
// )),
|
||||
// eventFromEventPusher(user.NewHumanRefreshTokenRenewedEvent(
|
||||
// context.Background(),
|
||||
// &user.NewAggregate("userID", "orgID").Aggregate,
|
||||
// "tokenID",
|
||||
// "refreshToken1",
|
||||
// 1*time.Hour,
|
||||
// )),
|
||||
// },
|
||||
// ),
|
||||
// ),
|
||||
// idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "accessTokenID1", "refreshToken1"),
|
||||
// keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
// },
|
||||
// args: args{
|
||||
// ctx: context.Background(),
|
||||
// orgID: "orgID",
|
||||
// agentID: "agentID",
|
||||
// clientID: "clientID",
|
||||
// userID: "userID",
|
||||
// refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
// audience: []string{"clientID1"},
|
||||
// scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
// authMethodsReferences: []string{"password"},
|
||||
// lifetime: 5 * time.Minute,
|
||||
// authTime: time.Now(),
|
||||
// },
|
||||
// res: res{
|
||||
// err: zerrors.IsInternal,
|
||||
// },
|
||||
//},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, gotRefresh, err := c.AddAccessAndRefreshToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.refreshToken,
|
||||
tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.refreshIdleExpiration, tt.args.refreshExpiration, tt.args.authTime, tt.args.reason, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.token, got)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefresh)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_RevokeRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
@@ -669,395 +371,3 @@ func TestCommands_RevokeRefreshTokens(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func refreshTokenEncryptionAlgorithm(ctrl *gomock.Controller) crypto.EncryptionAlgorithm {
|
||||
mCrypto := crypto.NewMockEncryptionAlgorithm(ctrl)
|
||||
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
|
||||
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
|
||||
mCrypto.EXPECT().Encrypt(gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(refrehToken []byte) ([]byte, error) {
|
||||
return refrehToken, nil
|
||||
},
|
||||
)
|
||||
mCrypto.EXPECT().Decrypt(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
|
||||
func(refrehToken []byte, keyID string) ([]byte, error) {
|
||||
if keyID != "id" {
|
||||
return nil, zerrors.ThrowInternal(nil, "id", "invalid key id")
|
||||
}
|
||||
return refrehToken, nil
|
||||
},
|
||||
)
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
func TestCommands_addRefreshToken(t *testing.T) {
|
||||
authTime := time.Now().Add(-1 * time.Hour)
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
accessToken *domain.Token
|
||||
authMethodsReferences []string
|
||||
authTime time.Time
|
||||
idleExpiration time.Duration
|
||||
expiration time.Duration
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
type res struct {
|
||||
event *user.HumanRefreshTokenAddedEvent
|
||||
refreshToken string
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
|
||||
{
|
||||
name: "add refresh Token",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
accessToken: &domain.Token{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "userID",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
TokenID: "accessTokenID1",
|
||||
ApplicationID: "clientID",
|
||||
UserAgentID: "agentID",
|
||||
RefreshTokenID: "refreshTokenID",
|
||||
Audience: []string{"clientID1"},
|
||||
Expiration: time.Now().Add(5 * time.Minute),
|
||||
Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
PreferredLanguage: "de",
|
||||
},
|
||||
authMethodsReferences: []string{"password"},
|
||||
authTime: authTime,
|
||||
idleExpiration: 1 * time.Hour,
|
||||
expiration: 10 * time.Hour,
|
||||
},
|
||||
res: res{
|
||||
event: user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "org1").Aggregate,
|
||||
"refreshTokenID",
|
||||
"clientID",
|
||||
"agentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
authTime,
|
||||
1*time.Hour,
|
||||
10*time.Hour,
|
||||
nil,
|
||||
),
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:refreshTokenID:refreshTokenID")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotEvent, gotRefreshToken, err := c.addRefreshToken(tt.args.ctx, tt.args.accessToken, tt.args.authMethodsReferences, tt.args.authTime, tt.args.idleExpiration, tt.args.expiration, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.event, gotEvent)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_renewRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
userID string
|
||||
orgID string
|
||||
refreshToken string
|
||||
idleExpiration time.Duration
|
||||
}
|
||||
type res struct {
|
||||
event *user.HumanRefreshTokenRenewedEvent
|
||||
refreshTokenID string
|
||||
newRefreshToken string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *renewedRefreshToken
|
||||
wantErr func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "empty token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid token, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid token (invalid userID), error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token inactive, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token expired, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "user deactivated, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
user.NewUserDeactivatedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "user signedout, error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanSignedOutEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"userAgentID",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
wantErr: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "token renewed, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"applicationID",
|
||||
"userAgentID",
|
||||
"de",
|
||||
[]string{"clientID1"},
|
||||
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
|
||||
[]string{"password"},
|
||||
time.Now(),
|
||||
1*time.Hour,
|
||||
24*time.Hour,
|
||||
nil,
|
||||
)),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshToken1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
userID: "userID",
|
||||
orgID: "orgID",
|
||||
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
|
||||
idleExpiration: 1 * time.Hour,
|
||||
},
|
||||
want: &renewedRefreshToken{
|
||||
event: user.NewHumanRefreshTokenRenewedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "orgID").Aggregate,
|
||||
"tokenID",
|
||||
"refreshToken1",
|
||||
1*time.Hour,
|
||||
),
|
||||
authMethodsReferences: []string{"password"},
|
||||
tokenID: "tokenID",
|
||||
token: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:refreshToken1")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.renewRefreshToken(tt.args.ctx, tt.args.userID, tt.args.orgID, tt.args.refreshToken, tt.args.idleExpiration)
|
||||
if tt.wantErr != nil && !tt.wantErr(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.wantErr == nil {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want.event, got.event)
|
||||
assert.Equal(t, tt.want.authMethodsReferences, got.authMethodsReferences)
|
||||
assert.Equal(t, tt.want.tokenID, got.tokenID)
|
||||
assert.Equal(t, tt.want.token, got.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
@@ -1433,91 +1432,6 @@ func TestCommandSide_RemoveUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandSide_AddUserToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type (
|
||||
args struct {
|
||||
ctx context.Context
|
||||
orgID string
|
||||
agentID string
|
||||
clientID string
|
||||
userID string
|
||||
audience []string
|
||||
scopes []string
|
||||
authMethodsReferences []string
|
||||
lifetime time.Duration
|
||||
authTime time.Time
|
||||
reason domain.TokenReason
|
||||
actor *domain.TokenActor
|
||||
}
|
||||
)
|
||||
type res struct {
|
||||
want *domain.Token
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userid missing, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "org1",
|
||||
userID: "",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user not existing, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
orgID: "org1",
|
||||
userID: "user1",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := r.AddUserToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.authTime, tt.args.reason, tt.args.actor)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_RevokeAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
|
||||
Reference in New Issue
Block a user