perf(oidc): optimize token creation (#7822)

* implement code exchange

* port tokenexchange to v2 tokens

* implement refresh token

* implement client credentials

* implement jwt profile

* implement device token

* cleanup unused code

* fix current unit tests

* add user agent unit test

* unit test domain package

* need refresh token as argument

* test commands create oidc session

* test commands device auth

* fix device auth build error

* implicit for oidc session API

* implement authorize callback handler for legacy implicit mode

* upgrade oidc module to working draft

* add missing auth methods and time

* handle all errors in defer

* do not fail auth request on error

the oauth2 Go client automagically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with the Errors.AuthRequest.NoCode, because the auth request state is already set to failed.
The original error is then already lost and the oauth2 library does not return the original error.

Therefore we should not fail the auth request.

Might be worth discussing and perhaps send a bug report to Oauth2?

* fix code flow tests by explicitly setting code exchanged

* fix unit tests in command package

* return allowed scope from client credential client

* add device auth done reducer

* carry nonce thru session into ID token

* fix token exchange integration tests

* allow project role scope prefix in client credentials client

* gci formatting

* do not return refresh token in client credentials and jwt profile

* check org scope

* solve linting issue on authorize callback error

* end session based on v2 session ID

* use preferred language and user agent ID for v2 access tokens

* pin oidc v3.23.2

* add integration test for jwt profile and client credentials with org scopes

* refresh token v1 to v2

* add user token v2 audit event

* add activity trigger

* cleanup and set panics for unused methods

* use the encrypted code for v1 auth request get by code

* add missing event translation

* fix pipeline errors (hopefully)

* fix another test

* revert pointer usage of preferred language

* solve browser info panic in device auth

* remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim

* revoke v1 refresh token to prevent reuse

* fix terminate oidc session

* always return a new refresh toke in refresh token grant

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Tim Möhlmann
2024-05-16 08:07:56 +03:00
committed by GitHub
parent 6cf9ca9f7e
commit 8e0c8393e9
84 changed files with 3429 additions and 2635 deletions

View File

@@ -12,21 +12,22 @@ import (
)
type AuthRequest struct {
ID string
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
ID string
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
NeedRefreshToken bool
}
type CurrentAuthRequest struct {
@@ -69,6 +70,7 @@ func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest)
authRequest.MaxAge,
authRequest.LoginHint,
authRequest.HintUserID,
authRequest.NeedRefreshToken,
))
if err != nil {
return nil, err
@@ -148,25 +150,6 @@ func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code s
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
}
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
if code == "" {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
}
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
if err != nil {
return nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
return &CurrentAuthRequest{
AuthRequest: &AuthRequest{

View File

@@ -34,6 +34,7 @@ type AuthRequestWriteModel struct {
AuthTime time.Time
AuthMethods []domain.UserAuthMethodType
AuthRequestState domain.AuthRequestState
NeedRefreshToken bool
}
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
@@ -64,6 +65,7 @@ func (m *AuthRequestWriteModel) Reduce() error {
m.LoginHint = e.LoginHint
m.HintUserID = e.HintUserID
m.AuthRequestState = domain.AuthRequestStateAdded
m.NeedRefreshToken = e.NeedRefreshToken
case *authrequest.SessionLinkedEvent:
m.SessionID = e.SessionID
m.UserID = e.UserID

View File

@@ -10,6 +10,7 @@ import (
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
@@ -59,6 +60,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
nil,
nil,
nil,
false,
),
),
),
@@ -96,6 +98,7 @@ func TestCommands_AddAuthRequest(t *testing.T) {
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
false,
),
),
),
@@ -223,6 +226,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
eventFromEventPusher(
@@ -263,6 +267,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -301,6 +306,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -338,6 +344,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -354,7 +361,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow.Add(-5*time.Minute)),
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -398,6 +405,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -447,6 +455,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -463,7 +472,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow),
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -532,6 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -548,7 +558,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow),
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -674,6 +684,7 @@ func TestCommands_FailAuthRequest(t *testing.T) {
nil,
nil,
nil,
true,
),
),
),
@@ -771,6 +782,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
true,
),
),
),
@@ -807,6 +819,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
true,
),
),
eventFromEventPusher(
@@ -841,166 +854,3 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
})
}
}
func TestCommands_ExchangeAuthCode(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
code string
}
type res struct {
authRequest *CurrentAuthRequest
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"empty code error",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: mockCtx,
code: "",
},
res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
},
},
{
"no code added error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
},
},
{
"code exchanged",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectPush(
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
authRequest: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_authRequestID",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
assert.ErrorIs(t, tt.res.err, err)
if err == nil {
// equal on time won't work -> test separately and clear it before comparing the rest
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authRequest, got)
})
}
}

View File

@@ -2,16 +2,20 @@ package command
import (
"context"
"fmt"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string) (*domain.ObjectDetails, error) {
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string, needRefreshToken bool) (*domain.ObjectDetails, error) {
aggr := deviceauth.NewAggregate(deviceCode, authz.GetInstance(ctx).InstanceID())
model := NewDeviceAuthWriteModel(deviceCode, aggr.ResourceOwner)
@@ -24,6 +28,7 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
expires,
scopes,
audience,
needRefreshToken,
))
if err != nil {
return nil, err
@@ -36,7 +41,16 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
return writeModelToObjectDetails(&model.WriteModel), nil
}
func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject string, authMethods []domain.UserAuthMethodType, authTime time.Time) (*domain.ObjectDetails, error) {
func (c *Commands) ApproveDeviceAuth(
ctx context.Context,
deviceCode,
userID,
userOrgID string,
authMethods []domain.UserAuthMethodType,
authTime time.Time,
preferredLanguage *language.Tag,
userAgent *domain.UserAgent,
) (*domain.ObjectDetails, error) {
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
if err != nil {
return nil, err
@@ -44,9 +58,7 @@ func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject st
if !model.State.Exists() {
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound")
}
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject, authMethods, authTime))
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, model.aggregate, userID, userOrgID, authMethods, authTime, preferredLanguage, userAgent))
if err != nil {
return nil, err
}
@@ -66,9 +78,7 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
if !model.State.Exists() {
return nil, zerrors.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound")
}
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, aggr, reason))
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, model.aggregate, reason))
if err != nil {
return nil, err
}
@@ -81,10 +91,89 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
}
func (c *Commands) getDeviceAuthWriteModelByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuthWriteModel, error) {
model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: deviceCode}}
model := &DeviceAuthWriteModel{
WriteModel: eventstore.WriteModel{AggregateID: deviceCode},
}
err := c.eventstore.FilterToQueryReducer(ctx, model)
if err != nil {
return nil, err
}
model.aggregate = deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
return model, nil
}
type DeviceAuthStateError domain.DeviceAuthState
func (e DeviceAuthStateError) Error() string {
return fmt.Sprintf("device auth state not approved: %s", domain.DeviceAuthState(e).String())
}
// CreateOIDCSessionFromDeviceAuth creates a new OIDC session if the device authorization
// flow is completed (user logged in).
// A [DeviceAuthStateError] is returned if the device authorization was not approved,
// containing a [domain.DeviceAuthState] which can be used to inform the client about the state.
//
// As devices can poll at various intervals, an explicit state takes precedence over expiry.
// This is to prevent cases where users might approve or deny the authorization on time, but the next poll
// happens after expiry.
func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCode string) (_ *OIDCSession, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
deviceAuthModel, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
if err != nil {
return nil, err
}
switch deviceAuthModel.State {
case domain.DeviceAuthStateApproved:
break
case domain.DeviceAuthStateUndefined:
return nil, zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound")
case domain.DeviceAuthStateInitiated:
if deviceAuthModel.Expires.Before(time.Now()) {
c.asyncPush(ctx, deviceauth.NewCanceledEvent(ctx, deviceAuthModel.aggregate, domain.DeviceAuthCanceledExpired))
return nil, DeviceAuthStateError(domain.DeviceAuthStateExpired)
}
fallthrough
case domain.DeviceAuthStateDenied, domain.DeviceAuthStateExpired, domain.DeviceAuthStateDone:
fallthrough
default:
return nil, DeviceAuthStateError(deviceAuthModel.State)
}
cmd, err := c.newOIDCSessionAddEvents(ctx, deviceAuthModel.UserOrgID)
if err != nil {
return nil, err
}
cmd.AddSession(ctx,
deviceAuthModel.UserID,
deviceAuthModel.UserOrgID,
"",
deviceAuthModel.ClientID,
deviceAuthModel.Audience,
deviceAuthModel.Scopes,
deviceAuthModel.UserAuthMethods,
deviceAuthModel.AuthTime,
"",
deviceAuthModel.PreferredLanguage,
deviceAuthModel.UserAgent,
)
if err = cmd.AddAccessToken(ctx, deviceAuthModel.Scopes, deviceAuthModel.UserID, deviceAuthModel.UserOrgID, domain.TokenReasonAuthRequest, nil); err != nil {
return nil, err
}
if deviceAuthModel.NeedRefreshToken {
if err = cmd.AddRefreshToken(ctx, deviceAuthModel.UserID); err != nil {
return nil, err
}
}
cmd.DeviceAuthRequestDone(ctx, deviceAuthModel.aggregate)
return cmd.PushEvents(ctx)
}
func (cmd *OIDCSessionEvents) DeviceAuthRequestDone(ctx context.Context, deviceAuthAggregate *eventstore.Aggregate) {
cmd.events = append(cmd.events, deviceauth.NewDoneEvent(ctx, deviceAuthAggregate))
}

View File

@@ -3,6 +3,8 @@ package command
import (
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
@@ -10,16 +12,22 @@ import (
type DeviceAuthWriteModel struct {
eventstore.WriteModel
aggregate *eventstore.Aggregate
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
State domain.DeviceAuthState
Subject string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
Audience []string
State domain.DeviceAuthState
UserID string
UserOrgID string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
PreferredLanguage *language.Tag
UserAgent *domain.UserAgent
NeedRefreshToken bool
}
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
@@ -28,6 +36,7 @@ func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteM
AggregateID: deviceCode,
ResourceOwner: resourceOwner,
},
aggregate: deviceauth.NewAggregate(deviceCode, resourceOwner),
}
}
@@ -40,14 +49,21 @@ func (m *DeviceAuthWriteModel) Reduce() error {
m.UserCode = e.UserCode
m.Expires = e.Expires
m.Scopes = e.Scopes
m.Audience = e.Audience
m.State = e.State
m.NeedRefreshToken = e.NeedRefreshToken
case *deviceauth.ApprovedEvent:
m.State = domain.DeviceAuthStateApproved
m.Subject = e.Subject
m.UserID = e.UserID
m.UserOrgID = e.UserOrgID
m.UserAuthMethods = e.UserAuthMethods
m.AuthTime = e.AuthTime
m.PreferredLanguage = e.PreferredLanguage
m.UserAgent = e.UserAgent
case *deviceauth.CanceledEvent:
m.State = e.Reason.State()
case *deviceauth.DoneEvent:
m.State = domain.DeviceAuthStateDone
}
}

View File

@@ -3,16 +3,27 @@ package command
import (
"context"
"errors"
"io"
"net"
"net/http"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -25,16 +36,17 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
require.Len(t, unique, 2)
type fields struct {
eventstore *eventstore.Eventstore
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
clientID string
deviceCode string
userCode string
expires time.Time
scopes []string
audience []string
ctx context.Context
clientID string
deviceCode string
userCode string
expires time.Time
scopes []string
audience []string
needRefreshToken bool
}
tests := []struct {
name string
@@ -46,24 +58,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
{
name: "success",
fields: fields{
eventstore: eventstoreExpect(t, expectPush(
eventstore: expectEventstore(expectPush(
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance1"),
clientID: "client_id",
deviceCode: "123",
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
ctx: authz.WithInstanceID(context.Background(), "instance1"),
clientID: "client_id",
deviceCode: "123",
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
needRefreshToken: true,
},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
@@ -72,24 +85,25 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
{
name: "push error",
fields: fields{
eventstore: eventstoreExpect(t, expectPushFailed(pushErr,
eventstore: expectEventstore(expectPushFailed(pushErr,
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, false,
)),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance1"),
clientID: "client_id",
deviceCode: "123",
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
ctx: authz.WithInstanceID(context.Background(), "instance1"),
clientID: "client_id",
deviceCode: "123",
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
needRefreshToken: false,
},
wantErr: pushErr,
},
@@ -97,9 +111,9 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
eventstore: tt.fields.eventstore(t),
}
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience)
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience, tt.args.needRefreshToken)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantDetails, gotDetails)
})
@@ -115,11 +129,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
subject string
authMethods []domain.UserAuthMethodType
authTime time.Time
ctx context.Context
id string
userID string
userOrgID string
authMethods []domain.UserAuthMethodType
authTime time.Time
preferredLanguage *language.Tag
userAgent *domain.UserAgent
}
tests := []struct {
name string
@@ -136,9 +153,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
),
},
args: args{
ctx, "123", "subj",
ctx, "123", "subj", "orgID",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
},
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
},
@@ -153,22 +175,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
expectPushFailed(pushErr,
deviceauth.NewApprovedEvent(
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
),
},
args: args{
ctx, "123", "subj",
ctx, "123", "subj", "orgID",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
},
wantErr: pushErr,
},
@@ -183,22 +215,32 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
expectPush(
deviceauth.NewApprovedEvent(
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
ctx, deviceauth.NewAggregate("123", "instance1"), "subj", "orgID",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
),
},
args: args{
ctx, "123", "subj",
ctx, "123", "subj", "orgID",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
time.Unix(123, 456), &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
@@ -210,7 +252,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject, tt.args.authMethods, tt.args.authTime)
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, gotDetails, tt.wantDetails)
})
@@ -258,7 +300,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
expectPushFailed(pushErr,
@@ -283,7 +325,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
expectPush(
@@ -310,7 +352,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
[]string{"projectID", "clientID"}, true,
),
)),
expectPush(
@@ -338,3 +380,392 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
})
}
}
func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instance1")
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
deviceCode string
}
tests := []struct {
name string
fields fields
args args
want *OIDCSession
wantErr error
}{
{
name: "device auth filter error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx,
"device1",
},
wantErr: io.ErrClosedPipe,
},
{
name: "not yet approved",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
),
),
},
args: args{
ctx,
"123",
},
wantErr: DeviceAuthStateError(domain.DeviceAuthStateInitiated),
},
{
name: "not found",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx,
"123",
},
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound"),
},
{
name: "expired",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
),
expectPushSlow(time.Second, deviceauth.NewCanceledEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledExpired,
)),
),
},
args: args{
ctx,
"123",
},
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
},
{
name: "already expired",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewCanceledEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledExpired,
),
),
),
),
},
args: args{
ctx,
"123",
},
wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired),
},
{
name: "denied",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewCanceledEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledDenied,
),
),
),
),
},
args: args{
ctx,
"123",
},
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDenied),
},
{
name: "already done",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewCanceledEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledDenied,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewDoneEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
),
),
),
),
},
args: args{
ctx,
"123",
},
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone),
},
{
name: "approved, success",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewApprovedEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
"userID", "org1",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
testNow, &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
),
expectFilter(), // token lifetime
expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
oidcsession.NewAccessTokenAddedEvent(context.Background(),
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
),
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
deviceauth.NewDoneEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx,
"123",
},
want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID",
ClientID: "clientID",
UserID: "userID",
Audience: []string{"audience"},
Expiration: time.Time{}.Add(time.Hour),
Scope: []string{"openid", "offline_access"},
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow,
PreferredLanguage: &language.Afrikaans,
UserAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
Reason: domain.TokenReasonAuthRequest,
},
},
{
name: "approved, with refresh token",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, true,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewApprovedEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
"userID", "org1",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
testNow, &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
),
expectFilter(), // token lifetime
expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
oidcsession.NewAccessTokenAddedEvent(context.Background(),
&oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil,
),
user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"),
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour,
),
deviceauth.NewDoneEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx,
"123",
},
want: &OIDCSession{
TokenID: "V2_oidcSessionID-at_accessTokenID",
ClientID: "clientID",
UserID: "userID",
Audience: []string{"audience"},
Expiration: time.Time{}.Add(time.Hour),
Scope: []string{"openid", "offline_access"},
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow,
PreferredLanguage: &language.Afrikaans,
UserAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
Reason: domain.TokenReasonAuthRequest,
RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
got, err := c.CreateOIDCSessionFromDeviceAuth(tt.args.ctx, tt.args.deviceCode)
c.jobs.Wait()
require.ErrorIs(t, err, tt.wantErr)
if got != nil {
assert.WithinRange(t, got.AuthTime, tt.want.AuthTime.Add(-time.Second), tt.want.AuthTime.Add(time.Second))
got.AuthTime = time.Time{}
tt.want.AuthTime = time.Time{}
}
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -8,7 +8,9 @@ import (
"time"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@@ -17,6 +19,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -28,60 +31,175 @@ const (
oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
)
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
return "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
return accessTokenID, accessTokenExpiration, err
type OIDCSession struct {
SessionID string
TokenID string
ClientID string
UserID string
Audience []string
Expiration time.Time
Scope []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
Nonce string
PreferredLanguage *language.Tag
UserAgent *domain.UserAgent
Reason domain.TokenReason
Actor *domain.TokenActor
RefreshToken string
}
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
type AuthRequestComplianceChecker func(context.Context, *AuthRequestWriteModel) error
// CreateOIDCSessionFromAuthRequest creates a new OIDC Session, creates an access token and refresh token.
// It returns the access token id, expiration and the refresh token.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReqId string, complianceCheck AuthRequestComplianceChecker, needRefreshToken bool) (session *OIDCSession, state string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if authReqId == "" {
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
}
authReqModel, err := c.getAuthRequestWriteModel(ctx, authReqId)
if err != nil {
return "", "", time.Time{}, err
return nil, "", err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil {
return "", "", time.Time{}, err
if authReqModel.ResponseType == domain.OIDCResponseTypeCode && authReqModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
}
if err = cmd.AddRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
if err != nil {
return nil, "", err
}
if err = sessionModel.CheckIsActive(); err != nil {
return nil, "", err
}
cmd, err := c.newOIDCSessionAddEvents(ctx, sessionModel.UserResourceOwner)
if err != nil {
return nil, "", err
}
if authReqModel.ResponseType == domain.OIDCResponseTypeCode {
if err = cmd.SetAuthRequestCodeExchanged(ctx, authReqModel); err != nil {
return nil, "", err
}
}
if err = complianceCheck(ctx, authReqModel); err != nil {
return nil, "", err
}
cmd.AddSession(ctx,
sessionModel.UserID,
sessionModel.UserResourceOwner,
sessionModel.AggregateID,
authReqModel.ClientID,
authReqModel.Audience,
authReqModel.Scope,
authReqModel.AuthMethods,
authReqModel.AuthTime,
authReqModel.Nonce,
sessionModel.PreferredLanguage,
sessionModel.UserAgent,
)
if authReqModel.ResponseType != domain.OIDCResponseTypeIDToken {
if err = cmd.AddAccessToken(ctx, authReqModel.Scope, sessionModel.UserID, sessionModel.UserResourceOwner, domain.TokenReasonAuthRequest, nil); err != nil {
return nil, "", err
}
}
if authReqModel.NeedRefreshToken && needRefreshToken {
if err = cmd.AddRefreshToken(ctx, sessionModel.UserID); err != nil {
return nil, "", err
}
}
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
session, err = cmd.PushEvents(ctx)
return session, authReqModel.State, err
}
func (c *Commands) CreateOIDCSession(ctx context.Context,
userID,
resourceOwner,
clientID string,
scope,
audience []string,
authMethods []domain.UserAuthMethodType,
authTime time.Time,
nonce string,
preferredLanguage *language.Tag,
userAgent *domain.UserAgent,
reason domain.TokenReason,
actor *domain.TokenActor,
needRefreshToken bool,
) (session *OIDCSession, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
cmd, err := c.newOIDCSessionAddEvents(ctx, resourceOwner)
if err != nil {
return nil, err
}
if reason == domain.TokenReasonImpersonation {
if err := c.checkPermission(ctx, "impersonation", resourceOwner, userID); err != nil {
return nil, err
}
cmd.UserImpersonated(ctx, userID, resourceOwner, clientID, actor)
}
cmd.AddSession(ctx, userID, resourceOwner, "", clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent)
if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil {
return nil, err
}
if needRefreshToken {
if err = cmd.AddRefreshToken(ctx, userID); err != nil {
return nil, err
}
}
cmd.SetAuthRequestSuccessful(ctx)
return cmd.PushEvents(ctx)
}
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
// It returns the access token id and expiration and the new refresh token.
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, refreshToken string, scope []string, complianceCheck RefreshTokenComplianceChecker) (_ *OIDCSession, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
cmd, err := c.newOIDCSessionUpdateEvents(ctx, refreshToken)
if err != nil {
return "", "", time.Time{}, err
return nil, err
}
if err = cmd.AddAccessToken(ctx, scope, domain.TokenReasonRefresh, nil); err != nil {
return "", "", time.Time{}, err
scope, err = complianceCheck(ctx, cmd.oidcSessionWriteModel, scope)
if err != nil {
return nil, err
}
err = cmd.AddAccessToken(ctx, scope,
cmd.oidcSessionWriteModel.UserID,
cmd.oidcSessionWriteModel.UserResourceOwner,
domain.TokenReasonRefresh,
cmd.oidcSessionWriteModel.AccessTokenActor,
)
if err != nil {
return nil, err
}
if err = cmd.RenewRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
return nil, err
}
return cmd.PushEvents(ctx)
}
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (_ *OIDCSessionWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken)
if err != nil {
return nil, err
@@ -146,26 +264,7 @@ func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID s
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return nil, err
}
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
return nil, err
}
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetInstance(ctx).InstanceID())
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
if err = sessionWriteModel.CheckIsActive(); err != nil {
return nil, err
}
resourceOwner, err := c.getResourceOwnerOfSessionUser(ctx, sessionWriteModel.UserID, sessionWriteModel.InstanceID)
if err != nil {
return nil, err
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, resourceOwner string, pending ...eventstore.Command) (*OIDCSessionEvents, error) {
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
@@ -179,42 +278,24 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID st
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
events: pending,
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner),
sessionWriteModel: sessionWriteModel,
authRequestWriteModel: authRequestWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
func (c *Commands) getResourceOwnerOfSessionUser(ctx context.Context, userID, instanceID string) (string, error) {
events, err := c.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
InstanceID(instanceID).
AllowTimeTravel().
OrderAsc().
Limit(1).
AddQuery().
AggregateTypes(user.AggregateType).
AggregateIDs(userID).
Builder())
if err != nil || len(events) != 1 {
return "", zerrors.ThrowInternal(err, "OIDCS-sferh", "Errors.Internal")
}
return events[0].Aggregate().ResourceOwner, nil
}
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshTokenID string, err error) {
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
if err != nil {
return "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid")
}
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
if err != nil {
return "", err
return "", "", err
}
_, refreshTokenID, err = parseRefreshToken(decrypted)
return refreshTokenID, err
return parseRefreshToken(decrypted)
}
func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) {
@@ -227,8 +308,8 @@ func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID strin
return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil
}
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken string) (*OIDCSessionEvents, error) {
oidcSessionID, refreshTokenID, err := c.decryptRefreshToken(refreshToken)
if err != nil {
return nil, err
}
@@ -255,13 +336,12 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel
sessionWriteModel *SessionWriteModel
authRequestWriteModel *AuthRequestWriteModel
eventstore *eventstore.Eventstore
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel
accessTokenLifetime time.Duration
refreshTokenLifeTime time.Duration
refreshTokenIdleLifetime time.Duration
@@ -270,44 +350,75 @@ type OIDCSessionEvents struct {
accessTokenID string
// refreshToken is set by the command
refreshToken string
refreshTokenID string
refreshToken string
}
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
func (c *OIDCSessionEvents) AddSession(
ctx context.Context,
userID,
userResourceOwner,
sessionID,
clientID string,
audience,
scope []string,
authMethods []domain.UserAuthMethodType,
authTime time.Time,
nonce string,
preferredLanguage *language.Tag,
userAgent *domain.UserAgent,
) {
c.events = append(c.events, oidcsession.NewAddedEvent(
ctx,
c.oidcSessionWriteModel.aggregate,
c.sessionWriteModel.UserID,
c.sessionWriteModel.AggregateID,
c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope,
c.sessionWriteModel.AuthMethodTypes(),
c.sessionWriteModel.AuthenticationTime(),
userID,
userResourceOwner,
sessionID,
clientID,
audience,
scope,
authMethods,
authTime,
nonce,
preferredLanguage,
userAgent,
))
}
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
func (c *OIDCSessionEvents) SetAuthRequestCodeExchanged(ctx context.Context, model *AuthRequestWriteModel) error {
event := authrequest.NewCodeExchangedEvent(ctx, model.aggregate)
model.AppendEvents(event)
c.events = append(c.events, event)
return model.Reduce()
}
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, reason domain.TokenReason, actor *domain.TokenActor) error {
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context, authRequestAggregate *eventstore.Aggregate) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, authRequestAggregate))
}
func (c *OIDCSessionEvents) SetAuthRequestFailed(ctx context.Context, authRequestAggregate *eventstore.Aggregate, err error) {
c.events = append(c.events, authrequest.NewFailedEvent(ctx, authRequestAggregate, domain.OIDCErrorReasonFromError(err)))
}
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, userID, resourceOwner string, reason domain.TokenReason, actor *domain.TokenActor) error {
accessTokenID, err := c.idGenerator.Next()
if err != nil {
return err
}
c.accessTokenID = AccessTokenPrefix + accessTokenID
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor))
c.events = append(c.events,
oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor),
user.NewUserTokenV2AddedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, c.accessTokenID), // for user audit log
)
return nil
}
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID)
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context, userID string) (err error) {
c.refreshTokenID, c.refreshToken, err = c.generateRefreshToken(userID)
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
return nil
}
@@ -321,6 +432,10 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
return nil
}
func (c *OIDCSessionEvents) UserImpersonated(ctx context.Context, userID, resourceOwner, clientID string, actor *domain.TokenActor) {
c.events = append(c.events, user.NewUserImpersonatedEvent(ctx, &user.NewAggregate(userID, resourceOwner).Aggregate, clientID, actor))
}
func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) {
refreshTokenID, err = c.idGenerator.Next()
if err != nil {
@@ -334,18 +449,38 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
if err != nil {
return "", "", time.Time{}, err
return nil, err
}
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
if err != nil {
return "", "", time.Time{}, err
return nil, err
}
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
session := &OIDCSession{
SessionID: c.oidcSessionWriteModel.SessionID,
ClientID: c.oidcSessionWriteModel.ClientID,
UserID: c.oidcSessionWriteModel.UserID,
Audience: c.oidcSessionWriteModel.Audience,
Expiration: c.oidcSessionWriteModel.AccessTokenExpiration,
Scope: c.oidcSessionWriteModel.Scope,
AuthMethods: c.oidcSessionWriteModel.AuthMethods,
AuthTime: c.oidcSessionWriteModel.AuthTime,
Nonce: c.oidcSessionWriteModel.Nonce,
PreferredLanguage: c.oidcSessionWriteModel.PreferredLanguage,
UserAgent: c.oidcSessionWriteModel.UserAgent,
Reason: c.oidcSessionWriteModel.AccessTokenReason,
Actor: c.oidcSessionWriteModel.AccessTokenActor,
RefreshToken: c.refreshToken,
}
if c.accessTokenID != "" {
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
}
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
return session, nil
}
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
@@ -368,3 +503,14 @@ func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime
}
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
}
func tokenReasonToActivityMethodType(r domain.TokenReason) activity.TriggerMethod {
if r == domain.TokenReasonUnspecified {
return activity.Unspecified
}
if r == domain.TokenReasonRefresh {
return activity.OIDCRefreshToken
}
// all other reasons result in an access token
return activity.OIDCAccessToken
}

View File

@@ -3,6 +3,8 @@ package command
import (
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
@@ -13,12 +15,16 @@ type OIDCSessionWriteModel struct {
eventstore.WriteModel
UserID string
UserResourceOwner string
PreferredLanguage *language.Tag
SessionID string
ClientID string
Audience []string
Scope []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
Nonce string
UserAgent *domain.UserAgent
State domain.OIDCSessionState
AccessTokenID string
AccessTokenCreation time.Time
@@ -85,12 +91,16 @@ func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.UserID = e.UserID
wm.UserResourceOwner = e.UserResourceOwner
wm.SessionID = e.SessionID
wm.ClientID = e.ClientID
wm.Audience = e.Audience
wm.Scope = e.Scope
wm.AuthMethods = e.AuthMethods
wm.AuthTime = e.AuthTime
wm.Nonce = e.Nonce
wm.PreferredLanguage = e.PreferredLanguage
wm.UserAgent = e.UserAgent
wm.State = domain.OIDCSessionStateActive
// the write model might be initialized without resource owner,
// so update the aggregate

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,8 @@ import (
"fmt"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
@@ -56,12 +58,12 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
}
// CheckUser defines a user check to be executed for a session update
func CheckUser(id string, resourceOwner string) SessionCommand {
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
}
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now())
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
}
}
@@ -171,8 +173,8 @@ func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent
s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate, userAgent))
}
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time) error {
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt))
func (s *SessionCommands) UserChecked(ctx context.Context, userID, resourceOwner string, checkedAt time.Time, preferredLanguage *language.Tag) error {
s.eventCommands = append(s.eventCommands, session.NewUserCheckedEvent(ctx, s.sessionWriteModel.aggregate, userID, resourceOwner, checkedAt, preferredLanguage))
// set the userID so other checks can use it
s.sessionWriteModel.UserID = userID
s.sessionWriteModel.UserResourceOwner = resourceOwner

View File

@@ -3,6 +3,8 @@ package command
import (
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
@@ -40,6 +42,7 @@ type SessionWriteModel struct {
TokenID string
UserID string
UserResourceOwner string
PreferredLanguage *language.Tag
UserCheckedAt time.Time
PasswordCheckedAt time.Time
IntentCheckedAt time.Time
@@ -50,6 +53,7 @@ type SessionWriteModel struct {
WebAuthNUserVerified bool
Metadata map[string][]byte
State domain.SessionState
UserAgent *domain.UserAgent
Expiration time.Time
WebAuthNChallenge *WebAuthNChallengeModel
@@ -137,12 +141,14 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
wm.State = domain.SessionStateActive
wm.UserAgent = e.UserAgent
}
func (wm *SessionWriteModel) reduceUserChecked(e *session.UserCheckedEvent) {
wm.UserID = e.UserID
wm.UserResourceOwner = e.UserResourceOwner
wm.UserCheckedAt = e.CheckedAt
wm.PreferredLanguage = e.PreferredLanguage
}
func (wm *SessionWriteModel) reducePasswordChecked(e *session.PasswordCheckedEvent) {

View File

@@ -566,7 +566,7 @@ func TestCommands_updateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectPush(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow,
"userID", "org1", testNow, &language.Afrikaans,
),
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow,
@@ -585,7 +585,7 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1"),
CheckUser("userID", "org1", &language.Afrikaans),
CheckPassword("password"),
},
eventstore: eventstoreExpect(t,
@@ -634,7 +634,7 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1"),
CheckUser("userID", "org1", &language.Afrikaans),
CheckIntent("intent", "aW50ZW50"),
},
eventstore: eventstoreExpect(t,
@@ -673,7 +673,7 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1"),
CheckUser("userID", "org1", &language.Afrikaans),
CheckIntent("intent", "aW50ZW50"),
},
eventstore: eventstoreExpect(t,
@@ -723,7 +723,7 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1"),
CheckUser("userID", "org1", &language.Afrikaans),
CheckIntent("intent2", "aW50ZW50"),
},
eventstore: eventstoreExpect(t),
@@ -751,7 +751,7 @@ func TestCommands_updateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectPush(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow),
"userID", "org1", testNow, &language.Afrikaans),
session.NewIntentCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow),
session.NewMetadataSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -766,7 +766,7 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1"),
CheckUser("userID", "org1", &language.Afrikaans),
CheckIntent("intent", "aW50ZW50"),
},
eventstore: eventstoreExpect(t,
@@ -1188,7 +1188,7 @@ func TestCommands_TerminateSession(t *testing.T) {
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"user1", "org1", testNow),
"user1", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -1229,7 +1229,7 @@ func TestCommands_TerminateSession(t *testing.T) {
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow),
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
@@ -1271,7 +1271,7 @@ func TestCommands_TerminateSession(t *testing.T) {
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,
"userID", "", testNow),
"userID", "", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org2").Aggregate,

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"strings"
"time"
"github.com/zitadel/logging"
@@ -13,7 +12,6 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
@@ -232,35 +230,6 @@ func (c *Commands) RemoveUser(ctx context.Context, userID, resourceOwner string,
return writeModelToObjectDetails(&existingUser.WriteModel), nil
}
func (c *Commands) AddUserToken(
ctx context.Context,
orgID,
agentID,
clientID,
userID string,
audience,
scopes,
authMethodsReferences []string,
lifetime time.Duration,
authTime time.Time,
reason domain.TokenReason,
actor *domain.TokenActor,
) (*domain.Token, error) {
if userID == "" { //do not check for empty orgID (JWT Profile requests won't provide it, so service user requests fail)
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dbge4", "Errors.IDMissing")
}
userWriteModel := NewUserWriteModel(userID, orgID)
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, "", audience, scopes, authMethodsReferences, lifetime, authTime, reason, actor)
if err != nil {
return nil, err
}
_, err = c.eventstore.Push(ctx, cmds...)
if err != nil {
return nil, err
}
return accessToken, nil
}
func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
removeEvent, accessTokenWriteModel, err := c.removeAccessToken(ctx, userID, orgID, tokenID)
if err != nil {
@@ -277,61 +246,6 @@ func (c *Commands) RevokeAccessToken(ctx context.Context, userID, orgID, tokenID
return writeModelToObjectDetails(&accessTokenWriteModel.WriteModel), nil
}
func (c *Commands) addUserToken(ctx context.Context, userWriteModel *UserWriteModel, agentID, clientID, refreshTokenID string, audience, scopes, authMethodsReferences []string, lifetime time.Duration, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) ([]eventstore.Command, *domain.Token, error) {
err := c.eventstore.FilterToQueryReducer(ctx, userWriteModel)
if err != nil {
return nil, nil, err
}
if userWriteModel.UserState != domain.UserStateActive {
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-1d6Gg", "Errors.User.NotFound")
}
//nolint:contextcheck
userAgg := UserAggregateFromWriteModel(&userWriteModel.WriteModel)
var cmds []eventstore.Command
if reason == domain.TokenReasonImpersonation {
if err := c.checkPermission(ctx, "impersonation", userWriteModel.ResourceOwner, userWriteModel.AggregateID); err != nil {
return nil, nil, err
}
cmds = append(cmds, user.NewUserImpersonatedEvent(ctx, userAgg, clientID, actor))
}
preferredLanguage := ""
existingHuman, err := c.getHumanWriteModelByID(ctx, userWriteModel.AggregateID, userWriteModel.ResourceOwner)
if err != nil {
return nil, nil, err
}
if existingHuman != nil {
preferredLanguage = existingHuman.PreferredLanguage.String()
}
expiration := time.Now().UTC().Add(lifetime)
tokenID, err := c.idGenerator.Next()
if err != nil {
return nil, nil, err
}
cmds = append(cmds,
user.NewUserTokenAddedEvent(ctx, userAgg, tokenID, clientID, agentID, preferredLanguage, refreshTokenID, audience, scopes, authMethodsReferences, authTime, expiration, reason, actor),
)
return cmds, &domain.Token{
ObjectRoot: models.ObjectRoot{
AggregateID: userWriteModel.AggregateID,
},
TokenID: tokenID,
UserAgentID: agentID,
ApplicationID: clientID,
RefreshTokenID: refreshTokenID,
Audience: audience,
Scopes: scopes,
Expiration: expiration,
PreferredLanguage: preferredLanguage,
Reason: reason,
Actor: actor,
}, nil
}
func (c *Commands) removeAccessToken(ctx context.Context, userID, orgID, tokenID string) (*user.UserTokenRemovedEvent, *UserAccessTokenWriteModel, error) {
if userID == "" || orgID == "" || tokenID == "" {
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Dng42", "Errors.IDMissing")

View File

@@ -1,9 +1,6 @@
package command
import (
"encoding/base64"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/repository/user"
)
@@ -81,16 +78,6 @@ func writeModelToAddress(wm *HumanAddressWriteModel) *domain.Address {
}
}
func writeModelToMachine(wm *MachineWriteModel) *domain.Machine {
return &domain.Machine{
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
Username: wm.UserName,
Name: wm.Name,
Description: wm.Description,
State: wm.UserState,
}
}
func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
return &domain.MachineKey{
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
@@ -100,18 +87,6 @@ func keyWriteModelToMachineKey(wm *MachineKeyWriteModel) *domain.MachineKey {
}
}
func personalTokenWriteModelToToken(wm *PersonalAccessTokenWriteModel, algorithm crypto.EncryptionAlgorithm) (*domain.Token, string, error) {
encrypted, err := algorithm.Encrypt([]byte(wm.TokenID + ":" + wm.AggregateID))
if err != nil {
return nil, "", err
}
return &domain.Token{
ObjectRoot: writeModelToObjectRoot(wm.WriteModel),
TokenID: wm.TokenID,
Expiration: wm.ExpirationDate,
}, base64.RawURLEncoding.EncodeToString(encrypted), nil
}
func readModelToWebAuthNTokens(readModel HumanWebAuthNTokensReadModel) []*domain.WebAuthNToken {
tokens := make([]*domain.WebAuthNToken, len(readModel.GetWebAuthNTokens()))
for i, token := range readModel.GetWebAuthNTokens() {

View File

@@ -2,7 +2,6 @@ package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
@@ -10,98 +9,6 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
func (c *Commands) AddAccessAndRefreshToken(
ctx context.Context,
orgID,
agentID,
clientID,
userID,
refreshToken string,
audience,
scopes,
authMethodsReferences []string,
accessLifetime,
refreshIdleExpiration,
refreshExpiration time.Duration,
authTime time.Time,
reason domain.TokenReason,
actor *domain.TokenActor,
) (accessToken *domain.Token, newRefreshToken string, err error) {
if refreshToken == "" {
return c.AddNewRefreshTokenAndAccessToken(ctx, userID, orgID, agentID, clientID, audience, scopes, authMethodsReferences, refreshExpiration, accessLifetime, refreshIdleExpiration, authTime, reason, actor)
}
return c.RenewRefreshTokenAndAccessToken(ctx, userID, orgID, refreshToken, agentID, clientID, audience, scopes, refreshIdleExpiration, accessLifetime, actor)
}
func (c *Commands) AddNewRefreshTokenAndAccessToken(
ctx context.Context,
userID,
orgID,
agentID,
clientID string,
audience,
scopes,
authMethodsReferences []string,
refreshExpiration,
accessLifetime,
refreshIdleExpiration time.Duration,
authTime time.Time,
reason domain.TokenReason,
actor *domain.TokenActor,
) (accessToken *domain.Token, newRefreshToken string, err error) {
if userID == "" || clientID == "" {
return nil, "", zerrors.ThrowInvalidArgument(nil, "COMMAND-adg4r", "Errors.IDMissing")
}
userWriteModel := NewUserWriteModel(userID, orgID)
refreshTokenID, err := c.idGenerator.Next()
if err != nil {
return nil, "", err
}
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, refreshTokenID, audience, scopes, authMethodsReferences, accessLifetime, authTime, reason, actor)
if err != nil {
return nil, "", err
}
refreshTokenEvent, newRefreshToken, err := c.addRefreshToken(ctx, accessToken, authMethodsReferences, authTime, refreshIdleExpiration, refreshExpiration, actor)
if err != nil {
return nil, "", err
}
cmds = append(cmds, refreshTokenEvent)
_, err = c.eventstore.Push(ctx, cmds...)
if err != nil {
return nil, "", err
}
return accessToken, newRefreshToken, nil
}
func (c *Commands) RenewRefreshTokenAndAccessToken(
ctx context.Context,
userID,
orgID,
refreshToken,
agentID,
clientID string,
audience,
scopes []string,
idleExpiration,
accessLifetime time.Duration,
actor *domain.TokenActor,
) (accessToken *domain.Token, newRefreshToken string, err error) {
renewed, err := c.renewRefreshToken(ctx, userID, orgID, refreshToken, idleExpiration)
if err != nil {
return nil, "", err
}
userWriteModel := NewUserWriteModel(userID, orgID)
cmds, accessToken, err := c.addUserToken(ctx, userWriteModel, agentID, clientID, renewed.tokenID, audience, scopes, renewed.authMethodsReferences, accessLifetime, renewed.authTime, domain.TokenReasonRefresh, actor)
if err != nil {
return nil, "", err
}
_, err = c.eventstore.Push(ctx, append(cmds, renewed.event)...)
if err != nil {
return nil, "", err
}
return accessToken, renewed.token, nil
}
func (c *Commands) RevokeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*domain.ObjectDetails, error) {
removeEvent, refreshTokenWriteModel, err := c.removeRefreshToken(ctx, userID, orgID, tokenID)
if err != nil {
@@ -134,70 +41,6 @@ func (c *Commands) RevokeRefreshTokens(ctx context.Context, userID, orgID string
return err
}
func (c *Commands) addRefreshToken(ctx context.Context, accessToken *domain.Token, authMethodsReferences []string, authTime time.Time, idleExpiration, expiration time.Duration, actor *domain.TokenActor) (*user.HumanRefreshTokenAddedEvent, string, error) {
refreshToken, err := domain.NewRefreshToken(accessToken.AggregateID, accessToken.RefreshTokenID, c.keyAlgorithm)
if err != nil {
return nil, "", err
}
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(accessToken.AggregateID, accessToken.ResourceOwner, accessToken.RefreshTokenID)
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
return user.NewHumanRefreshTokenAddedEvent(ctx, userAgg, accessToken.RefreshTokenID, accessToken.ApplicationID, accessToken.UserAgentID,
accessToken.PreferredLanguage, accessToken.Audience, accessToken.Scopes, authMethodsReferences, authTime, idleExpiration, expiration, actor),
refreshToken, nil
}
type renewedRefreshToken struct {
event *user.HumanRefreshTokenRenewedEvent
authTime time.Time
authMethodsReferences []string
tokenID string
token string
}
func (c *Commands) renewRefreshToken(ctx context.Context, userID, orgID, refreshToken string, idleExpiration time.Duration) (*renewedRefreshToken, error) {
if refreshToken == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-DHrr3", "Errors.IDMissing")
}
tokenUserID, tokenID, token, err := domain.FromRefreshToken(refreshToken, c.keyAlgorithm)
if err != nil {
return nil, err
}
if tokenUserID != userID {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Ht2g2", "Errors.User.RefreshToken.Invalid")
}
refreshTokenWriteModel := NewHumanRefreshTokenWriteModel(userID, orgID, tokenID)
err = c.eventstore.FilterToQueryReducer(ctx, refreshTokenWriteModel)
if err != nil {
return nil, err
}
if refreshTokenWriteModel.UserState != domain.UserStateActive {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-BHnhs", "Errors.User.RefreshToken.Invalid")
}
if refreshTokenWriteModel.RefreshToken != token ||
refreshTokenWriteModel.IdleExpiration.Before(time.Now()) ||
refreshTokenWriteModel.Expiration.Before(time.Now()) {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-Vr43e", "Errors.User.RefreshToken.Invalid")
}
newToken, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
newRefreshToken, err := domain.RefreshToken(userID, tokenID, newToken, c.keyAlgorithm)
if err != nil {
return nil, err
}
userAgg := UserAggregateFromWriteModel(&refreshTokenWriteModel.WriteModel)
return &renewedRefreshToken{
event: user.NewHumanRefreshTokenRenewedEvent(ctx, userAgg, tokenID, newToken, idleExpiration),
authTime: refreshTokenWriteModel.AuthTime,
authMethodsReferences: refreshTokenWriteModel.AuthMethodsReferences,
tokenID: tokenID,
token: newRefreshToken,
}, nil
}
func (c *Commands) removeRefreshToken(ctx context.Context, userID, orgID, tokenID string) (*user.HumanRefreshTokenRemovedEvent, *HumanRefreshTokenWriteModel, error) {
if userID == "" || orgID == "" || tokenID == "" {
return nil, nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-GVDgf", "Errors.IDMissing")

View File

@@ -2,316 +2,18 @@ package command
import (
"context"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestCommands_AddAccessAndRefreshToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
orgID string
agentID string
clientID string
userID string
refreshToken string
audience []string
scopes []string
authMethodsReferences []string
lifetime time.Duration
authTime time.Time
refreshIdleExpiration time.Duration
refreshExpiration time.Duration
reason domain.TokenReason
actor *domain.TokenActor
}
type res struct {
token *domain.Token
refreshToken string
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "missing ID, error",
fields: fields{
eventstore: eventstoreExpect(t),
},
args: args{},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
{
name: "add refresh token, user deactivated, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewUserDeactivatedEvent(context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
),
),
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshTokenID1"),
},
args: args{
ctx: context.Background(),
orgID: "orgID",
agentID: "agentID",
userID: "userID",
clientID: "clientID",
},
res: res{
err: zerrors.IsNotFound,
},
},
{
name: "renew refresh token, invalid token, error",
fields: fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
refreshToken: "invalid",
},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
{
name: "renew refresh token, invalid token (invalid userID), error",
fields: fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
{
name: "renew refresh token, token inactive, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
)),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
{
name: "renew refresh token, token expired, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
-1*time.Hour,
24*time.Hour,
nil,
)),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
//fails because of timestamp equality
//{
// name: "push failed, error",
// fields: fields{
// eventstore: eventstoreExpect(t,
// expectFilter(
// eventFromEventPusher(user.NewHumanAddedEvent(
// context.Background(),
// &user.NewAggregate("userID", "orgID").Aggregate,
// "username",
// "firstname",
// "lastname",
// "nickname",
// "displayname",
// language.German,
// domain.GenderUnspecified,
// "email",
// true,
// )),
// ),
// expectFilter(
// eventFromEventPusherWithCreationDateNow(user.NewHumanAddedEvent(
// context.Background(),
// &user.NewAggregate("userID", "orgID").Aggregate,
// "username",
// "firstname",
// "lastname",
// "nickname",
// "displayname",
// language.German,
// domain.GenderUnspecified,
// "email",
// true,
// )),
// ),
// expectFilter(
// eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
// context.Background(),
// &user.NewAggregate("userID", "orgID").Aggregate,
// "tokenID",
// "applicationID",
// "userAgentID",
// "de",
// []string{"clientID1"},
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
// []string{"password"},
// time.Now(),
// 1*time.Hour,
// 24*time.Hour,
// )),
// ),
// expectPushFailed(
// zerrors.ThrowInternal(nil, "ERROR", "internal"),
// []*repository.Event{
// eventFromEventPusher(user.NewUserTokenAddedEvent(
// context.Background(),
// &user.NewAggregate("userID", "orgID").Aggregate,
// "accessTokenID1",
// "clientID",
// "agentID",
// "de",
// []string{"clientID1"},
// []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
// time.Now().Add(5*time.Minute),
// )),
// eventFromEventPusher(user.NewHumanRefreshTokenRenewedEvent(
// context.Background(),
// &user.NewAggregate("userID", "orgID").Aggregate,
// "tokenID",
// "refreshToken1",
// 1*time.Hour,
// )),
// },
// ),
// ),
// idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "accessTokenID1", "refreshToken1"),
// keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
// },
// args: args{
// ctx: context.Background(),
// orgID: "orgID",
// agentID: "agentID",
// clientID: "clientID",
// userID: "userID",
// refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
// audience: []string{"clientID1"},
// scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
// authMethodsReferences: []string{"password"},
// lifetime: 5 * time.Minute,
// authTime: time.Now(),
// },
// res: res{
// err: zerrors.IsInternal,
// },
//},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
keyAlgorithm: tt.fields.keyAlgorithm,
}
got, gotRefresh, err := c.AddAccessAndRefreshToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.refreshToken,
tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.refreshIdleExpiration, tt.args.refreshExpiration, tt.args.authTime, tt.args.reason, tt.args.actor)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.res.err == nil {
assert.Equal(t, tt.res.token, got)
assert.Equal(t, tt.res.refreshToken, gotRefresh)
}
})
}
}
func TestCommands_RevokeRefreshToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
@@ -669,395 +371,3 @@ func TestCommands_RevokeRefreshTokens(t *testing.T) {
})
}
}
func refreshTokenEncryptionAlgorithm(ctrl *gomock.Controller) crypto.EncryptionAlgorithm {
mCrypto := crypto.NewMockEncryptionAlgorithm(ctrl)
mCrypto.EXPECT().Algorithm().AnyTimes().Return("enc")
mCrypto.EXPECT().EncryptionKeyID().AnyTimes().Return("id")
mCrypto.EXPECT().Encrypt(gomock.Any()).AnyTimes().DoAndReturn(
func(refrehToken []byte) ([]byte, error) {
return refrehToken, nil
},
)
mCrypto.EXPECT().Decrypt(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(
func(refrehToken []byte, keyID string) ([]byte, error) {
if keyID != "id" {
return nil, zerrors.ThrowInternal(nil, "id", "invalid key id")
}
return refrehToken, nil
},
)
return mCrypto
}
func TestCommands_addRefreshToken(t *testing.T) {
authTime := time.Now().Add(-1 * time.Hour)
type fields struct {
eventstore *eventstore.Eventstore
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
accessToken *domain.Token
authMethodsReferences []string
authTime time.Time
idleExpiration time.Duration
expiration time.Duration
actor *domain.TokenActor
}
type res struct {
event *user.HumanRefreshTokenAddedEvent
refreshToken string
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "add refresh Token",
fields: fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
accessToken: &domain.Token{
ObjectRoot: models.ObjectRoot{
AggregateID: "userID",
ResourceOwner: "org1",
},
TokenID: "accessTokenID1",
ApplicationID: "clientID",
UserAgentID: "agentID",
RefreshTokenID: "refreshTokenID",
Audience: []string{"clientID1"},
Expiration: time.Now().Add(5 * time.Minute),
Scopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
PreferredLanguage: "de",
},
authMethodsReferences: []string{"password"},
authTime: authTime,
idleExpiration: 1 * time.Hour,
expiration: 10 * time.Hour,
},
res: res{
event: user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"refreshTokenID",
"clientID",
"agentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
authTime,
1*time.Hour,
10*time.Hour,
nil,
),
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:refreshTokenID:refreshTokenID")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotEvent, gotRefreshToken, err := c.addRefreshToken(tt.args.ctx, tt.args.accessToken, tt.args.authMethodsReferences, tt.args.authTime, tt.args.idleExpiration, tt.args.expiration, tt.args.actor)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.res.err == nil {
assert.Equal(t, tt.res.event, gotEvent)
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
}
})
}
}
func TestCommands_renewRefreshToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
userID string
orgID string
refreshToken string
idleExpiration time.Duration
}
type res struct {
event *user.HumanRefreshTokenRenewedEvent
refreshTokenID string
newRefreshToken string
}
tests := []struct {
name string
fields fields
args args
want *renewedRefreshToken
wantErr func(error) bool
}{
{
name: "empty token, error",
fields: fields{
eventstore: eventstoreExpect(t),
},
args: args{
ctx: context.Background(),
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "invalid token, error",
fields: fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
refreshToken: "invalid",
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "invalid token (invalid userID), error",
fields: fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID2:tokenID:token")),
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "token inactive, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
eventFromEventPusher(user.NewHumanRefreshTokenRemovedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
)),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:token")),
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "token expired, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "user deactivated, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
eventFromEventPusher(
user.NewUserDeactivatedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
),
),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
idleExpiration: 1 * time.Hour,
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "user signedout, error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
eventFromEventPusher(
user.NewHumanSignedOutEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"userAgentID",
),
),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
idleExpiration: 1 * time.Hour,
},
wantErr: zerrors.IsErrorInvalidArgument,
},
{
name: "token renewed, ok",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(user.NewHumanRefreshTokenAddedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"applicationID",
"userAgentID",
"de",
[]string{"clientID1"},
[]string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopeOfflineAccess},
[]string{"password"},
time.Now(),
1*time.Hour,
24*time.Hour,
nil,
)),
),
),
keyAlgorithm: refreshTokenEncryptionAlgorithm(gomock.NewController(t)),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "refreshToken1"),
},
args: args{
ctx: context.Background(),
userID: "userID",
orgID: "orgID",
refreshToken: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:tokenID")),
idleExpiration: 1 * time.Hour,
},
want: &renewedRefreshToken{
event: user.NewHumanRefreshTokenRenewedEvent(
context.Background(),
&user.NewAggregate("userID", "orgID").Aggregate,
"tokenID",
"refreshToken1",
1*time.Hour,
),
authMethodsReferences: []string{"password"},
tokenID: "tokenID",
token: base64.RawURLEncoding.EncodeToString([]byte("userID:tokenID:refreshToken1")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
keyAlgorithm: tt.fields.keyAlgorithm,
}
got, err := c.renewRefreshToken(tt.args.ctx, tt.args.userID, tt.args.orgID, tt.args.refreshToken, tt.args.idleExpiration)
if tt.wantErr != nil && !tt.wantErr(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.wantErr == nil {
require.NoError(t, err)
assert.Equal(t, tt.want.event, got.event)
assert.Equal(t, tt.want.authMethodsReferences, got.authMethodsReferences)
assert.Equal(t, tt.want.tokenID, got.tokenID)
assert.Equal(t, tt.want.token, got.token)
}
})
}
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
@@ -1433,91 +1432,6 @@ func TestCommandSide_RemoveUser(t *testing.T) {
}
}
func TestCommandSide_AddUserToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
}
type (
args struct {
ctx context.Context
orgID string
agentID string
clientID string
userID string
audience []string
scopes []string
authMethodsReferences []string
lifetime time.Duration
authTime time.Time
reason domain.TokenReason
actor *domain.TokenActor
}
)
type res struct {
want *domain.Token
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "userid missing, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
),
},
args: args{
ctx: context.Background(),
orgID: "org1",
userID: "",
},
res: res{
err: zerrors.IsErrorInvalidArgument,
},
},
{
name: "user not existing, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
),
},
args: args{
ctx: context.Background(),
orgID: "org1",
userID: "user1",
},
res: res{
err: zerrors.IsNotFound,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := r.AddUserToken(tt.args.ctx, tt.args.orgID, tt.args.agentID, tt.args.clientID, tt.args.userID, tt.args.audience, tt.args.scopes, tt.args.authMethodsReferences, tt.args.lifetime, tt.args.authTime, tt.args.reason, tt.args.actor)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.res.err == nil {
assert.Equal(t, tt.res.want, got)
}
})
}
}
func TestCommands_RevokeAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore