feat: device authorization RFC 8628 (#5646)

* device auth: implement the write events

* add grant type device code

* fix(init): check if default value implements stringer

---------

Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
This commit is contained in:
Tim Möhlmann
2023-04-19 11:46:02 +03:00
committed by GitHub
parent 3cd2cecfdf
commit 5819924275
49 changed files with 2313 additions and 38 deletions

View File

@@ -0,0 +1,141 @@
package query
import (
"context"
"database/sql"
errs "errors"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
deviceAuthTable = table{
name: projection.DeviceAuthProjectionTable,
instanceIDCol: projection.DeviceAuthColumnInstanceID,
}
DeviceAuthColumnID = Column{
name: projection.DeviceAuthColumnID,
table: deviceAuthTable,
}
DeviceAuthColumnClientID = Column{
name: projection.DeviceAuthColumnClientID,
table: deviceAuthTable,
}
DeviceAuthColumnDeviceCode = Column{
name: projection.DeviceAuthColumnDeviceCode,
table: deviceAuthTable,
}
DeviceAuthColumnUserCode = Column{
name: projection.DeviceAuthColumnUserCode,
table: deviceAuthTable,
}
DeviceAuthColumnExpires = Column{
name: projection.DeviceAuthColumnExpires,
table: deviceAuthTable,
}
DeviceAuthColumnScopes = Column{
name: projection.DeviceAuthColumnScopes,
table: deviceAuthTable,
}
DeviceAuthColumnState = Column{
name: projection.DeviceAuthColumnState,
table: deviceAuthTable,
}
DeviceAuthColumnSubject = Column{
name: projection.DeviceAuthColumnSubject,
table: deviceAuthTable,
}
DeviceAuthColumnCreationDate = Column{
name: projection.DeviceAuthColumnCreationDate,
table: deviceAuthTable,
}
DeviceAuthColumnChangeDate = Column{
name: projection.DeviceAuthColumnChangeDate,
table: deviceAuthTable,
}
DeviceAuthColumnSequence = Column{
name: projection.DeviceAuthColumnSequence,
table: deviceAuthTable,
}
DeviceAuthColumnInstanceID = Column{
name: projection.DeviceAuthColumnInstanceID,
table: deviceAuthTable,
}
)
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (_ *domain.DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareDeviceAuthQuery(ctx, q.client)
eq := sq.Eq{
DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
DeviceAuthColumnClientID.identifier(): clientID,
DeviceAuthColumnDeviceCode.identifier(): deviceCode,
}
query, args, err := stmt.Where(eq).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement")
}
return scan(q.client.QueryRowContext(ctx, query, args...))
}
func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ *domain.DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareDeviceAuthQuery(ctx, q.client)
eq := sq.Eq{
DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
DeviceAuthColumnUserCode.identifier(): userCode,
}
query, args, err := stmt.Where(eq).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement")
}
return scan(q.client.QueryRowContext(ctx, query, args...))
}
var deviceAuthSelectColumns = []string{
DeviceAuthColumnID.identifier(),
DeviceAuthColumnClientID.identifier(),
DeviceAuthColumnScopes.identifier(),
DeviceAuthColumnExpires.identifier(),
DeviceAuthColumnState.identifier(),
DeviceAuthColumnSubject.identifier(),
}
func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.DeviceAuth, error)) {
return sq.Select(deviceAuthSelectColumns...).From(deviceAuthTable.identifier()).PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*domain.DeviceAuth, error) {
dst := new(domain.DeviceAuth)
var scopes database.StringArray
err := row.Scan(
&dst.AggregateID,
&dst.ClientID,
&scopes,
&dst.Expires,
&dst.State,
&dst.Subject,
)
if errs.Is(err, sql.ErrNoRows) {
return nil, errors.ThrowNotFound(err, "QUERY-Sah9a", "Errors.DeviceAuth.NotExisting")
}
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Voo3o", "Errors.Internal")
}
dst.Scopes = scopes
return dst, nil
}
}

View File

@@ -0,0 +1,158 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
const (
expectedDeviceAuthQueryC = `SELECT` +
` projections.device_authorizations.id,` +
` projections.device_authorizations.client_id,` +
` projections.device_authorizations.scopes,` +
` projections.device_authorizations.expires,` +
` projections.device_authorizations.state,` +
` projections.device_authorizations.subject` +
` FROM projections.device_authorizations`
expectedDeviceAuthWhereDeviceCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_authorizations.client_id = $1` +
` AND projections.device_authorizations.device_code = $2` +
` AND projections.device_authorizations.instance_id = $3`
expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC +
` WHERE projections.device_authorizations.instance_id = $1` +
` AND projections.device_authorizations.user_code = $2`
)
var (
expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC)
expectedDeviceAuthWhereDeviceCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereDeviceCodeQueryC)
expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC)
expectedDeviceAuthValues = []driver.Value{
"primary-id",
"client-id",
database.StringArray{"a", "b", "c"},
testNow,
domain.DeviceAuthStateApproved,
"subject",
}
expectedDeviceAuth = &domain.DeviceAuth{
ObjectRoot: models.ObjectRoot{
AggregateID: "primary-id",
},
ClientID: "client-id",
Scopes: []string{"a", "b", "c"},
Expires: testNow,
State: domain.DeviceAuthStateApproved,
Subject: "subject",
}
)
func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
client, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
}
defer client.Close()
mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
)
q := Queries{
client: &database.DB{DB: client},
}
got, err := q.DeviceAuthByDeviceCode(context.TODO(), "123", "456")
require.NoError(t, err)
assert.Equal(t, expectedDeviceAuth, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func TestQueries_DeviceAuthByUserCode(t *testing.T) {
client, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
}
defer client.Close()
mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
)
q := Queries{
client: &database.DB{DB: client},
}
got, err := q.DeviceAuthByUserCode(context.TODO(), "789")
require.NoError(t, err)
assert.Equal(t, expectedDeviceAuth, got)
require.NoError(t, mock.ExpectationsWereMet())
}
func Test_prepareDeviceAuthQuery(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
want want
object any
}{
{
name: "success",
want: want{
sqlExpectations: mockQueries(
expectedDeviceAuthQuery,
deviceAuthSelectColumns,
[][]driver.Value{expectedDeviceAuthValues},
),
},
object: expectedDeviceAuth,
},
{
name: "not found error",
want: want{
sqlExpectations: mockQueryErr(
expectedDeviceAuthQuery,
sql.ErrNoRows,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("err should be sql.ErrNoRows got: %w", err), false
}
return nil, true
},
},
},
{
name: "other error",
want: want{
sqlExpectations: mockQueryErr(
expectedDeviceAuthQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -0,0 +1,161 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
)
const (
DeviceAuthProjectionTable = "projections.device_authorizations"
DeviceAuthColumnID = "id"
DeviceAuthColumnClientID = "client_id"
DeviceAuthColumnDeviceCode = "device_code"
DeviceAuthColumnUserCode = "user_code"
DeviceAuthColumnExpires = "expires"
DeviceAuthColumnScopes = "scopes"
DeviceAuthColumnState = "state"
DeviceAuthColumnSubject = "subject"
DeviceAuthColumnCreationDate = "creation_date"
DeviceAuthColumnChangeDate = "change_date"
DeviceAuthColumnSequence = "sequence"
DeviceAuthColumnInstanceID = "instance_id"
)
type deviceAuthProjection struct {
crdb.StatementHandler
}
func newDeviceAuthProjection(ctx context.Context, config crdb.StatementHandlerConfig) *deviceAuthProjection {
p := new(deviceAuthProjection)
config.ProjectionName = DeviceAuthProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewTableCheck(
crdb.NewTable([]*crdb.Column{
crdb.NewColumn(DeviceAuthColumnID, crdb.ColumnTypeText),
crdb.NewColumn(DeviceAuthColumnClientID, crdb.ColumnTypeText),
crdb.NewColumn(DeviceAuthColumnDeviceCode, crdb.ColumnTypeText),
crdb.NewColumn(DeviceAuthColumnUserCode, crdb.ColumnTypeText),
crdb.NewColumn(DeviceAuthColumnExpires, crdb.ColumnTypeTimestamp),
crdb.NewColumn(DeviceAuthColumnScopes, crdb.ColumnTypeTextArray),
crdb.NewColumn(DeviceAuthColumnState, crdb.ColumnTypeEnum, crdb.Default(domain.DeviceAuthStateInitiated)),
crdb.NewColumn(DeviceAuthColumnSubject, crdb.ColumnTypeText, crdb.Default("")),
crdb.NewColumn(DeviceAuthColumnCreationDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(DeviceAuthColumnChangeDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(DeviceAuthColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(DeviceAuthColumnInstanceID, crdb.ColumnTypeText),
},
crdb.NewPrimaryKey(DeviceAuthColumnInstanceID, DeviceAuthColumnID),
crdb.WithIndex(crdb.NewIndex("user_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnUserCode})),
crdb.WithIndex(crdb.NewIndex("device_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnClientID, DeviceAuthColumnDeviceCode})),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}
func (p *deviceAuthProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: deviceauth.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: deviceauth.AddedEventType,
Reduce: p.reduceAdded,
},
{
Event: deviceauth.ApprovedEventType,
Reduce: p.reduceAppoved,
},
{
Event: deviceauth.CanceledEventType,
Reduce: p.reduceCanceled,
},
{
Event: deviceauth.RemovedEventType,
Reduce: p.reduceRemoved,
},
},
},
}
}
func (p *deviceAuthProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.AddedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-chu6O", "reduce.wrong.event.type %T != %s", event, deviceauth.AddedEventType)
}
return crdb.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnID, e.Aggregate().ID),
handler.NewCol(DeviceAuthColumnClientID, e.ClientID),
handler.NewCol(DeviceAuthColumnDeviceCode, e.DeviceCode),
handler.NewCol(DeviceAuthColumnUserCode, e.UserCode),
handler.NewCol(DeviceAuthColumnExpires, e.Expires),
handler.NewCol(DeviceAuthColumnScopes, e.Scopes),
handler.NewCol(DeviceAuthColumnCreationDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
handler.NewCol(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
},
), nil
}
func (p *deviceAuthProjection) reduceAppoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.ApprovedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-kei0A", "reduce.wrong.event.type %T != %s", event, deviceauth.ApprovedEventType)
}
return crdb.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnState, domain.DeviceAuthStateApproved),
handler.NewCol(DeviceAuthColumnSubject, e.Subject),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
},
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}
func (p *deviceAuthProjection) reduceCanceled(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.CanceledEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-eeS8d", "reduce.wrong.event.type %T != %s", event, deviceauth.CanceledEventType)
}
return crdb.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(DeviceAuthColumnState, e.Reason.State()),
handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()),
handler.NewCol(DeviceAuthColumnSequence, e.Sequence()),
},
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}
func (p *deviceAuthProjection) reduceRemoved(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*deviceauth.RemovedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-AJi1u", "reduce.wrong.event.type %T != %s", event, deviceauth.RemovedEventType)
}
return crdb.NewDeleteStatement(e,
[]handler.Condition{
handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID),
},
), nil
}

View File

@@ -64,6 +64,7 @@ var (
NotificationPolicyProjection *notificationPolicyProjection
NotificationsProjection interface{}
NotificationsQuotaProjection interface{}
DeviceAuthProjection *deviceAuthProjection
)
type projection interface {
@@ -139,6 +140,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
KeyProjection = newKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), keyEncryptionAlgorithm, certEncryptionAlgorithm)
SecurityPolicyProjection = newSecurityPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["security_policies"]))
NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"]))
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
newProjectionsList()
return nil
}
@@ -234,5 +236,6 @@ func newProjectionsList() {
KeyProjection,
SecurityPolicyProjection,
NotificationPolicyProjection,
DeviceAuthProjection,
}
}