feat(oidc): allow additional audience based on scope in device auth (#7685)

feat(oidc): allow additional audience based on scope
This commit is contained in:
Tim Möhlmann 2024-04-03 09:06:21 +03:00 committed by GitHub
parent 2d25244c77
commit 5b3946b67e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 61 additions and 24 deletions

View File

@ -41,12 +41,12 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
return o.createAuthRequest(ctx, req, userID)
}
func (o *OPStorage) createAuthRequestScopeAndAudience(ctx context.Context, req *oidc.AuthRequest) (scope, audience []string, err error) {
project, err := o.query.ProjectByClientID(ctx, req.ClientID)
func (o *OPStorage) createAuthRequestScopeAndAudience(ctx context.Context, clientID string, reqScope []string) (scope, audience []string, err error) {
project, err := o.query.ProjectByClientID(ctx, clientID)
if err != nil {
return nil, nil, err
}
scope, err = o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
scope, err = o.assertProjectRoleScopesByProject(ctx, project, reqScope)
if err != nil {
return nil, nil, err
}
@ -59,7 +59,7 @@ func (o *OPStorage) createAuthRequestScopeAndAudience(ctx context.Context, req *
}
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, req)
scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, req.ClientID, req.Scopes)
if err != nil {
return nil, err
}
@ -96,7 +96,7 @@ func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
}
scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, req)
scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, req.ClientID, req.Scopes)
if err != nil {
return nil, err
}

View File

@ -11,7 +11,6 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
@ -68,21 +67,20 @@ func (c *DeviceAuthorizationConfig) toOPConfig() op.DeviceAuthorizationConfig {
// StoreDeviceAuthorization creates a new Device Authorization request.
// Implements the op.DeviceAuthorizationStorage interface.
func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (err error) {
func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scope []string) (err error) {
const logMsg = "store device authorization"
logger := logging.WithFields("client_id", clientID, "device_code", deviceCode, "user_code", userCode, "expires", expires, "scopes", scopes)
logger := logging.WithFields("client_id", clientID, "device_code", deviceCode, "user_code", userCode, "expires", expires, "scope", scope)
ctx, span := tracing.NewSpan(ctx)
defer func() {
logger.OnError(err).Error(logMsg)
span.EndWithError(err)
}()
scopes, err = o.assertProjectRoleScopes(ctx, clientID, scopes)
scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, clientID, scope)
if err != nil {
return zerrors.ThrowPreconditionFailed(err, "OIDC-She4t", "Errors.Internal")
return err
}
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scopes)
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scope, audience)
if err == nil {
logger.SetFields("details", details).Debug(logMsg)
}
@ -94,6 +92,7 @@ func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationSta
return &op.DeviceAuthorizationState{
ClientID: d.ClientID,
Scopes: d.Scopes,
Audience: d.Audience,
Expires: d.Expires,
Done: d.State.Done(),
Denied: d.State.Denied(),

View File

@ -11,7 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (*domain.ObjectDetails, error) {
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes, audience []string) (*domain.ObjectDetails, error) {
aggr := deviceauth.NewAggregate(deviceCode, authz.GetInstance(ctx).InstanceID())
model := NewDeviceAuthWriteModel(deviceCode, aggr.ResourceOwner)
@ -23,6 +23,7 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
userCode,
expires,
scopes,
audience,
))
if err != nil {
return nil, err

View File

@ -34,6 +34,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
userCode string
expires time.Time
scopes []string
audience []string
}
tests := []struct {
name string
@ -51,6 +52,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
},
@ -61,6 +63,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
@ -75,6 +78,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
)),
),
},
@ -85,6 +89,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
audience: []string{"projectID", "clientID"},
},
wantErr: pushErr,
},
@ -94,7 +99,7 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes)
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes, tt.args.audience)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantDetails, gotDetails)
})
@ -148,6 +153,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
expectPushFailed(pushErr,
@ -177,6 +183,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
expectPush(
@ -251,6 +258,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
expectPushFailed(pushErr,
@ -275,6 +283,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
expectPush(
@ -301,6 +310,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
[]string{"projectID", "clientID"},
),
)),
expectPush(

View File

@ -63,6 +63,7 @@ type AuthRequestDevice struct {
DeviceCode string
UserCode string
Scopes []string
Audience []string
}
func (*AuthRequestDevice) Type() AuthRequestType {

View File

@ -37,6 +37,10 @@ var (
name: projection.DeviceAuthRequestColumnScopes,
table: deviceAuthRequestTable,
}
DeviceAuthRequestColumnAudience = Column{
name: projection.DeviceAuthRequestColumnAudience,
table: deviceAuthRequestTable,
}
DeviceAuthRequestColumnCreationDate = Column{
name: projection.DeviceAuthRequestColumnCreationDate,
table: deviceAuthRequestTable,
@ -61,6 +65,7 @@ type DeviceAuth struct {
UserCode string
Expires time.Time
Scopes []string
Audience []string
State domain.DeviceAuthState
Subject string
UserAuthMethods []domain.UserAuthMethodType
@ -109,6 +114,7 @@ var deviceAuthSelectColumns = []string{
DeviceAuthRequestColumnDeviceCode.identifier(),
DeviceAuthRequestColumnUserCode.identifier(),
DeviceAuthRequestColumnScopes.identifier(),
DeviceAuthRequestColumnAudience.identifier(),
}
func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.AuthRequestDevice, error)) {
@ -116,7 +122,8 @@ func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectB
func(row *sql.Row) (*domain.AuthRequestDevice, error) {
dst := new(domain.AuthRequestDevice)
var (
scopes database.TextArray[string]
scopes database.TextArray[string]
audience database.TextArray[string]
)
err := row.Scan(
@ -124,6 +131,7 @@ func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectB
&dst.DeviceCode,
&dst.UserCode,
&scopes,
&audience,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-Sah9a", "Errors.DeviceAuth.NotExisting")
@ -132,6 +140,7 @@ func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectB
return nil, zerrors.ThrowInternal(err, "QUERY-Voo3o", "Errors.Internal")
}
dst.Scopes = scopes
dst.Audience = audience
return dst, nil
}
}

View File

@ -29,6 +29,7 @@ func (m *DeviceAuthReadModel) Reduce() error {
m.UserCode = e.UserCode
m.Expires = e.Expires
m.Scopes = e.Scopes
m.Audience = e.Audience
m.State = e.State
case *deviceauth.ApprovedEvent:
m.State = domain.DeviceAuthStateApproved

View File

@ -55,6 +55,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
),
),
@ -64,6 +65,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateInitiated,
},
},
@ -75,6 +77,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewApprovedEvent(
ctx,
@ -90,6 +93,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateApproved,
Subject: "user1",
UserAuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
@ -104,6 +108,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
@ -118,6 +123,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateDenied,
},
},
@ -129,6 +135,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
[]string{"projectID", "clientID"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
@ -143,6 +150,7 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
Audience: []string{"projectID", "clientID"},
State: domain.DeviceAuthStateExpired,
},
},
@ -161,14 +169,15 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
const (
expectedDeviceAuthQueryC = `SELECT` +
` projections.device_auth_requests.client_id,` +
` projections.device_auth_requests.device_code,` +
` projections.device_auth_requests.user_code,` +
` projections.device_auth_requests.scopes` +
` FROM projections.device_auth_requests`
` projections.device_auth_requests1.client_id,` +
` projections.device_auth_requests1.device_code,` +
` projections.device_auth_requests1.user_code,` +
` projections.device_auth_requests1.scopes,` +
` projections.device_auth_requests1.audience` +
` FROM projections.device_auth_requests1`
expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_auth_requests.instance_id = $1` +
` AND projections.device_auth_requests.user_code = $2`
` WHERE projections.device_auth_requests1.instance_id = $1` +
` AND projections.device_auth_requests1.user_code = $2`
)
var (
@ -179,12 +188,14 @@ var (
"device1",
"user-code",
database.TextArray[string]{"a", "b", "c"},
[]string{"projectID", "clientID"},
}
expectedDeviceAuth = &domain.AuthRequestDevice{
ClientID: "client-id",
DeviceCode: "device1",
UserCode: "user-code",
Scopes: []string{"a", "b", "c"},
Audience: []string{"projectID", "clientID"},
}
)

View File

@ -11,12 +11,13 @@ import (
)
const (
DeviceAuthRequestProjectionTable = "projections.device_auth_requests"
DeviceAuthRequestProjectionTable = "projections.device_auth_requests1"
DeviceAuthRequestColumnClientID = "client_id"
DeviceAuthRequestColumnDeviceCode = "device_code"
DeviceAuthRequestColumnUserCode = "user_code"
DeviceAuthRequestColumnScopes = "scopes"
DeviceAuthRequestColumnAudience = "audience"
DeviceAuthRequestColumnCreationDate = "creation_date"
DeviceAuthRequestColumnChangeDate = "change_date"
DeviceAuthRequestColumnSequence = "sequence"
@ -44,6 +45,7 @@ func (*deviceAuthRequestProjection) Init() *old_handler.Check {
handler.NewColumn(DeviceAuthRequestColumnDeviceCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnUserCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnScopes, handler.ColumnTypeTextArray),
handler.NewColumn(DeviceAuthRequestColumnAudience, handler.ColumnTypeTextArray),
handler.NewColumn(DeviceAuthRequestColumnCreationDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthRequestColumnChangeDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthRequestColumnSequence, handler.ColumnTypeInt64),
@ -89,6 +91,7 @@ func (p *deviceAuthRequestProjection) reduceAdded(event eventstore.Event) (*hand
handler.NewCol(DeviceAuthRequestColumnDeviceCode, e.DeviceCode),
handler.NewCol(DeviceAuthRequestColumnUserCode, e.UserCode),
handler.NewCol(DeviceAuthRequestColumnScopes, e.Scopes),
handler.NewCol(DeviceAuthRequestColumnAudience, e.Audience),
handler.NewCol(DeviceAuthRequestColumnCreationDate, e.CreationDate()),
handler.NewCol(DeviceAuthRequestColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthRequestColumnSequence, e.Sequence()),

View File

@ -23,6 +23,7 @@ type AddedEvent struct {
UserCode string
Expires time.Time
Scopes []string
Audience []string
State domain.DeviceAuthState
}
@ -46,12 +47,13 @@ func NewAddedEvent(
userCode string,
expires time.Time,
scopes []string,
audience []string,
) *AddedEvent {
return &AddedEvent{
eventstore.NewBaseEventForPush(
ctx, aggregate, AddedEventType,
),
clientID, deviceCode, userCode, expires, scopes, domain.DeviceAuthStateInitiated}
clientID, deviceCode, userCode, expires, scopes, audience, domain.DeviceAuthStateInitiated}
}
type ApprovedEvent struct {