diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 8996e384c6..1afc617c5a 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -233,6 +233,8 @@ OIDC: Path: /oidc/v1/end_session Keys: Path: /oauth/v2/keys + DeviceAuth: + Path: /oauth/v2/device_authorization SAML: ProviderConfig: diff --git a/cmd/start/start.go b/cmd/start/start.go index f80fc6fef1..c3ba766fea 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -12,14 +12,13 @@ import ( "syscall" "time" - "github.com/zitadel/saml/pkg/provider" - clockpkg "github.com/benbjohnson/clock" "github.com/gorilla/mux" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/zitadel/logging" "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/saml/pkg/provider" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -294,6 +293,7 @@ func startAPIs( return fmt.Errorf("unable to start login: %w", err) } apis.RegisterHandlerOnPrefix(login.HandlerPrefix, l.Handler()) + apis.HandleFunc(login.EndpointDeviceAuth, login.RedirectDeviceAuthToPrefix) // handle grpc at last to be able to handle the root, because grpc and gateway require a lot of different prefixes apis.RouteGRPC() diff --git a/go.mod b/go.mod index eb61ce26e2..668aa02305 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/minio/minio-go/v7 v7.0.50 github.com/mitchellh/mapstructure v1.5.0 github.com/muesli/gamut v0.3.1 + github.com/muhlemmer/gu v0.3.1 github.com/nicksnyder/go-i18n/v2 v2.2.1 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.4.0 @@ -57,7 +58,7 @@ require ( github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203 github.com/ttacon/libphonenumber v1.2.1 github.com/zitadel/logging v0.3.4 - github.com/zitadel/oidc/v2 v2.2.6 + github.com/zitadel/oidc/v2 v2.4.0 github.com/zitadel/saml v0.0.11 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.40.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.40.0 @@ -70,10 +71,10 @@ require ( go.opentelemetry.io/otel/sdk/metric v0.37.0 go.opentelemetry.io/otel/trace v1.14.0 golang.org/x/crypto v0.7.0 - golang.org/x/net v0.8.0 - golang.org/x/oauth2 v0.6.0 + golang.org/x/net v0.9.0 + golang.org/x/oauth2 v0.7.0 golang.org/x/sync v0.1.0 - golang.org/x/text v0.8.0 + golang.org/x/text v0.9.0 golang.org/x/tools v0.7.0 google.golang.org/api v0.115.0 google.golang.org/genproto v0.0.0-20230403163135-c38d8f061ccd @@ -90,7 +91,6 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/google/pprof v0.0.0-20230323073829-e72429f035bd // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/muhlemmer/gu v0.3.1 // indirect github.com/pelletier/go-toml/v2 v2.0.7 // indirect go.opentelemetry.io/otel/exporters/otlp/internal/retry v1.14.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go.sum b/go.sum index 4efcc17bb2..382d482406 100644 --- a/go.sum +++ b/go.sum @@ -1130,8 +1130,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zitadel/logging v0.3.4 h1:9hZsTjMMTE3X2LUi0xcF9Q9EdLo+FAezeu52ireBbHM= github.com/zitadel/logging v0.3.4/go.mod h1:aPpLQhE+v6ocNK0TWrBrd363hZ95KcI17Q1ixAQwZF0= -github.com/zitadel/oidc/v2 v2.2.6 h1:L2k5q1X8Rucax5Ynp3B3lz7JQDJxUwfWCOmgc9Bh0BM= -github.com/zitadel/oidc/v2 v2.2.6/go.mod h1:tGkj9lQk6KVj5hsM89XPadvi6I06666sMy3KtykvSFM= +github.com/zitadel/oidc/v2 v2.4.0 h1:BKx61qOxDf+GjrY8T6lFxPjea0aMfkFvHD9pqyJGpFk= +github.com/zitadel/oidc/v2 v2.4.0/go.mod h1:wBOrfB0m/tGXo6isym1F5k3VeXSUinGsAt2H8V/+Uks= github.com/zitadel/saml v0.0.11 h1:kObucnBrcu1PHCO7RGT0iVeuJL/5I50gUgr40S41nMs= github.com/zitadel/saml v0.0.11/go.mod h1:YGWAvPZRv4DbEZ78Ht/2P0AWzGn+6WGhFf90PMXl0Po= github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= @@ -1342,8 +1342,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= +golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1360,8 +1360,8 @@ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210413134643-5e61552d6c78/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= -golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= +golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g= +golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1477,8 +1477,8 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/api/api.go b/internal/api/api.go index 3cb851e3d0..d634e50109 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -101,6 +101,12 @@ func (a *API) RegisterService(ctx context.Context, grpcServer server.Server) err return nil } +// HandleFunc allows registering a [http.HandlerFunc] on an exact +// path, instead of prefix like RegisterHandlerOnPrefix. +func (a *API) HandleFunc(path string, f http.HandlerFunc) { + a.router.HandleFunc(path, f) +} + // RegisterHandlerOnPrefix registers a http handler on a path prefix // the prefix will not be passed to the actual handler func (a *API) RegisterHandlerOnPrefix(prefix string, handler http.Handler) { diff --git a/internal/api/grpc/project/application.go b/internal/api/grpc/project/application.go index 5565ba359e..e555e93d27 100644 --- a/internal/api/grpc/project/application.go +++ b/internal/api/grpc/project/application.go @@ -136,6 +136,8 @@ func OIDCGrantTypesFromModel(grantTypes []domain.OIDCGrantType) []app_pb.OIDCGra oidcGrantTypes[i] = app_pb.OIDCGrantType_OIDC_GRANT_TYPE_IMPLICIT case domain.OIDCGrantTypeRefreshToken: oidcGrantTypes[i] = app_pb.OIDCGrantType_OIDC_GRANT_TYPE_REFRESH_TOKEN + case domain.OIDCGrantTypeDeviceCode: + oidcGrantTypes[i] = app_pb.OIDCGrantType_OIDC_GRANT_TYPE_DEVICE_CODE } } return oidcGrantTypes @@ -154,6 +156,8 @@ func OIDCGrantTypesToDomain(grantTypes []app_pb.OIDCGrantType) []domain.OIDCGran oidcGrantTypes[i] = domain.OIDCGrantTypeImplicit case app_pb.OIDCGrantType_OIDC_GRANT_TYPE_REFRESH_TOKEN: oidcGrantTypes[i] = domain.OIDCGrantTypeRefreshToken + case app_pb.OIDCGrantType_OIDC_GRANT_TYPE_DEVICE_CODE: + oidcGrantTypes[i] = domain.OIDCGrantTypeDeviceCode } } return oidcGrantTypes diff --git a/internal/api/oidc/auth_request_converter.go b/internal/api/oidc/auth_request_converter.go index 48729705e7..6473460843 100644 --- a/internal/api/oidc/auth_request_converter.go +++ b/internal/api/oidc/auth_request_converter.go @@ -99,15 +99,6 @@ func (a *AuthRequest) GetSubject() string { return a.UserID } -func (a *AuthRequest) Done() bool { - for _, step := range a.PossibleSteps { - if step.Type() == domain.NextStepRedirectToCallback { - return true - } - } - return false -} - func (a *AuthRequest) oidc() *domain.AuthRequestOIDC { return a.Request.(*domain.AuthRequestOIDC) } diff --git a/internal/api/oidc/client_converter.go b/internal/api/oidc/client_converter.go index 749a5c3dff..6b32f38927 100644 --- a/internal/api/oidc/client_converter.go +++ b/internal/api/oidc/client_converter.go @@ -200,6 +200,8 @@ func grantTypeToOIDC(grantType domain.OIDCGrantType) oidc.GrantType { return oidc.GrantTypeImplicit case domain.OIDCGrantTypeRefreshToken: return oidc.GrantTypeRefreshToken + case domain.OIDCGrantTypeDeviceCode: + return oidc.GrantTypeDeviceCode default: return oidc.GrantTypeCode } diff --git a/internal/api/oidc/device_auth.go b/internal/api/oidc/device_auth.go new file mode 100644 index 0000000000..7eee06096a --- /dev/null +++ b/internal/api/oidc/device_auth.go @@ -0,0 +1,176 @@ +package oidc + +import ( + "context" + "time" + + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v2/pkg/oidc" + "github.com/zitadel/oidc/v2/pkg/op" + + "github.com/zitadel/zitadel/internal/api/ui/login" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +const ( + DeviceAuthDefaultLifetime = 5 * time.Minute + DeviceAuthDefaultPollInterval = 5 * time.Second +) + +type DeviceAuthorizationConfig struct { + Lifetime time.Duration + PollInterval time.Duration + UserCode *UserCodeConfig +} + +type UserCodeConfig struct { + CharSet string + CharAmount int + DashInterval int +} + +// toOPConfig converts DeviceAuthorizationConfig to a [op.DeviceAuthorizationConfig], +// setting sane defaults for empty values. +// Safe to call when c is nil. +func (c *DeviceAuthorizationConfig) toOPConfig() op.DeviceAuthorizationConfig { + out := op.DeviceAuthorizationConfig{ + Lifetime: DeviceAuthDefaultLifetime, + PollInterval: DeviceAuthDefaultPollInterval, + UserFormPath: login.EndpointDeviceAuth, + UserCode: op.UserCodeBase20, + } + if c == nil { + return out + } + if c.Lifetime != 0 { + out.Lifetime = c.Lifetime + } + if c.PollInterval != 0 { + out.PollInterval = c.PollInterval + } + + if c.UserCode == nil { + return out + } + if c.UserCode.CharSet != "" { + out.UserCode.CharSet = c.UserCode.CharSet + } + if c.UserCode.CharAmount != 0 { + out.UserCode.CharAmount = c.UserCode.CharAmount + } + if c.UserCode.DashInterval != 0 { + out.UserCode.DashInterval = c.UserCode.CharAmount + } + return out +} + +// StoreDeviceAuthorization creates a new Device Authorization request. +// Implements the op.DeviceAuthorizationStorage interface. +func (o *OPStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (err error) { + const logMsg = "store device authorization" + logger := logging.WithFields("client_id", clientID, "device_code", deviceCode, "user_code", userCode, "expires", expires, "scopes", scopes) + + ctx, span := tracing.NewSpan(ctx) + defer func() { + logger.OnError(err).Error(logMsg) + span.EndWithError(err) + }() + + // TODO(muhlemmer): Remove the following code block with oidc v3 + // https://github.com/zitadel/oidc/issues/370 + client, err := o.GetClientByClientID(ctx, clientID) + if err != nil { + return err + } + if !op.ValidateGrantType(client, oidc.GrantTypeDeviceCode) { + return errors.ThrowPermissionDeniedf(nil, "OIDC-et1Ae", "grant type %q not allowed for client", oidc.GrantTypeDeviceCode) + } + + scopes, err = o.assertProjectRoleScopes(ctx, clientID, scopes) + if err != nil { + return errors.ThrowPreconditionFailed(err, "OIDC-She4t", "Errors.Internal") + } + aggrID, details, err := o.command.AddDeviceAuth(ctx, clientID, deviceCode, userCode, expires, scopes) + if err == nil { + logger.SetFields("aggregate_id", aggrID, "details", details).Debug(logMsg) + } + + return err +} + +func newDeviceAuthorizationState(d *domain.DeviceAuth) *op.DeviceAuthorizationState { + return &op.DeviceAuthorizationState{ + ClientID: d.ClientID, + Scopes: d.Scopes, + Expires: d.Expires, + Done: d.State.Done(), + Subject: d.Subject, + Denied: d.State.Denied(), + } +} + +// GetDeviceAuthorizatonState retieves the current state of the Device Authorization process. +// It implements the [op.DeviceAuthorizationStorage] interface and is used by devices that +// are polling until they successfully receive a token or we indicate a denied or expired state. +// As generated user codes are of low entropy, this implementation also takes care or +// device authorization request cleanup, when it has been Approved, Denied or Expired. +func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) { + const logMsg = "get device authorization state" + logger := logging.WithFields("client_id", clientID, "device_code", deviceCode) + + ctx, span := tracing.NewSpan(ctx) + defer func() { + if err != nil { + logger.WithError(err).Error(logMsg) + } + span.EndWithError(err) + }() + + deviceAuth, err := o.query.DeviceAuthByDeviceCode(ctx, clientID, deviceCode) + if err != nil { + return nil, err + } + logger.SetFields( + "expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes, + "subject", deviceAuth.Subject, "state", deviceAuth.State, + ).Debug("device authorization state") + + // Cancel the request if it is expired, only if it wasn't Done meanwhile + if !deviceAuth.State.Done() && deviceAuth.Expires.Before(time.Now()) { + _, err = o.command.CancelDeviceAuth(ctx, deviceAuth.AggregateID, domain.DeviceAuthCanceledExpired) + if err != nil { + return nil, err + } + deviceAuth.State = domain.DeviceAuthStateExpired + } + + // When the request is more then initiated, it has been either Approved, Denied or Expired. + // At this point we should remove it from the DB to avoid user code conflicts. + if deviceAuth.State > domain.DeviceAuthStateInitiated { + _, err = o.command.RemoveDeviceAuth(ctx, deviceAuth.AggregateID) + if err != nil { + return nil, err + } + } + + return newDeviceAuthorizationState(deviceAuth), nil +} + +// TODO(muhlemmer): remove the following methods with oidc v3. +// They are actually not used, but are required by the oidc device storage interface. +// https://github.com/zitadel/oidc/issues/371 +func (o *OPStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) { + return nil, nil +} + +func (o *OPStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) (err error) { + return nil +} + +func (o *OPStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) (err error) { + return nil +} + +// TODO end. diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 48167f402c..9574f561b6 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -40,6 +40,7 @@ type Config struct { UserAgentCookieConfig *middleware.UserAgentCookieConfig Cache *middleware.CacheConfig CustomEndpoints *EndpointConfig + DeviceAuth *DeviceAuthorizationConfig } type EndpointConfig struct { @@ -50,6 +51,7 @@ type EndpointConfig struct { Revocation *Endpoint EndSession *Endpoint Keys *Endpoint + DeviceAuth *Endpoint } type Endpoint struct { @@ -108,6 +110,7 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey [] GrantTypeRefreshToken: config.GrantTypeRefreshToken, RequestObjectSupported: config.RequestObjectSupported, SupportedUILocales: supportedLanguages, + DeviceAuthorization: config.DeviceAuth.toOPConfig(), } if cryptoLength := len(cryptoKey); cryptoLength != 32 { return nil, caos_errs.ThrowInternalf(nil, "OIDC-D43gf", "crypto key must be 32 bytes, but is %d", cryptoLength) @@ -165,6 +168,9 @@ func customEndpoints(endpointConfig *EndpointConfig) []op.Option { if endpointConfig.Keys != nil { options = append(options, op.WithCustomKeysEndpoint(op.NewEndpointWithURL(endpointConfig.Keys.Path, endpointConfig.Keys.URL))) } + if endpointConfig.DeviceAuth != nil { + options = append(options, op.WithCustomDeviceAuthorizationEndpoint(op.NewEndpointWithURL(endpointConfig.DeviceAuth.Path, endpointConfig.DeviceAuth.URL))) + } return options } diff --git a/internal/api/saml/auth_request_converter.go b/internal/api/saml/auth_request_converter.go index a19fed0920..28f2a3c548 100644 --- a/internal/api/saml/auth_request_converter.go +++ b/internal/api/saml/auth_request_converter.go @@ -63,14 +63,6 @@ func (a *AuthRequest) GetUserID() string { func (a *AuthRequest) GetUserName() string { return a.UserName } -func (a *AuthRequest) Done() bool { - for _, step := range a.PossibleSteps { - if step.Type() == domain.NextStepRedirectToCallback { - return true - } - } - return false -} func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ models.AuthRequestInt, err error) { if _, ok := authReq.Request.(*domain.AuthRequestSAML); !ok { diff --git a/internal/api/ui/login/device_auth.go b/internal/api/ui/login/device_auth.go new file mode 100644 index 0000000000..e2322ee04f --- /dev/null +++ b/internal/api/ui/login/device_auth.go @@ -0,0 +1,201 @@ +package login + +import ( + errs "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/gorilla/mux" + "github.com/muhlemmer/gu" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" +) + +const ( + tmplDeviceAuthUserCode = "device-usercode" + tmplDeviceAuthAction = "device-action" +) + +func (l *Login) renderDeviceAuthUserCode(w http.ResponseWriter, r *http.Request, err error) { + var errID, errMessage string + if err != nil { + logging.WithError(err).Error() + errID, errMessage = l.getErrorMessage(r, err) + } + + data := l.getBaseData(r, nil, "DeviceAuth.Title", "DeviceAuth.UserCode.Description", errID, errMessage) + translator := l.getTranslator(r.Context(), nil) + l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplDeviceAuthUserCode], data, nil) +} + +func (l *Login) renderDeviceAuthAction(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, scopes []string) { + data := &struct { + baseData + AuthRequestID string + Username string + ClientID string + Scopes []string + }{ + baseData: l.getBaseData(r, authReq, "DeviceAuth.Title", "DeviceAuth.Action.Description", "", ""), + AuthRequestID: authReq.ID, + Username: authReq.UserName, + ClientID: authReq.ApplicationID, + Scopes: scopes, + } + + translator := l.getTranslator(r.Context(), authReq) + l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplDeviceAuthAction], data, nil) +} + +const ( + deviceAuthAllowed = "allowed" + deviceAuthDenied = "denied" +) + +// renderDeviceAuthDone renders success.html when the action was allowed and error.html when it was denied. +func (l *Login) renderDeviceAuthDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, action string) { + data := &struct { + baseData + Message string + }{ + baseData: l.getBaseData(r, authReq, "DeviceAuth.Title", "DeviceAuth.Done.Description", "", ""), + } + + translator := l.getTranslator(r.Context(), authReq) + switch action { + case deviceAuthAllowed: + data.Message = translator.LocalizeFromRequest(r, "DeviceAuth.Done.Approved", nil) + l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplSuccess], data, nil) + case deviceAuthDenied: + data.ErrMessage = translator.LocalizeFromRequest(r, "DeviceAuth.Done.Denied", nil) + l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplError], data, nil) + } +} + +// handleDeviceUserCode serves the Device Authorization user code submission form. +// The "user_code" may be submitted by URL (GET) or form (POST). +// When a "user_code" is received and found through query, +// handleDeviceAuthUserCode will create a new AuthRequest in the repository. +// The user is then redirected to the /login endpoint to complete authentication. +// +// The agent ID from the context is set to the authentication request +// to ensure the complete login flow is completed from the same browser. +func (l *Login) handleDeviceAuthUserCode(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + err := r.ParseForm() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + l.renderDeviceAuthUserCode(w, r, err) + return + } + userCode := r.Form.Get("user_code") + if userCode == "" { + if prompt, _ := url.QueryUnescape(r.Form.Get("prompt")); prompt != "" { + err = errs.New(prompt) + } + l.renderDeviceAuthUserCode(w, r, err) + return + } + deviceAuth, err := l.query.DeviceAuthByUserCode(ctx, userCode) + if err != nil { + l.renderDeviceAuthUserCode(w, r, err) + return + } + userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) + if !ok { + l.renderDeviceAuthUserCode(w, r, errs.New("internal error: agent ID missing")) + return + } + authRequest, err := l.authRepo.CreateAuthRequest(ctx, &domain.AuthRequest{ + CreationDate: time.Now(), + AgentID: userAgentID, + ApplicationID: deviceAuth.ClientID, + InstanceID: authz.GetInstance(ctx).InstanceID(), + Request: &domain.AuthRequestDevice{ + ID: deviceAuth.AggregateID, + DeviceCode: deviceAuth.DeviceCode, + UserCode: deviceAuth.UserCode, + Scopes: deviceAuth.Scopes, + }, + }) + if err != nil { + l.renderDeviceAuthUserCode(w, r, err) + return + } + + http.Redirect(w, r, l.renderer.pathPrefix+EndpointLogin+"?authRequestID="+authRequest.ID, http.StatusFound) +} + +// redirectDeviceAuthStart redirects the user to the start point of +// the device authorization flow. A prompt can be set to inform the user +// of the reason why they are redirected back. +func (l *Login) redirectDeviceAuthStart(w http.ResponseWriter, r *http.Request, prompt string) { + values := make(url.Values) + values.Set("prompt", url.QueryEscape(prompt)) + + url := url.URL{ + Path: l.renderer.pathPrefix + EndpointDeviceAuth, + RawQuery: values.Encode(), + } + http.Redirect(w, r, url.String(), http.StatusSeeOther) +} + +// handleDeviceAuthAction is the handler where the user is redirected after login. +// The authRequest is checked if the login was indeed completed. +// When the action of "allowed" or "denied", the device authorization is updated accordingly. +// Else the user is presented with a page where they can choose / submit either action. +func (l *Login) handleDeviceAuthAction(w http.ResponseWriter, r *http.Request) { + authReq, err := l.getAuthRequest(r) + if authReq == nil { + err = errors.ThrowInvalidArgument(err, "LOGIN-OLah8", "invalid or missing auth request") + l.redirectDeviceAuthStart(w, r, err.Error()) + return + } + if !authReq.Done() { + l.redirectDeviceAuthStart(w, r, "authentication not completed") + return + } + authDev, ok := authReq.Request.(*domain.AuthRequestDevice) + if !ok { + l.redirectDeviceAuthStart(w, r, fmt.Sprintf("wrong auth request type: %T", authReq.Request)) + return + } + + action := mux.Vars(r)["action"] + switch action { + case deviceAuthAllowed: + _, err = l.command.ApproveDeviceAuth(r.Context(), authDev.ID, authReq.UserID) + case deviceAuthDenied: + _, err = l.command.CancelDeviceAuth(r.Context(), authDev.ID, domain.DeviceAuthCanceledDenied) + default: + l.renderDeviceAuthAction(w, r, authReq, authDev.Scopes) + return + } + if err != nil { + l.redirectDeviceAuthStart(w, r, err.Error()) + return + } + + l.renderDeviceAuthDone(w, r, authReq, action) +} + +// deviceAuthCallbackURL creates the callback URL with which the user +// is redirected back to the device authorization flow. +func (l *Login) deviceAuthCallbackURL(authRequestID string) string { + return l.renderer.pathPrefix + EndpointDeviceAuthAction + "?authRequestID=" + authRequestID +} + +// RedirectDeviceAuthToPrefix allows users to use https://domain.com/device without the /ui/login prefix +// and redirects them to the prefixed endpoint. +// [rfc 8628](https://www.rfc-editor.org/rfc/rfc8628#section-3.2) recommends the URL to be as short as possible. +func RedirectDeviceAuthToPrefix(w http.ResponseWriter, r *http.Request) { + target := gu.PtrCopy(r.URL) + target.Path = HandlerPrefix + EndpointDeviceAuth + http.Redirect(w, r, target.String(), http.StatusFound) +} diff --git a/internal/api/ui/login/login_success_handler.go b/internal/api/ui/login/login_success_handler.go index aad9a67393..f05ee48185 100644 --- a/internal/api/ui/login/login_success_handler.go +++ b/internal/api/ui/login/login_success_handler.go @@ -69,6 +69,8 @@ func (l *Login) authRequestCallback(ctx context.Context, authReq *domain.AuthReq return l.oidcAuthCallbackURL(ctx, authReq.ID), nil case *domain.AuthRequestSAML: return l.samlAuthCallbackURL(ctx, authReq.ID), nil + case *domain.AuthRequestDevice: + return l.deviceAuthCallbackURL(authReq.ID), nil default: return "", caos_errs.ThrowInternal(nil, "LOGIN-rhjQF", "Errors.AuthRequest.RequestTypeNotSupported") } diff --git a/internal/api/ui/login/renderer.go b/internal/api/ui/login/renderer.go index 8266b74d07..a9b12f19dc 100644 --- a/internal/api/ui/login/renderer.go +++ b/internal/api/ui/login/renderer.go @@ -25,7 +25,8 @@ import ( ) const ( - tmplError = "error" + tmplError = "error" + tmplSuccess = "success" ) type Renderer struct { @@ -45,6 +46,7 @@ func CreateRenderer(pathPrefix string, staticDir http.FileSystem, staticStorage } tmplMapping := map[string]string{ tmplError: "error.html", + tmplSuccess: "success.html", tmplLogin: "login.html", tmplUserSelection: "select_user.html", tmplPassword: "password.html", @@ -77,6 +79,8 @@ func CreateRenderer(pathPrefix string, staticDir http.FileSystem, staticStorage tmplExternalNotFoundOption: "external_not_found_option.html", tmplLoginSuccess: "login_success.html", tmplLDAPLogin: "ldap_login.html", + tmplDeviceAuthUserCode: "device_usercode.html", + tmplDeviceAuthAction: "device_action.html", } funcs := map[string]interface{}{ "resourceUrl": func(file string) string { @@ -323,6 +327,7 @@ func (l *Login) chooseNextStep(w http.ResponseWriter, r *http.Request, authReq * func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { var msg string if err != nil { + logging.WithError(err).WithField("auth_req_id", authReq.ID).Error() _, msg = l.getErrorMessage(r, err) } data := l.getBaseData(r, authReq, "Errors.Internal", "", "Internal", msg) diff --git a/internal/api/ui/login/router.go b/internal/api/ui/login/router.go index e723cad1ac..8ad27d7573 100644 --- a/internal/api/ui/login/router.go +++ b/internal/api/ui/login/router.go @@ -46,6 +46,9 @@ const ( EndpointResources = "/resources" EndpointDynamicResources = "/resources/dynamic" + + EndpointDeviceAuth = "/device" + EndpointDeviceAuthAction = "/device/{action}" ) var ( @@ -107,5 +110,7 @@ func CreateRouter(login *Login, staticDir http.FileSystem, interceptors ...mux.M router.HandleFunc(EndpointLDAPLogin, login.handleLDAP).Methods(http.MethodGet) router.HandleFunc(EndpointLDAPCallback, login.handleLDAPCallback).Methods(http.MethodPost) router.SkipClean(true).Handle("", http.RedirectHandler(HandlerPrefix+"/", http.StatusMovedPermanently)) + router.HandleFunc(EndpointDeviceAuth, login.handleDeviceAuthUserCode).Methods(http.MethodGet, http.MethodPost) + router.HandleFunc(EndpointDeviceAuthAction, login.handleDeviceAuthAction).Methods(http.MethodGet, http.MethodPost) return router } diff --git a/internal/api/ui/login/static/i18n/de.yaml b/internal/api/ui/login/static/i18n/de.yaml index 5b12002119..96a5d461fa 100644 --- a/internal/api/ui/login/static/i18n/de.yaml +++ b/internal/api/ui/login/static/i18n/de.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: Geräteautorisierung + UserCode: + Label: Benutzercode + Description: Geben Sie den auf dem Gerät angezeigten Benutzercode ein + ButtonNext: weiter + Action: + Description: Gerätezugriff erlauben + GrantDevice: Sie sind dabei, das Gerät zu erlauben + AccessToScopes: Zugriff auf die folgenden Daten + Button: + Allow: erlauben + Deny: verweigern + Done: + Description: Abgeschlossen + Approved: Gerätezulassung genehmigt. Sie können jetzt zum Gerät zurückkehren. + Denied: Geräteautorisierung verweigert. Sie können jetzt zum Gerät zurückkehren. + Footer: PoweredBy: Powered By Tos: AGB @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: Registrierung ist nicht erlaubt + DeviceAuth: + NotExisting: Benutzercode existiert nicht optional: (optional) diff --git a/internal/api/ui/login/static/i18n/en.yaml b/internal/api/ui/login/static/i18n/en.yaml index 14b0d90011..14ac6a5808 100644 --- a/internal/api/ui/login/static/i18n/en.yaml +++ b/internal/api/ui/login/static/i18n/en.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: Device Authorization + UserCode: + Label: User Code + Description: Enter the user code presented on the device. + ButtonNext: next + Action: + Description: Grant device access. + GrantDevice: you are about to grant device + AccessToScopes: access to the following scopes + Button: + Allow: allow + Deny: deny + Done: + Description: Done. + Approved: Device authorization approved. You can now return to the device. + Denied: Device authorization denied. You can now return to the device. + Footer: PoweredBy: Powered By Tos: TOS @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: Registration is not allowed + DeviceAuth: + NotExisting: User Code doesn't exist optional: (optional) diff --git a/internal/api/ui/login/static/i18n/fr.yaml b/internal/api/ui/login/static/i18n/fr.yaml index be50b4327e..3678b555a0 100644 --- a/internal/api/ui/login/static/i18n/fr.yaml +++ b/internal/api/ui/login/static/i18n/fr.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: Autorisation de l'appareil + UserCode: + Label: Code d'utilisateur + Description: Saisissez le code utilisateur présenté sur l'appareil. + ButtonNext: suivant + Action: + Description: Accordez l'accès à l'appareil. + GrantDevice: vous êtes sur le point d'accorder un appareil + AccessToScopes: accès aux périmètres suivants + Button: + Allow: permettre + Deny: refuser + Done: + Description: Fait. + Approved: Autorisation de l'appareil approuvée. Vous pouvez maintenant retourner à l'appareil. + Denied: Autorisation de l'appareil refusée. Vous pouvez maintenant retourner à l'appareil. + Footer: PoweredBy: Promulgué par Tos: TOS @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: L'enregistrement n'est pas autorisé + DeviceAuth: + NotExisting: Le code utilisateur n'existe pas optional: (facultatif) diff --git a/internal/api/ui/login/static/i18n/it.yaml b/internal/api/ui/login/static/i18n/it.yaml index b27e27b072..61bf0e3190 100644 --- a/internal/api/ui/login/static/i18n/it.yaml +++ b/internal/api/ui/login/static/i18n/it.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: Autorizzazione del dispositivo + UserCode: + Label: Codice utente + Description: Inserire il codice utente presentato sul dispositivo. + ButtonNext: prossimo + Action: + Description: Concedi l'accesso al dispositivo. + GrantDevice: stai per concedere il dispositivo + AccessToScopes: accesso ai seguenti ambiti + Button: + Allow: permettere + Deny: negare + Done: + Description: Fatto. + Approved: Autorizzazione del dispositivo approvata. Ora puoi tornare al dispositivo. + Denied: Autorizzazione dispositivo negata. Ora puoi tornare al dispositivo. + Footer: PoweredBy: Alimentato da Tos: Termini di servizio @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: la registrazione non è consentita. + DeviceAuth: + NotExisting: Il codice utente non esiste optional: (opzionale) diff --git a/internal/api/ui/login/static/i18n/ja.yaml b/internal/api/ui/login/static/i18n/ja.yaml index bdb9ec09d7..3a7b964b2b 100644 --- a/internal/api/ui/login/static/i18n/ja.yaml +++ b/internal/api/ui/login/static/i18n/ja.yaml @@ -309,6 +309,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: デバイス認証 + UserCode: + Label: ユーザーコード + Description: デバイスに表示されたユーザー コードを入力します。 + ButtonNext: 次 + Action: + Description: デバイスへのアクセスを許可します。 + GrantDevice: デバイスを許可しようとしています + AccessToScopes: 次のスコープへのアクセス + Button: + Allow: 許可する + Deny: 拒否 + Done: + Description: 終わり。 + Approved: デバイス認証が承認されました。 これで、デバイスに戻ることができます。 + Denied: デバイス認証が拒否されました。 これで、デバイスに戻ることができます。 + Footer: PoweredBy: Powered By Tos: TOS @@ -385,5 +403,7 @@ Errors: IAM: LockoutPolicy: NotExisting: ロックアウトポリシーが存在しません + DeviceAuth: + NotExisting: ユーザーコードが存在しません optional: "(オプション)" diff --git a/internal/api/ui/login/static/i18n/pl.yaml b/internal/api/ui/login/static/i18n/pl.yaml index ca038b2561..894de884bd 100644 --- a/internal/api/ui/login/static/i18n/pl.yaml +++ b/internal/api/ui/login/static/i18n/pl.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: Autoryzacja urządzenia + UserCode: + Label: Kod użytkownika + Description: Wprowadź kod użytkownika prezentowany na urządzeniu. + ButtonNext: Następny + Action: + Description: Przyznaj dostęp do urządzenia. + GrantDevice: zamierzasz przyznać urządzenie + AccessToScopes: dostęp do następujących zakresów + Button: + Allow: umożliwić + Deny: zaprzeczyć + Done: + Description: Zrobione. + Approved: Zatwierdzono autoryzację urządzenia. Możesz teraz wrócić do urządzenia. + Denied: Odmowa autoryzacji urządzenia. Możesz teraz wrócić do urządzenia. + Footer: PoweredBy: Obsługiwane przez Tos: TOS @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: Rejestracja nie jest dozwolona + DeviceAuth: + NotExisting: Kod użytkownika nie istnieje optional: (opcjonalny) diff --git a/internal/api/ui/login/static/i18n/zh.yaml b/internal/api/ui/login/static/i18n/zh.yaml index 830476299f..e2947e6d17 100644 --- a/internal/api/ui/login/static/i18n/zh.yaml +++ b/internal/api/ui/login/static/i18n/zh.yaml @@ -317,6 +317,24 @@ ExternalNotFound: Japanese: 日本語 Spanish: Español +DeviceAuth: + Title: 设备授权 + UserCode: + Label: 用户代码 + Description: 输入设备上显示的用户代码。 + ButtonNext: 下一个 + Action: + Description: 授予设备访问权限。 + GrantDevice: 您即将授予设备 + AccessToScopes: 访问以下范围 + Button: + Allow: 允许 + Deny: 否定 + Done: + Description: 完毕。 + Approved: 设备授权已批准。 您现在可以返回设备。 + Denied: 设备授权被拒绝。 您现在可以返回设备。 + Footer: PoweredBy: Powered By Tos: 服务条款 @@ -425,5 +443,7 @@ Errors: Org: LoginPolicy: RegistrationNotAllowed: 不允许注册 + DeviceAuth: + NotExisting: 用户代码不存在 optional: (可选) diff --git a/internal/api/ui/login/static/templates/device_action.html b/internal/api/ui/login/static/templates/device_action.html new file mode 100644 index 0000000000..4e0cc2801d --- /dev/null +++ b/internal/api/ui/login/static/templates/device_action.html @@ -0,0 +1,18 @@ +{{template "main-top" .}} + +

{{.Title}}

+

+ {{.Username}}, {{t "DeviceAuth.Action.GrantDevice"}} {{.ClientID}} {{t "DeviceAuth.Action.AccessToScopes"}}: {{.Scopes}}. +

+
+ {{ .CSRF }} + + + +
+ +{{template "main-bottom" .}} diff --git a/internal/api/ui/login/static/templates/device_usercode.html b/internal/api/ui/login/static/templates/device_usercode.html new file mode 100644 index 0000000000..5d053cabac --- /dev/null +++ b/internal/api/ui/login/static/templates/device_usercode.html @@ -0,0 +1,21 @@ +{{template "main-top" .}} + +

{{.Title}}

+
+ + {{ .CSRF }} + +
+ + +
+ + {{template "error-message" .}} + +
+ + +
+
+ +{{template "main-bottom" .}} diff --git a/internal/api/ui/login/static/templates/success.html b/internal/api/ui/login/static/templates/success.html new file mode 100644 index 0000000000..bc5042f13b --- /dev/null +++ b/internal/api/ui/login/static/templates/success.html @@ -0,0 +1,12 @@ +{{template "main-top" .}} + +
+
+ +

+ {{ .Message }} +

+
+
+ +{{template "main-bottom" .}} diff --git a/internal/api/ui/login/statik/generate.go b/internal/api/ui/login/statik/generate.go index 75330afad9..5388980631 100644 --- a/internal/api/ui/login/statik/generate.go +++ b/internal/api/ui/login/statik/generate.go @@ -1,3 +1,3 @@ package statik -//go:generate statik -src=../static -dest=.. -ns=login +//go:generate statik -f -src=../static -dest=.. -ns=login diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 34a30a6abf..84f884620b 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -1446,7 +1446,7 @@ func linkingIDPConfigExistingInAllowedIDPs(linkingUsers []*domain.ExternalUser, func userGrantRequired(ctx context.Context, request *domain.AuthRequest, user *user_model.UserView, userGrantProvider userGrantProvider) (_ bool, err error) { var project *query.Project switch request.Request.Type() { - case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML: + case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML, domain.AuthRequestTypeDevice: project, err = userGrantProvider.ProjectByClientID(ctx, request.ApplicationID, false) if err != nil { return false, err @@ -1467,13 +1467,13 @@ func userGrantRequired(ctx context.Context, request *domain.AuthRequest, user *u func projectRequired(ctx context.Context, request *domain.AuthRequest, projectProvider projectProvider) (missingGrant bool, err error) { var project *query.Project switch request.Request.Type() { - case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML: + case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML, domain.AuthRequestTypeDevice: project, err = projectProvider.ProjectByClientID(ctx, request.ApplicationID, false) if err != nil { return false, err } default: - return false, errors.ThrowPreconditionFailed(nil, "EVENT-dfrw2", "Errors.AuthRequest.RequestTypeNotSupported") + return false, errors.ThrowPreconditionFailed(nil, "EVENT-ku4He", "Errors.AuthRequest.RequestTypeNotSupported") } // if the user and project are part of the same organisation we do not need to check if the project exists on that org if !project.HasProjectCheck || project.ResourceOwner == request.UserOrgID { diff --git a/internal/command/device_auth.go b/internal/command/device_auth.go new file mode 100644 index 0000000000..6c3e1a3cfa --- /dev/null +++ b/internal/command/device_auth.go @@ -0,0 +1,113 @@ +package command + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/deviceauth" +) + +func (c *Commands) AddDeviceAuth(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) (string, *domain.ObjectDetails, error) { + aggrID, err := c.idGenerator.Next() + if err != nil { + return "", nil, err + } + + aggr := deviceauth.NewAggregate(aggrID, authz.GetInstance(ctx).InstanceID()) + model := NewDeviceAuthWriteModel(aggrID, aggr.ResourceOwner) + + pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewAddedEvent( + ctx, + aggr, + clientID, + deviceCode, + userCode, + expires, + scopes, + )) + if err != nil { + return "", nil, err + } + err = AppendAndReduce(model, pushedEvents...) + if err != nil { + return "", nil, err + } + + return model.AggregateID, writeModelToObjectDetails(&model.WriteModel), nil +} + +func (c *Commands) ApproveDeviceAuth(ctx context.Context, id, subject string) (*domain.ObjectDetails, error) { + model, err := c.getDeviceAuthWriteModelByID(ctx, id) + if err != nil { + return nil, err + } + if !model.State.Exists() { + return nil, caos_errs.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound") + } + aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID) + + pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewApprovedEvent(ctx, aggr, subject)) + if err != nil { + return nil, err + } + err = AppendAndReduce(model, pushedEvents...) + if err != nil { + return nil, err + } + + return writeModelToObjectDetails(&model.WriteModel), nil +} + +func (c *Commands) CancelDeviceAuth(ctx context.Context, id string, reason domain.DeviceAuthCanceled) (*domain.ObjectDetails, error) { + model, err := c.getDeviceAuthWriteModelByID(ctx, id) + if err != nil { + return nil, err + } + if !model.State.Exists() { + return nil, caos_errs.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound") + } + aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID) + + pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewCanceledEvent(ctx, aggr, reason)) + if err != nil { + return nil, err + } + err = AppendAndReduce(model, pushedEvents...) + if err != nil { + return nil, err + } + + return writeModelToObjectDetails(&model.WriteModel), nil +} + +func (c *Commands) RemoveDeviceAuth(ctx context.Context, id string) (*domain.ObjectDetails, error) { + model, err := c.getDeviceAuthWriteModelByID(ctx, id) + if err != nil { + return nil, err + } + aggr := deviceauth.NewAggregate(model.AggregateID, model.InstanceID) + + pushedEvents, err := c.eventstore.Push(ctx, deviceauth.NewRemovedEvent(ctx, aggr, model.ClientID, model.DeviceCode, model.UserCode)) + if err != nil { + return nil, err + } + err = AppendAndReduce(model, pushedEvents...) + if err != nil { + return nil, err + } + + return writeModelToObjectDetails(&model.WriteModel), nil +} + +func (c *Commands) getDeviceAuthWriteModelByID(ctx context.Context, id string) (*DeviceAuthWriteModel, error) { + model := &DeviceAuthWriteModel{WriteModel: eventstore.WriteModel{AggregateID: id}} + err := c.eventstore.FilterToQueryReducer(ctx, model) + if err != nil { + return nil, err + } + return model, nil +} diff --git a/internal/command/device_auth_model.go b/internal/command/device_auth_model.go new file mode 100644 index 0000000000..2ea52a39ab --- /dev/null +++ b/internal/command/device_auth_model.go @@ -0,0 +1,61 @@ +package command + +import ( + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/deviceauth" +) + +type DeviceAuthWriteModel struct { + eventstore.WriteModel + + ClientID string + DeviceCode string + UserCode string + Expires time.Time + Scopes []string + Subject string + State domain.DeviceAuthState +} + +func NewDeviceAuthWriteModel(aggrID, resourceOwner string) *DeviceAuthWriteModel { + return &DeviceAuthWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: aggrID, + ResourceOwner: resourceOwner, + }, + } +} + +func (m *DeviceAuthWriteModel) Reduce() error { + for _, event := range m.Events { + switch e := event.(type) { + case *deviceauth.AddedEvent: + m.ClientID = e.ClientID + m.DeviceCode = e.DeviceCode + m.UserCode = e.UserCode + m.Expires = e.Expires + m.Scopes = e.Scopes + m.State = e.State + case *deviceauth.ApprovedEvent: + m.Subject = e.Subject + m.State = domain.DeviceAuthStateApproved + case *deviceauth.CanceledEvent: + m.State = e.Reason.State() + case *deviceauth.RemovedEvent: + m.State = domain.DeviceAuthStateRemoved + } + } + + return m.WriteModel.Reduce() +} + +func (m *DeviceAuthWriteModel) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(deviceauth.AggregateType). + AggregateIDs(m.AggregateID). + Builder() +} diff --git a/internal/command/device_auth_test.go b/internal/command/device_auth_test.go new file mode 100644 index 0000000000..d0d3dd8281 --- /dev/null +++ b/internal/command/device_auth_test.go @@ -0,0 +1,481 @@ +package command + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/id" + id_mock "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/deviceauth" +) + +func TestCommands_AddDeviceAuth(t *testing.T) { + ctx := authz.WithInstanceID(context.Background(), "instance1") + idErr := errors.New("idErr") + pushErr := errors.New("pushErr") + now := time.Now() + + unique := deviceauth.NewAddUniqueConstraints("client_id", "123", "456") + require.Len(t, unique, 2) + + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + } + type args struct { + ctx context.Context + clientID string + deviceCode string + userCode string + expires time.Time + scopes []string + } + tests := []struct { + name string + fields fields + args args + wantID string + wantDetails *domain.ObjectDetails + wantErr error + }{ + { + name: "idGenerator error", + fields: fields{ + eventstore: eventstoreExpect(t), + idGenerator: func() id.Generator { + m := id_mock.NewMockGenerator(gomock.NewController(t)) + m.EXPECT().Next().Return("", idErr) + return m + }(), + }, + args: args{ + ctx: ctx, + clientID: "client_id", + deviceCode: "123", + userCode: "456", + expires: now, + scopes: []string{"a", "b", "c"}, + }, + wantErr: idErr, + }, + { + name: "success", + fields: fields{ + eventstore: eventstoreExpect(t, expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instance1", deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + )), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[0]), + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[1]), + )), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "1999"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "instance1"), + clientID: "client_id", + deviceCode: "123", + userCode: "456", + expires: now, + scopes: []string{"a", "b", "c"}, + }, + wantID: "1999", + wantDetails: &domain.ObjectDetails{ + ResourceOwner: "instance1", + }, + }, + { + name: "push error", + fields: fields{ + eventstore: eventstoreExpect(t, expectPushFailed(pushErr, + []*repository.Event{ + eventFromEventPusherWithInstanceID("instance1", deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + )), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[0]), + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[1]), + )), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "1999"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "instance1"), + clientID: "client_id", + deviceCode: "123", + userCode: "456", + expires: now, + scopes: []string{"a", "b", "c"}, + }, + wantErr: pushErr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + } + gotID, gotDetails, err := c.AddDeviceAuth(tt.args.ctx, tt.args.clientID, tt.args.deviceCode, tt.args.userCode, tt.args.expires, tt.args.scopes) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, gotID, tt.wantID) + assert.Equal(t, gotDetails, tt.wantDetails) + }) + } +} + +func TestCommands_ApproveDeviceAuth(t *testing.T) { + ctx := authz.WithInstanceID(context.Background(), "instance1") + now := time.Now() + pushErr := errors.New("pushErr") + + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + subject string + } + tests := []struct { + name string + fields fields + args args + wantDetails *domain.ObjectDetails + wantErr error + }{ + { + name: "not found error", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusherWithInstanceID("instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + ), + eventFromEventPusherWithInstanceID("instance1", + deviceauth.NewRemovedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", + ), + ), + ), + ), + }, + args: args{ctx, "1999", "subj"}, + wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-Hief9", "Errors.DeviceAuth.NotFound"), + }, + { + name: "push error", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPushFailed(pushErr, + []*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewApprovedEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), "subj", + ), + )}, + ), + ), + }, + args: args{ctx, "1999", "subj"}, + wantErr: pushErr, + }, + { + name: "success", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPush([]*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewApprovedEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), "subj", + ), + )}), + ), + }, + args: args{ctx, "1999", "subj"}, + wantDetails: &domain.ObjectDetails{ + ResourceOwner: "instance1", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.subject) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, gotDetails, tt.wantDetails) + }) + } +} + +func TestCommands_CancelDeviceAuth(t *testing.T) { + ctx := authz.WithInstanceID(context.Background(), "instance1") + now := time.Now() + pushErr := errors.New("pushErr") + + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + reason domain.DeviceAuthCanceled + } + tests := []struct { + name string + fields fields + args args + wantDetails *domain.ObjectDetails + wantErr error + }{ + { + name: "not found error", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusherWithInstanceID("instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + ), + eventFromEventPusherWithInstanceID("instance1", + deviceauth.NewRemovedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", + ), + ), + ), + ), + }, + args: args{ctx, "1999", domain.DeviceAuthCanceledDenied}, + wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-gee5A", "Errors.DeviceAuth.NotFound"), + }, + { + name: "push error", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPushFailed(pushErr, + []*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewCanceledEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), + domain.DeviceAuthCanceledDenied, + ), + )}, + ), + ), + }, + args: args{ctx, "1999", domain.DeviceAuthCanceledDenied}, + wantErr: pushErr, + }, + { + name: "success/denied", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPush([]*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewCanceledEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), + domain.DeviceAuthCanceledDenied, + ), + )}), + ), + }, + args: args{ctx, "1999", domain.DeviceAuthCanceledDenied}, + wantDetails: &domain.ObjectDetails{ + ResourceOwner: "instance1", + }, + }, + { + name: "success/expired", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPush([]*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewCanceledEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), + domain.DeviceAuthCanceledExpired, + ), + )}), + ), + }, + args: args{ctx, "1999", domain.DeviceAuthCanceledExpired}, + wantDetails: &domain.ObjectDetails{ + ResourceOwner: "instance1", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + gotDetails, err := c.CancelDeviceAuth(tt.args.ctx, tt.args.id, tt.args.reason) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, gotDetails, tt.wantDetails) + }) + } +} + +func TestCommands_RemoveDeviceAuth(t *testing.T) { + ctx := authz.WithInstanceID(context.Background(), "instance1") + now := time.Now() + pushErr := errors.New("pushErr") + + unique := deviceauth.NewRemoveUniqueConstraints("client_id", "123", "456") + require.Len(t, unique, 2) + + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + } + tests := []struct { + name string + fields fields + args args + wantDetails *domain.ObjectDetails + wantErr error + }{ + { + name: "push error", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPushFailed(pushErr, + []*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewRemovedEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", + ), + )}, + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[0]), + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[1]), + ), + ), + }, + args: args{ctx, "1999"}, + wantErr: pushErr, + }, + { + name: "success", + fields: fields{ + eventstore: eventstoreExpect(t, + expectFilter(eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", now, + []string{"a", "b", "c"}, + ), + )), + expectPush( + []*repository.Event{eventFromEventPusherWithInstanceID( + "instance1", deviceauth.NewRemovedEvent( + ctx, deviceauth.NewAggregate("1999", "instance1"), + "client_id", "123", "456", + ), + )}, + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[0]), + uniqueConstraintsFromEventConstraintWithInstanceID("instance1", unique[1]), + ), + ), + }, + args: args{ctx, "1999"}, + wantDetails: &domain.ObjectDetails{ + ResourceOwner: "instance1", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + gotDetails, err := c.RemoveDeviceAuth(tt.args.ctx, tt.args.id) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, gotDetails, tt.wantDetails) + }) + } +} diff --git a/internal/domain/application_oidc.go b/internal/domain/application_oidc.go index ad08b18d7a..e0ae783e8e 100644 --- a/internal/domain/application_oidc.go +++ b/internal/domain/application_oidc.go @@ -90,6 +90,7 @@ const ( OIDCGrantTypeAuthorizationCode OIDCGrantType = iota OIDCGrantTypeImplicit OIDCGrantTypeRefreshToken + OIDCGrantTypeDeviceCode ) type OIDCApplicationType int32 diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index 84ca2b9ab8..86cd0575f6 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -122,6 +122,8 @@ func NewAuthRequestFromType(requestType AuthRequestType) (*AuthRequest, error) { return &AuthRequest{Request: &AuthRequestOIDC{}}, nil case AuthRequestTypeSAML: return &AuthRequest{Request: &AuthRequestSAML{}}, nil + case AuthRequestTypeDevice: + return &AuthRequest{Request: &AuthRequestDevice{}}, nil } return nil, errors.ThrowInvalidArgument(nil, "DOMAIN-ds2kl", "invalid request type") } @@ -184,3 +186,12 @@ func (a *AuthRequest) GetScopeOrgID() string { } return "" } + +func (a *AuthRequest) Done() bool { + for _, step := range a.PossibleSteps { + if step.Type() == NextStepRedirectToCallback { + return true + } + } + return false +} diff --git a/internal/domain/device_auth.go b/internal/domain/device_auth.go new file mode 100644 index 0000000000..79f30250f0 --- /dev/null +++ b/internal/domain/device_auth.go @@ -0,0 +1,78 @@ +package domain + +import ( + "time" + + "github.com/zitadel/zitadel/internal/eventstore/v1/models" +) + +// DeviceAuth describes a Device Authorization request. +// It is used as input and output model in the command and query packages. +type DeviceAuth struct { + models.ObjectRoot + + ClientID string + DeviceCode string + UserCode string + Expires time.Time + Scopes []string + Subject string + State DeviceAuthState +} + +// DeviceAuthState describes the step the +// the device authorization process is in. +// We generate the Stringer implemntation for pretier +// log output. +// +//go:generate stringer -type=DeviceAuthState -linecomment +type DeviceAuthState uint + +const ( + DeviceAuthStateUndefined DeviceAuthState = iota // undefined + DeviceAuthStateInitiated // initiated + DeviceAuthStateApproved // approved + DeviceAuthStateDenied // denied + DeviceAuthStateExpired // expired + DeviceAuthStateRemoved // removed +) + +// Exists returns true when not Undefined and +// any status lower than Removed. +func (s DeviceAuthState) Exists() bool { + return s > DeviceAuthStateUndefined && s < DeviceAuthStateRemoved +} + +// Done returns true when DeviceAuthState is Approved. +// This implements the OIDC interface requirement of "Done" +func (s DeviceAuthState) Done() bool { + return s == DeviceAuthStateApproved +} + +// Denied returns true when DeviceAuthState is Denied, Expired or Removed. +// This implements the OIDC interface requirement of "Denied". +func (s DeviceAuthState) Denied() bool { + return s >= DeviceAuthStateDenied +} + +// DeviceAuthCanceled is a subset of DeviceAuthState, allowed to +// be used in the deviceauth.CanceledEvent. +// The string type is used to make the eventstore more readable +// on the reason of cancelation. +type DeviceAuthCanceled string + +const ( + DeviceAuthCanceledDenied = "denied" + DeviceAuthCanceledExpired = "expired" +) + +func (c DeviceAuthCanceled) State() DeviceAuthState { + switch c { + case DeviceAuthCanceledDenied: + return DeviceAuthStateDenied + case DeviceAuthCanceledExpired: + return DeviceAuthStateExpired + default: + return DeviceAuthStateUndefined + } +} diff --git a/internal/domain/device_auth_test.go b/internal/domain/device_auth_test.go new file mode 100644 index 0000000000..c3fcf359da --- /dev/null +++ b/internal/domain/device_auth_test.go @@ -0,0 +1,158 @@ +package domain + +import ( + "testing" +) + +func TestDeviceAuthState_Exists(t *testing.T) { + tests := []struct { + s DeviceAuthState + want bool + }{ + { + s: DeviceAuthStateUndefined, + want: false, + }, + { + s: DeviceAuthStateInitiated, + want: true, + }, + { + s: DeviceAuthStateApproved, + want: true, + }, + { + s: DeviceAuthStateDenied, + want: true, + }, + { + s: DeviceAuthStateExpired, + want: true, + }, + { + s: DeviceAuthStateRemoved, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.s.String(), func(t *testing.T) { + if got := tt.s.Exists(); got != tt.want { + t.Errorf("DeviceAuthState.Exists() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDeviceAuthState_Done(t *testing.T) { + tests := []struct { + s DeviceAuthState + want bool + }{ + { + s: DeviceAuthStateUndefined, + want: false, + }, + { + s: DeviceAuthStateInitiated, + want: false, + }, + { + s: DeviceAuthStateApproved, + want: true, + }, + { + s: DeviceAuthStateDenied, + want: false, + }, + { + s: DeviceAuthStateExpired, + want: false, + }, + { + s: DeviceAuthStateRemoved, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.s.String(), func(t *testing.T) { + if got := tt.s.Done(); got != tt.want { + t.Errorf("DeviceAuthState.Done() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDeviceAuthState_Denied(t *testing.T) { + tests := []struct { + name string + s DeviceAuthState + want bool + }{ + { + s: DeviceAuthStateUndefined, + want: false, + }, + { + s: DeviceAuthStateInitiated, + want: false, + }, + { + s: DeviceAuthStateApproved, + want: false, + }, + { + s: DeviceAuthStateDenied, + want: true, + }, + { + s: DeviceAuthStateExpired, + want: true, + }, + { + s: DeviceAuthStateRemoved, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.Denied(); got != tt.want { + t.Errorf("DeviceAuthState.Denied() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDeviceAuthCanceled_State(t *testing.T) { + tests := []struct { + name string + c DeviceAuthCanceled + want DeviceAuthState + }{ + { + name: "empty", + want: DeviceAuthStateUndefined, + }, + { + name: "invalid", + c: "foo", + want: DeviceAuthStateUndefined, + }, + { + name: "denied", + c: DeviceAuthCanceledDenied, + want: DeviceAuthStateDenied, + }, + { + name: "expired", + c: DeviceAuthCanceledExpired, + want: DeviceAuthStateExpired, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.State(); got != tt.want { + t.Errorf("DeviceAuthCanceled.State() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/domain/deviceauthstate_string.go b/internal/domain/deviceauthstate_string.go new file mode 100644 index 0000000000..b47a6bc7e8 --- /dev/null +++ b/internal/domain/deviceauthstate_string.go @@ -0,0 +1,28 @@ +// Code generated by "stringer -type=DeviceAuthState -linecomment"; DO NOT EDIT. + +package domain + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DeviceAuthStateUndefined-0] + _ = x[DeviceAuthStateInitiated-1] + _ = x[DeviceAuthStateApproved-2] + _ = x[DeviceAuthStateDenied-3] + _ = x[DeviceAuthStateExpired-4] + _ = x[DeviceAuthStateRemoved-5] +} + +const _DeviceAuthState_name = "undefinedinitiatedapproveddeniedexpiredremoved" + +var _DeviceAuthState_index = [...]uint8{0, 9, 18, 26, 32, 39, 46} + +func (i DeviceAuthState) String() string { + if i >= DeviceAuthState(len(_DeviceAuthState_index)-1) { + return "DeviceAuthState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DeviceAuthState_name[_DeviceAuthState_index[i]:_DeviceAuthState_index[i+1]] +} diff --git a/internal/domain/request.go b/internal/domain/request.go index d76ab2cf4f..7f91463921 100644 --- a/internal/domain/request.go +++ b/internal/domain/request.go @@ -22,6 +22,7 @@ type AuthRequestType int32 const ( AuthRequestTypeOIDC AuthRequestType = iota AuthRequestTypeSAML + AuthRequestTypeDevice ) type AuthRequestOIDC struct { @@ -56,3 +57,18 @@ func (a *AuthRequestSAML) Type() AuthRequestType { func (a *AuthRequestSAML) IsValid() bool { return true } + +type AuthRequestDevice struct { + ID string + DeviceCode string + UserCode string + Scopes []string +} + +func (*AuthRequestDevice) Type() AuthRequestType { + return AuthRequestTypeDevice +} + +func (a *AuthRequestDevice) IsValid() bool { + return a.DeviceCode != "" && a.UserCode != "" && len(a.Scopes) > 0 +} diff --git a/internal/eventstore/eventstore.go b/internal/eventstore/eventstore.go index 6e9e47f8e4..216a29bdee 100644 --- a/internal/eventstore/eventstore.go +++ b/internal/eventstore/eventstore.go @@ -304,3 +304,21 @@ func uniqueConstraintActionToRepository(action UniqueConstraintAction) repositor return repository.UniqueConstraintAdd } } + +type BaseEventSetter[T any] interface { + Event + SetBaseEvent(*BaseEvent) + *T +} + +func GenericEventMapper[T any, PT BaseEventSetter[T]](event *repository.Event) (Event, error) { + e := PT(new(T)) + e.SetBaseEvent(BaseEventFromRepo(event)) + + err := json.Unmarshal(event.Data, e) + if err != nil { + return nil, errors.ThrowInternal(err, "V2-Thai6", "unable to unmarshal event") + } + + return e, nil +} diff --git a/internal/eventstore/handler/crdb/handler_stmt.go b/internal/eventstore/handler/crdb/handler_stmt.go index 65eb99426a..8f6d9481f0 100644 --- a/internal/eventstore/handler/crdb/handler_stmt.go +++ b/internal/eventstore/handler/crdb/handler_stmt.go @@ -267,6 +267,7 @@ func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) erro } err = stmt.Execute(tx, h.ProjectionName) if err != nil { + logging.WithError(err).Error() _, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt") if rollbackErr != nil { return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed") diff --git a/internal/eventstore/handler/crdb/init.go b/internal/eventstore/handler/crdb/init.go index 9ca9c34a9d..420b11731a 100644 --- a/internal/eventstore/handler/crdb/init.go +++ b/internal/eventstore/handler/crdb/init.go @@ -377,6 +377,8 @@ func defaultValue(value interface{}) string { switch v := value.(type) { case string: return "'" + v + "'" + case fmt.Stringer: + return fmt.Sprintf("%#v", v) default: return fmt.Sprintf("%v", v) } diff --git a/internal/eventstore/handler/crdb/init_test.go b/internal/eventstore/handler/crdb/init_test.go new file mode 100644 index 0000000000..1e7e6bd823 --- /dev/null +++ b/internal/eventstore/handler/crdb/init_test.go @@ -0,0 +1,49 @@ +package crdb + +import "testing" + +func Test_defaultValue(t *testing.T) { + type args struct { + value interface{} + } + tests := []struct { + name string + args args + want string + }{ + { + name: "string", + args: args{ + value: "asdf", + }, + want: "'asdf'", + }, + { + name: "primitive non string", + args: args{ + value: 1, + }, + want: "1", + }, + { + name: "stringer", + args: args{ + value: testStringer(0), + }, + want: "0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := defaultValue(tt.args.value); got != tt.want { + t.Errorf("defaultValue() = %v, want %v", got, tt.want) + } + }) + } +} + +type testStringer int + +func (t testStringer) String() string { + return "0529958243" +} diff --git a/internal/query/device_auth.go b/internal/query/device_auth.go new file mode 100644 index 0000000000..98faff200b --- /dev/null +++ b/internal/query/device_auth.go @@ -0,0 +1,141 @@ +package query + +import ( + "context" + "database/sql" + errs "errors" + + sq "github.com/Masterminds/squirrel" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +var ( + deviceAuthTable = table{ + name: projection.DeviceAuthProjectionTable, + instanceIDCol: projection.DeviceAuthColumnInstanceID, + } + DeviceAuthColumnID = Column{ + name: projection.DeviceAuthColumnID, + table: deviceAuthTable, + } + DeviceAuthColumnClientID = Column{ + name: projection.DeviceAuthColumnClientID, + table: deviceAuthTable, + } + DeviceAuthColumnDeviceCode = Column{ + name: projection.DeviceAuthColumnDeviceCode, + table: deviceAuthTable, + } + DeviceAuthColumnUserCode = Column{ + name: projection.DeviceAuthColumnUserCode, + table: deviceAuthTable, + } + DeviceAuthColumnExpires = Column{ + name: projection.DeviceAuthColumnExpires, + table: deviceAuthTable, + } + DeviceAuthColumnScopes = Column{ + name: projection.DeviceAuthColumnScopes, + table: deviceAuthTable, + } + DeviceAuthColumnState = Column{ + name: projection.DeviceAuthColumnState, + table: deviceAuthTable, + } + DeviceAuthColumnSubject = Column{ + name: projection.DeviceAuthColumnSubject, + table: deviceAuthTable, + } + DeviceAuthColumnCreationDate = Column{ + name: projection.DeviceAuthColumnCreationDate, + table: deviceAuthTable, + } + DeviceAuthColumnChangeDate = Column{ + name: projection.DeviceAuthColumnChangeDate, + table: deviceAuthTable, + } + DeviceAuthColumnSequence = Column{ + name: projection.DeviceAuthColumnSequence, + table: deviceAuthTable, + } + DeviceAuthColumnInstanceID = Column{ + name: projection.DeviceAuthColumnInstanceID, + table: deviceAuthTable, + } +) + +func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (_ *domain.DeviceAuth, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + stmt, scan := prepareDeviceAuthQuery(ctx, q.client) + eq := sq.Eq{ + DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + DeviceAuthColumnClientID.identifier(): clientID, + DeviceAuthColumnDeviceCode.identifier(): deviceCode, + } + query, args, err := stmt.Where(eq).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement") + } + + return scan(q.client.QueryRowContext(ctx, query, args...)) +} + +func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ *domain.DeviceAuth, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + stmt, scan := prepareDeviceAuthQuery(ctx, q.client) + eq := sq.Eq{ + DeviceAuthColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + DeviceAuthColumnUserCode.identifier(): userCode, + } + query, args, err := stmt.Where(eq).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement") + } + + return scan(q.client.QueryRowContext(ctx, query, args...)) +} + +var deviceAuthSelectColumns = []string{ + DeviceAuthColumnID.identifier(), + DeviceAuthColumnClientID.identifier(), + DeviceAuthColumnScopes.identifier(), + DeviceAuthColumnExpires.identifier(), + DeviceAuthColumnState.identifier(), + DeviceAuthColumnSubject.identifier(), +} + +func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.DeviceAuth, error)) { + return sq.Select(deviceAuthSelectColumns...).From(deviceAuthTable.identifier()).PlaceholderFormat(sq.Dollar), + func(row *sql.Row) (*domain.DeviceAuth, error) { + dst := new(domain.DeviceAuth) + var scopes database.StringArray + + err := row.Scan( + &dst.AggregateID, + &dst.ClientID, + &scopes, + &dst.Expires, + &dst.State, + &dst.Subject, + ) + if errs.Is(err, sql.ErrNoRows) { + return nil, errors.ThrowNotFound(err, "QUERY-Sah9a", "Errors.DeviceAuth.NotExisting") + } + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Voo3o", "Errors.Internal") + } + + dst.Scopes = scopes + return dst, nil + } +} diff --git a/internal/query/device_auth_test.go b/internal/query/device_auth_test.go new file mode 100644 index 0000000000..938cb9f844 --- /dev/null +++ b/internal/query/device_auth_test.go @@ -0,0 +1,158 @@ +package query + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore/v1/models" +) + +const ( + expectedDeviceAuthQueryC = `SELECT` + + ` projections.device_authorizations.id,` + + ` projections.device_authorizations.client_id,` + + ` projections.device_authorizations.scopes,` + + ` projections.device_authorizations.expires,` + + ` projections.device_authorizations.state,` + + ` projections.device_authorizations.subject` + + ` FROM projections.device_authorizations` + expectedDeviceAuthWhereDeviceCodeQueryC = expectedDeviceAuthQueryC + + ` WHERE projections.device_authorizations.client_id = $1` + + ` AND projections.device_authorizations.device_code = $2` + + ` AND projections.device_authorizations.instance_id = $3` + expectedDeviceAuthWhereUserCodeQueryC = expectedDeviceAuthQueryC + + ` WHERE projections.device_authorizations.instance_id = $1` + + ` AND projections.device_authorizations.user_code = $2` +) + +var ( + expectedDeviceAuthQuery = regexp.QuoteMeta(expectedDeviceAuthQueryC) + expectedDeviceAuthWhereDeviceCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereDeviceCodeQueryC) + expectedDeviceAuthWhereUserCodeQuery = regexp.QuoteMeta(expectedDeviceAuthWhereUserCodeQueryC) + expectedDeviceAuthValues = []driver.Value{ + "primary-id", + "client-id", + database.StringArray{"a", "b", "c"}, + testNow, + domain.DeviceAuthStateApproved, + "subject", + } + expectedDeviceAuth = &domain.DeviceAuth{ + ObjectRoot: models.ObjectRoot{ + AggregateID: "primary-id", + }, + ClientID: "client-id", + Scopes: []string{"a", "b", "c"}, + Expires: testNow, + State: domain.DeviceAuthStateApproved, + Subject: "subject", + } +) + +func TestQueries_DeviceAuthByDeviceCode(t *testing.T) { + client, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to build mock client: %v", err) + } + defer client.Close() + + mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows( + sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), + ) + q := Queries{ + client: &database.DB{DB: client}, + } + got, err := q.DeviceAuthByDeviceCode(context.TODO(), "123", "456") + require.NoError(t, err) + assert.Equal(t, expectedDeviceAuth, got) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestQueries_DeviceAuthByUserCode(t *testing.T) { + client, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to build mock client: %v", err) + } + defer client.Close() + + mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows( + sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), + ) + q := Queries{ + client: &database.DB{DB: client}, + } + got, err := q.DeviceAuthByUserCode(context.TODO(), "789") + require.NoError(t, err) + assert.Equal(t, expectedDeviceAuth, got) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func Test_prepareDeviceAuthQuery(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + want want + object any + }{ + { + name: "success", + want: want{ + sqlExpectations: mockQueries( + expectedDeviceAuthQuery, + deviceAuthSelectColumns, + [][]driver.Value{expectedDeviceAuthValues}, + ), + }, + object: expectedDeviceAuth, + }, + { + name: "not found error", + want: want{ + sqlExpectations: mockQueryErr( + expectedDeviceAuthQuery, + sql.ErrNoRows, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("err should be sql.ErrNoRows got: %w", err), false + } + return nil, true + }, + }, + }, + { + name: "other error", + want: want{ + sqlExpectations: mockQueryErr( + expectedDeviceAuthQuery, + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + }) + } +} diff --git a/internal/query/projection/device_auth.go b/internal/query/projection/device_auth.go new file mode 100644 index 0000000000..c678dbd301 --- /dev/null +++ b/internal/query/projection/device_auth.go @@ -0,0 +1,161 @@ +package projection + +import ( + "context" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/crdb" + "github.com/zitadel/zitadel/internal/repository/deviceauth" +) + +const ( + DeviceAuthProjectionTable = "projections.device_authorizations" + + DeviceAuthColumnID = "id" + DeviceAuthColumnClientID = "client_id" + DeviceAuthColumnDeviceCode = "device_code" + DeviceAuthColumnUserCode = "user_code" + DeviceAuthColumnExpires = "expires" + DeviceAuthColumnScopes = "scopes" + DeviceAuthColumnState = "state" + DeviceAuthColumnSubject = "subject" + + DeviceAuthColumnCreationDate = "creation_date" + DeviceAuthColumnChangeDate = "change_date" + DeviceAuthColumnSequence = "sequence" + DeviceAuthColumnInstanceID = "instance_id" +) + +type deviceAuthProjection struct { + crdb.StatementHandler +} + +func newDeviceAuthProjection(ctx context.Context, config crdb.StatementHandlerConfig) *deviceAuthProjection { + p := new(deviceAuthProjection) + config.ProjectionName = DeviceAuthProjectionTable + config.Reducers = p.reducers() + config.InitCheck = crdb.NewTableCheck( + crdb.NewTable([]*crdb.Column{ + crdb.NewColumn(DeviceAuthColumnID, crdb.ColumnTypeText), + crdb.NewColumn(DeviceAuthColumnClientID, crdb.ColumnTypeText), + crdb.NewColumn(DeviceAuthColumnDeviceCode, crdb.ColumnTypeText), + crdb.NewColumn(DeviceAuthColumnUserCode, crdb.ColumnTypeText), + crdb.NewColumn(DeviceAuthColumnExpires, crdb.ColumnTypeTimestamp), + crdb.NewColumn(DeviceAuthColumnScopes, crdb.ColumnTypeTextArray), + crdb.NewColumn(DeviceAuthColumnState, crdb.ColumnTypeEnum, crdb.Default(domain.DeviceAuthStateInitiated)), + crdb.NewColumn(DeviceAuthColumnSubject, crdb.ColumnTypeText, crdb.Default("")), + crdb.NewColumn(DeviceAuthColumnCreationDate, crdb.ColumnTypeTimestamp), + crdb.NewColumn(DeviceAuthColumnChangeDate, crdb.ColumnTypeTimestamp), + crdb.NewColumn(DeviceAuthColumnSequence, crdb.ColumnTypeInt64), + crdb.NewColumn(DeviceAuthColumnInstanceID, crdb.ColumnTypeText), + }, + crdb.NewPrimaryKey(DeviceAuthColumnInstanceID, DeviceAuthColumnID), + crdb.WithIndex(crdb.NewIndex("user_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnUserCode})), + crdb.WithIndex(crdb.NewIndex("device_code", []string{DeviceAuthColumnInstanceID, DeviceAuthColumnClientID, DeviceAuthColumnDeviceCode})), + ), + ) + + p.StatementHandler = crdb.NewStatementHandler(ctx, config) + return p +} + +func (p *deviceAuthProjection) reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: deviceauth.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: deviceauth.AddedEventType, + Reduce: p.reduceAdded, + }, + { + Event: deviceauth.ApprovedEventType, + Reduce: p.reduceAppoved, + }, + { + Event: deviceauth.CanceledEventType, + Reduce: p.reduceCanceled, + }, + { + Event: deviceauth.RemovedEventType, + Reduce: p.reduceRemoved, + }, + }, + }, + } +} + +func (p *deviceAuthProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*deviceauth.AddedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-chu6O", "reduce.wrong.event.type %T != %s", event, deviceauth.AddedEventType) + } + return crdb.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(DeviceAuthColumnID, e.Aggregate().ID), + handler.NewCol(DeviceAuthColumnClientID, e.ClientID), + handler.NewCol(DeviceAuthColumnDeviceCode, e.DeviceCode), + handler.NewCol(DeviceAuthColumnUserCode, e.UserCode), + handler.NewCol(DeviceAuthColumnExpires, e.Expires), + handler.NewCol(DeviceAuthColumnScopes, e.Scopes), + handler.NewCol(DeviceAuthColumnCreationDate, e.CreationDate()), + handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()), + handler.NewCol(DeviceAuthColumnSequence, e.Sequence()), + handler.NewCol(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *deviceAuthProjection) reduceAppoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*deviceauth.ApprovedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-kei0A", "reduce.wrong.event.type %T != %s", event, deviceauth.ApprovedEventType) + } + return crdb.NewUpdateStatement(e, + []handler.Column{ + handler.NewCol(DeviceAuthColumnState, domain.DeviceAuthStateApproved), + handler.NewCol(DeviceAuthColumnSubject, e.Subject), + handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()), + handler.NewCol(DeviceAuthColumnSequence, e.Sequence()), + }, + []handler.Condition{ + handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID), + }, + ), nil +} + +func (p *deviceAuthProjection) reduceCanceled(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*deviceauth.CanceledEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-eeS8d", "reduce.wrong.event.type %T != %s", event, deviceauth.CanceledEventType) + } + return crdb.NewUpdateStatement(e, + []handler.Column{ + handler.NewCol(DeviceAuthColumnState, e.Reason.State()), + handler.NewCol(DeviceAuthColumnChangeDate, e.CreationDate()), + handler.NewCol(DeviceAuthColumnSequence, e.Sequence()), + }, + []handler.Condition{ + handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID), + }, + ), nil +} + +func (p *deviceAuthProjection) reduceRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*deviceauth.RemovedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-AJi1u", "reduce.wrong.event.type %T != %s", event, deviceauth.RemovedEventType) + } + return crdb.NewDeleteStatement(e, + []handler.Condition{ + handler.NewCond(DeviceAuthColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(DeviceAuthColumnID, e.Aggregate().ID), + }, + ), nil +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index e3c5cd4f71..fb461927c2 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -64,6 +64,7 @@ var ( NotificationPolicyProjection *notificationPolicyProjection NotificationsProjection interface{} NotificationsQuotaProjection interface{} + DeviceAuthProjection *deviceAuthProjection ) type projection interface { @@ -139,6 +140,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto KeyProjection = newKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), keyEncryptionAlgorithm, certEncryptionAlgorithm) SecurityPolicyProjection = newSecurityPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["security_policies"])) NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"])) + DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"])) newProjectionsList() return nil } @@ -234,5 +236,6 @@ func newProjectionsList() { KeyProjection, SecurityPolicyProjection, NotificationPolicyProjection, + DeviceAuthProjection, } } diff --git a/internal/repository/deviceauth/aggregate.go b/internal/repository/deviceauth/aggregate.go new file mode 100644 index 0000000000..da3645d112 --- /dev/null +++ b/internal/repository/deviceauth/aggregate.go @@ -0,0 +1,19 @@ +package deviceauth + +import "github.com/zitadel/zitadel/internal/eventstore" + +const ( + AggregateType = "device_auth" + AggregateVersion = "v1" +) + +func NewAggregate(aggrID, instanceID string) *eventstore.Aggregate { + return &eventstore.Aggregate{ + ID: aggrID, + Type: AggregateType, + // we use the id because we don't know the resource owner yet + ResourceOwner: instanceID, + InstanceID: instanceID, + Version: AggregateVersion, + } +} diff --git a/internal/repository/deviceauth/constraints.go b/internal/repository/deviceauth/constraints.go new file mode 100644 index 0000000000..679220524c --- /dev/null +++ b/internal/repository/deviceauth/constraints.go @@ -0,0 +1,46 @@ +package deviceauth + +import ( + "strings" + + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + UniqueUserCode = "user_code" + UniqueDeviceCode = "device_code" + DuplicateUserCode = "Errors.DeviceUserCode.AlreadyExists" + DuplicateDeviceCode = "Errors.DeviceCode.AlreadyExists" +) + +func deviceCodeUniqueField(clientID, deviceCode string) string { + return strings.Join([]string{clientID, deviceCode}, ":") +} + +func NewAddUniqueConstraints(clientID, deviceCode, userCode string) []*eventstore.EventUniqueConstraint { + return []*eventstore.EventUniqueConstraint{ + eventstore.NewAddEventUniqueConstraint( + UniqueDeviceCode, + deviceCodeUniqueField(clientID, deviceCode), + DuplicateDeviceCode, + ), + eventstore.NewAddEventUniqueConstraint( + UniqueUserCode, + userCode, + DuplicateUserCode, + ), + } +} + +func NewRemoveUniqueConstraints(clientID, deviceCode, userCode string) []*eventstore.EventUniqueConstraint { + return []*eventstore.EventUniqueConstraint{ + eventstore.NewRemoveEventUniqueConstraint( + UniqueDeviceCode, + deviceCodeUniqueField(clientID, deviceCode), + ), + eventstore.NewRemoveEventUniqueConstraint( + UniqueUserCode, + userCode, + ), + } +} diff --git a/internal/repository/deviceauth/device_auth.go b/internal/repository/deviceauth/device_auth.go new file mode 100644 index 0000000000..0ece3e78f9 --- /dev/null +++ b/internal/repository/deviceauth/device_auth.go @@ -0,0 +1,141 @@ +package deviceauth + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + eventTypePrefix eventstore.EventType = "device.authorization." + AddedEventType = eventTypePrefix + "added" + ApprovedEventType = eventTypePrefix + "approved" + CanceledEventType = eventTypePrefix + "canceled" + RemovedEventType = eventTypePrefix + "removed" +) + +type AddedEvent struct { + *eventstore.BaseEvent + + ClientID string + DeviceCode string + UserCode string + Expires time.Time + Scopes []string + State domain.DeviceAuthState +} + +func (e *AddedEvent) SetBaseEvent(b *eventstore.BaseEvent) { + e.BaseEvent = b +} + +func (e *AddedEvent) Data() any { + return e +} + +func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return NewAddUniqueConstraints(e.ClientID, e.DeviceCode, e.UserCode) +} + +func NewAddedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + clientID string, + deviceCode string, + userCode string, + expires time.Time, + scopes []string, +) *AddedEvent { + return &AddedEvent{ + eventstore.NewBaseEventForPush( + ctx, aggregate, AddedEventType, + ), + clientID, deviceCode, userCode, expires, scopes, domain.DeviceAuthStateInitiated} +} + +type ApprovedEvent struct { + *eventstore.BaseEvent + + Subject string +} + +func (e *ApprovedEvent) SetBaseEvent(b *eventstore.BaseEvent) { + e.BaseEvent = b +} + +func (e *ApprovedEvent) Data() any { + return e +} + +func (e *ApprovedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewApprovedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + subject string, +) *ApprovedEvent { + return &ApprovedEvent{ + eventstore.NewBaseEventForPush( + ctx, aggregate, ApprovedEventType, + ), + subject, + } +} + +type CanceledEvent struct { + *eventstore.BaseEvent + Reason domain.DeviceAuthCanceled +} + +func (e *CanceledEvent) SetBaseEvent(b *eventstore.BaseEvent) { + e.BaseEvent = b +} + +func (e *CanceledEvent) Data() any { + return e +} + +func (e *CanceledEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewCanceledEvent(ctx context.Context, aggregate *eventstore.Aggregate, reason domain.DeviceAuthCanceled) *CanceledEvent { + return &CanceledEvent{eventstore.NewBaseEventForPush(ctx, aggregate, CanceledEventType), reason} +} + +type RemovedEvent struct { + *eventstore.BaseEvent + + ClientID string + DeviceCode string + UserCode string +} + +func (e *RemovedEvent) SetBaseEvent(b *eventstore.BaseEvent) { + e.BaseEvent = b +} + +func (e *RemovedEvent) Data() any { + return e +} + +func (e *RemovedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return NewRemoveUniqueConstraints(e.ClientID, e.DeviceCode, e.UserCode) +} + +func NewRemovedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + clientID, deviceCode, userCode string, +) *RemovedEvent { + return &RemovedEvent{ + eventstore.NewBaseEventForPush( + ctx, aggregate, RemovedEventType, + ), + clientID, deviceCode, userCode, + } +} diff --git a/internal/repository/org/eventstore.go b/internal/repository/org/eventstore.go index fb85ca86ee..662bf77b4b 100644 --- a/internal/repository/org/eventstore.go +++ b/internal/repository/org/eventstore.go @@ -2,6 +2,7 @@ package org import ( "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/deviceauth" ) func RegisterEventMappers(es *eventstore.Eventstore) { @@ -107,5 +108,9 @@ func RegisterEventMappers(es *eventstore.Eventstore) { RegisterFilterEventMapper(AggregateType, MetadataRemovedAllType, MetadataRemovedAllEventMapper). RegisterFilterEventMapper(AggregateType, NotificationPolicyAddedEventType, NotificationPolicyAddedEventMapper). RegisterFilterEventMapper(AggregateType, NotificationPolicyChangedEventType, NotificationPolicyChangedEventMapper). - RegisterFilterEventMapper(AggregateType, NotificationPolicyRemovedEventType, NotificationPolicyRemovedEventMapper) + RegisterFilterEventMapper(AggregateType, NotificationPolicyRemovedEventType, NotificationPolicyRemovedEventMapper). + RegisterFilterEventMapper(AggregateType, deviceauth.AddedEventType, eventstore.GenericEventMapper[deviceauth.AddedEvent]). + RegisterFilterEventMapper(AggregateType, deviceauth.ApprovedEventType, eventstore.GenericEventMapper[deviceauth.ApprovedEvent]). + RegisterFilterEventMapper(AggregateType, deviceauth.CanceledEventType, eventstore.GenericEventMapper[deviceauth.CanceledEvent]). + RegisterFilterEventMapper(AggregateType, deviceauth.RemovedEventType, eventstore.GenericEventMapper[deviceauth.RemovedEvent]) } diff --git a/proto/zitadel/app.proto b/proto/zitadel/app.proto index d889135bc4..dcd5da5c25 100644 --- a/proto/zitadel/app.proto +++ b/proto/zitadel/app.proto @@ -180,6 +180,7 @@ enum OIDCGrantType{ OIDC_GRANT_TYPE_AUTHORIZATION_CODE = 0; OIDC_GRANT_TYPE_IMPLICIT = 1; OIDC_GRANT_TYPE_REFRESH_TOKEN = 2; + OIDC_GRANT_TYPE_DEVICE_CODE = 3; } enum OIDCAppType {