Merge branch 'rc' into next-rc

This commit is contained in:
adlerhurst
2023-06-01 13:12:00 +02:00
178 changed files with 10804 additions and 3303 deletions

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server"
@@ -19,24 +20,22 @@ import (
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type API struct {
port uint16
grpcServer *grpc.Server
verifier *internal_authz.TokenVerifier
health healthCheck
router *mux.Router
http1HostName string
grpcGateway *server.Gateway
healthServer *health.Server
cookieHandler *http_util.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
port uint16
grpcServer *grpc.Server
verifier *internal_authz.TokenVerifier
health healthCheck
router *mux.Router
http1HostName string
grpcGateway *server.Gateway
healthServer *health.Server
accessInterceptor *http_mw.AccessInterceptor
queries *query.Queries
}
type healthCheck interface {
@@ -51,23 +50,20 @@ func New(
verifier *internal_authz.TokenVerifier,
authZ internal_authz.Config,
tlsConfig *tls.Config, http2HostName, http1HostName string,
accessSvc *logstore.Service,
cookieHandler *http_util.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
) (_ *API, err error) {
api := &API{
port: port,
verifier: verifier,
health: queries,
router: router,
http1HostName: http1HostName,
cookieConfig: cookieConfig,
cookieHandler: cookieHandler,
queries: queries,
port: port,
verifier: verifier,
health: queries,
router: router,
http1HostName: http1HostName,
queries: queries,
accessInterceptor: accessInterceptor,
}
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig)
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessInterceptor.AccessService())
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, accessInterceptor)
if err != nil {
return nil, err
}
@@ -76,6 +72,7 @@ func New(
api.RegisterHandlerOnPrefix("/debug", api.healthHandler())
api.router.Handle("/", http.RedirectHandler(login.HandlerPrefix, http.StatusFound))
reflection.Register(api.grpcServer)
return api, nil
}
@@ -90,8 +87,7 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP
grpcServer,
a.port,
a.http1HostName,
a.cookieHandler,
a.cookieConfig,
a.accessInterceptor,
a.queries,
)
if err != nil {

View File

@@ -0,0 +1,16 @@
package authz
import (
"context"
"github.com/zitadel/zitadel/internal/errors"
)
// UserIDInCTX checks if the userID
// equals the authenticated user in the context.
func UserIDInCTX(ctx context.Context, userID string) error {
if GetCtxData(ctx).UserID != userID {
return errors.ThrowUnauthenticated(nil, "AUTH-Bohd2", "Errors.User.UserIDWrong")
}
return nil
}

View File

@@ -629,7 +629,7 @@ func (s *Server) importData(ctx context.Context, orgs []*admin_pb.DataOrg) (*adm
ExternalUserID: userLinks.ProvidedUserId,
DisplayName: userLinks.ProvidedUserName,
}
if err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
if _, err := s.command.AddUserIDPLink(ctx, userLinks.UserId, org.GetOrgId(), externalIDP); err != nil {
errors = append(errors, &admin_pb.ImportDataError{Type: "user_link", Id: userLinks.UserId + "_" + userLinks.IdpId, Message: err.Error()})
if isCtxTimeout(ctx) {
return &admin_pb.ImportDataResponse{Errors: errors, Success: success}, count, err

View File

@@ -0,0 +1,36 @@
package grpc
import (
"testing"
"google.golang.org/protobuf/reflect/protoreflect"
)
// AllFieldsSet recusively checks if all values in a message
// have a non-zero value.
func AllFieldsSet(t testing.TB, msg protoreflect.Message, ignoreTypes ...protoreflect.FullName) {
ignore := make(map[protoreflect.FullName]bool, len(ignoreTypes))
for _, name := range ignoreTypes {
ignore[name] = true
}
md := msg.Descriptor()
name := md.FullName()
if ignore[name] {
return
}
fields := md.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
if !msg.Has(fd) {
t.Errorf("not all fields set in %q, missing %q", name, fd.Name())
continue
}
if fd.Kind() == protoreflect.MessageKind {
AllFieldsSet(t, msg.Get(fd).Message(), ignoreTypes...)
}
}
}

View File

@@ -241,7 +241,6 @@ func AddHumanUserRequestToAddHuman(req *mgmt_pb.AddHumanUserRequest) *command.Ad
PasswordChangeRequired: true,
Passwordless: false,
Register: false,
ExternalIDP: false,
}
if req.Phone != nil {
human.Phone = command.Phone{

View File

@@ -1,5 +1,12 @@
package grpc
import (
"path"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
const (
Healthz = "/Healthz"
Readiness = "/Ready"
@@ -9,3 +16,18 @@ const (
var (
Probes = []string{Healthz, Readiness, Validation}
)
func init() {
Probes = append(Probes, AllPaths(grpc_reflection_v1alpha.ServerReflection_ServiceDesc)...)
}
func AllPaths(sd grpc.ServiceDesc) []string {
paths := make([]string, 0, len(sd.Methods)+len(sd.Streams))
for _, method := range sd.Methods {
paths = append(paths, path.Join("/", sd.ServiceName, method.MethodName))
}
for _, stream := range sd.Streams {
paths = append(paths, path.Join("/", sd.ServiceName, stream.StreamName))
}
return paths
}

View File

@@ -0,0 +1,32 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)
func TestAllPaths(t *testing.T) {
type args struct {
sd grpc.ServiceDesc
}
tests := []struct {
name string
args args
want []string
}{
{
name: "server reflection",
args: args{grpc_reflection_v1alpha.ServerReflection_ServiceDesc},
want: []string{"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := AllPaths(tt.args.sd)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -16,7 +16,6 @@ import (
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/query"
)
@@ -66,16 +65,15 @@ var (
)
type Gateway struct {
mux *runtime.ServeMux
http1HostName string
connection *grpc.ClientConn
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
mux *runtime.ServeMux
http1HostName string
connection *grpc.ClientConn
accessInterceptor *http_mw.AccessInterceptor
queries *query.Queries
}
func (g *Gateway) Handler() http.Handler {
return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries)
return addInterceptors(g.mux, g.http1HostName, g.accessInterceptor, g.queries)
}
type CustomHTTPResponse interface {
@@ -89,8 +87,7 @@ func CreateGatewayWithPrefix(
g WithGatewayPrefix,
port uint16,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) (http.Handler, string, error) {
runtimeMux := runtime.NewServeMux(serveMuxOptions...)
@@ -106,10 +103,10 @@ func CreateGatewayWithPrefix(
if err != nil {
return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
}
return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil
return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil
}
func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) {
func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) {
connection, err := dial(ctx,
port,
[]grpc.DialOption{
@@ -121,11 +118,10 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string, cooki
}
runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...)
return &Gateway{
mux: runtimeMux,
http1HostName: http1HostName,
connection: connection,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
mux: runtimeMux,
http1HostName: http1HostName,
connection: connection,
accessInterceptor: accessInterceptor,
}, nil
}
@@ -163,8 +159,7 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien
func addInterceptors(
handler http.Handler,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) http.Handler {
handler = http_mw.CallDurationHandler(handler)
@@ -174,7 +169,7 @@ func addInterceptors(
handler = http_mw.DefaultTelemetryHandler(handler)
// For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header
// only if it follows the http_mw.DefaultTelemetryHandler
handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries)
handler = exhaustedCookieInterceptor(handler, accessInterceptor, queries)
handler = http_mw.DefaultMetricsHandler(handler)
return handler
}
@@ -193,35 +188,32 @@ func http1Host(next http.Handler, http1HostName string) http.Handler {
func exhaustedCookieInterceptor(
next http.Handler,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
next.ServeHTTP(&cookieResponseWriter{
ResponseWriter: writer,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
request: request,
queries: queries,
ResponseWriter: writer,
accessInterceptor: accessInterceptor,
request: request,
queries: queries,
}, request)
})
}
type cookieResponseWriter struct {
http.ResponseWriter
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
request *http.Request
queries *query.Queries
accessInterceptor *http_mw.AccessInterceptor
request *http.Request
queries *query.Queries
}
func (r *cookieResponseWriter) WriteHeader(status int) {
if status >= 200 && status < 300 {
http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig)
r.accessInterceptor.DeleteExhaustedCookie(r.ResponseWriter)
}
if status == http.StatusTooManyRequests {
http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request)
r.accessInterceptor.SetExhaustedCookie(r.ResponseWriter, r.request)
}
r.ResponseWriter.WriteHeader(status)
}

View File

@@ -11,38 +11,13 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
settings "github.com/zitadel/zitadel/pkg/grpc/settings/v2alpha"
)
var ignoreMessageTypes = map[protoreflect.FullName]bool{
"google.protobuf.Duration": true,
}
// allFieldsSet recusively checks if all values in a message
// have a non-zero value.
func allFieldsSet(t testing.TB, msg protoreflect.Message) {
md := msg.Descriptor()
name := md.FullName()
if ignoreMessageTypes[name] {
return
}
fields := md.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
if !msg.Has(fd) {
t.Errorf("not all fields set in %q, missing %q", name, fd.Name())
continue
}
if fd.Kind() == protoreflect.MessageKind {
allFieldsSet(t, msg.Get(fd).Message())
}
}
}
var ignoreTypes = []protoreflect.FullName{"google.protobuf.Duration"}
func Test_loginSettingsToPb(t *testing.T) {
arg := &query.LoginPolicy{
@@ -100,7 +75,7 @@ func Test_loginSettingsToPb(t *testing.T) {
}
got := loginSettingsToPb(arg)
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("loginSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -241,7 +216,7 @@ func Test_passwordSettingsToPb(t *testing.T) {
}
got := passwordSettingsToPb(arg)
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("passwordSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -295,7 +270,7 @@ func Test_brandingSettingsToPb(t *testing.T) {
}
got := brandingSettingsToPb(arg, "http://example.com")
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("brandingSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -315,7 +290,7 @@ func Test_domainSettingsToPb(t *testing.T) {
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
}
got := domainSettingsToPb(arg)
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("domainSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -337,7 +312,7 @@ func Test_legalSettingsToPb(t *testing.T) {
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
}
got := legalAndSupportSettingsToPb(arg)
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("legalSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -353,7 +328,7 @@ func Test_lockoutSettingsToPb(t *testing.T) {
ResourceOwnerType: settings.ResourceOwnerType_RESOURCE_OWNER_TYPE_INSTANCE,
}
got := lockoutSettingsToPb(arg)
allFieldsSet(t, got.ProtoReflect())
grpc.AllFieldsSet(t, got.ProtoReflect(), ignoreTypes...)
if !proto.Equal(got, want) {
t.Errorf("lockoutSettingsToPb() =\n%v\nwant\n%v", got, want)
}
@@ -387,7 +362,7 @@ func Test_identityProvidersToPb(t *testing.T) {
got := identityProvidersToPb(arg)
require.Len(t, got, len(got))
for i, v := range got {
allFieldsSet(t, v.ProtoReflect())
grpc.AllFieldsSet(t, v.ProtoReflect(), ignoreTypes...)
if !proto.Equal(v, want[i]) {
t.Errorf("identityProvidersToPb() =\n%v\nwant\n%v", got, want)
}

View File

@@ -0,0 +1,104 @@
package user
import (
"context"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyRequest) (resp *user.RegisterPasskeyResponse, err error) {
var (
resourceOwner = authz.GetCtxData(ctx).ResourceOwner
authenticator = passkeyAuthenticatorToDomain(req.GetAuthenticator())
)
if code := req.GetCode(); code != nil {
return passkeyRegistrationDetailsToPb(
s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, s.userCodeAlg),
)
}
return passkeyRegistrationDetailsToPb(
s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, authenticator),
)
}
func passkeyAuthenticatorToDomain(pa user.PasskeyAuthenticator) domain.AuthenticatorAttachment {
switch pa {
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_UNSPECIFIED:
return domain.AuthenticatorAttachmentUnspecified
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM:
return domain.AuthenticatorAttachmentPlattform
case user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM:
return domain.AuthenticatorAttachmentCrossPlattform
default:
return domain.AuthenticatorAttachmentUnspecified
}
}
func passkeyRegistrationDetailsToPb(details *domain.PasskeyRegistrationDetails, err error) (*user.RegisterPasskeyResponse, error) {
if err != nil {
return nil, err
}
return &user.RegisterPasskeyResponse{
Details: object.DomainToDetailsPb(details.ObjectDetails),
PasskeyId: details.PasskeyID,
PublicKeyCredentialCreationOptions: details.PublicKeyCredentialCreationOptions,
}, nil
}
func (s *Server) VerifyPasskeyRegistration(ctx context.Context, req *user.VerifyPasskeyRegistrationRequest) (*user.VerifyPasskeyRegistrationResponse, error) {
resourceOwner := authz.GetCtxData(ctx).ResourceOwner
objectDetails, err := s.command.HumanHumanPasswordlessSetup(ctx, req.GetUserId(), resourceOwner, req.GetPasskeyName(), "", req.GetPublicKeyCredential())
if err != nil {
return nil, err
}
return &user.VerifyPasskeyRegistrationResponse{
Details: object.DomainToDetailsPb(objectDetails),
}, nil
}
func (s *Server) CreatePasskeyRegistrationLink(ctx context.Context, req *user.CreatePasskeyRegistrationLinkRequest) (resp *user.CreatePasskeyRegistrationLinkResponse, err error) {
resourceOwner := authz.GetCtxData(ctx).ResourceOwner
switch medium := req.Medium.(type) {
case nil:
return passkeyDetailsToPb(
s.command.AddUserPasskeyCode(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg),
)
case *user.CreatePasskeyRegistrationLinkRequest_SendLink:
return passkeyDetailsToPb(
s.command.AddUserPasskeyCodeURLTemplate(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg, medium.SendLink.GetUrlTemplate()),
)
case *user.CreatePasskeyRegistrationLinkRequest_ReturnCode:
return passkeyCodeDetailsToPb(
s.command.AddUserPasskeyCodeReturn(ctx, req.GetUserId(), resourceOwner, s.userCodeAlg),
)
default:
return nil, caos_errs.ThrowUnimplementedf(nil, "USERv2-gaD8y", "verification oneOf %T in method CreatePasskeyRegistrationLink not implemented", medium)
}
}
func passkeyDetailsToPb(details *domain.ObjectDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) {
if err != nil {
return nil, err
}
return &user.CreatePasskeyRegistrationLinkResponse{
Details: object.DomainToDetailsPb(details),
}, nil
}
func passkeyCodeDetailsToPb(details *domain.PasskeyCodeDetails, err error) (*user.CreatePasskeyRegistrationLinkResponse, error) {
if err != nil {
return nil, err
}
return &user.CreatePasskeyRegistrationLinkResponse{
Details: object.DomainToDetailsPb(details.ObjectDetails),
Code: &user.PasskeyRegistrationCode{
Id: details.CodeID,
Code: details.Code,
},
}, nil
}

View File

@@ -0,0 +1,309 @@
//go:build integration
package user_test
import (
"context"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/webauthn"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
func TestServer_RegisterPasskey(t *testing.T) {
userID := createHumanUser(t).GetUserId()
reg, err := Client.CreatePasskeyRegistrationLink(CTX, &user.CreatePasskeyRegistrationLinkRequest{
UserId: userID,
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
})
require.NoError(t, err)
client := webauthn.NewClient(Tester.Config.WebAuthNName, Tester.Config.ExternalDomain, "https://"+Tester.Host())
type args struct {
ctx context.Context
req *user.RegisterPasskeyRequest
}
tests := []struct {
name string
args args
want *user.RegisterPasskeyResponse
wantErr bool
}{
{
name: "missing user id",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{},
},
wantErr: true,
},
{
name: "register code",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{
UserId: userID,
Code: reg.GetCode(),
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
},
},
want: &user.RegisterPasskeyResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "reuse code (not allowed)",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{
UserId: userID,
Code: reg.GetCode(),
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
},
},
wantErr: true,
},
{
name: "wrong code",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{
UserId: userID,
Code: &user.PasskeyRegistrationCode{
Id: reg.GetCode().GetId(),
Code: "foobar",
},
Authenticator: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM,
},
},
wantErr: true,
},
{
name: "user mismatch",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{
UserId: userID,
},
},
wantErr: true,
},
/* TODO after we are able to obtain a Bearer token for a human user
{
name: "human user",
args: args{
ctx: CTX,
req: &user.RegisterPasskeyRequest{
UserId: humanUserID,
},
},
want: &user.RegisterPasskeyResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
},
*/
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.RegisterPasskey(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
integration.AssertDetails(t, tt.want, got)
if tt.want != nil {
assert.NotEmpty(t, got.GetPasskeyId())
assert.NotEmpty(t, got.GetPublicKeyCredentialCreationOptions())
_, err := client.CreateAttestationResponse(got.GetPublicKeyCredentialCreationOptions())
require.NoError(t, err)
}
})
}
}
func TestServer_VerifyPasskeyRegistration(t *testing.T) {
userID := createHumanUser(t).GetUserId()
reg, err := Client.CreatePasskeyRegistrationLink(CTX, &user.CreatePasskeyRegistrationLinkRequest{
UserId: userID,
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
})
require.NoError(t, err)
pkr, err := Client.RegisterPasskey(CTX, &user.RegisterPasskeyRequest{
UserId: userID,
Code: reg.GetCode(),
})
require.NoError(t, err)
require.NotEmpty(t, pkr.GetPasskeyId())
require.NotEmpty(t, pkr.GetPublicKeyCredentialCreationOptions())
client := webauthn.NewClient(Tester.Config.WebAuthNName, Tester.Config.ExternalDomain, "https://"+Tester.Host())
attestationResponse, err := client.CreateAttestationResponse(pkr.GetPublicKeyCredentialCreationOptions())
require.NoError(t, err)
type args struct {
ctx context.Context
req *user.VerifyPasskeyRegistrationRequest
}
tests := []struct {
name string
args args
want *user.VerifyPasskeyRegistrationResponse
wantErr bool
}{
{
name: "missing user id",
args: args{
ctx: CTX,
req: &user.VerifyPasskeyRegistrationRequest{
PasskeyId: pkr.GetPasskeyId(),
PublicKeyCredential: []byte(attestationResponse),
PasskeyName: "nice name",
},
},
wantErr: true,
},
{
name: "success",
args: args{
ctx: CTX,
req: &user.VerifyPasskeyRegistrationRequest{
UserId: userID,
PasskeyId: pkr.GetPasskeyId(),
PublicKeyCredential: attestationResponse,
PasskeyName: "nice name",
},
},
want: &user.VerifyPasskeyRegistrationResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "wrong credential",
args: args{
ctx: CTX,
req: &user.VerifyPasskeyRegistrationRequest{
UserId: userID,
PasskeyId: pkr.GetPasskeyId(),
PublicKeyCredential: []byte("attestationResponseattestationResponseattestationResponse"),
PasskeyName: "nice name",
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.VerifyPasskeyRegistration(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
integration.AssertDetails(t, tt.want, got)
})
}
}
func TestServer_CreatePasskeyRegistrationLink(t *testing.T) {
userID := createHumanUser(t).GetUserId()
type args struct {
ctx context.Context
req *user.CreatePasskeyRegistrationLinkRequest
}
tests := []struct {
name string
args args
want *user.CreatePasskeyRegistrationLinkResponse
wantCode bool
wantErr bool
}{
{
name: "missing user id",
args: args{
ctx: CTX,
req: &user.CreatePasskeyRegistrationLinkRequest{},
},
wantErr: true,
},
{
name: "send default mail",
args: args{
ctx: CTX,
req: &user.CreatePasskeyRegistrationLinkRequest{
UserId: userID,
},
},
want: &user.CreatePasskeyRegistrationLinkResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "send custom url",
args: args{
ctx: CTX,
req: &user.CreatePasskeyRegistrationLinkRequest{
UserId: userID,
Medium: &user.CreatePasskeyRegistrationLinkRequest_SendLink{
SendLink: &user.SendPasskeyRegistrationLink{
UrlTemplate: gu.Ptr("https://example.com/passkey/register?userID={{.UserID}}&orgID={{.OrgID}}&codeID={{.CodeID}}&code={{.Code}}"),
},
},
},
},
want: &user.CreatePasskeyRegistrationLinkResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
},
{
name: "return code",
args: args{
ctx: CTX,
req: &user.CreatePasskeyRegistrationLinkRequest{
UserId: userID,
Medium: &user.CreatePasskeyRegistrationLinkRequest_ReturnCode{},
},
},
want: &user.CreatePasskeyRegistrationLinkResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ID,
},
},
wantCode: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.CreatePasskeyRegistrationLink(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
integration.AssertDetails(t, tt.want, got)
if tt.wantCode {
assert.NotEmpty(t, got.GetCode().GetId())
assert.NotEmpty(t, got.GetCode().GetId())
}
})
}
}

View File

@@ -0,0 +1,210 @@
package user
import (
"io"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/domain"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
func Test_passkeyAuthenticatorToDomain(t *testing.T) {
tests := []struct {
pa user.PasskeyAuthenticator
want domain.AuthenticatorAttachment
}{
{
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_UNSPECIFIED,
want: domain.AuthenticatorAttachmentUnspecified,
},
{
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_PLATFORM,
want: domain.AuthenticatorAttachmentPlattform,
},
{
pa: user.PasskeyAuthenticator_PASSKEY_AUTHENTICATOR_CROSS_PLATFORM,
want: domain.AuthenticatorAttachmentCrossPlattform,
},
{
pa: 999,
want: domain.AuthenticatorAttachmentUnspecified,
},
}
for _, tt := range tests {
t.Run(tt.pa.String(), func(t *testing.T) {
got := passkeyAuthenticatorToDomain(tt.pa)
assert.Equal(t, tt.want, got)
})
}
}
func Test_passkeyRegistrationDetailsToPb(t *testing.T) {
type args struct {
details *domain.PasskeyRegistrationDetails
err error
}
tests := []struct {
name string
args args
want *user.RegisterPasskeyResponse
}{
{
name: "an error",
args: args{
details: nil,
err: io.ErrClosedPipe,
},
},
{
name: "ok",
args: args{
details: &domain.PasskeyRegistrationDetails{
ObjectDetails: &domain.ObjectDetails{
Sequence: 22,
EventDate: time.Unix(3000, 22),
ResourceOwner: "me",
},
PasskeyID: "123",
PublicKeyCredentialCreationOptions: []byte{1, 2, 3},
},
err: nil,
},
want: &user.RegisterPasskeyResponse{
Details: &object.Details{
Sequence: 22,
ChangeDate: &timestamppb.Timestamp{
Seconds: 3000,
Nanos: 22,
},
ResourceOwner: "me",
},
PasskeyId: "123",
PublicKeyCredentialCreationOptions: []byte{1, 2, 3},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := passkeyRegistrationDetailsToPb(tt.args.details, tt.args.err)
require.ErrorIs(t, err, tt.args.err)
assert.Equal(t, tt.want, got)
if tt.want != nil {
grpc.AllFieldsSet(t, got.ProtoReflect())
}
})
}
}
func Test_passkeyDetailsToPb(t *testing.T) {
type args struct {
details *domain.ObjectDetails
err error
}
tests := []struct {
name string
args args
want *user.CreatePasskeyRegistrationLinkResponse
}{
{
name: "an error",
args: args{
details: nil,
err: io.ErrClosedPipe,
},
},
{
name: "ok",
args: args{
details: &domain.ObjectDetails{
Sequence: 22,
EventDate: time.Unix(3000, 22),
ResourceOwner: "me",
},
err: nil,
},
want: &user.CreatePasskeyRegistrationLinkResponse{
Details: &object.Details{
Sequence: 22,
ChangeDate: &timestamppb.Timestamp{
Seconds: 3000,
Nanos: 22,
},
ResourceOwner: "me",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := passkeyDetailsToPb(tt.args.details, tt.args.err)
require.ErrorIs(t, err, tt.args.err)
assert.Equal(t, tt.want, got)
})
}
}
func Test_passkeyCodeDetailsToPb(t *testing.T) {
type args struct {
details *domain.PasskeyCodeDetails
err error
}
tests := []struct {
name string
args args
want *user.CreatePasskeyRegistrationLinkResponse
}{
{
name: "an error",
args: args{
details: nil,
err: io.ErrClosedPipe,
},
},
{
name: "ok",
args: args{
details: &domain.PasskeyCodeDetails{
ObjectDetails: &domain.ObjectDetails{
Sequence: 22,
EventDate: time.Unix(3000, 22),
ResourceOwner: "me",
},
CodeID: "123",
Code: "456",
},
err: nil,
},
want: &user.CreatePasskeyRegistrationLinkResponse{
Details: &object.Details{
Sequence: 22,
ChangeDate: &timestamppb.Timestamp{
Seconds: 3000,
Nanos: 22,
},
ResourceOwner: "me",
},
Code: &user.PasskeyRegistrationCode{
Id: "123",
Code: "456",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := passkeyCodeDetailsToPb(tt.args.details, tt.args.err)
require.ErrorIs(t, err, tt.args.err)
assert.Equal(t, tt.want, got)
if tt.want != nil {
grpc.AllFieldsSet(t, got.ProtoReflect())
}
})
}
}

View File

@@ -1,6 +1,8 @@
package user
import (
"context"
"google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz"
@@ -18,15 +20,25 @@ type Server struct {
command *command.Commands
query *query.Queries
userCodeAlg crypto.EncryptionAlgorithm
idpAlg crypto.EncryptionAlgorithm
idpCallback func(ctx context.Context) string
}
type Config struct{}
func CreateServer(command *command.Commands, query *query.Queries, userCodeAlg crypto.EncryptionAlgorithm) *Server {
func CreateServer(
command *command.Commands,
query *query.Queries,
userCodeAlg crypto.EncryptionAlgorithm,
idpAlg crypto.EncryptionAlgorithm,
idpCallback func(ctx context.Context) string,
) *Server {
return &Server{
command: command,
query: query,
userCodeAlg: userCodeAlg,
idpAlg: idpAlg,
idpCallback: idpCallback,
}
}

View File

@@ -2,15 +2,19 @@ package user
import (
"context"
"encoding/base64"
"io"
"golang.org/x/text/language"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
@@ -56,6 +60,14 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
Value: metadataEntry.GetValue(),
}
}
links := make([]*command.AddLink, len(req.GetIdpLinks()))
for i, link := range req.GetIdpLinks() {
links[i] = &command.AddLink{
IDPID: link.GetIdpId(),
IDPExternalID: link.GetIdpExternalId(),
DisplayName: link.GetDisplayName(),
}
}
return &command.AddHuman{
ID: req.GetUserId(),
Username: username,
@@ -76,9 +88,9 @@ func addUserRequestToAddHuman(req *user.AddHumanUserRequest) (*command.AddHuman,
BcryptedPassword: bcryptedPassword,
PasswordChangeRequired: passwordChangeRequired,
Passwordless: false,
ExternalIDP: false,
Register: false,
Metadata: metadata,
Links: links,
}, nil
}
@@ -107,3 +119,95 @@ func hashedPasswordToCommand(hashed *user.HashedPassword) (string, error) {
}
return hashed.GetHash(), nil
}
func (s *Server) AddIDPLink(ctx context.Context, req *user.AddIDPLinkRequest) (_ *user.AddIDPLinkResponse, err error) {
orgID := authz.GetCtxData(ctx).OrgID
details, err := s.command.AddUserIDPLink(ctx, req.UserId, orgID, &domain.UserIDPLink{
IDPConfigID: req.GetIdpLink().GetIdpId(),
ExternalUserID: req.GetIdpLink().GetIdpExternalId(),
DisplayName: req.GetIdpLink().GetDisplayName(),
})
if err != nil {
return nil, err
}
return &user.AddIDPLinkResponse{
Details: object.DomainToDetailsPb(details),
}, nil
}
func (s *Server) StartIdentityProviderFlow(ctx context.Context, req *user.StartIdentityProviderFlowRequest) (_ *user.StartIdentityProviderFlowResponse, err error) {
id, details, err := s.command.CreateIntent(ctx, req.GetIdpId(), req.GetSuccessUrl(), req.GetFailureUrl(), authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}
authURL, err := s.command.AuthURLFromProvider(ctx, req.GetIdpId(), id, s.idpCallback(ctx))
if err != nil {
return nil, err
}
return &user.StartIdentityProviderFlowResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{AuthUrl: authURL},
}, nil
}
func (s *Server) RetrieveIdentityProviderInformation(ctx context.Context, req *user.RetrieveIdentityProviderInformationRequest) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
intent, err := s.command.GetIntentWriteModel(ctx, req.GetIntentId(), authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}
if err := s.checkIntentToken(req.GetToken(), intent.AggregateID); err != nil {
return nil, err
}
if intent.State != domain.IDPIntentStateSucceeded {
return nil, errors.ThrowPreconditionFailed(nil, "IDP-Hk38e", "Errors.Intent.NotSucceeded")
}
return intentToIDPInformationPb(intent, s.idpAlg)
}
func intentToIDPInformationPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderInformationResponse, err error) {
var idToken *string
if intent.IDPIDToken != "" {
idToken = &intent.IDPIDToken
}
var accessToken string
if intent.IDPAccessToken != nil {
accessToken, err = crypto.DecryptString(intent.IDPAccessToken, alg)
if err != nil {
return nil, err
}
}
return &user.RetrieveIdentityProviderInformationResponse{
Details: &object_pb.Details{
Sequence: intent.ProcessedSequence,
ChangeDate: timestamppb.New(intent.ChangeDate),
ResourceOwner: intent.ResourceOwner,
},
IdpInformation: &user.IDPInformation{
Access: &user.IDPInformation_Oauth{
Oauth: &user.IDPOAuthAccessInformation{
AccessToken: accessToken,
IdToken: idToken,
},
},
IdpInformation: intent.IDPUser,
},
}, nil
}
func (s *Server) checkIntentToken(token string, intentID string) error {
if token == "" {
return errors.ThrowPermissionDenied(nil, "IDP-Sfefs", "Errors.Intent.InvalidToken")
}
data, err := base64.RawURLEncoding.DecodeString(token)
if err != nil {
return errors.ThrowPermissionDenied(err, "IDP-Swg31", "Errors.Intent.InvalidToken")
}
decryptedToken, err := s.idpAlg.Decrypt(data, s.idpAlg.EncryptionKeyID())
if err != nil {
return errors.ThrowPermissionDenied(err, "IDP-Sf4gt", "Errors.Intent.InvalidToken")
}
if string(decryptedToken) != intentID {
return errors.ThrowPermissionDenied(nil, "IDP-dkje3", "Errors.Intent.InvalidToken")
}
return nil
}

View File

@@ -6,16 +6,24 @@ import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/repository/idp"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
@@ -39,7 +47,60 @@ func TestMain(m *testing.M) {
}())
}
func createProvider(t *testing.T) string {
ctx := authz.WithInstance(context.Background(), Tester.Instance)
id, _, err := Tester.Commands.AddOrgGenericOAuthProvider(ctx, Tester.Organisation.ID, command.GenericOAuthProvider{
"idp",
"clientID",
"clientSecret",
"https://example.com/oauth/v2/authorize",
"https://example.com/oauth/v2/token",
"https://api.example.com/user",
[]string{"openid", "profile", "email"},
"id",
idp.Options{
IsLinkingAllowed: true,
IsCreationAllowed: true,
IsAutoCreation: true,
IsAutoUpdate: true,
},
})
require.NoError(t, err)
return id
}
func createIntent(t *testing.T, idpID string) string {
ctx := authz.WithInstance(context.Background(), Tester.Instance)
id, _, err := Tester.Commands.CreateIntent(ctx, idpID, "https://example.com/success", "https://example.com/failure", Tester.Organisation.ID)
require.NoError(t, err)
return id
}
func createSuccessfulIntent(t *testing.T, idpID string) (string, string, time.Time, uint64) {
ctx := authz.WithInstance(context.Background(), Tester.Instance)
intentID := createIntent(t, idpID)
writeModel, err := Tester.Commands.GetIntentWriteModel(ctx, intentID, Tester.Organisation.ID)
require.NoError(t, err)
idpUser := &oauth.UserMapper{
RawInfo: map[string]interface{}{
"id": "id",
},
}
idpSession := &oauth.Session{
Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
},
IDToken: "idToken",
},
}
token, err := Tester.Commands.SucceedIDPIntent(ctx, writeModel, idpUser, idpSession, "")
require.NoError(t, err)
return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence
}
func TestServer_AddHumanUser(t *testing.T) {
idpID := createProvider(t)
type args struct {
ctx context.Context
req *user.AddHumanUserRequest
@@ -287,6 +348,105 @@ func TestServer_AddHumanUser(t *testing.T) {
},
wantErr: true,
},
{
name: "missing idp",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{
Email: "livio@zitadel.com",
Verification: &user.SetHumanEmail_IsVerified{
IsVerified: true,
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: false,
},
},
IdpLinks: []*user.IDPLink{
{
IdpId: "idpID",
IdpExternalId: "externalID",
DisplayName: "displayName",
},
},
},
},
wantErr: true,
},
{
name: "with idp",
args: args{
CTX,
&user.AddHumanUserRequest{
Organisation: &object.Organisation{
Org: &object.Organisation_OrgId{
OrgId: Tester.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
FirstName: "Donald",
LastName: "Duck",
NickName: gu.Ptr("Dukkie"),
DisplayName: gu.Ptr("Donald Duck"),
PreferredLanguage: gu.Ptr("en"),
Gender: user.Gender_GENDER_DIVERSE.Enum(),
},
Email: &user.SetHumanEmail{
Email: "livio@zitadel.com",
Verification: &user.SetHumanEmail_IsVerified{
IsVerified: true,
},
},
Metadata: []*user.SetMetadataEntry{
{
Key: "somekey",
Value: []byte("somevalue"),
},
},
PasswordType: &user.AddHumanUserRequest_Password{
Password: &user.Password{
Password: "DifficultPW666!",
ChangeRequired: false,
},
},
IdpLinks: []*user.IDPLink{
{
IdpId: idpID,
IdpExternalId: "externalID",
DisplayName: "displayName",
},
},
},
},
want: &user.AddHumanUserResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -315,3 +475,226 @@ func TestServer_AddHumanUser(t *testing.T) {
})
}
}
func TestServer_AddIDPLink(t *testing.T) {
idpID := createProvider(t)
type args struct {
ctx context.Context
req *user.AddIDPLinkRequest
}
tests := []struct {
name string
args args
want *user.AddIDPLinkResponse
wantErr bool
}{
{
name: "user does not exist",
args: args{
CTX,
&user.AddIDPLinkRequest{
UserId: "userID",
IdpLink: &user.IDPLink{
IdpId: idpID,
IdpExternalId: "externalID",
DisplayName: "displayName",
},
},
},
want: nil,
wantErr: true,
},
{
name: "idp does not exist",
args: args{
CTX,
&user.AddIDPLinkRequest{
UserId: Tester.Users[integration.OrgOwner].ID,
IdpLink: &user.IDPLink{
IdpId: "idpID",
IdpExternalId: "externalID",
DisplayName: "displayName",
},
},
},
want: nil,
wantErr: true,
},
{
name: "add link",
args: args{
CTX,
&user.AddIDPLinkRequest{
UserId: Tester.Users[integration.OrgOwner].ID,
IdpLink: &user.IDPLink{
IdpId: idpID,
IdpExternalId: "externalID",
DisplayName: "displayName",
},
},
},
want: &user.AddIDPLinkResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.AddIDPLink(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
integration.AssertDetails(t, tt.want, got)
})
}
}
func TestServer_StartIdentityProviderFlow(t *testing.T) {
idpID := createProvider(t)
type args struct {
ctx context.Context
req *user.StartIdentityProviderFlowRequest
}
tests := []struct {
name string
args args
want *user.StartIdentityProviderFlowResponse
wantErr bool
}{
{
name: "missing urls",
args: args{
CTX,
&user.StartIdentityProviderFlowRequest{
IdpId: idpID,
},
},
want: nil,
wantErr: true,
},
{
name: "next step auth url",
args: args{
CTX,
&user.StartIdentityProviderFlowRequest{
IdpId: idpID,
SuccessUrl: "https://example.com/success",
FailureUrl: "https://example.com/failure",
},
},
want: &user.StartIdentityProviderFlowResponse{
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.StartIdentityProviderFlow(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
if nextStep := tt.want.GetNextStep(); nextStep != nil {
if !strings.HasPrefix(got.GetAuthUrl(), tt.want.GetAuthUrl()) {
assert.Failf(t, "auth url does not match", "expected: %s, but got: %s", tt.want.GetAuthUrl(), got.GetAuthUrl())
}
}
integration.AssertDetails(t, tt.want, got)
})
}
}
func TestServer_RetrieveIdentityProviderInformation(t *testing.T) {
idpID := createProvider(t)
intentID := createIntent(t, idpID)
successfulID, token, changeDate, sequence := createSuccessfulIntent(t, idpID)
type args struct {
ctx context.Context
req *user.RetrieveIdentityProviderInformationRequest
}
tests := []struct {
name string
args args
want *user.RetrieveIdentityProviderInformationResponse
wantErr bool
}{
{
name: "failed intent",
args: args{
CTX,
&user.RetrieveIdentityProviderInformationRequest{
IntentId: intentID,
Token: "",
},
},
wantErr: true,
},
{
name: "wrong token",
args: args{
CTX,
&user.RetrieveIdentityProviderInformationRequest{
IntentId: successfulID,
Token: "wrong token",
},
},
wantErr: true,
},
{
name: "retrieve successful intent",
args: args{
CTX,
&user.RetrieveIdentityProviderInformationRequest{
IntentId: successfulID,
Token: token,
},
},
want: &user.RetrieveIdentityProviderInformationResponse{
Details: &object.Details{
ChangeDate: timestamppb.New(changeDate),
ResourceOwner: Tester.Organisation.ID,
Sequence: sequence,
},
IdpInformation: &user.IDPInformation{
Access: &user.IDPInformation_Oauth{
Oauth: &user.IDPOAuthAccessInformation{
AccessToken: "accessToken",
IdToken: gu.Ptr("idToken"),
},
},
IdpInformation: []byte(`{"RawInfo":{"id":"id"}}`),
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.RetrieveIdentityProviderInformation(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.want.GetDetails(), got.GetDetails())
require.Equal(t, tt.want.GetIdpInformation(), got.GetIdpInformation())
})
}
}

View File

@@ -3,11 +3,21 @@ package user
import (
"errors"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
@@ -78,3 +88,118 @@ func Test_hashedPasswordToCommand(t *testing.T) {
})
}
}
func Test_intentToIDPInformationPb(t *testing.T) {
decryption := func(err error) crypto.EncryptionAlgorithm {
mCrypto := crypto.NewMockEncryptionAlgorithm(gomock.NewController(t))
mCrypto.EXPECT().Algorithm().Return("enc")
mCrypto.EXPECT().DecryptionKeyIDs().Return([]string{"id"})
mCrypto.EXPECT().DecryptString(gomock.Any(), gomock.Any()).DoAndReturn(
func(code []byte, keyID string) (string, error) {
if err != nil {
return "", err
}
return string(code), nil
})
return mCrypto
}
type args struct {
intent *command.IDPIntentWriteModel
alg crypto.EncryptionAlgorithm
}
type res struct {
resp *user.RetrieveIdentityProviderInformationResponse
err error
}
tests := []struct {
name string
args args
res res
}{
{
"decryption invalid key id error",
args{
intent: &command.IDPIntentWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: "intentID",
ProcessedSequence: 123,
ResourceOwner: "ro",
InstanceID: "instanceID",
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
},
IDPID: "idpID",
IDPUser: []byte(`{"id": "id"}`),
IDPAccessToken: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("accessToken"),
},
IDPIDToken: "idToken",
UserID: "userID",
State: domain.IDPIntentStateSucceeded,
},
alg: decryption(caos_errs.ThrowInternal(nil, "id", "invalid key id")),
},
res{
resp: nil,
err: caos_errs.ThrowInternal(nil, "id", "invalid key id"),
},
},
{
"successful",
args{
intent: &command.IDPIntentWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: "intentID",
ProcessedSequence: 123,
ResourceOwner: "ro",
InstanceID: "instanceID",
ChangeDate: time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local),
},
IDPID: "idpID",
IDPUser: []byte(`{"id": "id"}`),
IDPAccessToken: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("accessToken"),
},
IDPIDToken: "idToken",
UserID: "userID",
State: domain.IDPIntentStateSucceeded,
},
alg: decryption(nil),
},
res{
resp: &user.RetrieveIdentityProviderInformationResponse{
Details: &object_pb.Details{
Sequence: 123,
ChangeDate: timestamppb.New(time.Date(2019, 4, 1, 1, 1, 1, 1, time.Local)),
ResourceOwner: "ro",
},
IdpInformation: &user.IDPInformation{
Access: &user.IDPInformation_Oauth{
Oauth: &user.IDPOAuthAccessInformation{
AccessToken: "accessToken",
IdToken: gu.Ptr("idToken"),
}},
IdpInformation: []byte(`{"id": "id"}`),
},
},
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := intentToIDPInformationPb(tt.args.intent, tt.args.alg)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.resp, got)
if tt.res.resp != nil {
grpc.AllFieldsSet(t, got.ProtoReflect())
}
})
}
}

View File

@@ -123,8 +123,8 @@ func (c *CookieHandler) SetEncryptedCookie(w http.ResponseWriter, name, domain s
return nil
}
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, r *http.Request, name string) {
c.httpSet(w, name, r.Host, "", -1)
func (c *CookieHandler) DeleteCookie(w http.ResponseWriter, name string) {
c.httpSet(w, name, "", "", -1)
}
func (c *CookieHandler) httpSet(w http.ResponseWriter, name, domain, value string, maxage int) {

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"net"
"net/http"
"net/url"
@@ -32,15 +33,54 @@ type AccessConfig struct {
// NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor {
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
return &AccessInterceptor{
svc: svc,
cookieHandler: cookieHandler,
limitConfig: cookieConfig,
storeOnly: storeOnly,
}
}
func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
return &AccessInterceptor{
svc: a.svc,
cookieHandler: a.cookieHandler,
limitConfig: a.limitConfig,
storeOnly: true,
}
}
func (a *AccessInterceptor) AccessService() *logstore.Service {
return a.svc
}
func (a *AccessInterceptor) Limit(ctx context.Context) bool {
if !a.svc.Enabled() || a.storeOnly {
return false
}
instance := authz.GetInstance(ctx)
remaining := a.svc.Limit(ctx, instance.InstanceID())
return remaining != nil && *remaining <= 0
}
func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, domain, cookieValue)
}
func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter) {
a.cookieHandler.DeleteCookie(writer, a.limitConfig.ExhaustedCookieKey)
}
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if !a.svc.Enabled() {
return next
@@ -49,23 +89,16 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
ctx := request.Context()
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
instance := authz.GetInstance(ctx)
limit := false
if !a.storeOnly {
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
limit = remaining != nil && *remaining == 0
}
limited := a.Limit(tracingCtx)
checkSpan.End()
if limit {
// Limit can only be true when storeOnly is false, so set the cookie and the response code
SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request)
if limited {
a.SetExhaustedCookie(wrappedWriter, request)
http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
} else {
if !a.storeOnly {
// If not limited and not storeOnly, ensure the cookie is deleted
DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig)
}
// Always serve if not limited
}
if !limited && !a.storeOnly {
a.DeleteExhaustedCookie(wrappedWriter)
}
if !limited {
next.ServeHTTP(wrappedWriter, request)
}
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
@@ -75,6 +108,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
}
instance := authz.GetInstance(tracingCtx)
a.svc.Handle(tracingCtx, &access.Record{
LogDate: time.Now(),
Protocol: access.HTTP,
@@ -90,24 +124,6 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
})
}
func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue)
}
func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) {
cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey)
}
type statusRecorder struct {
http.ResponseWriter
status int

246
internal/api/idp/idp.go Normal file
View File

@@ -0,0 +1,246 @@
package idp
import (
"context"
"errors"
"net/http"
"github.com/gorilla/mux"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
z_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/form"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/github"
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
"github.com/zitadel/zitadel/internal/idp/providers/google"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
"github.com/zitadel/zitadel/internal/query"
)
const (
HandlerPrefix = "/idps"
callbackPath = "/callback"
paramIntentID = "id"
paramToken = "token"
paramUserID = "user"
paramError = "error"
paramErrorDescription = "error_description"
)
type Handler struct {
commands *command.Commands
queries *query.Queries
parser *form.Parser
encryptionAlgorithm crypto.EncryptionAlgorithm
callbackURL func(ctx context.Context) string
}
type externalIDPCallbackData struct {
State string `schema:"state"`
Code string `schema:"code"`
Error string `schema:"error"`
ErrorDescription string `schema:"error_description"`
}
// CallbackURL generates the instance specific URL to the IDP callback handler
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
return func(ctx context.Context) string {
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + callbackPath
}
}
func NewHandler(
commands *command.Commands,
queries *query.Queries,
encryptionAlgorithm crypto.EncryptionAlgorithm,
externalSecure bool,
instanceInterceptor func(next http.Handler) http.Handler,
) http.Handler {
h := &Handler{
commands: commands,
queries: queries,
parser: form.NewParser(),
encryptionAlgorithm: encryptionAlgorithm,
callbackURL: CallbackURL(externalSecure),
}
router := mux.NewRouter()
router.Use(instanceInterceptor)
router.HandleFunc(callbackPath, h.handleCallback)
return router
}
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
data, err := h.parseCallbackRequest(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
intent := h.getActiveIntent(w, r, data.State)
if intent == nil {
// if we didn't get an active intent the error was already handled (either redirected or display directly)
return
}
ctx := r.Context()
// the provider might have returned an error
if data.Error != "" {
cmdErr := h.commands.FailIDPIntent(ctx, intent, reason(data.Error, data.ErrorDescription))
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
return
}
provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
if err != nil {
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
redirectToFailureURLErr(w, r, intent, err)
return
}
idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code)
if err != nil {
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
redirectToFailureURLErr(w, r, intent, err)
return
}
userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
if err != nil {
redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
return
}
redirectToSuccessURL(w, r, intent, token, userID)
}
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
data := new(externalIDPCallbackData)
err := h.parser.Parse(r, data)
if err != nil {
return nil, err
}
if data.State == "" {
return nil, z_errs.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
}
return data, nil
}
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil
}
if intent.State == domain.IDPIntentStateUnspecified {
http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
return nil
}
if intent.State != domain.IDPIntentStateStarted {
redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
return nil
}
return intent
}
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
queries := intent.SuccessURL.Query()
queries.Set(paramIntentID, intent.AggregateID)
queries.Set(paramToken, token)
if userID != "" {
queries.Set(paramUserID, userID)
}
intent.SuccessURL.RawQuery = queries.Encode()
http.Redirect(w, r, intent.SuccessURL.String(), http.StatusFound)
}
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
msg := err.Error()
var description string
zErr := new(z_errs.CaosError)
if errors.As(err, &zErr) {
msg = zErr.GetID()
description = zErr.GetMessage() // TODO: i18n?
}
redirectToFailureURL(w, r, i, msg, description)
}
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
queries := i.FailureURL.Query()
queries.Set(paramIntentID, i.AggregateID)
queries.Set(paramError, err)
queries.Set(paramErrorDescription, description)
i.FailureURL.RawQuery = queries.Encode()
http.Redirect(w, r, i.FailureURL.String(), http.StatusFound)
}
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string) (user idp.User, idpTokens idp.Session, err error) {
var session idp.Session
switch provider := identityProvider.(type) {
case *oauth.Provider:
session = &oauth.Session{Provider: provider, Code: code}
case *openid.Provider:
session = &openid.Session{Provider: provider, Code: code}
case *azuread.Provider:
session = &oauth.Session{Provider: provider.Provider, Code: code}
case *github.Provider:
session = &oauth.Session{Provider: provider.Provider, Code: code}
case *gitlab.Provider:
session = &openid.Session{Provider: provider.Provider, Code: code}
case *google.Provider:
session = &openid.Session{Provider: provider.Provider, Code: code}
case *jwt.Provider, *ldap.Provider:
return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
default:
return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
}
user, err = session.FetchUser(ctx)
if err != nil {
return nil, nil, err
}
return user, session, nil
}
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
if err != nil {
return "", err
}
externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
if err != nil {
return "", err
}
queries := []query.SearchQuery{
idQuery, externalIDQuery,
}
links, err := h.queries.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, false)
if err != nil {
return "", err
}
if len(links.Links) != 1 {
return "", nil
}
return links.Links[0].UserID, nil
}
func reason(err, description string) string {
if description == "" {
return err
}
return err + ": " + description
}

View File

@@ -0,0 +1,220 @@
package idp
import (
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/command"
z_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/form"
)
func Test_redirectToSuccessURL(t *testing.T) {
type args struct {
id string
userID string
token string
failureURL string
successURL string
}
type res struct {
want string
}
tests := []struct {
name string
args args
res res
}{
{
"redirect",
args{
id: "id",
token: "token",
failureURL: "https://example.com/failure",
successURL: "https://example.com/success",
},
res{
"https://example.com/success?id=id&token=token",
},
},
{
"redirect with userID",
args{
id: "id",
userID: "user",
token: "token",
failureURL: "https://example.com/failure",
successURL: "https://example.com/success",
},
res{
"https://example.com/success?id=id&token=token&user=user",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
resp := httptest.NewRecorder()
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
redirectToSuccessURL(resp, req, wm, tt.args.token, tt.args.userID)
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
})
}
}
func Test_redirectToFailureURL(t *testing.T) {
type args struct {
id string
failureURL string
successURL string
err string
desc string
}
type res struct {
want string
}
tests := []struct {
name string
args args
res res
}{
{
"redirect",
args{
id: "id",
failureURL: "https://example.com/failure",
successURL: "https://example.com/success",
},
res{
"https://example.com/failure?error=&error_description=&id=id",
},
},
{
"redirect with error",
args{
id: "id",
failureURL: "https://example.com/failure",
successURL: "https://example.com/success",
err: "test",
desc: "testdesc",
},
res{
"https://example.com/failure?error=test&error_description=testdesc&id=id",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
resp := httptest.NewRecorder()
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
redirectToFailureURL(resp, req, wm, tt.args.err, tt.args.desc)
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
})
}
}
func Test_redirectToFailureURLErr(t *testing.T) {
type args struct {
id string
failureURL string
successURL string
err error
}
type res struct {
want string
}
tests := []struct {
name string
args args
res res
}{
{
"redirect with error",
args{
id: "id",
failureURL: "https://example.com/failure",
successURL: "https://example.com/success",
err: z_errors.ThrowError(nil, "test", "testdesc"),
},
res{
"https://example.com/failure?error=test&error_description=testdesc&id=id",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
resp := httptest.NewRecorder()
wm := command.NewIDPIntentWriteModel(tt.args.id, tt.args.id)
wm.FailureURL, _ = url.Parse(tt.args.failureURL)
wm.SuccessURL, _ = url.Parse(tt.args.successURL)
redirectToFailureURLErr(resp, req, wm, tt.args.err)
assert.Equal(t, tt.res.want, resp.Header().Get("Location"))
})
}
}
func Test_parseCallbackRequest(t *testing.T) {
type args struct {
url string
}
type res struct {
want *externalIDPCallbackData
err bool
}
tests := []struct {
name string
args args
res res
}{
{
"no state",
args{
url: "https://example.com?state=&code=code&error=error&error_description=desc",
},
res{
err: true,
},
},
{
"parse",
args{
url: "https://example.com?state=state&code=code&error=error&error_description=desc",
},
res{
want: &externalIDPCallbackData{
State: "state",
Code: "code",
Error: "error",
ErrorDescription: "desc",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", tt.args.url, nil)
handler := Handler{parser: form.NewParser()}
data, err := handler.parseCallbackRequest(req)
if tt.res.err {
assert.Error(t, err)
}
assert.Equal(t, tt.res.want, data)
})
}
}

View File

@@ -91,7 +91,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
return f, nil
}
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler func(http.Handler) http.Handler, limitingAccessInterceptor *middleware.AccessInterceptor, customerPortal string) (http.Handler, error) {
fSys, err := fs.Sub(static, "static")
if err != nil {
return nil, err
@@ -106,20 +106,27 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
handler := mux.NewRouter()
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor)
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure)
instance := authz.GetInstance(r.Context())
ctx := r.Context()
instance := authz.GetInstance(ctx)
instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance)
if err != nil {
http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError)
return
}
environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL)
exhausted := limitingAccessInterceptor.Limit(ctx)
environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL, exhausted)
if err != nil {
http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError)
return
}
if exhausted {
limitingAccessInterceptor.SetExhaustedCookie(w, r)
} else {
limitingAccessInterceptor.DeleteExhaustedCookie(w)
}
_, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json")
})))
@@ -148,19 +155,21 @@ func csp() *middleware.CSP {
return &csp
}
func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string) ([]byte, error) {
func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string, exhausted bool) ([]byte, error) {
environment := struct {
API string `json:"api,omitempty"`
Issuer string `json:"issuer,omitempty"`
ClientID string `json:"clientid,omitempty"`
CustomerPortal string `json:"customer_portal,omitempty"`
InstanceManagementURL string `json:"instance_management_url,omitempty"`
Exhausted bool `json:"exhausted,omitempty"`
}{
API: api,
Issuer: issuer,
ClientID: clientID,
CustomerPortal: customerPortal,
InstanceManagementURL: instanceMgmtUrl,
Exhausted: exhausted,
}
return json.Marshal(environment)
}