perf(oidc): optimize token creation (#7822)

* implement code exchange

* port tokenexchange to v2 tokens

* implement refresh token

* implement client credentials

* implement jwt profile

* implement device token

* cleanup unused code

* fix current unit tests

* add user agent unit test

* unit test domain package

* need refresh token as argument

* test commands create oidc session

* test commands device auth

* fix device auth build error

* implicit for oidc session API

* implement authorize callback handler for legacy implicit mode

* upgrade oidc module to working draft

* add missing auth methods and time

* handle all errors in defer

* do not fail auth request on error

the oauth2 Go client automagically retries on any error. If we fail the auth request on the first error, the next attempt will always fail with the Errors.AuthRequest.NoCode, because the auth request state is already set to failed.
The original error is then already lost and the oauth2 library does not return the original error.

Therefore we should not fail the auth request.

Might be worth discussing and perhaps send a bug report to Oauth2?

* fix code flow tests by explicitly setting code exchanged

* fix unit tests in command package

* return allowed scope from client credential client

* add device auth done reducer

* carry nonce thru session into ID token

* fix token exchange integration tests

* allow project role scope prefix in client credentials client

* gci formatting

* do not return refresh token in client credentials and jwt profile

* check org scope

* solve linting issue on authorize callback error

* end session based on v2 session ID

* use preferred language and user agent ID for v2 access tokens

* pin oidc v3.23.2

* add integration test for jwt profile and client credentials with org scopes

* refresh token v1 to v2

* add user token v2 audit event

* add activity trigger

* cleanup and set panics for unused methods

* use the encrypted code for v1 auth request get by code

* add missing event translation

* fix pipeline errors (hopefully)

* fix another test

* revert pointer usage of preferred language

* solve browser info panic in device auth

* remove duplicate entries in AMRToAuthMethodTypes to prevent future `mfa` claim

* revoke v1 refresh token to prevent reuse

* fix terminate oidc session

* always return a new refresh toke in refresh token grant

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Tim Möhlmann
2024-05-16 08:07:56 +03:00
committed by GitHub
parent 6cf9ca9f7e
commit 8e0c8393e9
84 changed files with 3429 additions and 2635 deletions

View File

@@ -5,6 +5,8 @@ import (
"strings"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
@@ -24,10 +26,13 @@ type OIDCSessionAccessTokenReadModel struct {
Scope []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
Nonce string
State domain.OIDCSessionState
AccessTokenID string
AccessTokenCreation time.Time
AccessTokenExpiration time.Time
PreferredLanguage *language.Tag
UserAgent *domain.UserAgent
Reason domain.TokenReason
Actor *domain.TokenActor
}
@@ -79,6 +84,9 @@ func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent
wm.Scope = e.Scope
wm.AuthMethods = e.AuthMethods
wm.AuthTime = e.AuthTime
wm.Nonce = e.Nonce
wm.PreferredLanguage = e.PreferredLanguage
wm.UserAgent = e.UserAgent
wm.State = domain.OIDCSessionStateActive
}
@@ -112,7 +120,7 @@ func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (m
if !model.AccessTokenExpiration.After(time.Now()) {
return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
}
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation); err != nil {
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation, model.UserAgent.GetFingerprintID()); err != nil {
return nil, err
}
return model, nil
@@ -132,16 +140,17 @@ func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSe
return model, nil
}
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event occurred after a certain time
// and will return an error if so.
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time) (err error) {
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event (or user events leading to a session termination)
// occurred after a certain time and will return an error if so.
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time, fingerprintID string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
model := &sessionTerminatedModel{
sessionID: sessionID,
creation: creation,
userID: userID,
sessionID: sessionID,
creation: creation,
userID: userID,
fingerPrintID: fingerprintID,
}
err = q.eventstore.FilterToQueryReducer(ctx, model)
if err != nil {
@@ -155,9 +164,10 @@ func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID,
}
type sessionTerminatedModel struct {
creation time.Time
sessionID string
userID string
creation time.Time
sessionID string
userID string
fingerPrintID string
events int
terminated bool
@@ -195,5 +205,12 @@ func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
user.UserLockedType,
user.UserRemovedType,
).
Or(). // for specific logout on v1 sessions from the same user agent
AggregateTypes(user.AggregateType).
AggregateIDs(s.userID).
EventTypes(
user.HumanSignedOutType,
).
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
Builder()
}

View File

@@ -209,7 +209,7 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i
return key, err
}
func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (key []byte, err error) {
func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string) (key []byte, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

View File

@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"time"
sq "github.com/Masterminds/squirrel"
@@ -59,34 +58,6 @@ var (
}
)
type DeviceAuth struct {
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
Audience []string
State domain.DeviceAuthState
Subject string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
// DeviceAuthByDeviceCode gets the current state of a Device Authorization directly from the eventstore.
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (deviceAuth *DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
model := NewDeviceAuthReadModel(deviceCode, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
if !model.State.Exists() {
return nil, zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting")
}
return &model.DeviceAuth, nil
}
// DeviceAuthRequestByUserCode finds a Device Authorization request by User-Code from the `device_auth_requests` projection.
func (q *Queries) DeviceAuthRequestByUserCode(ctx context.Context, userCode string) (authReq *domain.AuthRequestDevice, err error) {
ctx, span := tracing.NewSpan(ctx)

View File

@@ -1,59 +0,0 @@
package query
import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
)
type DeviceAuthReadModel struct {
eventstore.ReadModel
DeviceAuth
}
func NewDeviceAuthReadModel(deviceCode, resourceOwner string) *DeviceAuthReadModel {
return &DeviceAuthReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: deviceCode,
ResourceOwner: resourceOwner,
},
}
}
func (m *DeviceAuthReadModel) Reduce() error {
for _, event := range m.Events {
switch e := event.(type) {
case *deviceauth.AddedEvent:
m.ClientID = e.ClientID
m.DeviceCode = e.DeviceCode
m.UserCode = e.UserCode
m.Expires = e.Expires
m.Scopes = e.Scopes
m.Audience = e.Audience
m.State = e.State
case *deviceauth.ApprovedEvent:
m.State = domain.DeviceAuthStateApproved
m.Subject = e.Subject
m.UserAuthMethods = e.UserAuthMethods
m.AuthTime = e.AuthTime
case *deviceauth.CanceledEvent:
m.State = e.Reason.State()
}
}
return m.ReadModel.Reduce()
}
func (m *DeviceAuthReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(m.ResourceOwner).
AddQuery().
AggregateTypes(deviceauth.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(
deviceauth.AddedEventType,
deviceauth.ApprovedEventType,
deviceauth.CanceledEventType,
).
Builder()
}

View File

@@ -6,167 +6,18 @@ import (
"database/sql/driver"
"errors"
"fmt"
"io"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
ctx := authz.NewMockContext("inst1", "org1", "user1")
timestamp := time.Date(2015, 12, 15, 22, 13, 45, 0, time.UTC)
tests := []struct {
name string
eventstore func(t *testing.T) *eventstore.Eventstore
want *DeviceAuth
wantErr error
}{
{
name: "filter error",
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
wantErr: io.ErrClosedPipe,
},
{
name: "not found",
eventstore: expectEventstore(
expectFilter(),
),
wantErr: zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting"),
},
{
name: "ok, initiated",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateInitiated,
},
},
{
name: "ok, approved",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewApprovedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"user1", []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
timestamp,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateApproved,
Subject: "user1",
UserAuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
AuthTime: timestamp,
},
},
{
name: "ok, denied",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
domain.DeviceAuthCanceledDenied,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateDenied,
},
},
{
name: "ok, expired",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
domain.DeviceAuthCanceledExpired,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateExpired,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &Queries{
eventstore: tt.eventstore(t),
}
got, err := q.DeviceAuthByDeviceCode(ctx, "device1")
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
const (
expectedDeviceAuthQueryC = `SELECT` +
` projections.device_auth_requests2.client_id,` +

View File

@@ -74,6 +74,10 @@ func (p *deviceAuthRequestProjection) Reducers() []handler.AggregateReducer {
Event: deviceauth.CanceledEventType,
Reduce: p.reduceDoneEvents,
},
{
Event: deviceauth.DoneEventType,
Reduce: p.reduceDoneEvents,
},
},
},
}
@@ -103,7 +107,7 @@ func (p *deviceAuthRequestProjection) reduceAdded(event eventstore.Event) (*hand
// reduceDoneEvents removes the device auth request from the projection.
func (p *deviceAuthRequestProjection) reduceDoneEvents(event eventstore.Event) (*handler.Statement, error) {
switch event.(type) {
case *deviceauth.ApprovedEvent, *deviceauth.CanceledEvent:
case *deviceauth.ApprovedEvent, *deviceauth.CanceledEvent, *deviceauth.DoneEvent:
return handler.NewDeleteStatement(event,
[]handler.Condition{
handler.NewCond(DeviceAuthRequestColumnInstanceID, event.Aggregate().InstanceID),