feat(oidc): id token for device authorization (#7088)

* cleanup todo

* pass id token details to oidc

* feat(oidc): id token for device authorization

This changes updates to the newest oidc version,
so the Device Authorization grant can return ID tokens when
the scope `openid` is set.
There is also some refactoring done, so that the eventstore can be
queried directly when polling for state.
The projection is cleaned up to a minimum with only data required for the login UI.

* try to be explicit wit hthe timezone to fix github

* pin oidc v3.8.0

* remove TBD entry
This commit is contained in:
Tim Möhlmann 2023-12-20 14:21:08 +02:00 committed by GitHub
parent e15f6229cd
commit e22689c125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 629 additions and 621 deletions

4
go.mod
View File

@ -60,7 +60,7 @@ require (
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
github.com/ttacon/libphonenumber v1.2.1
github.com/zitadel/logging v0.5.0
github.com/zitadel/oidc/v3 v3.5.0
github.com/zitadel/oidc/v3 v3.8.0
github.com/zitadel/passwap v0.4.0
github.com/zitadel/saml v0.1.3
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.0
@ -155,7 +155,7 @@ require (
github.com/golang/geo v0.0.0-20230421003525-6adc56603217 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/uuid v1.4.0
github.com/google/uuid v1.5.0
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.0 // indirect
github.com/gorilla/handlers v1.5.2 // indirect

10
go.sum
View File

@ -410,8 +410,8 @@ github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
@ -867,12 +867,10 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8=
github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA=
github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE=
github.com/zitadel/oidc/v3 v3.5.0 h1:z51AN6FPo5UuwYJ1r9nLvHlxpTGYd8QXg5MrtYm/dgM=
github.com/zitadel/oidc/v3 v3.5.0/go.mod h1:R8sF5DPR98QQnOoyySsaNqI4NcF/VFMkf/XoYiBUuXQ=
github.com/zitadel/oidc/v3 v3.8.0 h1:4Nvok+e6o3FDpqrf14JOg4EVBvwXNFOI1lFHPZU75iA=
github.com/zitadel/oidc/v3 v3.8.0/go.mod h1:v+aHyg4lBAUuuUHINwXqHtKunPJZo8kPvMpRRBYEKHY=
github.com/zitadel/passwap v0.4.0 h1:cMaISx+Ve7ilgG7Q8xOli4Z6IWr8Gndss+jeBk5A3O0=
github.com/zitadel/passwap v0.4.0/go.mod h1:yHaDM4A68yRkdic5BZ4iUNoc19hT+kYt8n1/Nz+I87g=
github.com/zitadel/saml v0.1.2 h1:RICwNTuP2upX4A1sZ8iq1rv4/x3DhZHzFx1e5bTKoTo=
github.com/zitadel/saml v0.1.2/go.mod h1:M+X+3vMUulpoLofKeH/W1/qjQQ3owitc2GuGDu3oYpM=
github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM=
github.com/zitadel/saml v0.1.3/go.mod h1:MdkjyU3mwnTuh4lNnhPG+RyZL/VfzD72wUG/eWWBaXc=
github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0=

View File

@ -5,11 +5,11 @@ import (
"time"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -78,47 +78,39 @@ func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, devi
span.EndWithError(err)
}()
// TODO(muhlemmer): Remove the following code block with oidc v3
// https://github.com/zitadel/oidc/issues/370
client, err := o.GetClientByClientID(ctx, clientID)
if err != nil {
return err
}
if !op.ValidateGrantType(client, oidc.GrantTypeDeviceCode) {
return zerrors.ThrowPermissionDeniedf(nil, "OIDC-et1Ae", "grant type %q not allowed for client", oidc.GrantTypeDeviceCode)
}
scopes, err = o.assertProjectRoleScopes(ctx, clientID, scopes)
if err != nil {
return zerrors.ThrowPreconditionFailed(err, "OIDC-She4t", "Errors.Internal")
}
aggrID, details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scopes)
details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scopes)
if err == nil {
logger.SetFields("aggregate_id", aggrID, "details", details).Debug(logMsg)
logger.SetFields("details", details).Debug(logMsg)
}
return err
}
func newDeviceAuthorizationState(d *domain.DeviceAuth) *op.DeviceAuthorizationState {
func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationState {
return &op.DeviceAuthorizationState{
ClientID: d.ClientID,
Scopes: d.Scopes,
Expires: d.Expires,
Done: d.State.Done(),
Subject: d.Subject,
Denied: d.State.Denied(),
Subject: d.Subject,
AMR: AuthMethodTypesToAMR(d.UserAuthMethods),
AuthTime: d.AuthTime,
}
}
// GetDeviceAuthorizatonState retieves the current state of the Device Authorization process.
// GetDeviceAuthorizatonState retrieves the current state of the Device Authorization process.
// It implements the [op.DeviceAuthorizationStorage] interface and is used by devices that
// are polling until they successfully receive a token or we indicate a denied or expired state.
// As generated user codes are of low entropy, this implementation also takes care or
// device authorization request cleanup, when it has been Approved, Denied or Expired.
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
const logMsg = "get device authorization state"
logger := logging.WithFields("client_id", clientID, "device_code", deviceCode)
logger := logging.WithFields("device_code", deviceCode)
ctx, span := tracing.NewSpan(ctx)
defer func() {
@ -128,7 +120,7 @@ func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, de
span.EndWithError(err)
}()
deviceAuth, err := o.query.DeviceAuthByDeviceCode(ctx, clientID, deviceCode)
deviceAuth, err := o.query.DeviceAuthByDeviceCode(ctx, deviceCode)
if err != nil {
return nil, err
}
@ -139,38 +131,12 @@ func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, de
// Cancel the request if it is expired, only if it wasn't Done meanwhile
if !deviceAuth.State.Done() && deviceAuth.Expires.Before(time.Now()) {
_, err = o.command.CancelDeviceAuth(ctx, deviceAuth.AggregateID, domain.DeviceAuthCanceledExpired)
_, err = o.command.CancelDeviceAuth(ctx, deviceAuth.DeviceCode, domain.DeviceAuthCanceledExpired)
if err != nil {
return nil, err
}
deviceAuth.State = domain.DeviceAuthStateExpired
}
// When the request is more then initiated, it has been either Approved, Denied or Expired.
// At this point we should remove it from the DB to avoid user code conflicts.
if deviceAuth.State > domain.DeviceAuthStateInitiated {
_, err = o.command.RemoveDeviceAuth(ctx, deviceAuth.AggregateID)
if err != nil {
return nil, err
}
}
return newDeviceAuthorizationState(deviceAuth), nil
}
// TODO(muhlemmer): remove the following methods with oidc v3.
// They are actually not used, but are required by the oidc device storage interface.
// https://github.com/zitadel/oidc/issues/371
func (o *OPStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
return nil, nil
}
func (o *OPStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) (err error) {
return nil
}
func (o *OPStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) (err error) {
return nil
}
// TODO end.

View File

@ -100,7 +100,7 @@ func (l *Login) handleDeviceAuthUserCode(w http.ResponseWriter, r *http.Request)
l.renderDeviceAuthUserCode(w, r, err)
return
}
deviceAuth, err := l.query.DeviceAuthByUserCode(ctx, userCode)
deviceAuthReq, err := l.query.DeviceAuthRequestByUserCode(ctx, userCode)
if err != nil {
l.renderDeviceAuthUserCode(w, r, err)
return
@ -113,14 +113,9 @@ func (l *Login) handleDeviceAuthUserCode(w http.ResponseWriter, r *http.Request)
authRequest, err := l.authRepo.CreateAuthRequest(ctx, &domain.AuthRequest{
CreationDate: time.Now(),
AgentID: userAgentID,
ApplicationID: deviceAuth.ClientID,
ApplicationID: deviceAuthReq.ClientID,
InstanceID: authz.GetInstance(ctx).InstanceID(),
Request: &domain.AuthRequestDevice{
ID: deviceAuth.AggregateID,
DeviceCode: deviceAuth.DeviceCode,
UserCode: deviceAuth.UserCode,
Scopes: deviceAuth.Scopes,
},
Request: deviceAuthReq,
})
if err != nil {
l.renderDeviceAuthUserCode(w, r, err)
@ -168,9 +163,9 @@ func (l *Login) handleDeviceAuthAction(w http.ResponseWriter, r *http.Request) {
action := mux.Vars(r)["action"]
switch action {
case deviceAuthAllowed:
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.ID, authReq.UserID)
_, err = l.command.ApproveDeviceAuth(r.Context(), authDev.DeviceCode, authReq.UserID, authReq.UserAuthMethodTypes(), authReq.AuthTime)
case deviceAuthDenied:
_, err = l.command.CancelDeviceAuth(r.Context(), authDev.ID, domain.DeviceAuthCanceledDenied)
_, err = l.command.CancelDeviceAuth(r.Context(), authDev.DeviceCode, domain.DeviceAuthCanceledDenied)
default:
l.renderDeviceAuthAction(w, r, authReq, authDev.Scopes)
return

View File

@ -24,6 +24,7 @@ import (
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/repository/feature"
"github.com/zitadel/zitadel/internal/repository/idpintent"
instance_repo "github.com/zitadel/zitadel/internal/repository/instance"
@ -166,6 +167,7 @@ func StartCommands(
oidcsession.RegisterEventMappers(repo.eventstore)
milestone.RegisterEventMappers(repo.eventstore)
feature.RegisterEventMappers(repo.eventstore)
deviceauth.RegisterEventMappers(repo.eventstore)
repo.codeAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)
repo.userPasswordHasher, err = defaults.PasswordHasher.PasswordHasher()

View File

@ -11,14 +11,9 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (string, *domain.ObjectDetails, error) {
aggrID, err := c.idGenerator.Next()
if err != nil {
return "", nil, err
}
aggr := deviceauth.NewAggregate(aggrID, authz.GetInstance(ctx).InstanceID())
model := NewDeviceAuthWriteModel(aggrID, aggr.ResourceOwner)
func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (*domain.ObjectDetails, error) {
aggr := deviceauth.NewAggregate(deviceCode, authz.GetInstance(ctx).InstanceID())
model := NewDeviceAuthWriteModel(deviceCode, aggr.ResourceOwner)
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewAddedEvent(
ctx,
@ -30,18 +25,18 @@ func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, user
scopes,
))
if err != nil {
return "", nil, err
return nil, err
}
err = AppendAndReduce(model, pushedEvents...)
if err != nil {
return "", nil, err
return nil, err
}
return model.AggregateID, writeModelToObjectDetails(&model.WriteModel), nil
return writeModelToObjectDetails(&model.WriteModel), nil
}
func (c *Commands) ApproveDeviceAuth(ctx context.Context, id, subject string) (*domain.ObjectDetails, error) {
model, err := c.getDeviceAuthWriteModelByID(ctx, id)
func (c *Commands) ApproveDeviceAuth(ctx context.Context, deviceCode, subject string, authMethods []domain.UserAuthMethodType, authTime time.Time) (*domain.ObjectDetails, error) {
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, deviceCode)
if err != nil {
return nil, err
}
@ -50,7 +45,7 @@ func (c *Commands) ApproveDeviceAuth(ctx context.Context, id, subject string) (*
}
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject))
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject, authMethods, authTime))
if err != nil {
return nil, err
}
@ -63,7 +58,7 @@ func (c *Commands) ApproveDeviceAuth(ctx context.Context, id, subject string) (*
}
func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domain.DeviceAuthCanceled) (*domain.ObjectDetails, error) {
model, err := c.getDeviceAuthWriteModelByID(ctx, id)
model, err := c.getDeviceAuthWriteModelByDeviceCode(ctx, id)
if err != nil {
return nil, err
}
@ -84,27 +79,8 @@ func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domai
return writeModelToObjectDetails(&model.WriteModel), nil
}
func (c *Commands) RemoveDeviceAuth(ctx context.Context, id string) (*domain.ObjectDetails, error) {
model, err := c.getDeviceAuthWriteModelByID(ctx, id)
if err != nil {
return nil, err
}
aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID)
pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewRemovedEvent(ctx, aggr, model.ClientID, model.DeviceCode, model.UserCode))
if err != nil {
return nil, err
}
err = AppendAndReduce(model, pushedEvents...)
if err != nil {
return nil, err
}
return writeModelToObjectDetails(&model.WriteModel), nil
}
func (c *Commands) getDeviceAuthWriteModelByID(ctx context.Context, id string) (*DeviceAuthWriteModel, error) {
model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: id}}
func (c *Commands) getDeviceAuthWriteModelByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuthWriteModel, error) {
model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: deviceCode}}
err := c.eventstore.FilterToQueryReducer(ctx, model)
if err != nil {
return nil, err

View File

@ -11,19 +11,21 @@ import (
type DeviceAuthWriteModel struct {
eventstore.WriteModel
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
Subject string
State domain.DeviceAuthState
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
State domain.DeviceAuthState
Subject string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
func NewDeviceAuthWriteModel(aggrID, resourceOwner string) *DeviceAuthWriteModel {
func NewDeviceAuthWriteModel(deviceCode, resourceOwner string) *DeviceAuthWriteModel {
return &DeviceAuthWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: aggrID,
AggregateID: deviceCode,
ResourceOwner: resourceOwner,
},
}
@ -40,12 +42,12 @@ func (m *DeviceAuthWriteModel) Reduce() error {
m.Scopes = e.Scopes
m.State = e.State
case *deviceauth.ApprovedEvent:
m.Subject = e.Subject
m.State = domain.DeviceAuthStateApproved
m.Subject = e.Subject
m.UserAuthMethods = e.UserAuthMethods
m.AuthTime = e.AuthTime
case *deviceauth.CanceledEvent:
m.State = e.Reason.State()
case *deviceauth.RemovedEvent:
m.State = domain.DeviceAuthStateRemoved
}
}
@ -54,8 +56,14 @@ func (m *DeviceAuthWriteModel) Reduce() error {
func (m *DeviceAuthWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(m.ResourceOwner).
AddQuery().
AggregateTypes(deviceauth.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(
deviceauth.AddedEventType,
deviceauth.ApprovedEventType,
deviceauth.CanceledEventType,
).
Builder()
}

View File

@ -8,29 +8,24 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestCommands_AddDeviceAuth(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instance1")
idErr := errors.New("idErr")
pushErr := errors.New("pushErr")
now := time.Now()
unique := deviceauth.NewAddUniqueConstraints("client_id", "123", "456")
unique := deviceauth.NewAddUniqueConstraints("123", "456")
require.Len(t, unique, 2)
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
@ -44,42 +39,20 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
name string
fields fields
args args
wantID string
wantDetails *domain.ObjectDetails
wantErr error
}{
{
name: "idGenerator error",
fields: fields{
eventstore: eventstoreExpect(t),
idGenerator: func() id.Generator {
m := id_mock.NewMockGenerator(gomock.NewController(t))
m.EXPECT().Next().Return("", idErr)
return m
}(),
},
args: args{
ctx: ctx,
clientID: "client_id",
deviceCode: "123",
userCode: "456",
expires: now,
scopes: []string{"a", "b", "c"},
},
wantErr: idErr,
},
{
name: "success",
fields: fields{
eventstore: eventstoreExpect(t, expectPush(
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "1999"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance1"),
@ -89,7 +62,6 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
expires: now,
scopes: []string{"a", "b", "c"},
},
wantID: "1999",
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
},
@ -100,12 +72,11 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
eventstore: eventstoreExpect(t, expectPushFailed(pushErr,
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
)),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "1999"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance1"),
@ -121,12 +92,10 @@ func TestCommands_AddDeviceAuth(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
eventstore: tt.fields.eventstore,
}
gotID, gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes)
gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantID, gotID)
assert.Equal(t, tt.wantDetails, gotDetails)
})
}
@ -141,9 +110,11 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
subject string
ctx context.Context
id string
subject string
authMethods []domain.UserAuthMethodType
authTime time.Time
}
tests := []struct {
name string
@ -156,26 +127,14 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
name: "not found error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithInstanceID("instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
),
eventFromEventPusherWithInstanceID("instance1",
deviceauth.NewRemovedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456",
),
),
),
expectFilter(),
),
},
args: args{ctx, "1999", "subj"},
args: args{
ctx, "123", "subj",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
},
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"),
},
{
@ -186,19 +145,25 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPushFailed(pushErr,
deviceauth.NewApprovedEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"), "subj",
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
),
),
),
},
args: args{ctx, "1999", "subj"},
args: args{
ctx, "123", "subj",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
},
wantErr: pushErr,
},
{
@ -209,19 +174,25 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPush(
deviceauth.NewApprovedEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"), "subj",
ctx, deviceauth.NewAggregate("123", "instance1"), "subj",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
),
),
),
},
args: args{ctx, "1999", "subj"},
args: args{
ctx, "123", "subj",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
time.Unix(123, 456),
},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
},
@ -232,7 +203,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject)
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject, tt.args.authMethods, tt.args.authTime)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, gotDetails, tt.wantDetails)
})
@ -263,26 +234,10 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
name: "not found error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithInstanceID("instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
),
eventFromEventPusherWithInstanceID("instance1",
deviceauth.NewRemovedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456",
),
),
),
expectFilter(),
),
},
args: args{ctx, "1999", domain.DeviceAuthCanceledDenied},
args: args{ctx, "123", domain.DeviceAuthCanceledDenied},
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound"),
},
{
@ -293,20 +248,20 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPushFailed(pushErr,
deviceauth.NewCanceledEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"),
ctx, deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledDenied,
),
),
),
},
args: args{ctx, "1999", domain.DeviceAuthCanceledDenied},
args: args{ctx, "123", domain.DeviceAuthCanceledDenied},
wantErr: pushErr,
},
{
@ -317,20 +272,20 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPush(
deviceauth.NewCanceledEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"),
ctx, deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledDenied,
),
),
),
},
args: args{ctx, "1999", domain.DeviceAuthCanceledDenied},
args: args{ctx, "123", domain.DeviceAuthCanceledDenied},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
},
@ -343,20 +298,20 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
deviceauth.NewAggregate("123", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPush(
deviceauth.NewCanceledEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"),
ctx, deviceauth.NewAggregate("123", "instance1"),
domain.DeviceAuthCanceledExpired,
),
),
),
},
args: args{ctx, "1999", domain.DeviceAuthCanceledExpired},
args: args{ctx, "123", domain.DeviceAuthCanceledExpired},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
},
@ -373,88 +328,3 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
})
}
}
func TestCommands_RemoveDeviceAuth(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instance1")
now := time.Now()
pushErr := errors.New("pushErr")
unique := deviceauth.NewRemoveUniqueConstraints("client_id", "123", "456")
require.Len(t, unique, 2)
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
}
tests := []struct {
name string
fields fields
args args
wantDetails *domain.ObjectDetails
wantErr error
}{
{
name: "push error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPushFailed(pushErr,
deviceauth.NewRemovedEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456",
),
),
),
},
args: args{ctx, "1999"},
wantErr: pushErr,
},
{
name: "success",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456", now,
[]string{"a", "b", "c"},
),
)),
expectPush(
deviceauth.NewRemovedEvent(
ctx, deviceauth.NewAggregate("1999", "instance1"),
"client_id", "123", "456",
),
),
),
},
args: args{ctx, "1999"},
wantDetails: &domain.ObjectDetails{
ResourceOwner: "instance1",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
gotDetails, err := c.RemoveDeviceAuth(tt.args.ctx, tt.args.id)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, gotDetails, tt.wantDetails)
})
}
}

View File

@ -19,6 +19,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
action_repo "github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/repository/feature"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
@ -63,6 +64,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
limits.RegisterEventMappers(es)
restrictions.RegisterEventMappers(es)
feature.RegisterEventMappers(es)
deviceauth.RegisterEventMappers(es)
return es
}

View File

@ -110,6 +110,23 @@ const (
MFATypeOTPEmail
)
func (m MFAType) UserAuthMethodType() UserAuthMethodType {
switch m {
case MFATypeTOTP:
return UserAuthMethodTypeTOTP
case MFATypeU2F:
return UserAuthMethodTypeU2F
case MFATypeU2FUserVerification:
return UserAuthMethodTypePasswordless
case MFATypeOTPSMS:
return UserAuthMethodTypeOTPSMS
case MFATypeOTPEmail:
return UserAuthMethodTypeOTPEmail
default:
return UserAuthMethodTypeUnspecified
}
}
type MFALevel int
const (
@ -223,3 +240,14 @@ func (a *AuthRequest) PrivateLabelingOrgID(defaultID string) string {
}
return defaultID
}
func (a *AuthRequest) UserAuthMethodTypes() []UserAuthMethodType {
list := make([]UserAuthMethodType, 0, len(a.MFAsVerified)+1)
if a.PasswordVerified {
list = append(list, UserAuthMethodTypePassword)
}
for _, mfa := range a.MFAsVerified {
list = append(list, mfa.UserAuthMethodType())
}
return list
}

View File

@ -0,0 +1,108 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMFAType_UserAuthMethodType(t *testing.T) {
tests := []struct {
name string
m MFAType
want UserAuthMethodType
}{
{
name: "totp",
m: MFATypeTOTP,
want: UserAuthMethodTypeTOTP,
},
{
name: "u2f",
m: MFATypeU2F,
want: UserAuthMethodTypeU2F,
},
{
name: "passwordless",
m: MFATypeU2FUserVerification,
want: UserAuthMethodTypePasswordless,
},
{
name: "otp sms",
m: MFATypeOTPSMS,
want: UserAuthMethodTypeOTPSMS,
},
{
name: "otp email",
m: MFATypeOTPEmail,
want: UserAuthMethodTypeOTPEmail,
},
{
name: "unspecified",
m: 99,
want: UserAuthMethodTypeUnspecified,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.m.UserAuthMethodType()
assert.Equal(t, tt.want, got)
})
}
}
func TestAuthRequest_UserAuthMethodTypes(t *testing.T) {
type fields struct {
PasswordVerified bool
MFAsVerified []MFAType
}
tests := []struct {
name string
fields fields
want []UserAuthMethodType
}{
{
name: "no auth methods",
fields: fields{
PasswordVerified: false,
MFAsVerified: nil,
},
want: []UserAuthMethodType{},
},
{
name: "only password",
fields: fields{
PasswordVerified: true,
MFAsVerified: nil,
},
want: []UserAuthMethodType{
UserAuthMethodTypePassword,
},
},
{
name: "password, with mfa",
fields: fields{
PasswordVerified: true,
MFAsVerified: []MFAType{
MFATypeTOTP,
MFATypeU2F,
},
},
want: []UserAuthMethodType{
UserAuthMethodTypePassword,
UserAuthMethodTypeTOTP,
UserAuthMethodTypeU2F,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &AuthRequest{
PasswordVerified: tt.fields.PasswordVerified,
MFAsVerified: tt.fields.MFAsVerified,
}
got := a.UserAuthMethodTypes()
assert.Equal(t, tt.want, got)
})
}
}

View File

@ -2,28 +2,11 @@ package domain
import (
"strconv"
"time"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
// DeviceAuth describes a Device Authorization request.
// It is used as input and output model in the command and query packages.
type DeviceAuth struct {
models.ObjectRoot
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
Subject string
State DeviceAuthState
}
// DeviceAuthState describes the step the
// the device authorization process is in.
// We generate the Stringer implemntation for pretier
// We generate the Stringer implementation for prettier
// log output.
//
//go:generate stringer -type=DeviceAuthState -linecomment
@ -35,13 +18,14 @@ const (
DeviceAuthStateApproved // approved
DeviceAuthStateDenied // denied
DeviceAuthStateExpired // expired
DeviceAuthStateRemoved // removed
deviceAuthStateCount // invalid
)
// Exists returns true when not Undefined and
// any status lower than Removed.
// any status lower than deviceAuthStateCount.
func (s DeviceAuthState) Exists() bool {
return s > DeviceAuthStateUndefined && s < DeviceAuthStateRemoved
return s > DeviceAuthStateUndefined && s < deviceAuthStateCount
}
// Done returns true when DeviceAuthState is Approved.

View File

@ -30,7 +30,7 @@ func TestDeviceAuthState_Exists(t *testing.T) {
want: true,
},
{
s: DeviceAuthStateRemoved,
s: deviceAuthStateCount,
want: false,
},
}
@ -68,10 +68,6 @@ func TestDeviceAuthState_Done(t *testing.T) {
s: DeviceAuthStateExpired,
want: false,
},
{
s: DeviceAuthStateRemoved,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.s.String(), func(t *testing.T) {
@ -108,10 +104,6 @@ func TestDeviceAuthState_Denied(t *testing.T) {
s: DeviceAuthStateExpired,
want: true,
},
{
s: DeviceAuthStateRemoved,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@ -13,10 +13,10 @@ func _() {
_ = x[DeviceAuthStateApproved-2]
_ = x[DeviceAuthStateDenied-3]
_ = x[DeviceAuthStateExpired-4]
_ = x[DeviceAuthStateRemoved-5]
_ = x[deviceAuthStateCount-5]
}
const _DeviceAuthState_name = "undefinedinitiatedapproveddeniedexpiredremoved"
const _DeviceAuthState_name = "undefinedinitiatedapproveddeniedexpiredinvalid"
var _DeviceAuthState_index = [...]uint8{0, 9, 18, 26, 32, 39, 46}

View File

@ -59,7 +59,7 @@ func (a *AuthRequestSAML) IsValid() bool {
}
type AuthRequestDevice struct {
ID string
ClientID string
DeviceCode string
UserCode string
Scopes []string
@ -70,5 +70,5 @@ func (*AuthRequestDevice) Type() AuthRequestType {
}
func (a *AuthRequestDevice) IsValid() bool {
return a.DeviceCode != "" && a.UserCode != "" && len(a.Scopes) > 0
return a.DeviceCode != "" && a.UserCode != ""
}

View File

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"time"
sq "github.com/Masterminds/squirrel"
@ -16,90 +17,80 @@ import (
)
var (
deviceAuthTable = table{
name: projection.DeviceAuthProjectionTable,
instanceIDCol: projection.DeviceAuthColumnInstanceID,
deviceAuthRequestTable = table{
name: projection.DeviceAuthRequestProjectionTable,
instanceIDCol: projection.DeviceAuthRequestColumnInstanceID,
}
DeviceAuthColumnID = Column{
name: projection.DeviceAuthColumnID,
table: deviceAuthTable,
DeviceAuthRequestColumnClientID = Column{
name: projection.DeviceAuthRequestColumnClientID,
table: deviceAuthRequestTable,
}
DeviceAuthColumnClientID = Column{
name: projection.DeviceAuthColumnClientID,
table: deviceAuthTable,
DeviceAuthRequestColumnDeviceCode = Column{
name: projection.DeviceAuthRequestColumnDeviceCode,
table: deviceAuthRequestTable,
}
DeviceAuthColumnDeviceCode = Column{
name: projection.DeviceAuthColumnDeviceCode,
table: deviceAuthTable,
DeviceAuthRequestColumnUserCode = Column{
name: projection.DeviceAuthRequestColumnUserCode,
table: deviceAuthRequestTable,
}
DeviceAuthColumnUserCode = Column{
name: projection.DeviceAuthColumnUserCode,
table: deviceAuthTable,
DeviceAuthRequestColumnScopes = Column{
name: projection.DeviceAuthRequestColumnScopes,
table: deviceAuthRequestTable,
}
DeviceAuthColumnExpires = Column{
name: projection.DeviceAuthColumnExpires,
table: deviceAuthTable,
DeviceAuthRequestColumnCreationDate = Column{
name: projection.DeviceAuthRequestColumnCreationDate,
table: deviceAuthRequestTable,
}
DeviceAuthColumnScopes = Column{
name: projection.DeviceAuthColumnScopes,
table: deviceAuthTable,
DeviceAuthRequestColumnChangeDate = Column{
name: projection.DeviceAuthRequestColumnChangeDate,
table: deviceAuthRequestTable,
}
DeviceAuthColumnState = Column{
name: projection.DeviceAuthColumnState,
table: deviceAuthTable,
DeviceAuthRequestColumnSequence = Column{
name: projection.DeviceAuthRequestColumnSequence,
table: deviceAuthRequestTable,
}
DeviceAuthColumnSubject = Column{
name: projection.DeviceAuthColumnSubject,
table: deviceAuthTable,
}
DeviceAuthColumnCreationDate = Column{
name: projection.DeviceAuthColumnCreationDate,
table: deviceAuthTable,
}
DeviceAuthColumnChangeDate = Column{
name: projection.DeviceAuthColumnChangeDate,
table: deviceAuthTable,
}
DeviceAuthColumnSequence = Column{
name: projection.DeviceAuthColumnSequence,
table: deviceAuthTable,
}
DeviceAuthColumnInstanceID = Column{
name: projection.DeviceAuthColumnInstanceID,
table: deviceAuthTable,
DeviceAuthRequestColumnInstanceID = Column{
name: projection.DeviceAuthRequestColumnInstanceID,
table: deviceAuthRequestTable,
}
)
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (deviceAuth *domain.DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareDeviceAuthQuery(ctx, q.client)
eq := sq.Eq{
DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
DeviceAuthColumnClientID.identifier(): clientID,
DeviceAuthColumnDeviceCode.identifier(): deviceCode,
}
query, args, err := stmt.Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
deviceAuth, err = scan(row)
return err
}, query, args...)
return deviceAuth, err
type DeviceAuth struct {
ClientID string
DeviceCode string
UserCode string
Expires time.Time
Scopes []string
State domain.DeviceAuthState
Subject string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (deviceAuth *domain.DeviceAuth, err error) {
// DeviceAuthByDeviceCode gets the current state of a Device Authorization directly from the eventstore.
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (deviceAuth *DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
model := NewDeviceAuthReadModel(deviceCode, authz.GetInstance(ctx).InstanceID())
if err := q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
if !model.State.Exists() {
return nil, zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting")
}
return &model.DeviceAuth, nil
}
// DeviceAuthRequestByUserCode finds a Device Authorization request by User-Code from the `device_auth_requests` projection.
func (q *Queries) DeviceAuthRequestByUserCode(ctx context.Context, userCode string) (authReq *domain.AuthRequestDevice, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareDeviceAuthQuery(ctx, q.client)
eq := sq.Eq{
DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
DeviceAuthColumnUserCode.identifier(): userCode,
DeviceAuthRequestColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
DeviceAuthRequestColumnUserCode.identifier(): userCode,
}
query, args, err := stmt.Where(eq).ToSql()
if err != nil {
@ -107,34 +98,32 @@ func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (de
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
deviceAuth, err = scan(row)
authReq, err = scan(row)
return err
}, query, args...)
return deviceAuth, err
return authReq, err
}
var deviceAuthSelectColumns = []string{
DeviceAuthColumnID.identifier(),
DeviceAuthColumnClientID.identifier(),
DeviceAuthColumnScopes.identifier(),
DeviceAuthColumnExpires.identifier(),
DeviceAuthColumnState.identifier(),
DeviceAuthColumnSubject.identifier(),
DeviceAuthRequestColumnClientID.identifier(),
DeviceAuthRequestColumnDeviceCode.identifier(),
DeviceAuthRequestColumnUserCode.identifier(),
DeviceAuthRequestColumnScopes.identifier(),
}
func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.DeviceAuth, error)) {
return sq.Select(deviceAuthSelectColumns...).From(deviceAuthTable.identifier()).PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*domain.DeviceAuth, error) {
dst := new(domain.DeviceAuth)
var scopes database.TextArray[string]
func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.AuthRequestDevice, error)) {
return sq.Select(deviceAuthSelectColumns...).From(deviceAuthRequestTable.identifier()).PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*domain.AuthRequestDevice, error) {
dst := new(domain.AuthRequestDevice)
var (
scopes database.TextArray[string]
)
err := row.Scan(
&dst.AggregateID,
&dst.ClientID,
&dst.DeviceCode,
&dst.UserCode,
&scopes,
&dst.Expires,
&dst.State,
&dst.Subject,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-Sah9a", "Errors.DeviceAuth.NotExisting")
@ -142,7 +131,6 @@ func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectB
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Voo3o", "Errors.Internal")
}
dst.Scopes = scopes
return dst, nil
}

View File

@ -0,0 +1,58 @@
package query
import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
)
type DeviceAuthReadModel struct {
eventstore.ReadModel
DeviceAuth
}
func NewDeviceAuthReadModel(deviceCode, resourceOwner string) *DeviceAuthReadModel {
return &DeviceAuthReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: deviceCode,
ResourceOwner: resourceOwner,
},
}
}
func (m *DeviceAuthReadModel) Reduce() error {
for _, event := range m.Events {
switch e := event.(type) {
case *deviceauth.AddedEvent:
m.ClientID = e.ClientID
m.DeviceCode = e.DeviceCode
m.UserCode = e.UserCode
m.Expires = e.Expires
m.Scopes = e.Scopes
m.State = e.State
case *deviceauth.ApprovedEvent:
m.State = domain.DeviceAuthStateApproved
m.Subject = e.Subject
m.UserAuthMethods = e.UserAuthMethods
m.AuthTime = e.AuthTime
case *deviceauth.CanceledEvent:
m.State = e.Reason.State()
}
}
return m.ReadModel.Reduce()
}
func (m *DeviceAuthReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(m.ResourceOwner).
AddQuery().
AggregateTypes(deviceauth.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(
deviceauth.AddedEventType,
deviceauth.ApprovedEventType,
deviceauth.CanceledEventType,
).
Builder()
}

View File

@ -6,82 +6,188 @@ import (
"database/sql/driver"
"errors"
"fmt"
"io"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
const (
expectedDeviceAuthQueryC = `SELECT` +
` projections.device_authorizations.id,` +
` projections.device_authorizations.client_id,` +
` projections.device_authorizations.scopes,` +
` projections.device_authorizations.expires,` +
` projections.device_authorizations.state,` +
` projections.device_authorizations.subject` +
` FROM projections.device_authorizations`
expectedDeviceAuthWhereDeviceCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_authorizations.client_id = $1` +
` AND projections.device_authorizations.device_code = $2` +
` AND projections.device_authorizations.instance_id = $3`
expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_authorizations.instance_id = $1` +
` AND projections.device_authorizations.user_code = $2`
)
var (
expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC)
expectedDeviceAuthWhereDeviceCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereDeviceCodeQueryC)
expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC)
expectedDeviceAuthValues = []driver.Value{
"primary-id",
"client-id",
database.TextArray[string]{"a", "b", "c"},
testNow,
domain.DeviceAuthStateApproved,
"subject",
}
expectedDeviceAuth = &domain.DeviceAuth{
ObjectRoot: models.ObjectRoot{
AggregateID: "primary-id",
},
ClientID: "client-id",
Scopes: []string{"a", "b", "c"},
Expires: testNow,
State: domain.DeviceAuthStateApproved,
Subject: "subject",
}
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
client, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
ctx := authz.NewMockContext("inst1", "org1", "user1")
timestamp := time.Date(2015, 12, 15, 22, 13, 45, 0, time.UTC)
tests := []struct {
name string
eventstore func(t *testing.T) *eventstore.Eventstore
want *DeviceAuth
wantErr error
}{
{
name: "filter error",
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
wantErr: io.ErrClosedPipe,
},
{
name: "not found",
eventstore: expectEventstore(
expectFilter(),
),
wantErr: zerrors.ThrowNotFound(nil, "QUERY-eeR0e", "Errors.DeviceAuth.NotExisting"),
},
{
name: "ok, initiated",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
State: domain.DeviceAuthStateInitiated,
},
},
{
name: "ok, approved",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
)),
eventFromEventPusher(deviceauth.NewApprovedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"user1", []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
timestamp,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
State: domain.DeviceAuthStateApproved,
Subject: "user1",
UserAuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
AuthTime: timestamp,
},
},
{
name: "ok, denied",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
domain.DeviceAuthCanceledDenied,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
State: domain.DeviceAuthStateDenied,
},
},
{
name: "ok, expired",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
"client1", "device1", "user-code", timestamp, []string{"foo", "bar"},
)),
eventFromEventPusher(deviceauth.NewCanceledEvent(
ctx,
deviceauth.NewAggregate("device1", "instance1"),
domain.DeviceAuthCanceledExpired,
)),
),
),
want: &DeviceAuth{
ClientID: "client1",
DeviceCode: "device1",
UserCode: "user-code",
Expires: timestamp,
Scopes: []string{"foo", "bar"},
State: domain.DeviceAuthStateExpired,
},
},
}
defer client.Close()
mock.ExpectBegin()
mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
)
mock.ExpectCommit()
q := Queries{
client: &database.DB{DB: client},
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &Queries{
eventstore: tt.eventstore(t),
}
got, err := q.DeviceAuthByDeviceCode(ctx, "device1")
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
got, err := q.DeviceAuthByDeviceCode(context.TODO(), "123", "456")
require.NoError(t, err)
assert.Equal(t, expectedDeviceAuth, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestQueries_DeviceAuthByUserCode(t *testing.T) {
const (
expectedDeviceAuthQueryC = `SELECT` +
` projections.device_auth_requests.client_id,` +
` projections.device_auth_requests.device_code,` +
` projections.device_auth_requests.user_code,` +
` projections.device_auth_requests.scopes` +
` FROM projections.device_auth_requests`
expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_auth_requests.instance_id = $1` +
` AND projections.device_auth_requests.user_code = $2`
)
var (
expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC)
expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC)
expectedDeviceAuthValues = []driver.Value{
"client-id",
"device1",
"user-code",
database.TextArray[string]{"a", "b", "c"},
}
expectedDeviceAuth = &domain.AuthRequestDevice{
ClientID: "client-id",
DeviceCode: "device1",
UserCode: "user-code",
Scopes: []string{"a", "b", "c"},
}
)
func TestQueries_DeviceAuthRequestByUserCode(t *testing.T) {
client, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
@ -96,7 +202,7 @@ func TestQueries_DeviceAuthByUserCode(t *testing.T) {
q := Queries{
client: &database.DB{DB: client},
}
got, err := q.DeviceAuthByUserCode(context.TODO(), "789")
got, err := q.DeviceAuthRequestByUserCode(context.TODO(), "789")
require.NoError(t, err)
assert.Equal(t, expectedDeviceAuth, got)
require.NoError(t, mock.ExpectationsWereMet())
@ -110,7 +216,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) {
tests := []struct {
name string
want want
object any
object *domain.AuthRequestDevice
}{
{
name: "success",
@ -137,7 +243,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) {
return nil, true
},
},
object: (*domain.DeviceAuth)(nil),
object: nil,
},
{
name: "other error",
@ -153,7 +259,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) {
return nil, true
},
},
object: (*domain.DeviceAuth)(nil),
object: nil,
},
}
for _, tt := range tests {

View File

@ -3,7 +3,6 @@ package projection
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
@ -12,57 +11,51 @@ import (
)
const (
DeviceAuthProjectionTable = "projections.device_authorizations"
DeviceAuthRequestProjectionTable = "projections.device_auth_requests"
DeviceAuthColumnID = "id"
DeviceAuthColumnClientID = "client_id"
DeviceAuthColumnDeviceCode = "device_code"
DeviceAuthColumnUserCode = "user_code"
DeviceAuthColumnExpires = "expires"
DeviceAuthColumnScopes = "scopes"
DeviceAuthColumnState = "state"
DeviceAuthColumnSubject = "subject"
DeviceAuthColumnCreationDate = "creation_date"
DeviceAuthColumnChangeDate = "change_date"
DeviceAuthColumnSequence = "sequence"
DeviceAuthColumnInstanceID = "instance_id"
DeviceAuthRequestColumnClientID = "client_id"
DeviceAuthRequestColumnDeviceCode = "device_code"
DeviceAuthRequestColumnUserCode = "user_code"
DeviceAuthRequestColumnScopes = "scopes"
DeviceAuthRequestColumnCreationDate = "creation_date"
DeviceAuthRequestColumnChangeDate = "change_date"
DeviceAuthRequestColumnSequence = "sequence"
DeviceAuthRequestColumnInstanceID = "instance_id"
)
type deviceAuthProjection struct{}
// deviceAuthRequestProjection holds device authorization requests
// and makes them search-able by User Code.
// In principle the projected data is only needed during user login.
// Device Token logic uses the eventstore directly.
type deviceAuthRequestProjection struct{}
func newDeviceAuthProjection(ctx context.Context, config handler.Config) *handler.Handler {
return handler.NewHandler(ctx, &config, new(deviceAuthProjection))
return handler.NewHandler(ctx, &config, new(deviceAuthRequestProjection))
}
func (*deviceAuthProjection) Name() string {
return DeviceAuthProjectionTable
func (*deviceAuthRequestProjection) Name() string {
return DeviceAuthRequestProjectionTable
}
func (*deviceAuthProjection) Init() *old_handler.Check {
func (*deviceAuthRequestProjection) Init() *old_handler.Check {
return handler.NewTableCheck(
handler.NewTable([]*handler.InitColumn{
handler.NewColumn(DeviceAuthColumnID, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthColumnClientID, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthColumnDeviceCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthColumnUserCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthColumnExpires, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthColumnScopes, handler.ColumnTypeTextArray),
handler.NewColumn(DeviceAuthColumnState, handler.ColumnTypeEnum, handler.Default(domain.DeviceAuthStateInitiated)),
handler.NewColumn(DeviceAuthColumnSubject, handler.ColumnTypeText, handler.Default("")),
handler.NewColumn(DeviceAuthColumnCreationDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthColumnChangeDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthColumnSequence, handler.ColumnTypeInt64),
handler.NewColumn(DeviceAuthColumnInstanceID, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnClientID, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnDeviceCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnUserCode, handler.ColumnTypeText),
handler.NewColumn(DeviceAuthRequestColumnScopes, handler.ColumnTypeTextArray),
handler.NewColumn(DeviceAuthRequestColumnCreationDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthRequestColumnChangeDate, handler.ColumnTypeTimestamp),
handler.NewColumn(DeviceAuthRequestColumnSequence, handler.ColumnTypeInt64),
handler.NewColumn(DeviceAuthRequestColumnInstanceID, handler.ColumnTypeText),
},
handler.NewPrimaryKey(DeviceAuthColumnInstanceID, DeviceAuthColumnID),
handler.WithIndex(handler.NewIndex("user_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnUserCode})),
handler.WithIndex(handler.NewIndex("device_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnClientID, DeviceAuthColumnDeviceCode})),
handler.NewPrimaryKey(DeviceAuthRequestColumnInstanceID, DeviceAuthRequestColumnDeviceCode),
handler.WithIndex(handler.NewIndex("user_code", []string{DeviceAuthRequestColumnInstanceID, DeviceAuthRequestColumnUserCode})),
),
)
}
func (p *deviceAuthProjection) Reducers() []handler.AggregateReducer {
func (p *deviceAuthRequestProjection) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: deviceauth.AggregateType,
@ -73,22 +66,18 @@ func (p *deviceAuthProjection) Reducers() []handler.AggregateReducer {
},
{
Event: deviceauth.ApprovedEventType,
Reduce: p.reduceAppoved,
Reduce: p.reduceDoneEvents,
},
{
Event: deviceauth.CanceledEventType,
Reduce: p.reduceCanceled,
},
{
Event: deviceauth.RemovedEventType,
Reduce: p.reduceRemoved,
Reduce: p.reduceDoneEvents,
},
},
},
}
}
func (p *deviceAuthProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) {
func (p *deviceAuthRequestProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.AddedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-chu6O", "reduce.wrong.event.type %T != %s", event, deviceauth.AddedEventType)
@ -96,66 +85,30 @@ func (p *deviceAuthProjection) reduceAdded(event eventstore.Event) (*handler.Sta
return handler.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnID, e.Aggregate().ID),
handler.NewCol(DeviceAuthColumnClientID, e.ClientID),
handler.NewCol(DeviceAuthColumnDeviceCode, e.DeviceCode),
handler.NewCol(DeviceAuthColumnUserCode, e.UserCode),
handler.NewCol(DeviceAuthColumnExpires, e.Expires),
handler.NewCol(DeviceAuthColumnScopes, e.Scopes),
handler.NewCol(DeviceAuthColumnCreationDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
handler.NewCol(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(DeviceAuthRequestColumnClientID, e.ClientID),
handler.NewCol(DeviceAuthRequestColumnDeviceCode, e.DeviceCode),
handler.NewCol(DeviceAuthRequestColumnUserCode, e.UserCode),
handler.NewCol(DeviceAuthRequestColumnScopes, e.Scopes),
handler.NewCol(DeviceAuthRequestColumnCreationDate, e.CreationDate()),
handler.NewCol(DeviceAuthRequestColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthRequestColumnSequence, e.Sequence()),
handler.NewCol(DeviceAuthRequestColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *deviceAuthProjection) reduceAppoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.ApprovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-kei0A", "reduce.wrong.event.type %T != %s", event, deviceauth.ApprovedEventType)
}
return handler.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnState, domain.DeviceAuthStateApproved),
handler.NewCol(DeviceAuthColumnSubject, e.Subject),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
},
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}
// reduceDoneEvents removes the device auth request from the projection.
func (p *deviceAuthRequestProjection) reduceDoneEvents(event eventstore.Event) (*handler.Statement, error) {
switch event.(type) {
case *deviceauth.ApprovedEvent, *deviceauth.CanceledEvent:
return handler.NewDeleteStatement(event,
[]handler.Condition{
handler.NewCond(DeviceAuthRequestColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(DeviceAuthRequestColumnDeviceCode, event.Aggregate().ID),
},
), nil
func (p *deviceAuthProjection) reduceCanceled(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.CanceledEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-eeS8d", "reduce.wrong.event.type %T != %s", event, deviceauth.CanceledEventType)
default:
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-eeS8d", "reduce.wrong.event.type %T", event)
}
return handler.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnState, e.Reason.State()),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
},
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}
func (p *deviceAuthProjection) reduceRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.RemovedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-AJi1u", "reduce.wrong.event.type %T != %s", event, deviceauth.RemovedEventType)
}
return handler.NewDeleteStatement(e,
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}

View File

@ -20,6 +20,7 @@ import (
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/keypair"
@ -99,6 +100,7 @@ func StartQueries(
quota.RegisterEventMappers(repo.eventstore)
limits.RegisterEventMappers(repo.eventstore)
restrictions.RegisterEventMappers(repo.eventstore)
deviceauth.RegisterEventMappers(repo.eventstore)
repo.checkPermission = permissionCheck(repo)

View File

@ -12,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
action_repo "github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
"github.com/zitadel/zitadel/internal/repository/feature"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
@ -54,6 +55,7 @@ func expectEventstore(expects ...expect) func(*testing.T) *eventstore.Eventstore
quota_repo.RegisterEventMappers(es)
limits.RegisterEventMappers(es)
feature.RegisterEventMappers(es)
deviceauth.RegisterEventMappers(es)
return es
}
}

View File

@ -1,8 +1,6 @@
package deviceauth
import (
"strings"
"github.com/zitadel/zitadel/internal/eventstore"
)
@ -13,15 +11,11 @@ const (
DuplicateDeviceCode = "Errors.DeviceCode.AlreadyExists"
)
func deviceCodeUniqueField(clientID, deviceCode string) string {
return strings.Join([]string{clientID, deviceCode}, ":")
}
func NewAddUniqueConstraints(clientID, deviceCode, userCode string) []*eventstore.UniqueConstraint {
func NewAddUniqueConstraints(deviceCode, userCode string) []*eventstore.UniqueConstraint {
return []*eventstore.UniqueConstraint{
eventstore.NewAddEventUniqueConstraint(
UniqueDeviceCode,
deviceCodeUniqueField(clientID, deviceCode),
deviceCode,
DuplicateDeviceCode,
),
eventstore.NewAddEventUniqueConstraint(
@ -32,11 +26,11 @@ func NewAddUniqueConstraints(clientID, deviceCode, userCode string) []*eventstor
}
}
func NewRemoveUniqueConstraints(clientID, deviceCode, userCode string) []*eventstore.UniqueConstraint {
func NewRemoveUniqueConstraints(deviceCode, userCode string) []*eventstore.UniqueConstraint {
return []*eventstore.UniqueConstraint{
eventstore.NewRemoveUniqueConstraint(
UniqueDeviceCode,
deviceCodeUniqueField(clientID, deviceCode),
deviceCode,
),
eventstore.NewRemoveUniqueConstraint(
UniqueUserCode,

View File

@ -13,7 +13,6 @@ const (
AddedEventType = eventTypePrefix + "added"
ApprovedEventType = eventTypePrefix + "approved"
CanceledEventType = eventTypePrefix + "canceled"
RemovedEventType = eventTypePrefix + "removed"
)
type AddedEvent struct {
@ -36,7 +35,7 @@ func (e *AddedEvent) Payload() any {
}
func (e *AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return NewAddUniqueConstraints(e.ClientID, e.DeviceCode, e.UserCode)
return NewAddUniqueConstraints(e.DeviceCode, e.UserCode)
}
func NewAddedEvent(
@ -58,7 +57,9 @@ func NewAddedEvent(
type ApprovedEvent struct {
*eventstore.BaseEvent `json:"-"`
Subject string
Subject string
UserAuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
func (e *ApprovedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
@ -77,12 +78,16 @@ func NewApprovedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
subject string,
userAuthMethods []domain.UserAuthMethodType,
authTime time.Time,
) *ApprovedEvent {
return &ApprovedEvent{
eventstore.NewBaseEventForPush(
ctx, aggregate, ApprovedEventType,
),
subject,
userAuthMethods,
authTime,
}
}
@ -107,36 +112,3 @@ func (e *CanceledEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
func NewCanceledEvent(ctx context.Context, aggregate *eventstore.Aggregate, reason domain.DeviceAuthCanceled) *CanceledEvent {
return &CanceledEvent{eventstore.NewBaseEventForPush(ctx, aggregate, CanceledEventType), reason}
}
type RemovedEvent struct {
*eventstore.BaseEvent `json:"-"`
ClientID string
DeviceCode string
UserCode string
}
func (e *RemovedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
e.BaseEvent = b
}
func (e *RemovedEvent) Payload() any {
return e
}
func (e *RemovedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return NewRemoveUniqueConstraints(e.ClientID, e.DeviceCode, e.UserCode)
}
func NewRemovedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
clientID, deviceCode, userCode string,
) *RemovedEvent {
return &RemovedEvent{
eventstore.NewBaseEventForPush(
ctx, aggregate, RemovedEventType,
),
clientID, deviceCode, userCode,
}
}

View File

@ -0,0 +1,9 @@
package deviceauth
import "github.com/zitadel/zitadel/internal/eventstore"
func RegisterEventMappers(es *eventstore.Eventstore) {
es.RegisterFilterEventMapper(AggregateType, AddedEventType, eventstore.GenericEventMapper[AddedEvent]).
RegisterFilterEventMapper(AggregateType, ApprovedEventType, eventstore.GenericEventMapper[ApprovedEvent]).
RegisterFilterEventMapper(AggregateType, CanceledEventType, eventstore.GenericEventMapper[CanceledEvent])
}

View File

@ -2,7 +2,6 @@ package org
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
)
func RegisterEventMappers(es *eventstore.Eventstore) {
@ -114,9 +113,5 @@ func RegisterEventMappers(es *eventstore.Eventstore) {
RegisterFilterEventMapper(AggregateType, MetadataRemovedAllType, MetadataRemovedAllEventMapper).
RegisterFilterEventMapper(AggregateType, NotificationPolicyAddedEventType, NotificationPolicyAddedEventMapper).
RegisterFilterEventMapper(AggregateType, NotificationPolicyChangedEventType, NotificationPolicyChangedEventMapper).
RegisterFilterEventMapper(AggregateType, NotificationPolicyRemovedEventType, NotificationPolicyRemovedEventMapper).
RegisterFilterEventMapper(AggregateType, deviceauth.AddedEventType, eventstore.GenericEventMapper[deviceauth.AddedEvent]).
RegisterFilterEventMapper(AggregateType, deviceauth.ApprovedEventType, eventstore.GenericEventMapper[deviceauth.ApprovedEvent]).
RegisterFilterEventMapper(AggregateType, deviceauth.CanceledEventType, eventstore.GenericEventMapper[deviceauth.CanceledEvent]).
RegisterFilterEventMapper(AggregateType, deviceauth.RemovedEventType, eventstore.GenericEventMapper[deviceauth.RemovedEvent])
RegisterFilterEventMapper(AggregateType, NotificationPolicyRemovedEventType, NotificationPolicyRemovedEventMapper)
}