mirror of
https://github.com/zitadel/zitadel.git
synced 2025-10-15 22:31:25 +00:00
Merge branch 'rc' into next-rc
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/reflection"
|
||||
|
||||
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server"
|
||||
@@ -19,24 +20,22 @@ import (
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
port uint16
|
||||
grpcServer *grpc.Server
|
||||
verifier *internal_authz.TokenVerifier
|
||||
health healthCheck
|
||||
router *mux.Router
|
||||
http1HostName string
|
||||
grpcGateway *server.Gateway
|
||||
healthServer *health.Server
|
||||
cookieHandler *http_util.CookieHandler
|
||||
cookieConfig *http_mw.AccessConfig
|
||||
queries *query.Queries
|
||||
port uint16
|
||||
grpcServer *grpc.Server
|
||||
verifier *internal_authz.TokenVerifier
|
||||
health healthCheck
|
||||
router *mux.Router
|
||||
http1HostName string
|
||||
grpcGateway *server.Gateway
|
||||
healthServer *health.Server
|
||||
accessInterceptor *http_mw.AccessInterceptor
|
||||
queries *query.Queries
|
||||
}
|
||||
|
||||
type healthCheck interface {
|
||||
@@ -51,23 +50,20 @@ func New(
|
||||
verifier *internal_authz.TokenVerifier,
|
||||
authZ internal_authz.Config,
|
||||
tlsConfig *tls.Config, http2HostName, http1HostName string,
|
||||
accessSvc *logstore.Service,
|
||||
cookieHandler *http_util.CookieHandler,
|
||||
cookieConfig *http_mw.AccessConfig,
|
||||
accessInterceptor *http_mw.AccessInterceptor,
|
||||
) (_ *API, err error) {
|
||||
api := &API{
|
||||
port: port,
|
||||
verifier: verifier,
|
||||
health: queries,
|
||||
router: router,
|
||||
http1HostName: http1HostName,
|
||||
cookieConfig: cookieConfig,
|
||||
cookieHandler: cookieHandler,
|
||||
queries: queries,
|
||||
port: port,
|
||||
verifier: verifier,
|
||||
health: queries,
|
||||
router: router,
|
||||
http1HostName: http1HostName,
|
||||
queries: queries,
|
||||
accessInterceptor: accessInterceptor,
|
||||
}
|
||||
|
||||
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
|
||||
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig)
|
||||
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessInterceptor.AccessService())
|
||||
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, accessInterceptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,6 +72,7 @@ func New(
|
||||
api.RegisterHandlerOnPrefix("/debug", api.healthHandler())
|
||||
api.router.Handle("/", http.RedirectHandler(login.HandlerPrefix, http.StatusFound))
|
||||
|
||||
reflection.Register(api.grpcServer)
|
||||
return api, nil
|
||||
}
|
||||
|
||||
@@ -90,8 +87,7 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP
|
||||
grpcServer,
|
||||
a.port,
|
||||
a.http1HostName,
|
||||
a.cookieHandler,
|
||||
a.cookieConfig,
|
||||
a.accessInterceptor,
|
||||
a.queries,
|
||||
)
|
||||
if err != nil {
|
||||
|
16
internal/api/authz/user.go
Normal file
16
internal/api/authz/user.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package authz
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
// UserIDInCTX checks if the userID
|
||||
// equals the authenticated user in the context.
|
||||
func UserIDInCTX(ctx context.Context, userID string) error {
|
||||
if GetCtxData(ctx).UserID != userID {
|
||||
return errors.ThrowUnauthenticated(nil, "AUTH-Bohd2", "Errors.User.UserIDWrong")
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -629,7 +629,7 @@ func (s *Server) importData(ctx context.Context, orgs []*admin_pb.DataOrg) (*adm
|
||||
ExternalUserID: userLinks.ProvidedUserId,
|
||||
DisplayName: userLinks.ProvidedUserName,
|
||||
}
|
||||
if err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
|
||||
if _, err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
|
||||
errors = append(errors, &admin_pb.ImportDataError{Type: "user_link", Id: userLinks.UserId + "_" + userLinks.IdpId, Message: err.Error()})
|
||||
if isCtxTimeout(ctx) {
|
||||
return &admin_pb.ImportDataResponse{Errors: errors, Success: success}, count, err
|
||||
|
36
internal/api/grpc/fields.go
Normal file
36
internal/api/grpc/fields.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
// AllFieldsSet recusively checks if all values in a message
|
||||
// have a non-zero value.
|
||||
func AllFieldsSet(t testing.TB, msg protoreflect.Message, ignoreTypes ...protoreflect.FullName) {
|
||||
ignore := make(map[protoreflect.FullName]bool, len(ignoreTypes))
|
||||
for _, name := range ignoreTypes {
|
||||
ignore[name] = true
|
||||
}
|
||||
|
||||
md := msg.Descriptor()
|
||||
name := md.FullName()
|
||||
if ignore[name] {
|
||||
return
|
||||
}
|
||||
|
||||
fields := md.Fields()
|
||||
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
fd := fields.Get(i)
|
||||
if !msg.Has(fd) {
|
||||
t.Errorf("not all fields set in %q, missing %q", name, fd.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
AllFieldsSet(t, msg.Get(fd).Message(), ignoreTypes...)
|
||||
}
|
||||
}
|
||||
}
|
@@ -241,7 +241,6 @@ func AddHumanUserRequestToAddHuman(req *mgmt_pb.AddHumanUserRequest) *command.Ad
|
||||
PasswordChangeRequired: true,
|
||||
Passwordless: false,
|
||||
Register: false,
|
||||
ExternalIDP: false,
|
||||
}
|
||||
if req.Phone != nil {
|
||||
human.Phone = command.Phone{
|
||||
|
@@ -1,5 +1,12 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"path"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
||||
)
|
||||
|
||||
const (
|
||||
Healthz = "/Healthz"
|
||||
Readiness = "/Ready"
|
||||
@@ -9,3 +16,18 @@ const (
|
||||
var (
|
||||
Probes = []string{Healthz, Readiness, Validation}
|
||||
)
|
||||
|
||||
func init() {
|
||||
Probes = append(Probes, AllPaths(grpc_reflection_v1alpha.ServerReflection_ServiceDesc)...)
|
||||
}
|
||||
|
||||
func AllPaths(sd grpc.ServiceDesc) []string {
|
||||
paths := make([]string, 0, len(sd.Methods)+len(sd.Streams))
|
||||
for _, method := range sd.Methods {
|
||||
paths = append(paths, path.Join("/", sd.ServiceName, method.MethodName))
|
||||
}
|
||||
for _, stream := range sd.Streams {
|
||||
paths = append(paths, path.Join("/", sd.ServiceName, stream.StreamName))
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
32
internal/api/grpc/probes_test.go
Normal file
32
internal/api/grpc/probes_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
|
||||
)
|
||||
|
||||
func TestAllPaths(t *testing.T) {
|
||||
type args struct {
|
||||
sd grpc.ServiceDesc
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "server reflection",
|
||||
args: args{grpc_reflection_v1alpha.ServerReflection_ServiceDesc},
|
||||
want: []string{"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := AllPaths(tt.args.sd)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
@@ -16,7 +16,6 @@ import (
|
||||
|
||||
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
@@ -66,16 +65,15 @@ var (
|
||||
)
|
||||
|
||||
type Gateway struct {
|
||||
mux *runtime.ServeMux
|
||||
http1HostName string
|
||||
connection *grpc.ClientConn
|
||||
cookieHandler *http_utils.CookieHandler
|
||||
cookieConfig *http_mw.AccessConfig
|
||||
queries *query.Queries
|
||||
mux *runtime.ServeMux
|
||||
http1HostName string
|
||||
connection *grpc.ClientConn
|
||||
accessInterceptor *http_mw.AccessInterceptor
|
||||
queries *query.Queries
|
||||
}
|
||||
|
||||
func (g *Gateway) Handler() http.Handler {
|
||||
return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries)
|
||||
return addInterceptors(g.mux, g.http1HostName, g.accessInterceptor, g.queries)
|
||||
}
|
||||
|
||||
type CustomHTTPResponse interface {
|
||||
@@ -89,8 +87,7 @@ func CreateGatewayWithPrefix(
|
||||
g WithGatewayPrefix,
|
||||
port uint16,
|
||||
http1HostName string,
|
||||
cookieHandler *http_utils.CookieHandler,
|
||||
cookieConfig *http_mw.AccessConfig,
|
||||
accessInterceptor *http_mw.AccessInterceptor,
|
||||
queries *query.Queries,
|
||||
) (http.Handler, string, error) {
|
||||
runtimeMux := runtime.NewServeMux(serveMuxOptions...)
|
||||
@@ -106,10 +103,10 @@ func CreateGatewayWithPrefix(
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
|
||||
}
|
||||
return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil
|
||||
return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil
|
||||
}
|
||||
|
||||
func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) {
|
||||
func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) {
|
||||
connection, err := dial(ctx,
|
||||
port,
|
||||
[]grpc.DialOption{
|
||||
@@ -121,11 +118,10 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string, cooki
|
||||
}
|
||||
runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...)
|
||||
return &Gateway{
|
||||
mux: runtimeMux,
|
||||
http1HostName: http1HostName,
|
||||
connection: connection,
|
||||
cookieHandler: cookieHandler,
|
||||
cookieConfig: cookieConfig,
|
||||
mux: runtimeMux,
|
||||
http1HostName: http1HostName,
|
||||
connection: connection,
|
||||
accessInterceptor: accessInterceptor,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -163,8 +159,7 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien
|
||||
func addInterceptors(
|
||||
handler http.Handler,
|
||||
http1HostName string,
|
||||
cookieHandler *http_utils.CookieHandler,
|
||||
cookieConfig *http_mw.AccessConfig,
|
||||
accessInterceptor *http_mw.AccessInterceptor,
|
||||
queries *query.Queries,
|
||||
) http.Handler {
|
||||
handler = http_mw.CallDurationHandler(handler)
|
||||
@@ -174,7 +169,7 @@ func addInterceptors(
|
||||
handler = http_mw.DefaultTelemetryHandler(handler)
|
||||
// For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header
|
||||
// only if it follows the http_mw.DefaultTelemetryHandler
|
||||
handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries)
|
||||
handler = exhaustedCookieInterceptor(handler, accessInterceptor, queries)
|
||||
handler = http_mw.DefaultMetricsHandler(handler)
|
||||
return handler
|
||||
}
|
||||
@@ -193,35 +188,32 @@ func http1Host(next http.Handler, http1HostName string) http.Handler {
|
||||
|
||||
func exhaustedCookieInterceptor(
|
||||
next http.Handler,
|
||||
cookieHandler *http_utils.CookieHandler,
|
||||
cookieConfig *http_mw.AccessConfig,
|
||||
accessInterceptor *http_mw.AccessInterceptor,
|
||||
queries *query.Queries,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
next.ServeHTTP(&cookieResponseWriter{
|
||||
ResponseWriter: writer,
|
||||
cookieHandler: cookieHandler,
|
||||
cookieConfig: cookieConfig,
|
||||
request: request,
|
||||
queries: queries,
|
||||
ResponseWriter: writer,
|
||||
accessInterceptor: accessInterceptor,
|
||||
request: request,
|
||||
queries: queries,
|
||||
}, request)
|
||||
})
|
||||
}
|
||||
|
||||
type cookieResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
cookieHandler *http_utils.CookieHandler
|
||||
cookieConfig *http_mw.AccessConfig
|
||||
request *http.Request
|
||||
queries *query.Queries
|
||||
accessInterceptor *http_mw.AccessInterceptor
|
||||
request *http.Request
|
||||
queries *query.Queries
|
||||
}
|
||||
|
||||
func (r *cookieResponseWriter) WriteHeader(status int) {
|
||||
if status >= 200 && status < 300 {
|
||||
http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig)
|
||||
r.accessInterceptor.DeleteExhaustedCookie(r.ResponseWriter)
|
||||
}
|
||||
if status == http.StatusTooManyRequests {
|
||||
http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request)
|
||||
r.accessInterceptor.SetExhaustedCookie(r.ResponseWriter, r.request)
|
||||
}
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
@@ -11,38 +11,13 @@ import (
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
settings "github.com/zitadel/zitadel/pkg/grpc/settings/v2alpha"
|
||||
)
|
||||
|
||||
var ignoreMessageTypes = map[protoreflect.FullName]bool{
|
||||
"google.protobuf.Duration": true,
|
||||
}
|
||||
|
||||
// allFieldsSet recusively checks if all values in a message
|
||||
// have a non-zero value.
|
||||
func allFieldsSet(t testing.TB, msg protoreflect.Message) {
|
||||
md := msg.Descriptor()
|
||||
name := md.FullName()
|
||||
if ignoreMessageTypes[name] {
|
||||
return
|
||||
}
|
||||
|
||||
fields := md.Fields()
|
||||
|
||||
for i := 0; i < fields.Len(); i++ {
|
||||
fd := fields.Get(i)
|
||||
if !msg.Has(fd) {
|
||||
t.Errorf("not all fields set in %q, missing %q", name, fd.Name())
|
||||
continue
|
||||
}
|
||||
|
||||
if fd.Kind() == protoreflect.MessageKind {
|
||||
allFieldsSet(t, msg.Get(fd).Message())
|
||||
}
|
||||
}
|
||||
}
|
||||
var ignoreTypes = []protoreflect.FullName{"google.protobuf.Duration"}
|
||||
|
||||
func Test_loginSettingsToPb(t *testing.T) {
|
||||
arg := &query.LoginPolicy{
|
||||
@@ -100,7 +75,7 @@ func Test_loginSettingsToPb(t *testing.T) {
|
||||
}
|
||||
|
||||
got := loginSettingsToPb(arg)
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("loginSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -241,7 +216,7 @@ func Test_passwordSettingsToPb(t *testing.T) {
|
||||
}
|
||||
|
||||
got := passwordSettingsToPb(arg)
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("passwordSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -295,7 +270,7 @@ func Test_brandingSettingsToPb(t *testing.T) {
|
||||
}
|
||||
|
||||
got := brandingSettingsToPb(arg, "http://example.com")
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("brandingSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -315,7 +290,7 @@ func Test_domainSettingsToPb(t *testing.T) {
|
||||
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
|
||||
}
|
||||
got := domainSettingsToPb(arg)
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("domainSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -337,7 +312,7 @@ func Test_legalSettingsToPb(t *testing.T) {
|
||||
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
|
||||
}
|
||||
got := legalAndSupportSettingsToPb(arg)
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("legalSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -353,7 +328,7 @@ func Test_lockoutSettingsToPb(t *testing.T) {
|
||||
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
|
||||
}
|
||||
got := lockoutSettingsToPb(arg)
|
||||
allFieldsSet(t, got.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(got, want) {
|
||||
t.Errorf("lockoutSettingsToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
@@ -387,7 +362,7 @@ func Test_identityProvidersToPb(t *testing.T) {
|
||||
got := identityProvidersToPb(arg)
|
||||
require.Len(t, got, len(got))
|
||||
for i, v := range got {
|
||||
allFieldsSet(t, v.ProtoReflect())
|
||||
grpc.AllFieldsSet(t, v.ProtoReflect(), ignoreTypes...)
|
||||
if !proto.Equal(v, want[i]) {
|
||||
t.Errorf("identityProvidersToPb() =\n%v\nwant\n%v", got, want)
|
||||
}
|
||||
|
104
internal/api/grpc/user/v2/passkey.go
Normal file
104
internal/api/grpc/user/v2/passkey.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyRequest) (resp *user.RegisterPasskeyResponse, err error) {
|
||||
var (
|
||||
resourceOwner = authz.GetCtxData(ctx).ResourceOwner
|
||||
authenticator = passkeyAuthenticatorToDomain(req.GetAuthenticator())
|
||||
)
|
||||
if code := req.GetCode(); code != nil {
|
||||
return passkeyRegistrationDetailsToPb(
|
||||
s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, s.userCodeAlg),
|
||||
)
|
||||
}
|
||||
return passkeyRegistrationDetailsToPb(
|
||||
s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, authenticator),
|
||||
)
|
||||
}
|
||||
|
||||
func passkeyAuthenticatorToDomain(pa user.PasskeyAuthenticator) domain.AuthenticatorAttachment {
|
||||
switch pa {
|
||||
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_UNSPECIFIED:
|
||||
return domain.AuthenticatorAttachmentUnspecified
|
||||
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM:
|
||||
return domain.AuthenticatorAttachmentPlattform
|
||||
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM:
|
||||
return domain.AuthenticatorAttachmentCrossPlattform
|
||||
default:
|
||||
return domain.AuthenticatorAttachmentUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func passkeyRegistrationDetailsToPb(details *domain.PasskeyRegistrationDetails, err error) (*user.RegisterPasskeyResponse, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.RegisterPasskeyResponse{
|
||||
Details: object.DomainToDetailsPb(details.ObjectDetails),
|
||||
PasskeyId: details.PasskeyID,
|
||||
PublicKeyCredentialCreationOptions: details.PublicKeyCredentialCreationOptions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *user.VerifyPasskeyRegistrationRequest) (*user.VerifyPasskeyRegistrationResponse, error) {
|
||||
resourceOwner := authz.GetCtxData(ctx).ResourceOwner
|
||||
objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.GetUserId(), resourceOwner, req.GetPasskeyName(), "", req.GetPublicKeyCredential())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.VerifyPasskeyRegistrationResponse{
|
||||
Details: object.DomainToDetailsPb(objectDetails),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *user.CreatePasskeyRegistrationLinkRequest) (resp *user.CreatePasskeyRegistrationLinkResponse, err error) {
|
||||
resourceOwner := authz.GetCtxData(ctx).ResourceOwner
|
||||
|
||||
switch medium := req.Medium.(type) {
|
||||
case nil:
|
||||
return passkeyDetailsToPb(
|
||||
s.command.AddUserPasskeyCode(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg),
|
||||
)
|
||||
case *user.CreatePasskeyRegistrationLinkRequest_SendLink:
|
||||
return passkeyDetailsToPb(
|
||||
s.command.AddUserPasskeyCodeURLTemplate(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg, medium.SendLink.GetUrlTemplate()),
|
||||
)
|
||||
case *user.CreatePasskeyRegistrationLinkRequest_ReturnCode:
|
||||
return passkeyCodeDetailsToPb(
|
||||
s.command.AddUserPasskeyCodeReturn(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg),
|
||||
)
|
||||
default:
|
||||
return nil, caos_errs.ThrowUnimplementedf(nil, "USERv2-gaD8y", "verification oneOf %T in method CreatePasskeyRegistrationLink not implemented", medium)
|
||||
}
|
||||
}
|
||||
|
||||
func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: object.DomainToDetailsPb(details.ObjectDetails),
|
||||
Code: &user.PasskeyRegistrationCode{
|
||||
Id: details.CodeID,
|
||||
Code: details.Code,
|
||||
},
|
||||
}, nil
|
||||
}
|
309
internal/api/grpc/user/v2/passkey_integration_test.go
Normal file
309
internal/api/grpc/user/v2/passkey_integration_test.go
Normal file
@@ -0,0 +1,309 @@
|
||||
//go:build integration
|
||||
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/webauthn"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
func TestServer_RegisterPasskey(t *testing.T) {
|
||||
userID := createHumanUser(t).GetUserId()
|
||||
reg, err := Client.CreatePasskeyRegistrationLink(CTX, &user.CreatePasskeyRegistrationLinkRequest{
|
||||
UserId: userID,
|
||||
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
client := webauthn.NewClient(Tester.Config.WebAuthNName, Tester.Config.ExternalDomain, "https://"+Tester.Host())
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.RegisterPasskeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.RegisterPasskeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing user id",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "register code",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
Code: reg.GetCode(),
|
||||
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
|
||||
},
|
||||
},
|
||||
want: &user.RegisterPasskeyResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reuse code (not allowed)",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
Code: reg.GetCode(),
|
||||
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong code",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
Code: &user.PasskeyRegistrationCode{
|
||||
Id: reg.GetCode().GetId(),
|
||||
Code: "foobar",
|
||||
},
|
||||
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
/* TODO after we are able to obtain a Bearer token for a human user
|
||||
{
|
||||
name: "human user",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: humanUserID,
|
||||
},
|
||||
},
|
||||
want: &user.RegisterPasskeyResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
*/
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.RegisterPasskey(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
assert.NotEmpty(t, got.GetPasskeyId())
|
||||
assert.NotEmpty(t, got.GetPublicKeyCredentialCreationOptions())
|
||||
_, err := client.CreateAttestationResponse(got.GetPublicKeyCredentialCreationOptions())
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_VerifyPasskeyRegistration(t *testing.T) {
|
||||
userID := createHumanUser(t).GetUserId()
|
||||
reg, err := Client.CreatePasskeyRegistrationLink(CTX, &user.CreatePasskeyRegistrationLinkRequest{
|
||||
UserId: userID,
|
||||
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
pkr, err := Client.RegisterPasskey(CTX, &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
Code: reg.GetCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, pkr.GetPasskeyId())
|
||||
require.NotEmpty(t, pkr.GetPublicKeyCredentialCreationOptions())
|
||||
|
||||
client := webauthn.NewClient(Tester.Config.WebAuthNName, Tester.Config.ExternalDomain, "https://"+Tester.Host())
|
||||
attestationResponse, err := client.CreateAttestationResponse(pkr.GetPublicKeyCredentialCreationOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.VerifyPasskeyRegistrationRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.VerifyPasskeyRegistrationResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing user id",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.VerifyPasskeyRegistrationRequest{
|
||||
PasskeyId: pkr.GetPasskeyId(),
|
||||
PublicKeyCredential: []byte(attestationResponse),
|
||||
PasskeyName: "nice name",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.VerifyPasskeyRegistrationRequest{
|
||||
UserId: userID,
|
||||
PasskeyId: pkr.GetPasskeyId(),
|
||||
PublicKeyCredential: attestationResponse,
|
||||
PasskeyName: "nice name",
|
||||
},
|
||||
},
|
||||
want: &user.VerifyPasskeyRegistrationResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wrong credential",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.VerifyPasskeyRegistrationRequest{
|
||||
UserId: userID,
|
||||
PasskeyId: pkr.GetPasskeyId(),
|
||||
PublicKeyCredential: []byte("attestationResponseattestationResponseattestationResponse"),
|
||||
PasskeyName: "nice name",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.VerifyPasskeyRegistration(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreatePasskeyRegistrationLink(t *testing.T) {
|
||||
userID := createHumanUser(t).GetUserId()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.CreatePasskeyRegistrationLinkRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.CreatePasskeyRegistrationLinkResponse
|
||||
wantCode bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing user id",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.CreatePasskeyRegistrationLinkRequest{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "send default mail",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.CreatePasskeyRegistrationLinkRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
want: &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "send custom url",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.CreatePasskeyRegistrationLinkRequest{
|
||||
UserId: userID,
|
||||
Medium: &user.CreatePasskeyRegistrationLinkRequest_SendLink{
|
||||
SendLink: &user.SendPasskeyRegistrationLink{
|
||||
UrlTemplate: gu.Ptr("https://example.com/passkey/register?userID={{.UserID}}&orgID={{.OrgID}}&codeID={{.CodeID}}&code={{.Code}}"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "return code",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
req: &user.CreatePasskeyRegistrationLinkRequest{
|
||||
UserId: userID,
|
||||
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
|
||||
},
|
||||
},
|
||||
want: &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
wantCode: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.CreatePasskeyRegistrationLink(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.wantCode {
|
||||
assert.NotEmpty(t, got.GetCode().GetId())
|
||||
assert.NotEmpty(t, got.GetCode().GetId())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
210
internal/api/grpc/user/v2/passkey_test.go
Normal file
210
internal/api/grpc/user/v2/passkey_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
func Test_passkeyAuthenticatorToDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
pa user.PasskeyAuthenticator
|
||||
want domain.AuthenticatorAttachment
|
||||
}{
|
||||
{
|
||||
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_UNSPECIFIED,
|
||||
want: domain.AuthenticatorAttachmentUnspecified,
|
||||
},
|
||||
{
|
||||
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
|
||||
want: domain.AuthenticatorAttachmentPlattform,
|
||||
},
|
||||
{
|
||||
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM,
|
||||
want: domain.AuthenticatorAttachmentCrossPlattform,
|
||||
},
|
||||
{
|
||||
pa: 999,
|
||||
want: domain.AuthenticatorAttachmentUnspecified,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pa.String(), func(t *testing.T) {
|
||||
got := passkeyAuthenticatorToDomain(tt.pa)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_passkeyRegistrationDetailsToPb(t *testing.T) {
|
||||
type args struct {
|
||||
details *domain.PasskeyRegistrationDetails
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.RegisterPasskeyResponse
|
||||
}{
|
||||
{
|
||||
name: "an error",
|
||||
args: args{
|
||||
details: nil,
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
details: &domain.PasskeyRegistrationDetails{
|
||||
ObjectDetails: &domain.ObjectDetails{
|
||||
Sequence: 22,
|
||||
EventDate: time.Unix(3000, 22),
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
PasskeyID: "123",
|
||||
PublicKeyCredentialCreationOptions: []byte{1, 2, 3},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
want: &user.RegisterPasskeyResponse{
|
||||
Details: &object.Details{
|
||||
Sequence: 22,
|
||||
ChangeDate: ×tamppb.Timestamp{
|
||||
Seconds: 3000,
|
||||
Nanos: 22,
|
||||
},
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
PasskeyId: "123",
|
||||
PublicKeyCredentialCreationOptions: []byte{1, 2, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := passkeyRegistrationDetailsToPb(tt.args.details, tt.args.err)
|
||||
require.ErrorIs(t, err, tt.args.err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_passkeyDetailsToPb(t *testing.T) {
|
||||
type args struct {
|
||||
details *domain.ObjectDetails
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.CreatePasskeyRegistrationLinkResponse
|
||||
}{
|
||||
{
|
||||
name: "an error",
|
||||
args: args{
|
||||
details: nil,
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
details: &domain.ObjectDetails{
|
||||
Sequence: 22,
|
||||
EventDate: time.Unix(3000, 22),
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
want: &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: &object.Details{
|
||||
Sequence: 22,
|
||||
ChangeDate: ×tamppb.Timestamp{
|
||||
Seconds: 3000,
|
||||
Nanos: 22,
|
||||
},
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := passkeyDetailsToPb(tt.args.details, tt.args.err)
|
||||
require.ErrorIs(t, err, tt.args.err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_passkeyCodeDetailsToPb(t *testing.T) {
|
||||
type args struct {
|
||||
details *domain.PasskeyCodeDetails
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.CreatePasskeyRegistrationLinkResponse
|
||||
}{
|
||||
{
|
||||
name: "an error",
|
||||
args: args{
|
||||
details: nil,
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
details: &domain.PasskeyCodeDetails{
|
||||
ObjectDetails: &domain.ObjectDetails{
|
||||
Sequence: 22,
|
||||
EventDate: time.Unix(3000, 22),
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
CodeID: "123",
|
||||
Code: "456",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
want: &user.CreatePasskeyRegistrationLinkResponse{
|
||||
Details: &object.Details{
|
||||
Sequence: 22,
|
||||
ChangeDate: ×tamppb.Timestamp{
|
||||
Seconds: 3000,
|
||||
Nanos: 22,
|
||||
},
|
||||
ResourceOwner: "me",
|
||||
},
|
||||
Code: &user.PasskeyRegistrationCode{
|
||||
Id: "123",
|
||||
Code: "456",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := passkeyCodeDetailsToPb(tt.args.details, tt.args.err)
|
||||
require.ErrorIs(t, err, tt.args.err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,6 +1,8 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
@@ -18,15 +20,25 @@ type Server struct {
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
userCodeAlg crypto.EncryptionAlgorithm
|
||||
idpAlg crypto.EncryptionAlgorithm
|
||||
idpCallback func(ctx context.Context) string
|
||||
}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func CreateServer(command *command.Commands, query *query.Queries, userCodeAlg crypto.EncryptionAlgorithm) *Server {
|
||||
func CreateServer(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
userCodeAlg crypto.EncryptionAlgorithm,
|
||||
idpAlg crypto.EncryptionAlgorithm,
|
||||
idpCallback func(ctx context.Context) string,
|
||||
) *Server {
|
||||
return &Server{
|
||||
command: command,
|
||||
query: query,
|
||||
userCodeAlg: userCodeAlg,
|
||||
idpAlg: idpAlg,
|
||||
idpCallback: idpCallback,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -2,15 +2,19 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
@@ -56,6 +60,14 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
|
||||
Value: metadataEntry.GetValue(),
|
||||
}
|
||||
}
|
||||
links := make([]*command.AddLink, len(req.GetIdpLinks()))
|
||||
for i, link := range req.GetIdpLinks() {
|
||||
links[i] = &command.AddLink{
|
||||
IDPID: link.GetIdpId(),
|
||||
IDPExternalID: link.GetIdpExternalId(),
|
||||
DisplayName: link.GetDisplayName(),
|
||||
}
|
||||
}
|
||||
return &command.AddHuman{
|
||||
ID: req.GetUserId(),
|
||||
Username: username,
|
||||
@@ -76,9 +88,9 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
|
||||
BcryptedPassword: bcryptedPassword,
|
||||
PasswordChangeRequired: passwordChangeRequired,
|
||||
Passwordless: false,
|
||||
ExternalIDP: false,
|
||||
Register: false,
|
||||
Metadata: metadata,
|
||||
Links: links,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -107,3 +119,95 @@ func hashedPasswordToCommand(hashed *user.HashedPassword) (string, error) {
|
||||
}
|
||||
return hashed.GetHash(), nil
|
||||
}
|
||||
|
||||
func (s *Server) AddIDPLink(ctx context.Context, req *user.AddIDPLinkRequest) (_ *user.AddIDPLinkResponse, err error) {
|
||||
orgID := authz.GetCtxData(ctx).OrgID
|
||||
details, err := s.command.AddUserIDPLink(ctx, req.UserId, orgID, &domain.UserIDPLink{
|
||||
IDPConfigID: req.GetIdpLink().GetIdpId(),
|
||||
ExternalUserID: req.GetIdpLink().GetIdpExternalId(),
|
||||
DisplayName: req.GetIdpLink().GetDisplayName(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.AddIDPLinkResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) StartIdentityProviderFlow(ctx context.Context, req *user.StartIdentityProviderFlowRequest) (_ *user.StartIdentityProviderFlowResponse, err error) {
|
||||
id, details, err := s.command.CreateIntent(ctx, req.GetIdpId(), req.GetSuccessUrl(), req.GetFailureUrl(), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authURL, err := s.command.AuthURLFromProvider(ctx, req.GetIdpId(), id, s.idpCallback(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user.StartIdentityProviderFlowResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{AuthUrl: authURL},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RetrieveIdentityProviderInformation(ctx context.Context, req *user.RetrieveIdentityProviderInformationRequest) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
|
||||
intent, err := s.command.GetIntentWriteModel(ctx, req.GetIntentId(), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.checkIntentToken(req.GetToken(), intent.AggregateID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if intent.State != domain.IDPIntentStateSucceeded {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "IDP-Hk38e", "Errors.Intent.NotSucceeded")
|
||||
}
|
||||
return intentToIDPInformationPb(intent, s.idpAlg)
|
||||
}
|
||||
|
||||
func intentToIDPInformationPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
|
||||
var idToken *string
|
||||
if intent.IDPIDToken != "" {
|
||||
idToken = &intent.IDPIDToken
|
||||
}
|
||||
var accessToken string
|
||||
if intent.IDPAccessToken != nil {
|
||||
accessToken, err = crypto.DecryptString(intent.IDPAccessToken, alg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object_pb.Details{
|
||||
Sequence: intent.ProcessedSequence,
|
||||
ChangeDate: timestamppb.New(intent.ChangeDate),
|
||||
ResourceOwner: intent.ResourceOwner,
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: accessToken,
|
||||
IdToken: idToken,
|
||||
},
|
||||
},
|
||||
IdpInformation: intent.IDPUser,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkIntentToken(token string, intentID string) error {
|
||||
if token == "" {
|
||||
return errors.ThrowPermissionDenied(nil, "IDP-Sfefs", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(err, "IDP-Swg31", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
decryptedToken, err := s.idpAlg.Decrypt(data, s.idpAlg.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(err, "IDP-Sf4gt", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
if string(decryptedToken) != intentID {
|
||||
return errors.ThrowPermissionDenied(nil, "IDP-dkje3", "Errors.Intent.InvalidToken")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -6,16 +6,24 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/internal/repository/idp"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,7 +47,60 @@ func TestMain(m *testing.M) {
|
||||
}())
|
||||
}
|
||||
|
||||
func createProvider(t *testing.T) string {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
id, _, err := Tester.Commands.AddOrgGenericOAuthProvider(ctx, Tester.Organisation.ID, command.GenericOAuthProvider{
|
||||
"idp",
|
||||
"clientID",
|
||||
"clientSecret",
|
||||
"https://example.com/oauth/v2/authorize",
|
||||
"https://example.com/oauth/v2/token",
|
||||
"https://api.example.com/user",
|
||||
[]string{"openid", "profile", "email"},
|
||||
"id",
|
||||
idp.Options{
|
||||
IsLinkingAllowed: true,
|
||||
IsCreationAllowed: true,
|
||||
IsAutoCreation: true,
|
||||
IsAutoUpdate: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func createIntent(t *testing.T, idpID string) string {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
id, _, err := Tester.Commands.CreateIntent(ctx, idpID, "https://example.com/success", "https://example.com/failure", Tester.Organisation.ID)
|
||||
require.NoError(t, err)
|
||||
return id
|
||||
}
|
||||
|
||||
func createSuccessfulIntent(t *testing.T, idpID string) (string, string, time.Time, uint64) {
|
||||
ctx := authz.WithInstance(context.Background(), Tester.Instance)
|
||||
intentID := createIntent(t, idpID)
|
||||
writeModel, err := Tester.Commands.GetIntentWriteModel(ctx, intentID, Tester.Organisation.ID)
|
||||
require.NoError(t, err)
|
||||
idpUser := &oauth.UserMapper{
|
||||
RawInfo: map[string]interface{}{
|
||||
"id": "id",
|
||||
},
|
||||
}
|
||||
idpSession := &oauth.Session{
|
||||
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
},
|
||||
IDToken: "idToken",
|
||||
},
|
||||
}
|
||||
token, err := Tester.Commands.SucceedIDPIntent(ctx, writeModel, idpUser, idpSession, "")
|
||||
require.NoError(t, err)
|
||||
return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence
|
||||
}
|
||||
|
||||
func TestServer_AddHumanUser(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.AddHumanUserRequest
|
||||
@@ -287,6 +348,105 @@ func TestServer_AddHumanUser(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing idp",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddHumanUserRequest{
|
||||
Organisation: &object.Organisation{
|
||||
Org: &object.Organisation_OrgId{
|
||||
OrgId: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
Profile: &user.SetHumanProfile{
|
||||
FirstName: "Donald",
|
||||
LastName: "Duck",
|
||||
NickName: gu.Ptr("Dukkie"),
|
||||
DisplayName: gu.Ptr("Donald Duck"),
|
||||
PreferredLanguage: gu.Ptr("en"),
|
||||
Gender: user.Gender_GENDER_DIVERSE.Enum(),
|
||||
},
|
||||
Email: &user.SetHumanEmail{
|
||||
Email: "livio@zitadel.com",
|
||||
Verification: &user.SetHumanEmail_IsVerified{
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []*user.SetMetadataEntry{
|
||||
{
|
||||
Key: "somekey",
|
||||
Value: []byte("somevalue"),
|
||||
},
|
||||
},
|
||||
PasswordType: &user.AddHumanUserRequest_Password{
|
||||
Password: &user.Password{
|
||||
Password: "DifficultPW666!",
|
||||
ChangeRequired: false,
|
||||
},
|
||||
},
|
||||
IdpLinks: []*user.IDPLink{
|
||||
{
|
||||
IdpId: "idpID",
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "with idp",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddHumanUserRequest{
|
||||
Organisation: &object.Organisation{
|
||||
Org: &object.Organisation_OrgId{
|
||||
OrgId: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
Profile: &user.SetHumanProfile{
|
||||
FirstName: "Donald",
|
||||
LastName: "Duck",
|
||||
NickName: gu.Ptr("Dukkie"),
|
||||
DisplayName: gu.Ptr("Donald Duck"),
|
||||
PreferredLanguage: gu.Ptr("en"),
|
||||
Gender: user.Gender_GENDER_DIVERSE.Enum(),
|
||||
},
|
||||
Email: &user.SetHumanEmail{
|
||||
Email: "livio@zitadel.com",
|
||||
Verification: &user.SetHumanEmail_IsVerified{
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
Metadata: []*user.SetMetadataEntry{
|
||||
{
|
||||
Key: "somekey",
|
||||
Value: []byte("somevalue"),
|
||||
},
|
||||
},
|
||||
PasswordType: &user.AddHumanUserRequest_Password{
|
||||
Password: &user.Password{
|
||||
Password: "DifficultPW666!",
|
||||
ChangeRequired: false,
|
||||
},
|
||||
},
|
||||
IdpLinks: []*user.IDPLink{
|
||||
{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &user.AddHumanUserResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -315,3 +475,226 @@ func TestServer_AddHumanUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_AddIDPLink(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.AddIDPLinkRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.AddIDPLinkResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "user does not exist",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: "userID",
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "idp does not exist",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: Tester.Users[integration.OrgOwner].ID,
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: "idpID",
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "add link",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.AddIDPLinkRequest{
|
||||
UserId: Tester.Users[integration.OrgOwner].ID,
|
||||
IdpLink: &user.IDPLink{
|
||||
IdpId: idpID,
|
||||
IdpExternalId: "externalID",
|
||||
DisplayName: "displayName",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &user.AddIDPLinkResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.AddIDPLink(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StartIdentityProviderFlow(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.StartIdentityProviderFlowRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.StartIdentityProviderFlowResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "missing urls",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.StartIdentityProviderFlowRequest{
|
||||
IdpId: idpID,
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "next step auth url",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.StartIdentityProviderFlowRequest{
|
||||
IdpId: idpID,
|
||||
SuccessUrl: "https://example.com/success",
|
||||
FailureUrl: "https://example.com/failure",
|
||||
},
|
||||
},
|
||||
want: &user.StartIdentityProviderFlowResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.StartIdentityProviderFlow(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if nextStep := tt.want.GetNextStep(); nextStep != nil {
|
||||
if !strings.HasPrefix(got.GetAuthUrl(), tt.want.GetAuthUrl()) {
|
||||
assert.Failf(t, "auth url does not match", "expected: %s, but got: %s", tt.want.GetAuthUrl(), got.GetAuthUrl())
|
||||
}
|
||||
}
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RetrieveIdentityProviderInformation(t *testing.T) {
|
||||
idpID := createProvider(t)
|
||||
intentID := createIntent(t, idpID)
|
||||
successfulID, token, changeDate, sequence := createSuccessfulIntent(t, idpID)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req *user.RetrieveIdentityProviderInformationRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *user.RetrieveIdentityProviderInformationResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "failed intent",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: intentID,
|
||||
Token: "",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "wrong token",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: successfulID,
|
||||
Token: "wrong token",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve successful intent",
|
||||
args: args{
|
||||
CTX,
|
||||
&user.RetrieveIdentityProviderInformationRequest{
|
||||
IntentId: successfulID,
|
||||
Token: token,
|
||||
},
|
||||
},
|
||||
want: &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.New(changeDate),
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
Sequence: sequence,
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: "accessToken",
|
||||
IdToken: gu.Ptr("idToken"),
|
||||
},
|
||||
},
|
||||
IdpInformation: []byte(`{"RawInfo":{"id":"id"}}`),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.RetrieveIdentityProviderInformation(tt.args.ctx, tt.args.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.want.GetDetails(), got.GetDetails())
|
||||
require.Equal(t, tt.want.GetIdpInformation(), got.GetIdpInformation())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -3,11 +3,21 @@ package user
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
@@ -78,3 +88,118 @@ func Test_hashedPasswordToCommand(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_intentToIDPInformationPb(t *testing.T) {
|
||||
decryption := func(err error) crypto.EncryptionAlgorithm {
|
||||
mCrypto := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
|
||||
mCrypto.EXPECT().Algorithm().Return("enc")
|
||||
mCrypto.EXPECT().DecryptionKeyIDs().Return([]string{"id"})
|
||||
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(code []byte, keyID string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(code), nil
|
||||
})
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
type args struct {
|
||||
intent *command.IDPIntentWriteModel
|
||||
alg crypto.EncryptionAlgorithm
|
||||
}
|
||||
type res struct {
|
||||
resp *user.RetrieveIdentityProviderInformationResponse
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"decryption invalid key id error",
|
||||
args{
|
||||
intent: &command.IDPIntentWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "intentID",
|
||||
ProcessedSequence: 123,
|
||||
ResourceOwner: "ro",
|
||||
InstanceID: "instanceID",
|
||||
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
|
||||
},
|
||||
IDPID: "idpID",
|
||||
IDPUser: []byte(`{"id": "id"}`),
|
||||
IDPAccessToken: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("accessToken"),
|
||||
},
|
||||
IDPIDToken: "idToken",
|
||||
UserID: "userID",
|
||||
State: domain.IDPIntentStateSucceeded,
|
||||
},
|
||||
alg: decryption(caos_errs.ThrowInternal(nil, "id", "invalid key id")),
|
||||
},
|
||||
res{
|
||||
resp: nil,
|
||||
err: caos_errs.ThrowInternal(nil, "id", "invalid key id"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"successful",
|
||||
args{
|
||||
intent: &command.IDPIntentWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "intentID",
|
||||
ProcessedSequence: 123,
|
||||
ResourceOwner: "ro",
|
||||
InstanceID: "instanceID",
|
||||
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
|
||||
},
|
||||
IDPID: "idpID",
|
||||
IDPUser: []byte(`{"id": "id"}`),
|
||||
IDPAccessToken: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("accessToken"),
|
||||
},
|
||||
IDPIDToken: "idToken",
|
||||
UserID: "userID",
|
||||
State: domain.IDPIntentStateSucceeded,
|
||||
},
|
||||
alg: decryption(nil),
|
||||
},
|
||||
res{
|
||||
resp: &user.RetrieveIdentityProviderInformationResponse{
|
||||
Details: &object_pb.Details{
|
||||
Sequence: 123,
|
||||
ChangeDate: timestamppb.New(time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local)),
|
||||
ResourceOwner: "ro",
|
||||
},
|
||||
IdpInformation: &user.IDPInformation{
|
||||
Access: &user.IDPInformation_Oauth{
|
||||
Oauth: &user.IDPOAuthAccessInformation{
|
||||
AccessToken: "accessToken",
|
||||
IdToken: gu.Ptr("idToken"),
|
||||
}},
|
||||
IdpInformation: []byte(`{"id": "id"}`),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := intentToIDPInformationPb(tt.args.intent, tt.args.alg)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.resp, got)
|
||||
if tt.res.resp != nil {
|
||||
grpc.AllFieldsSet(t, got.ProtoReflect())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -123,8 +123,8 @@ func (c *CookieHandler) SetEncryptedCookie(w http.ResponseWriter, name, domain s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||
c.httpSet(w, name, r.Host, "", -1)
|
||||
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
|
||||
c.httpSet(w, name, "", "", -1)
|
||||
}
|
||||
|
||||
func (c *CookieHandler) httpSet(w http.ResponseWriter, name, domain, value string, maxage int) {
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -32,15 +33,54 @@ type AccessConfig struct {
|
||||
// NewAccessInterceptor intercepts all requests and stores them to the logstore.
|
||||
// If storeOnly is false, it also checks if requests are exhausted.
|
||||
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
|
||||
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor {
|
||||
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
|
||||
return &AccessInterceptor{
|
||||
svc: svc,
|
||||
cookieHandler: cookieHandler,
|
||||
limitConfig: cookieConfig,
|
||||
storeOnly: storeOnly,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
|
||||
return &AccessInterceptor{
|
||||
svc: a.svc,
|
||||
cookieHandler: a.cookieHandler,
|
||||
limitConfig: a.limitConfig,
|
||||
storeOnly: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) AccessService() *logstore.Service {
|
||||
return a.svc
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) Limit(ctx context.Context) bool {
|
||||
if !a.svc.Enabled() || a.storeOnly {
|
||||
return false
|
||||
}
|
||||
instance := authz.GetInstance(ctx)
|
||||
remaining := a.svc.Limit(ctx, instance.InstanceID())
|
||||
return remaining != nil && *remaining <= 0
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
|
||||
cookieValue := "true"
|
||||
host := request.Header.Get(middleware.HTTP1Host)
|
||||
domain := host
|
||||
if strings.ContainsAny(host, ":") {
|
||||
var err error
|
||||
domain, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
|
||||
}
|
||||
}
|
||||
a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, domain, cookieValue)
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter) {
|
||||
a.cookieHandler.DeleteCookie(writer, a.limitConfig.ExhaustedCookieKey)
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
||||
if !a.svc.Enabled() {
|
||||
return next
|
||||
@@ -49,23 +89,16 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
||||
ctx := request.Context()
|
||||
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
|
||||
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
|
||||
instance := authz.GetInstance(ctx)
|
||||
limit := false
|
||||
if !a.storeOnly {
|
||||
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
|
||||
limit = remaining != nil && *remaining == 0
|
||||
}
|
||||
limited := a.Limit(tracingCtx)
|
||||
checkSpan.End()
|
||||
if limit {
|
||||
// Limit can only be true when storeOnly is false, so set the cookie and the response code
|
||||
SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request)
|
||||
if limited {
|
||||
a.SetExhaustedCookie(wrappedWriter, request)
|
||||
http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
|
||||
} else {
|
||||
if !a.storeOnly {
|
||||
// If not limited and not storeOnly, ensure the cookie is deleted
|
||||
DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig)
|
||||
}
|
||||
// Always serve if not limited
|
||||
}
|
||||
if !limited && !a.storeOnly {
|
||||
a.DeleteExhaustedCookie(wrappedWriter)
|
||||
}
|
||||
if !limited {
|
||||
next.ServeHTTP(wrappedWriter, request)
|
||||
}
|
||||
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
|
||||
@@ -75,6 +108,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
||||
if err != nil {
|
||||
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
|
||||
}
|
||||
instance := authz.GetInstance(tracingCtx)
|
||||
a.svc.Handle(tracingCtx, &access.Record{
|
||||
LogDate: time.Now(),
|
||||
Protocol: access.HTTP,
|
||||
@@ -90,24 +124,6 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) {
|
||||
cookieValue := "true"
|
||||
host := request.Header.Get(middleware.HTTP1Host)
|
||||
domain := host
|
||||
if strings.ContainsAny(host, ":") {
|
||||
var err error
|
||||
domain, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
|
||||
}
|
||||
}
|
||||
cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue)
|
||||
}
|
||||
|
||||
func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) {
|
||||
cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey)
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
|
246
internal/api/idp/idp.go
Normal file
246
internal/api/idp/idp.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/form"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/github"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/google"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
const (
|
||||
HandlerPrefix = "/idps"
|
||||
callbackPath = "/callback"
|
||||
|
||||
paramIntentID = "id"
|
||||
paramToken = "token"
|
||||
paramUserID = "user"
|
||||
paramError = "error"
|
||||
paramErrorDescription = "error_description"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
commands *command.Commands
|
||||
queries *query.Queries
|
||||
parser *form.Parser
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
callbackURL func(ctx context.Context) string
|
||||
}
|
||||
|
||||
type externalIDPCallbackData struct {
|
||||
State string `schema:"state"`
|
||||
Code string `schema:"code"`
|
||||
Error string `schema:"error"`
|
||||
ErrorDescription string `schema:"error_description"`
|
||||
}
|
||||
|
||||
// CallbackURL generates the instance specific URL to the IDP callback handler
|
||||
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
|
||||
return func(ctx context.Context) string {
|
||||
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + callbackPath
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
commands *command.Commands,
|
||||
queries *query.Queries,
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm,
|
||||
externalSecure bool,
|
||||
instanceInterceptor func(next http.Handler) http.Handler,
|
||||
) http.Handler {
|
||||
h := &Handler{
|
||||
commands: commands,
|
||||
queries: queries,
|
||||
parser: form.NewParser(),
|
||||
encryptionAlgorithm: encryptionAlgorithm,
|
||||
callbackURL: CallbackURL(externalSecure),
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.Use(instanceInterceptor)
|
||||
router.HandleFunc(callbackPath, h.handleCallback)
|
||||
return router
|
||||
}
|
||||
|
||||
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := h.parseCallbackRequest(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
intent := h.getActiveIntent(w, r, data.State)
|
||||
if intent == nil {
|
||||
// if we didn't get an active intent the error was already handled (either redirected or display directly)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
// the provider might have returned an error
|
||||
if data.Error != "" {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, reason(data.Error, data.ErrorDescription))
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
|
||||
idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code)
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
|
||||
|
||||
token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
|
||||
if err != nil {
|
||||
redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
|
||||
return
|
||||
}
|
||||
redirectToSuccessURL(w, r, intent, token, userID)
|
||||
}
|
||||
|
||||
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
|
||||
data := new(externalIDPCallbackData)
|
||||
err := h.parser.Parse(r, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data.State == "" {
|
||||
return nil, z_errs.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
|
||||
intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
if intent.State == domain.IDPIntentStateUnspecified {
|
||||
http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
|
||||
return nil
|
||||
}
|
||||
if intent.State != domain.IDPIntentStateStarted {
|
||||
redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
|
||||
return nil
|
||||
}
|
||||
return intent
|
||||
}
|
||||
|
||||
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
|
||||
queries := intent.SuccessURL.Query()
|
||||
queries.Set(paramIntentID, intent.AggregateID)
|
||||
queries.Set(paramToken, token)
|
||||
if userID != "" {
|
||||
queries.Set(paramUserID, userID)
|
||||
}
|
||||
intent.SuccessURL.RawQuery = queries.Encode()
|
||||
http.Redirect(w, r, intent.SuccessURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
|
||||
msg := err.Error()
|
||||
var description string
|
||||
zErr := new(z_errs.CaosError)
|
||||
if errors.As(err, &zErr) {
|
||||
msg = zErr.GetID()
|
||||
description = zErr.GetMessage() // TODO: i18n?
|
||||
}
|
||||
redirectToFailureURL(w, r, i, msg, description)
|
||||
}
|
||||
|
||||
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
|
||||
queries := i.FailureURL.Query()
|
||||
queries.Set(paramIntentID, i.AggregateID)
|
||||
queries.Set(paramError, err)
|
||||
queries.Set(paramErrorDescription, description)
|
||||
i.FailureURL.RawQuery = queries.Encode()
|
||||
http.Redirect(w, r, i.FailureURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string) (user idp.User, idpTokens idp.Session, err error) {
|
||||
var session idp.Session
|
||||
switch provider := identityProvider.(type) {
|
||||
case *oauth.Provider:
|
||||
session = &oauth.Session{Provider: provider, Code: code}
|
||||
case *openid.Provider:
|
||||
session = &openid.Session{Provider: provider, Code: code}
|
||||
case *azuread.Provider:
|
||||
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
||||
case *github.Provider:
|
||||
session = &oauth.Session{Provider: provider.Provider, Code: code}
|
||||
case *gitlab.Provider:
|
||||
session = &openid.Session{Provider: provider.Provider, Code: code}
|
||||
case *google.Provider:
|
||||
session = &openid.Session{Provider: provider.Provider, Code: code}
|
||||
case *jwt.Provider, *ldap.Provider:
|
||||
return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
default:
|
||||
return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
}
|
||||
|
||||
user, err = session.FetchUser(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return user, session, nil
|
||||
}
|
||||
|
||||
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
|
||||
idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
queries := []query.SearchQuery{
|
||||
idQuery, externalIDQuery,
|
||||
}
|
||||
links, err := h.queries.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(links.Links) != 1 {
|
||||
return "", nil
|
||||
}
|
||||
return links.Links[0].UserID, nil
|
||||
}
|
||||
|
||||
func reason(err, description string) string {
|
||||
if description == "" {
|
||||
return err
|
||||
}
|
||||
return err + ": " + description
|
||||
}
|
220
internal/api/idp/idp_test.go
Normal file
220
internal/api/idp/idp_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/form"
|
||||
)
|
||||
|
||||
func Test_redirectToSuccessURL(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
userID string
|
||||
token string
|
||||
failureURL string
|
||||
successURL string
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect",
|
||||
args{
|
||||
id: "id",
|
||||
token: "token",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/success?id=id&token=token",
|
||||
},
|
||||
},
|
||||
{
|
||||
"redirect with userID",
|
||||
args{
|
||||
id: "id",
|
||||
userID: "user",
|
||||
token: "token",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/success?id=id&token=token&user=user",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToSuccessURL(resp, req, wm, tt.args.token, tt.args.userID)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectToFailureURL(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
failureURL string
|
||||
successURL string
|
||||
err string
|
||||
desc string
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=&error_description=&id=id",
|
||||
},
|
||||
},
|
||||
{
|
||||
"redirect with error",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
err: "test",
|
||||
desc: "testdesc",
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=test&error_description=testdesc&id=id",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToFailureURL(resp, req, wm, tt.args.err, tt.args.desc)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_redirectToFailureURLErr(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
failureURL string
|
||||
successURL string
|
||||
err error
|
||||
}
|
||||
type res struct {
|
||||
want string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"redirect with error",
|
||||
args{
|
||||
id: "id",
|
||||
failureURL: "https://example.com/failure",
|
||||
successURL: "https://example.com/success",
|
||||
err: z_errors.ThrowError(nil, "test", "testdesc"),
|
||||
},
|
||||
res{
|
||||
"https://example.com/failure?error=test&error_description=testdesc&id=id",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
|
||||
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
|
||||
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
|
||||
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
|
||||
|
||||
redirectToFailureURLErr(resp, req, wm, tt.args.err)
|
||||
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseCallbackRequest(t *testing.T) {
|
||||
type args struct {
|
||||
url string
|
||||
}
|
||||
type res struct {
|
||||
want *externalIDPCallbackData
|
||||
err bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"no state",
|
||||
args{
|
||||
url: "https://example.com?state=&code=code&error=error&error_description=desc",
|
||||
},
|
||||
res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"parse",
|
||||
args{
|
||||
url: "https://example.com?state=state&code=code&error=error&error_description=desc",
|
||||
},
|
||||
res{
|
||||
want: &externalIDPCallbackData{
|
||||
State: "state",
|
||||
Code: "code",
|
||||
Error: "error",
|
||||
ErrorDescription: "desc",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tt.args.url, nil)
|
||||
handler := Handler{parser: form.NewParser()}
|
||||
|
||||
data, err := handler.parseCallbackRequest(req)
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.res.want, data)
|
||||
})
|
||||
}
|
||||
}
|
@@ -91,7 +91,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
|
||||
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler func(http.Handler) http.Handler, limitingAccessInterceptor *middleware.AccessInterceptor, customerPortal string) (http.Handler, error) {
|
||||
fSys, err := fs.Sub(static, "static")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -106,20 +106,27 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
|
||||
|
||||
handler := mux.NewRouter()
|
||||
|
||||
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor)
|
||||
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
|
||||
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
url := http_util.BuildOrigin(r.Host, externalSecure)
|
||||
instance := authz.GetInstance(r.Context())
|
||||
ctx := r.Context()
|
||||
instance := authz.GetInstance(ctx)
|
||||
instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL)
|
||||
exhausted := limitingAccessInterceptor.Limit(ctx)
|
||||
environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL, exhausted)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if exhausted {
|
||||
limitingAccessInterceptor.SetExhaustedCookie(w, r)
|
||||
} else {
|
||||
limitingAccessInterceptor.DeleteExhaustedCookie(w)
|
||||
}
|
||||
_, err = w.Write(environmentJSON)
|
||||
logging.OnError(err).Error("error serving environment.json")
|
||||
})))
|
||||
@@ -148,19 +155,21 @@ func csp() *middleware.CSP {
|
||||
return &csp
|
||||
}
|
||||
|
||||
func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string) ([]byte, error) {
|
||||
func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string, exhausted bool) ([]byte, error) {
|
||||
environment := struct {
|
||||
API string `json:"api,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
ClientID string `json:"clientid,omitempty"`
|
||||
CustomerPortal string `json:"customer_portal,omitempty"`
|
||||
InstanceManagementURL string `json:"instance_management_url,omitempty"`
|
||||
Exhausted bool `json:"exhausted,omitempty"`
|
||||
}{
|
||||
API: api,
|
||||
Issuer: issuer,
|
||||
ClientID: clientID,
|
||||
CustomerPortal: customerPortal,
|
||||
InstanceManagementURL: instanceMgmtUrl,
|
||||
Exhausted: exhausted,
|
||||
}
|
||||
return json.Marshal(environment)
|
||||
}
|
||||
|
Reference in New Issue
Block a user