diff --git a/cmd/start/start.go b/cmd/start/start.go index 12062951a9..b12dde2182 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -473,7 +473,7 @@ func startAPIs( if err := apis.RegisterService(ctx, idp_v2.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, action_v3_alpha.CreateServer(config.SystemDefaults, commands, queries, domain.AllFunctions, apis.ListGrpcMethods, apis.ListGrpcServices)); err != nil { + if err := apis.RegisterService(ctx, action_v3_alpha.CreateServer(config.SystemDefaults, commands, queries, domain.AllActionFunctions, apis.ListGrpcMethods, apis.ListGrpcServices)); err != nil { return nil, err } if err := apis.RegisterService(ctx, userschema_v3_alpha.CreateServer(config.SystemDefaults, commands, queries)); err != nil { diff --git a/docs/docs/concepts/features/actions_v2.md b/docs/docs/concepts/features/actions_v2.md index a06384639d..85f05b7c29 100644 --- a/docs/docs/concepts/features/actions_v2.md +++ b/docs/docs/concepts/features/actions_v2.md @@ -10,6 +10,12 @@ This is useful when you have special business requirements that ZITADEL doesn't We're working on Actions continuously. In the [roadmap](https://zitadel.com/roadmap), you see how we are planning to expand and improve it. Please tell us about your needs and help us prioritize further fixes and features. ::: +:::warning +To use Actions v2 activate the feature flag "Actions" [feature flag](/docs/apis/resources/feature_service_v2/feature-service-set-instance-features), to be able to manage the related resources. + +The Actions v2 will always be executed if available, even if the feature flag is switched off, to remove any Actions v2 the related Execution has to be removed. +::: + ## Why actions? ZITADEL can't anticipate and solve every possible business rule and integration requirements from all ZITADEL users. Here are some examples: - A business requires domain specific data validation before a user can be created or authenticated. @@ -31,9 +37,13 @@ so that everybody can implement their custom behaviour for as many processes as Possible conditions for the Execution: - Request, to react to or manipulate requests to ZITADEL, for example add information to newly created users - Response, to react to or manipulate responses to ZITADEL, for example to provision newly created users to other systems -- Function, to react to different functionality in ZITADEL, replaces [Actions](/concepts/features/actions) +- Function, to react to different functionality in ZITADEL, replaces [Actions](/concepts/features/actions). - Event, to create to different events which get created in ZITADEL, for example to inform somebody if a user gets locked +:::info +Currently, the defined Actions v2 will be executed additionally to the defined [Actions](/concepts/features/actions). +::: + ## Further reading - [Actions v2 reference](/apis/actions/v3/usage) diff --git a/go.mod b/go.mod index f316c3e866..3f81b16ac5 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,9 @@ require ( github.com/pquerna/otp v1.4.0 github.com/rakyll/statik v0.1.7 github.com/redis/go-redis/v9 v9.7.0 + github.com/riverqueue/river v0.16.0 + github.com/riverqueue/river/riverdriver v0.16.0 + github.com/riverqueue/river/rivertype v0.16.0 github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sony/gobreaker/v2 v2.0.0 @@ -70,7 +73,7 @@ require ( github.com/zitadel/logging v0.6.1 github.com/zitadel/oidc/v3 v3.32.0 github.com/zitadel/passwap v0.6.0 - github.com/zitadel/saml v0.3.3 + github.com/zitadel/saml v0.3.4 github.com/zitadel/schema v1.3.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 @@ -124,10 +127,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/riverqueue/river v0.16.0 // indirect - github.com/riverqueue/river/riverdriver v0.16.0 // indirect github.com/riverqueue/river/rivershared v0.16.0 // indirect - github.com/riverqueue/river/rivertype v0.16.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect diff --git a/go.sum b/go.sum index 709d6b7a14..9c992f3662 100644 --- a/go.sum +++ b/go.sum @@ -416,16 +416,14 @@ github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANyt github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.0 h1:FG6VLIdzvAPhnYqP14sQ2xhFLkiUQHCs6ySqO91kF4g= -github.com/jackc/pgx/v5 v5.7.0/go.mod h1:awP1KNnjylvpxHuHP63gzjhnGkI1iw+PMoIwvoleN/8= github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52 h1:jny9eqYPwkG8IVy7foUoRjQmFLcArCSz+uPsL6KS0HQ= @@ -648,12 +646,16 @@ github.com/riverqueue/river v0.16.0 h1:YyQrs0kGgjuABwgat02DPUYS0TMyG2ZFlzvf6+fSF github.com/riverqueue/river v0.16.0/go.mod h1:pEZ8Gc15XyFjVY89nJeL256ub5z18XF7ukYn8ktqQrs= github.com/riverqueue/river/riverdriver v0.16.0 h1:y4Df4e1Xk3Id0nnu1VxHJn9118OzmRHcmvOxM/i1Q30= github.com/riverqueue/river/riverdriver v0.16.0/go.mod h1:7Kdf5HQDrLyLUUqPqXobaK+7zbcMctWeAl7yhg4nHes= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.16.0 h1:T/DcMmZXiJAyLN3CSyAoNcf3U4oAD9Ht/8Vd5SXv5YU= +github.com/riverqueue/river/riverdriver/riverdatabasesql v0.16.0/go.mod h1:a9EUhD2yGsAeM9eWo+QrGGbL8LVWoGj2m8KEzm0xUxE= github.com/riverqueue/river/riverdriver/riverpgxv5 v0.16.0 h1:6HP296OPN+3ORL9qG1f561pldB5eovkLzfkNIQmaTXI= github.com/riverqueue/river/riverdriver/riverpgxv5 v0.16.0/go.mod h1:MAeBNoTQ+CD3nRvV9mF6iCBfsGJTxYHZeZSP4MYoeUE= github.com/riverqueue/river/rivershared v0.16.0 h1:L1lQ3gMwdIsxA6yF0/PwAdsFP0T82yBD1V03q2GuJDU= github.com/riverqueue/river/rivershared v0.16.0/go.mod h1:y5Xu8Shcp44DUNnEQV4c6oWH4m2OTkSMCe6nRrgzT34= github.com/riverqueue/river/rivertype v0.16.0 h1:iDjNtCiUbXwLraqNEyQdH/OD80f1wTo8Ai6WHYCwRxs= github.com/riverqueue/river/rivertype v0.16.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -727,7 +729,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= @@ -777,8 +778,8 @@ github.com/zitadel/oidc/v3 v3.32.0 h1:Mw0EPZRC6h+OXAuT0Uk2BZIjJQNHLqUpaJCm6c3IBy github.com/zitadel/oidc/v3 v3.32.0/go.mod h1:DyE/XClysRK/ozFaZSqlYamKVnTh4l6Ln25ihSNI03w= github.com/zitadel/passwap v0.6.0 h1:m9F3epFC0VkBXu25rihSLGyHvWiNlCzU5kk8RoI+SXQ= github.com/zitadel/passwap v0.6.0/go.mod h1:kqAiJ4I4eZvm3Y6oAk6hlEqlZZOkjMHraGXF90GG7LI= -github.com/zitadel/saml v0.3.3 h1:Cn+1ZNeWlzMM7wxUxJfgNjXSW+Yu6UD4zWbpySA5GQM= -github.com/zitadel/saml v0.3.3/go.mod h1:QqKcguOt7mMVI6tkEfpkyzwnYRdlmn3kYQj3VTPUw1g= +github.com/zitadel/saml v0.3.4 h1:L2pybnx2Hs+kqebZmUbnZUd9L/CY2sNw5psMWw2D/6Q= +github.com/zitadel/saml v0.3.4/go.mod h1:M0losAULJpLtAmXrYqBnf375ia2rMgJ75b1mpaU/GlA= github.com/zitadel/schema v1.3.0 h1:kQ9W9tvIwZICCKWcMvCEweXET1OcOyGEuFbHs4o5kg0= github.com/zitadel/schema v1.3.0/go.mod h1:NptN6mkBDFvERUCvZHlvWmmME+gmZ44xzwRXwhzsbtc= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= @@ -924,8 +925,6 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/api/grpc/oidc/v2/integration_test/oidc_test.go b/internal/api/grpc/oidc/v2/integration_test/oidc_test.go index 1eb031bd6d..43eaf1c6eb 100644 --- a/internal/api/grpc/oidc/v2/integration_test/oidc_test.go +++ b/internal/api/grpc/oidc/v2/integration_test/oidc_test.go @@ -104,13 +104,11 @@ func TestServer_CreateCallback(t *testing.T) { sessionResp := createSession(t, CTX, Instance.Users[integration.UserTypeOrgOwner].ID) tests := []struct { - name string - ctx context.Context - req *oidc_pb.CreateCallbackRequest - AuthError string - want *oidc_pb.CreateCallbackResponse - wantURL *url.URL - wantErr bool + name string + ctx context.Context + req *oidc_pb.CreateCallbackRequest + want *oidc_pb.CreateCallbackResponse + wantErr bool }{ { name: "Not found", diff --git a/internal/api/grpc/resources/action/v3alpha/integration_test/execution_target_test.go b/internal/api/grpc/resources/action/v3alpha/integration_test/execution_target_test.go index 7aff6afb3f..9e8bfac3eb 100644 --- a/internal/api/grpc/resources/action/v3alpha/integration_test/execution_target_test.go +++ b/internal/api/grpc/resources/action/v3alpha/integration_test/execution_target_test.go @@ -4,25 +4,54 @@ package action_test import ( "context" + "encoding/base64" "encoding/json" "io" "net/http" "net/http/httptest" + "net/url" "reflect" + "strings" "testing" "time" "github.com/brianvoe/gofakeit/v6" + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/client/rp" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "golang.org/x/text/language" "google.golang.org/protobuf/types/known/durationpb" "github.com/zitadel/zitadel/internal/api/grpc/server/middleware" + oidc_api "github.com/zitadel/zitadel/internal/api/oidc" + saml_api "github.com/zitadel/zitadel/internal/api/saml" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/pkg/grpc/app" + "github.com/zitadel/zitadel/pkg/grpc/management" + "github.com/zitadel/zitadel/pkg/grpc/metadata" object "github.com/zitadel/zitadel/pkg/grpc/object/v3alpha" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" action "github.com/zitadel/zitadel/pkg/grpc/resources/action/v3alpha" resource_object "github.com/zitadel/zitadel/pkg/grpc/resources/object/v3alpha" + saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +const ( + redirectURI = "https://callback" + logoutRedirectURI = "https://logged-out" + redirectURIImplicit = "http://localhost:9999/callback" +) + +var ( + loginV2 = &app.LoginVersion{Version: &app.LoginVersion_LoginV2{LoginV2: &app.LoginV2{BaseUri: nil}}} ) func TestServer_ExecutionTarget(t *testing.T) { @@ -408,3 +437,749 @@ func testServerCall( return server.URL, server.Close } + +func conditionFunction(function string) *action.Condition { + return &action.Condition{ + ConditionType: &action.Condition_Function{ + Function: &action.FunctionExecution{ + Name: function, + }, + }, + } +} + +func TestServer_ExecutionTargetPreUserinfo(t *testing.T) { + instance := integration.NewInstance(CTX) + ensureFeatureEnabled(t, instance) + isolatedIAMCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner) + ctxLoginClient := instance.WithAuthorization(CTX, integration.UserTypeLogin) + + client, err := instance.CreateOIDCImplicitFlowClient(isolatedIAMCtx, redirectURIImplicit, loginV2) + require.NoError(t, err) + + type want struct { + addedClaims map[string]any + addedLogClaims map[string][]string + setUserMetadata []*metadata.Metadata + } + tests := []struct { + name string + ctx context.Context + dep func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) + req *oidc_pb.CreateCallbackRequest + want want + wantErr bool + }{ + { + name: "append claim", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + AppendClaims: []*oidc_api.AppendClaim{ + {Key: "added", Value: "value"}, + }, + } + return expectPreUserinfoExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedClaims: map[string]any{ + "added": "value", + }, + }, + wantErr: false, + }, + { + name: "append log claim", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + AppendLogClaims: []string{ + "addedLog", + }, + } + return expectPreUserinfoExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedLogClaims: map[string][]string{ + "urn:zitadel:iam:action:function/preuserinfo:log": {"addedLog"}, + }, + }, + wantErr: false, + }, + { + name: "set user metadata", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + SetUserMetadata: []*domain.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + } + return expectPreUserinfoExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + setUserMetadata: []*metadata.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + }, + wantErr: false, + }, + { + name: "full usage", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + SetUserMetadata: []*domain.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + AppendLogClaims: []string{ + "addedLog1", + "addedLog2", + "addedLog3", + }, + AppendClaims: []*oidc_api.AppendClaim{ + {Key: "added1", Value: "value1"}, + {Key: "added2", Value: "value2"}, + {Key: "added3", Value: "value3"}, + }, + } + return expectPreUserinfoExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedClaims: map[string]any{ + "added1": "value1", + "added2": "value2", + "added3": "value3", + }, + setUserMetadata: []*metadata.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + addedLogClaims: map[string][]string{ + "urn:zitadel:iam:action:function/preuserinfo:log": {"addedLog1", "addedLog2", "addedLog3"}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userID, closeF := tt.dep(isolatedIAMCtx, t, tt.req) + defer closeF() + + got, err := instance.Client.OIDCv2.CreateCallback(tt.ctx, tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + callbackUrl, err := url.Parse(strings.Replace(got.GetCallbackUrl(), "#", "?", 1)) + require.NoError(t, err) + claims := getIDTokenClaimsFromCallbackURL(tt.ctx, t, instance, client.GetClientId(), callbackUrl) + + for k, v := range tt.want.addedClaims { + value, ok := claims[k] + if !assert.True(t, ok) { + return + } + assert.Equal(t, v, value) + } + for k, v := range tt.want.addedLogClaims { + value, ok := claims[k] + if !assert.True(t, ok) { + return + } + assert.ElementsMatch(t, v, value) + } + if len(tt.want.setUserMetadata) > 0 { + checkForSetMetadata(isolatedIAMCtx, t, instance, userID, tt.want.setUserMetadata) + } + }) + } +} + +func expectPreUserinfoExecution(ctx context.Context, t *testing.T, instance *integration.Instance, req *oidc_pb.CreateCallbackRequest, response *oidc_api.ContextInfoResponse) (string, func()) { + userEmail := gofakeit.Email() + userPhone := "+41" + gofakeit.Phone() + userResp := instance.CreateHumanUserVerified(ctx, instance.DefaultOrg.Id, userEmail, userPhone) + + sessionResp := createSession(ctx, t, instance, userResp.GetUserId()) + req.CallbackKind = &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + } + expectedContextInfo := contextInfoForUserOIDC(instance, "function/preuserinfo", userResp, userEmail, userPhone) + + targetURL, closeF := testServerCall(expectedContextInfo, 0, http.StatusOK, response) + + targetResp := waitForTarget(ctx, t, instance, targetURL, domain.TargetTypeCall, true) + waitForExecutionOnCondition(ctx, t, instance, conditionFunction("preuserinfo"), executionTargetsSingleTarget(targetResp.GetDetails().GetId())) + return userResp.GetUserId(), closeF +} + +func createSession(ctx context.Context, t *testing.T, instance *integration.Instance, userID string) *session.CreateSessionResponse { + sessionResp, err := instance.Client.SessionV2.CreateSession(ctx, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: userID, + }, + }, + }, + }) + require.NoError(t, err) + return sessionResp +} + +func checkForSetMetadata(ctx context.Context, t *testing.T, instance *integration.Instance, userID string, metadataExpected []*metadata.Metadata) { + integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + metadataResp, err := instance.Client.Mgmt.ListUserMetadata(ctx, &management.ListUserMetadataRequest{Id: userID}) + if !assert.NoError(ct, err) { + return + } + for _, dataExpected := range metadataExpected { + found := false + for _, dataCheck := range metadataResp.GetResult() { + if dataExpected.Key == dataCheck.Key { + found = true + if !assert.Equal(ct, dataExpected.Value, dataCheck.Value) { + return + } + } + } + if !assert.True(ct, found) { + return + } + } + }, retryDuration, tick) +} + +func getIDTokenClaimsFromCallbackURL(ctx context.Context, t *testing.T, instance *integration.Instance, clientID string, callbackURL *url.URL) map[string]any { + accessToken := callbackURL.Query().Get("access_token") + idToken := callbackURL.Query().Get("id_token") + + provider, err := instance.CreateRelyingParty(ctx, clientID, redirectURIImplicit, oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone) + require.NoError(t, err) + claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) + require.NoError(t, err) + return claims.Claims +} + +type CustomAccessTokenClaims struct { + oidc.TokenClaims + Added1 string `json:"added1,omitempty"` + Added2 string `json:"added2,omitempty"` + Added3 string `json:"added3,omitempty"` + Log []string `json:"urn:zitadel:iam:action:function/preaccesstoken:log,omitempty"` +} + +func getAccessTokenClaims(ctx context.Context, t *testing.T, instance *integration.Instance, callbackURL *url.URL) *CustomAccessTokenClaims { + accessToken := callbackURL.Query().Get("access_token") + + verifier := op.NewAccessTokenVerifier(instance.OIDCIssuer(), rp.NewRemoteKeySet(http.DefaultClient, instance.OIDCIssuer()+"/oauth/v2/keys")) + + claims, err := op.VerifyAccessToken[*CustomAccessTokenClaims](ctx, accessToken, verifier) + require.NoError(t, err) + return claims +} + +func contextInfoForUserOIDC(instance *integration.Instance, function string, userResp *user.AddHumanUserResponse, email, phone string) *oidc_api.ContextInfo { + return &oidc_api.ContextInfo{ + Function: function, + UserInfo: &oidc.UserInfo{ + Subject: userResp.GetUserId(), + }, + User: &query.User{ + ID: userResp.GetUserId(), + CreationDate: userResp.Details.ChangeDate.AsTime(), + ChangeDate: userResp.Details.ChangeDate.AsTime(), + ResourceOwner: instance.DefaultOrg.GetId(), + Sequence: userResp.Details.Sequence, + State: 1, + Username: email, + PreferredLoginName: email, + Human: &query.Human{ + FirstName: "Mickey", + LastName: "Mouse", + NickName: "Mickey", + DisplayName: "Mickey Mouse", + AvatarKey: "", + PreferredLanguage: language.Dutch, + Gender: 2, + Email: domain.EmailAddress(email), + IsEmailVerified: true, + Phone: domain.PhoneNumber(phone), + IsPhoneVerified: true, + PasswordChangeRequired: false, + PasswordChanged: time.Time{}, + MFAInitSkipped: time.Time{}, + }, + }, + UserMetadata: nil, + Org: &query.UserInfoOrg{ + ID: instance.DefaultOrg.GetId(), + Name: instance.DefaultOrg.GetName(), + PrimaryDomain: instance.DefaultOrg.GetPrimaryDomain(), + }, + UserGrants: nil, + Response: nil, + } +} + +func TestServer_ExecutionTargetPreAccessToken(t *testing.T) { + instance := integration.NewInstance(CTX) + ensureFeatureEnabled(t, instance) + isolatedIAMCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner) + ctxLoginClient := instance.WithAuthorization(CTX, integration.UserTypeLogin) + + client, err := instance.CreateOIDCImplicitFlowClient(isolatedIAMCtx, redirectURIImplicit, loginV2) + require.NoError(t, err) + + type want struct { + addedClaims *CustomAccessTokenClaims + addedLogClaims map[string][]string + setUserMetadata []*metadata.Metadata + } + tests := []struct { + name string + ctx context.Context + dep func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) + req *oidc_pb.CreateCallbackRequest + want want + wantErr bool + }{ + { + name: "append claim", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + AppendClaims: []*oidc_api.AppendClaim{ + {Key: "added1", Value: "value"}, + }, + } + return expectPreAccessTokenExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedClaims: &CustomAccessTokenClaims{ + Added1: "value", + }, + }, + wantErr: false, + }, + { + name: "append log claim", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + AppendLogClaims: []string{ + "addedLog", + }, + } + return expectPreAccessTokenExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedClaims: &CustomAccessTokenClaims{ + Log: []string{"addedLog"}, + }, + }, + wantErr: false, + }, + { + name: "set user metadata", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + SetUserMetadata: []*domain.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + } + return expectPreAccessTokenExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + setUserMetadata: []*metadata.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + }, + wantErr: false, + }, + { + name: "full usage", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *oidc_pb.CreateCallbackRequest) (string, func()) { + response := &oidc_api.ContextInfoResponse{ + SetUserMetadata: []*domain.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + AppendLogClaims: []string{ + "addedLog1", + "addedLog2", + "addedLog3", + }, + AppendClaims: []*oidc_api.AppendClaim{ + {Key: "added1", Value: "value1"}, + {Key: "added2", Value: "value2"}, + {Key: "added3", Value: "value3"}, + }, + } + return expectPreAccessTokenExecution(ctx, t, instance, req, response) + }, + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := instance.CreateOIDCAuthRequestImplicitWithoutLoginClientHeader(isolatedIAMCtx, client.GetClientId(), redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + }, + want: want{ + addedClaims: &CustomAccessTokenClaims{ + Added1: "value1", + Added2: "value2", + Added3: "value3", + Log: []string{"addedLog1", "addedLog2", "addedLog3"}, + }, + setUserMetadata: []*metadata.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userID, closeF := tt.dep(isolatedIAMCtx, t, tt.req) + defer closeF() + + got, err := instance.Client.OIDCv2.CreateCallback(tt.ctx, tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + callbackUrl, err := url.Parse(strings.Replace(got.GetCallbackUrl(), "#", "?", 1)) + require.NoError(t, err) + claims := getAccessTokenClaims(tt.ctx, t, instance, callbackUrl) + + if tt.want.addedClaims != nil { + assert.Equal(t, tt.want.addedClaims.Added1, claims.Added1) + assert.Equal(t, tt.want.addedClaims.Added2, claims.Added2) + assert.Equal(t, tt.want.addedClaims.Added3, claims.Added3) + assert.Equal(t, tt.want.addedClaims.Log, claims.Log) + } + if len(tt.want.setUserMetadata) > 0 { + checkForSetMetadata(isolatedIAMCtx, t, instance, userID, tt.want.setUserMetadata) + } + + }) + } +} + +func expectPreAccessTokenExecution(ctx context.Context, t *testing.T, instance *integration.Instance, req *oidc_pb.CreateCallbackRequest, response *oidc_api.ContextInfoResponse) (string, func()) { + userEmail := gofakeit.Email() + userPhone := "+41" + gofakeit.Phone() + userResp := instance.CreateHumanUserVerified(ctx, instance.DefaultOrg.Id, userEmail, userPhone) + + sessionResp := createSession(ctx, t, instance, userResp.GetUserId()) + req.CallbackKind = &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + } + expectedContextInfo := contextInfoForUserOIDC(instance, "function/preaccesstoken", userResp, userEmail, userPhone) + + targetURL, closeF := testServerCall(expectedContextInfo, 0, http.StatusOK, response) + + targetResp := waitForTarget(ctx, t, instance, targetURL, domain.TargetTypeCall, true) + waitForExecutionOnCondition(ctx, t, instance, conditionFunction("preaccesstoken"), executionTargetsSingleTarget(targetResp.GetDetails().GetId())) + return userResp.GetUserId(), closeF +} + +func TestServer_ExecutionTargetPreSAMLResponse(t *testing.T) { + instance := integration.NewInstance(CTX) + ensureFeatureEnabled(t, instance) + isolatedIAMCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner) + ctxLoginClient := instance.WithAuthorization(CTX, integration.UserTypeLogin) + + idpMetadata, err := instance.GetSAMLIDPMetadata() + require.NoError(t, err) + + acsPost := idpMetadata.IDPSSODescriptors[0].SingleSignOnServices[1] + _, _, spMiddlewarePost := createSAMLApplication(isolatedIAMCtx, t, instance, idpMetadata, saml.HTTPPostBinding, false, false) + + type want struct { + addedAttributes map[string][]saml.AttributeValue + setUserMetadata []*metadata.Metadata + } + tests := []struct { + name string + ctx context.Context + dep func(ctx context.Context, t *testing.T, req *saml_pb.CreateResponseRequest) (string, func()) + req *saml_pb.CreateResponseRequest + want want + wantErr bool + }{ + { + name: "append attribute", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *saml_pb.CreateResponseRequest) (string, func()) { + response := &saml_api.ContextInfoResponse{ + AppendAttribute: []*saml_api.AppendAttribute{ + {Name: "added", NameFormat: "format", Value: []string{"value"}}, + }, + } + return expectPreSAMLResponseExecution(ctx, t, instance, req, response) + }, + req: &saml_pb.CreateResponseRequest{ + SamlRequestId: func() string { + _, samlRequestID, err := instance.CreateSAMLAuthRequest(spMiddlewarePost, instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding) + require.NoError(t, err) + return samlRequestID + }(), + }, + want: want{ + addedAttributes: map[string][]saml.AttributeValue{ + "added": {saml.AttributeValue{Value: "value"}}, + }, + }, + wantErr: false, + }, + { + name: "set user metadata", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *saml_pb.CreateResponseRequest) (string, func()) { + response := &saml_api.ContextInfoResponse{ + SetUserMetadata: []*domain.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + } + return expectPreSAMLResponseExecution(ctx, t, instance, req, response) + }, + req: &saml_pb.CreateResponseRequest{ + SamlRequestId: func() string { + _, samlRequestID, err := instance.CreateSAMLAuthRequest(spMiddlewarePost, instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding) + require.NoError(t, err) + return samlRequestID + }(), + }, + want: want{ + setUserMetadata: []*metadata.Metadata{ + {Key: "key", Value: []byte("value")}, + }, + }, + wantErr: false, + }, + { + name: "set user metadata", + ctx: ctxLoginClient, + dep: func(ctx context.Context, t *testing.T, req *saml_pb.CreateResponseRequest) (string, func()) { + response := &saml_api.ContextInfoResponse{ + AppendAttribute: []*saml_api.AppendAttribute{ + {Name: "added1", NameFormat: "format", Value: []string{"value1"}}, + {Name: "added2", NameFormat: "format", Value: []string{"value2"}}, + {Name: "added3", NameFormat: "format", Value: []string{"value3"}}, + }, + SetUserMetadata: []*domain.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + } + return expectPreSAMLResponseExecution(ctx, t, instance, req, response) + }, + req: &saml_pb.CreateResponseRequest{ + SamlRequestId: func() string { + _, samlRequestID, err := instance.CreateSAMLAuthRequest(spMiddlewarePost, instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding) + require.NoError(t, err) + return samlRequestID + }(), + }, + want: want{ + addedAttributes: map[string][]saml.AttributeValue{ + "added1": {saml.AttributeValue{Value: "value1"}}, + "added2": {saml.AttributeValue{Value: "value2"}}, + "added3": {saml.AttributeValue{Value: "value3"}}, + }, + setUserMetadata: []*metadata.Metadata{ + {Key: "key1", Value: []byte("value1")}, + {Key: "key2", Value: []byte("value2")}, + {Key: "key3", Value: []byte("value3")}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + userID, closeF := tt.dep(isolatedIAMCtx, t, tt.req) + defer closeF() + + got, err := instance.Client.SAMLv2.CreateResponse(tt.ctx, tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + attributes := getSAMLResponseAttributes(t, got.GetPost().GetSamlResponse(), spMiddlewarePost) + for k, v := range tt.want.addedAttributes { + found := false + for _, attribute := range attributes { + if attribute.Name == k { + found = true + assert.Equal(t, v, attribute.Values) + } + } + if !assert.True(t, found) { + return + } + } + if len(tt.want.setUserMetadata) > 0 { + checkForSetMetadata(isolatedIAMCtx, t, instance, userID, tt.want.setUserMetadata) + } + }) + } +} + +func expectPreSAMLResponseExecution(ctx context.Context, t *testing.T, instance *integration.Instance, req *saml_pb.CreateResponseRequest, response *saml_api.ContextInfoResponse) (string, func()) { + userEmail := gofakeit.Email() + userPhone := "+41" + gofakeit.Phone() + userResp := instance.CreateHumanUserVerified(ctx, instance.DefaultOrg.Id, userEmail, userPhone) + + sessionResp := createSession(ctx, t, instance, userResp.GetUserId()) + req.ResponseKind = &saml_pb.CreateResponseRequest_Session{ + Session: &saml_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + } + expectedContextInfo := contextInfoForUserSAML(instance, "function/presamlresponse", userResp, userEmail, userPhone) + + targetURL, closeF := testServerCall(expectedContextInfo, 0, http.StatusOK, response) + + targetResp := waitForTarget(ctx, t, instance, targetURL, domain.TargetTypeCall, true) + waitForExecutionOnCondition(ctx, t, instance, conditionFunction("presamlresponse"), executionTargetsSingleTarget(targetResp.GetDetails().GetId())) + + return userResp.GetUserId(), closeF +} + +func createSAMLSP(t *testing.T, idpMetadata *saml.EntityDescriptor, binding string) (string, *samlsp.Middleware) { + rootURL := "example." + gofakeit.DomainName() + spMiddleware, err := integration.CreateSAMLSP("https://"+rootURL, idpMetadata, binding) + require.NoError(t, err) + return rootURL, spMiddleware +} + +func createSAMLApplication(ctx context.Context, t *testing.T, instance *integration.Instance, idpMetadata *saml.EntityDescriptor, binding string, projectRoleCheck, hasProjectCheck bool) (string, string, *samlsp.Middleware) { + project, err := instance.CreateProjectWithPermissionCheck(ctx, projectRoleCheck, hasProjectCheck) + require.NoError(t, err) + rootURL, sp := createSAMLSP(t, idpMetadata, binding) + _, err = instance.CreateSAMLClient(ctx, project.GetId(), sp) + require.NoError(t, err) + return project.GetId(), rootURL, sp +} + +func getSAMLResponseAttributes(t *testing.T, samlResponse string, sp *samlsp.Middleware) []saml.Attribute { + data, err := base64.StdEncoding.DecodeString(samlResponse) + require.NoError(t, err) + sp.ServiceProvider.AllowIDPInitiated = true + assertion, err := sp.ServiceProvider.ParseXMLResponse(data, []string{}) + require.NoError(t, err) + return assertion.AttributeStatements[0].Attributes +} + +func contextInfoForUserSAML(instance *integration.Instance, function string, userResp *user.AddHumanUserResponse, email, phone string) *saml_api.ContextInfo { + return &saml_api.ContextInfo{ + Function: function, + User: &query.User{ + ID: userResp.GetUserId(), + CreationDate: userResp.Details.ChangeDate.AsTime(), + ChangeDate: userResp.Details.ChangeDate.AsTime(), + ResourceOwner: instance.DefaultOrg.GetId(), + Sequence: userResp.Details.Sequence, + State: 1, + Type: domain.UserTypeHuman, + Username: email, + PreferredLoginName: email, + LoginNames: []string{email}, + Human: &query.Human{ + FirstName: "Mickey", + LastName: "Mouse", + NickName: "Mickey", + DisplayName: "Mickey Mouse", + AvatarKey: "", + PreferredLanguage: language.Dutch, + Gender: 2, + Email: domain.EmailAddress(email), + IsEmailVerified: true, + Phone: domain.PhoneNumber(phone), + IsPhoneVerified: true, + PasswordChangeRequired: false, + PasswordChanged: time.Time{}, + MFAInitSkipped: time.Time{}, + }, + }, + UserGrants: nil, + Response: nil, + } +} diff --git a/internal/api/grpc/resources/action/v3alpha/integration_test/execution_test.go b/internal/api/grpc/resources/action/v3alpha/integration_test/execution_test.go index b56efd6b99..a4d4fe24f8 100644 --- a/internal/api/grpc/resources/action/v3alpha/integration_test/execution_test.go +++ b/internal/api/grpc/resources/action/v3alpha/integration_test/execution_test.go @@ -774,7 +774,7 @@ func TestServer_SetExecution_Function(t *testing.T) { req: &action.SetExecutionRequest{ Condition: &action.Condition{ ConditionType: &action.Condition_Function{ - Function: &action.FunctionExecution{Name: "Action.Flow.Type.ExternalAuthentication.Action.TriggerType.PostAuthentication"}, + Function: &action.FunctionExecution{Name: "presamlresponse"}, }, }, Execution: &action.Execution{ diff --git a/internal/api/grpc/resources/action/v3alpha/integration_test/query_test.go b/internal/api/grpc/resources/action/v3alpha/integration_test/query_test.go index aa748ac4d8..b46870d98c 100644 --- a/internal/api/grpc/resources/action/v3alpha/integration_test/query_test.go +++ b/internal/api/grpc/resources/action/v3alpha/integration_test/query_test.go @@ -835,7 +835,7 @@ func TestServer_SearchExecutions(t *testing.T) { {ConditionType: &action.Condition_Event{Event: &action.EventExecution{Condition: &action.EventExecution_Event{Event: "user.added"}}}}, {ConditionType: &action.Condition_Event{Event: &action.EventExecution{Condition: &action.EventExecution_Group{Group: "user"}}}}, {ConditionType: &action.Condition_Event{Event: &action.EventExecution{Condition: &action.EventExecution_All{All: true}}}}, - {ConditionType: &action.Condition_Function{Function: &action.FunctionExecution{Name: "Action.Flow.Type.ExternalAuthentication.Action.TriggerType.PostAuthentication"}}}, + {ConditionType: &action.Condition_Function{Function: &action.FunctionExecution{Name: "presamlresponse"}}}, }, }, }, diff --git a/internal/api/grpc/server/middleware/execution_interceptor.go b/internal/api/grpc/server/middleware/execution_interceptor.go index c309827d94..3288f28ad8 100644 --- a/internal/api/grpc/server/middleware/execution_interceptor.go +++ b/internal/api/grpc/server/middleware/execution_interceptor.go @@ -3,23 +3,19 @@ package middleware import ( "context" "encoding/json" - "strings" - "github.com/zitadel/logging" "google.golang.org/grpc" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/execution" "github.com/zitadel/zitadel/internal/query" - exec_repo "github.com/zitadel/zitadel/internal/repository/execution" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - requestTargets, responseTargets := queryTargets(ctx, queries, info.FullMethod) + requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, info.FullMethod) // call targets otherwise return req handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req) @@ -81,49 +77,6 @@ func executeTargetsForResponse(ctx context.Context, targets []execution.Target, return execution.CallTargets(ctx, targets, info) } -type ExecutionQueries interface { - TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error) -} - -func queryTargets( - ctx context.Context, - queries ExecutionQueries, - fullMethod string, -) ([]execution.Target, []execution.Target) { - ctx, span := tracing.NewSpan(ctx) - defer span.End() - - targets, err := queries.TargetsByExecutionIDs(ctx, - idsForFullMethod(fullMethod, domain.ExecutionTypeRequest), - idsForFullMethod(fullMethod, domain.ExecutionTypeResponse), - ) - requestTargets := make([]execution.Target, 0, len(targets)) - responseTargets := make([]execution.Target, 0, len(targets)) - if err != nil { - logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets") - return requestTargets, responseTargets - } - - for _, target := range targets { - if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeRequest)) { - requestTargets = append(requestTargets, target) - } else if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeResponse)) { - responseTargets = append(responseTargets, target) - } - } - - return requestTargets, responseTargets -} - -func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string { - return []string{exec_repo.ID(executionType, fullMethod), exec_repo.ID(executionType, serviceFromFullMethod(fullMethod)), exec_repo.IDAll(executionType)} -} - -func serviceFromFullMethod(s string) string { - parts := strings.Split(s, "/") - return parts[1] -} - var _ execution.ContextInfo = &ContextInfoRequest{} type ContextInfoRequest struct { diff --git a/internal/api/grpc/user/v2/integration_test/user_test.go b/internal/api/grpc/user/v2/integration_test/user_test.go index 0293fd925d..f39212f7e3 100644 --- a/internal/api/grpc/user/v2/integration_test/user_test.go +++ b/internal/api/grpc/user/v2/integration_test/user_test.go @@ -20,7 +20,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/grpc" "github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration/sink" "github.com/zitadel/zitadel/pkg/grpc/auth" @@ -2114,18 +2113,29 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) { } func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { - idpID := Instance.AddGenericOAuthProvider(IamCTX, gofakeit.AppName()).GetId() - intentID := Instance.CreateIntent(CTX, idpID).GetIdpIntent().GetIdpIntentId() + oauthIdpID := Instance.AddGenericOAuthProvider(IamCTX, gofakeit.AppName()).GetId() + oidcIdpID := Instance.AddGenericOIDCProvider(IamCTX, gofakeit.AppName()).GetId() + samlIdpID := Instance.AddSAMLPostProvider(IamCTX) + ldapIdpID := Instance.AddLDAPProvider(IamCTX) + authURL, err := url.Parse(Instance.CreateIntent(CTX, oauthIdpID).GetAuthUrl()) + require.NoError(t, err) + intentID := authURL.Query().Get("state") - successfulID, token, changeDate, sequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), idpID, "id", "") + successfulID, token, changeDate, sequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), oauthIdpID, "id", "") require.NoError(t, err) - successfulWithUserID, withUsertoken, withUserchangeDate, withUsersequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), idpID, "id", "user") + successfulWithUserID, withUsertoken, withUserchangeDate, withUsersequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), oauthIdpID, "id", "user") require.NoError(t, err) - ldapSuccessfulID, ldapToken, ldapChangeDate, ldapSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), idpID, "id", "") + oidcSuccessful, oidcToken, oidcChangeDate, oidcSequence, err := sink.SuccessfulOIDCIntent(Instance.ID(), oidcIdpID, "id", "") require.NoError(t, err) - ldapSuccessfulWithUserID, ldapWithUserToken, ldapWithUserChangeDate, ldapWithUserSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), idpID, "id", "user") + oidcSuccessfulWithUserID, oidcWithUserIDToken, oidcWithUserIDChangeDate, oidcWithUserIDSequence, err := sink.SuccessfulOIDCIntent(Instance.ID(), oidcIdpID, "id", "user") require.NoError(t, err) - samlSuccessfulID, samlToken, samlChangeDate, samlSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), idpID, "id", "") + ldapSuccessfulID, ldapToken, ldapChangeDate, ldapSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), ldapIdpID, "id", "") + require.NoError(t, err) + ldapSuccessfulWithUserID, ldapWithUserToken, ldapWithUserChangeDate, ldapWithUserSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), ldapIdpID, "id", "user") + require.NoError(t, err) + samlSuccessfulID, samlToken, samlChangeDate, samlSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), samlIdpID, "id", "") + require.NoError(t, err) + samlSuccessfulWithUserID, samlWithUserToken, samlWithUserChangeDate, samlWithUserSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), samlIdpID, "id", "user") require.NoError(t, err) type args struct { ctx context.Context @@ -2160,7 +2170,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { wantErr: true, }, { - name: "retrieve successful intent", + name: "retrieve successful oauth intent", args: args{ CTX, &user.RetrieveIdentityProviderIntentRequest{ @@ -2181,18 +2191,31 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { IdToken: gu.Ptr("idToken"), }, }, - IdpId: idpID, + IdpId: oauthIdpID, UserId: "id", - UserName: "username", + UserName: "", RawInformation: func() *structpb.Struct { s, err := structpb.NewStruct(map[string]interface{}{ - "sub": "id", - "preferred_username": "username", + "RawInfo": map[string]interface{}{ + "id": "id", + "preferred_username": "username", + }, }) require.NoError(t, err) return s }(), }, + AddHumanUser: &user.AddHumanUserRequest{ + Profile: &user.SetHumanProfile{ + PreferredLanguage: gu.Ptr("und"), + }, + IdpLinks: []*user.IDPLink{ + {IdpId: oauthIdpID, UserId: "id"}, + }, + Email: &user.SetHumanEmail{ + Verification: &user.SetHumanEmail_SendCode{SendCode: &user.SendEmailVerificationCode{}}, + }, + }, }, wantErr: false, }, @@ -2219,7 +2242,97 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { IdToken: gu.Ptr("idToken"), }, }, - IdpId: idpID, + IdpId: oauthIdpID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "RawInfo": map[string]interface{}{ + "id": "id", + "preferred_username": "username", + }, + }) + require.NoError(t, err) + return s + }(), + }, + }, + wantErr: false, + }, + { + name: "retrieve successful oidc intent", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: oidcSuccessful, + IdpIntentToken: oidcToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(oidcChangeDate), + ResourceOwner: Instance.ID(), + Sequence: oidcSequence, + }, + UserId: "", + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + AccessToken: "accessToken", + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: oidcIdpID, + UserId: "id", + UserName: "username", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "sub": "id", + "preferred_username": "username", + }) + require.NoError(t, err) + return s + }(), + }, + AddHumanUser: &user.AddHumanUserRequest{ + Username: gu.Ptr("username"), + Profile: &user.SetHumanProfile{ + PreferredLanguage: gu.Ptr("und"), + }, + IdpLinks: []*user.IDPLink{ + {IdpId: oidcIdpID, UserId: "id", UserName: "username"}, + }, + Email: &user.SetHumanEmail{ + Verification: &user.SetHumanEmail_SendCode{SendCode: &user.SendEmailVerificationCode{}}, + }, + }, + }, + wantErr: false, + }, + { + name: "retrieve successful oidc intent with linked user", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: oidcSuccessfulWithUserID, + IdpIntentToken: oidcWithUserIDToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(oidcWithUserIDChangeDate), + ResourceOwner: Instance.ID(), + Sequence: oidcWithUserIDSequence, + }, + UserId: "user", + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + AccessToken: "accessToken", + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: oidcIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2263,7 +2376,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }(), }, }, - IdpId: idpID, + IdpId: ldapIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2276,6 +2389,18 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { return s }(), }, + AddHumanUser: &user.AddHumanUserRequest{ + Username: gu.Ptr("username"), + Profile: &user.SetHumanProfile{ + PreferredLanguage: gu.Ptr("en"), + }, + IdpLinks: []*user.IDPLink{ + {IdpId: ldapIdpID, UserId: "id", UserName: "username"}, + }, + Email: &user.SetHumanEmail{ + Verification: &user.SetHumanEmail_SendCode{SendCode: &user.SendEmailVerificationCode{}}, + }, + }, }, wantErr: false, }, @@ -2309,7 +2434,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }(), }, }, - IdpId: idpID, + IdpId: ldapIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2346,7 +2471,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { Assertion: []byte(""), }, }, - IdpId: idpID, + IdpId: samlIdpID, UserId: "id", UserName: "", RawInformation: func() *structpb.Struct { @@ -2360,6 +2485,56 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { return s }(), }, + AddHumanUser: &user.AddHumanUserRequest{ + Profile: &user.SetHumanProfile{ + PreferredLanguage: gu.Ptr("und"), + }, + IdpLinks: []*user.IDPLink{ + {IdpId: samlIdpID, UserId: "id"}, + }, + Email: &user.SetHumanEmail{ + Verification: &user.SetHumanEmail_SendCode{SendCode: &user.SendEmailVerificationCode{}}, + }, + }, + }, + wantErr: false, + }, + { + name: "retrieve successful saml intent with linked user", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: samlSuccessfulWithUserID, + IdpIntentToken: samlWithUserToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(samlWithUserChangeDate), + ResourceOwner: Instance.ID(), + Sequence: samlWithUserSequence, + }, + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Saml{ + Saml: &user.IDPSAMLAccessInformation{ + Assertion: []byte(""), + }, + }, + IdpId: samlIdpID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "id": "id", + "attributes": map[string]interface{}{ + "attribute1": []interface{}{"value1"}, + }, + }) + require.NoError(t, err) + return s + }(), + }, + UserId: "user", }, wantErr: false, }, @@ -2369,11 +2544,11 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { got, err := Client.RetrieveIdentityProviderIntent(tt.args.ctx, tt.args.req) if tt.wantErr { require.Error(t, err) - } else { - require.NoError(t, err) + return } + require.NoError(t, err) - grpc.AllFieldsEqual(t, tt.want.ProtoReflect(), got.ProtoReflect(), grpc.CustomMappers) + assert.EqualExportedValues(t, tt.want, got) }) } } diff --git a/internal/api/grpc/user/v2/intent.go b/internal/api/grpc/user/v2/intent.go new file mode 100644 index 0000000000..6e46dfd5c3 --- /dev/null +++ b/internal/api/grpc/user/v2/intent.go @@ -0,0 +1,370 @@ +package user + +import ( + "context" + "encoding/json" + "errors" + + oidc_pkg "github.com/zitadel/oidc/v3/pkg/oidc" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/idp" + "github.com/zitadel/zitadel/internal/idp/providers/apple" + "github.com/zitadel/zitadel/internal/idp/providers/azuread" + "github.com/zitadel/zitadel/internal/idp/providers/github" + "github.com/zitadel/zitadel/internal/idp/providers/gitlab" + "github.com/zitadel/zitadel/internal/idp/providers/google" + "github.com/zitadel/zitadel/internal/idp/providers/jwt" + "github.com/zitadel/zitadel/internal/idp/providers/ldap" + "github.com/zitadel/zitadel/internal/idp/providers/oauth" + "github.com/zitadel/zitadel/internal/idp/providers/oidc" + "github.com/zitadel/zitadel/internal/idp/providers/saml" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.StartIdentityProviderIntentRequest) (_ *user.StartIdentityProviderIntentResponse, err error) { + switch t := req.GetContent().(type) { + case *user.StartIdentityProviderIntentRequest_Urls: + return s.startIDPIntent(ctx, req.GetIdpId(), t.Urls) + case *user.StartIdentityProviderIntentRequest_Ldap: + return s.startLDAPIntent(ctx, req.GetIdpId(), t.Ldap) + default: + return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-S2g21", "type oneOf %T in method StartIdentityProviderIntent not implemented", t) + } +} + +func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { + state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) + if err != nil { + return nil, err + } + _, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters()) + if err != nil { + return nil, err + } + content, redirect := session.GetAuth(ctx) + if redirect { + return &user.StartIdentityProviderIntentResponse{ + Details: object.DomainToDetailsPb(details), + NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content}, + }, nil + } + return &user.StartIdentityProviderIntentResponse{ + Details: object.DomainToDetailsPb(details), + NextStep: &user.StartIdentityProviderIntentResponse_PostForm{ + PostForm: []byte(content), + }, + }, nil +} + +func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { + intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil) + if err != nil { + return nil, err + } + externalUser, userID, attributes, err := s.ldapLogin(ctx, intentWriteModel.IDPID, ldapCredentials.GetUsername(), ldapCredentials.GetPassword()) + if err != nil { + if err := s.command.FailIDPIntent(ctx, intentWriteModel, err.Error()); err != nil { + return nil, err + } + return nil, err + } + token, err := s.command.SucceedLDAPIDPIntent(ctx, intentWriteModel, externalUser, userID, attributes) + if err != nil { + return nil, err + } + return &user.StartIdentityProviderIntentResponse{ + Details: object.DomainToDetailsPb(details), + NextStep: &user.StartIdentityProviderIntentResponse_IdpIntent{ + IdpIntent: &user.IDPIntent{ + IdpIntentId: intentWriteModel.AggregateID, + IdpIntentToken: token, + UserId: userID, + }, + }, + }, nil +} + +func (s *Server) checkLinkedExternalUser(ctx context.Context, idpID, externalUserID string) (string, error) { + idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID) + if err != nil { + return "", err + } + externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID) + if err != nil { + return "", err + } + queries := []query.SearchQuery{ + idQuery, externalIDQuery, + } + links, err := s.query.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, nil) + if err != nil { + return "", err + } + if len(links.Links) == 1 { + return links.Links[0].UserID, nil + } + return "", nil +} + +func (s *Server) ldapLogin(ctx context.Context, idpID, username, password string) (idp.User, string, map[string][]string, error) { + provider, err := s.command.GetProvider(ctx, idpID, "", "") + if err != nil { + return nil, "", nil, err + } + ldapProvider, ok := provider.(*ldap.Provider) + if !ok { + return nil, "", nil, zerrors.ThrowInvalidArgument(nil, "IDP-9a02j2n2bh", "Errors.ExternalIDP.IDPTypeNotImplemented") + } + session := ldapProvider.GetSession(username, password) + externalUser, err := session.FetchUser(ctx) + if errors.Is(err, ldap.ErrFailedLogin) || errors.Is(err, ldap.ErrNoSingleUser) { + return nil, "", nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-nzun2i", "Errors.User.ExternalIDP.LoginFailed") + } + if err != nil { + return nil, "", nil, err + } + userID, err := s.checkLinkedExternalUser(ctx, idpID, externalUser.GetID()) + if err != nil { + return nil, "", nil, err + } + + attributes := make(map[string][]string, 0) + for _, item := range session.Entry.Attributes { + attributes[item.Name] = item.Values + } + return externalUser, userID, attributes, nil +} + +func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.RetrieveIdentityProviderIntentRequest) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { + intent, err := s.command.GetIntentWriteModel(ctx, req.GetIdpIntentId(), "") + if err != nil { + return nil, err + } + if err := s.checkIntentToken(req.GetIdpIntentToken(), intent.AggregateID); err != nil { + return nil, err + } + if intent.State != domain.IDPIntentStateSucceeded { + return nil, zerrors.ThrowPreconditionFailed(nil, "IDP-nme4gszsvx", "Errors.Intent.NotSucceeded") + } + idpIntent, err := idpIntentToIDPIntentPb(intent, s.idpAlg) + if err != nil { + return nil, err + } + if idpIntent.UserId == "" { + provider, err := s.command.GetProvider(ctx, idpIntent.IdpInformation.IdpId, "", "") + if err != nil && !errors.Is(err, oidc_pkg.ErrDiscoveryFailed) { + return nil, err + } + var idpUser idp.User + switch p := provider.(type) { + case *apple.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &apple.User{}) + case *oauth.Provider: + idpUser, err = unmarshalRawIdpUser(intent.IDPUser, p.User()) + case *oidc.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &oidc.User{UserInfo: &oidc_pkg.UserInfo{}}) + case *jwt.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &jwt.User{}) + case *azuread.Provider: + idpUser, err = unmarshalRawIdpUser(intent.IDPUser, p.User()) + case *github.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &github.User{}) + case *gitlab.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &oidc.User{UserInfo: &oidc_pkg.UserInfo{}}) + case *google.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &oidc.User{UserInfo: &oidc_pkg.UserInfo{}}) + case *saml.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &saml.UserMapper{}) + case *ldap.Provider: + idpUser, err = unmarshalIdpUser(intent.IDPUser, &ldap.User{}) + default: + return nil, zerrors.ThrowInvalidArgument(nil, "IDP-7rPBbls4Zn", "Errors.ExternalIDP.IDPTypeNotImplemented") + } + if err != nil { + return nil, err + } + idpIntent.AddHumanUser = idpUserToAddHumanUser(idpUser, idpIntent.IdpInformation.IdpId) + } + return idpIntent, nil +} + +type rawUserMapper struct { + RawInfo map[string]interface{} +} + +func unmarshalRawIdpUser(idpUserData []byte, idpUser idp.User) (idp.User, error) { + userMapper := &rawUserMapper{} + if err := json.Unmarshal(idpUserData, userMapper); err != nil { + return nil, err + } + idpUserData, err := json.Marshal(userMapper.RawInfo) + if err != nil { + return nil, err + } + return unmarshalIdpUser(idpUserData, idpUser) +} + +func unmarshalIdpUser(idpUserData []byte, idpUser idp.User) (idp.User, error) { + if err := json.Unmarshal(idpUserData, idpUser); err != nil { + return nil, err + } + return idpUser, nil +} + +func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { + rawInformation := new(structpb.Struct) + err = rawInformation.UnmarshalJSON(intent.IDPUser) + if err != nil { + return nil, err + } + information := &user.RetrieveIdentityProviderIntentResponse{ + IdpInformation: &user.IDPInformation{ + IdpId: intent.IDPID, + UserId: intent.IDPUserID, + UserName: intent.IDPUserName, + RawInformation: rawInformation, + }, + UserId: intent.UserID, + } + information.Details = intentToDetailsPb(intent) + // OAuth / OIDC + if intent.IDPIDToken != "" || intent.IDPAccessToken != nil { + information.IdpInformation.Access, err = idpOAuthTokensToPb(intent.IDPIDToken, intent.IDPAccessToken, alg) + if err != nil { + return nil, err + } + } + // LDAP + if intent.IDPEntryAttributes != nil { + access, err := IDPEntryAttributesToPb(intent.IDPEntryAttributes) + if err != nil { + return nil, err + } + information.IdpInformation.Access = access + } + // SAML + if intent.Assertion != nil { + assertion, err := crypto.Decrypt(intent.Assertion, alg) + if err != nil { + return nil, err + } + information.IdpInformation.Access = IDPSAMLResponseToPb(assertion) + } + return information, nil +} + +func idpOAuthTokensToPb(idpIDToken string, idpAccessToken *crypto.CryptoValue, alg crypto.EncryptionAlgorithm) (_ *user.IDPInformation_Oauth, err error) { + var idToken *string + if idpIDToken != "" { + idToken = &idpIDToken + } + var accessToken string + if idpAccessToken != nil { + accessToken, err = crypto.DecryptString(idpAccessToken, alg) + if err != nil { + return nil, err + } + } + return &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + AccessToken: accessToken, + IdToken: idToken, + }, + }, nil +} + +func intentToDetailsPb(intent *command.IDPIntentWriteModel) *object_pb.Details { + return &object_pb.Details{ + Sequence: intent.ProcessedSequence, + ChangeDate: timestamppb.New(intent.ChangeDate), + ResourceOwner: intent.ResourceOwner, + } +} + +func IDPEntryAttributesToPb(entryAttributes map[string][]string) (*user.IDPInformation_Ldap, error) { + values := make(map[string]interface{}, 0) + for k, v := range entryAttributes { + intValues := make([]interface{}, len(v)) + for i, value := range v { + intValues[i] = value + } + values[k] = intValues + } + attributes, err := structpb.NewStruct(values) + if err != nil { + return nil, err + } + return &user.IDPInformation_Ldap{ + Ldap: &user.IDPLDAPAccessInformation{ + Attributes: attributes, + }, + }, nil +} + +func IDPSAMLResponseToPb(assertion []byte) *user.IDPInformation_Saml { + return &user.IDPInformation_Saml{ + Saml: &user.IDPSAMLAccessInformation{ + Assertion: assertion, + }, + } +} + +func (s *Server) checkIntentToken(token string, intentID string) error { + return crypto.CheckToken(s.idpAlg, token, intentID) +} + +func idpUserToAddHumanUser(idpUser idp.User, idpID string) *user.AddHumanUserRequest { + addHumanUser := &user.AddHumanUserRequest{ + Profile: &user.SetHumanProfile{ + GivenName: idpUser.GetFirstName(), + FamilyName: idpUser.GetLastName(), + }, + Email: &user.SetHumanEmail{ + Email: string(idpUser.GetEmail()), + Verification: &user.SetHumanEmail_SendCode{}, + }, + Metadata: make([]*user.SetMetadataEntry, 0), + IdpLinks: []*user.IDPLink{ + { + IdpId: idpID, + UserId: idpUser.GetID(), + UserName: idpUser.GetPreferredUsername(), + }, + }, + } + if username := idpUser.GetPreferredUsername(); username != "" { + addHumanUser.Username = &username + } + if nickName := idpUser.GetNickname(); nickName != "" { + addHumanUser.Profile.NickName = &nickName + } + if displayName := idpUser.GetDisplayName(); displayName != "" { + addHumanUser.Profile.DisplayName = &displayName + } + if lang := idpUser.GetPreferredLanguage().String(); lang != "" { + addHumanUser.Profile.PreferredLanguage = &lang + } + if isEmailVerified := idpUser.IsEmailVerified(); isEmailVerified { + addHumanUser.Email.Verification = &user.SetHumanEmail_IsVerified{IsVerified: isEmailVerified} + } + if phone := idpUser.GetPhone(); phone != "" { + addHumanUser.Phone = &user.SetHumanPhone{ + Phone: string(phone), + Verification: &user.SetHumanPhone_SendCode{}, + } + if isPhoneVerified := idpUser.IsPhoneVerified(); isPhoneVerified { + addHumanUser.Phone.Verification = &user.SetHumanPhone_IsVerified{IsVerified: isPhoneVerified} + } + } + return addHumanUser +} diff --git a/internal/api/grpc/user/v2/user.go b/internal/api/grpc/user/v2/user.go index a743206cf0..0f958f0d40 100644 --- a/internal/api/grpc/user/v2/user.go +++ b/internal/api/grpc/user/v2/user.go @@ -2,28 +2,19 @@ package user import ( "context" - "errors" "io" "golang.org/x/text/language" - "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" - "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" - "github.com/zitadel/zitadel/internal/idp" - "github.com/zitadel/zitadel/internal/idp/providers/ldap" "github.com/zitadel/zitadel/internal/query" - "github.com/zitadel/zitadel/internal/zerrors" - object_pb "github.com/zitadel/zitadel/pkg/grpc/object/v2" "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) func (s *Server) AddHumanUser(ctx context.Context, req *user.AddHumanUserRequest) (_ *user.AddHumanUserResponse, err error) { - human, err := AddUserRequestToAddHuman(req) if err != nil { return nil, err @@ -356,236 +347,6 @@ func userGrantsToIDs(userGrants []*query.UserGrant) []string { return converted } -func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.StartIdentityProviderIntentRequest) (_ *user.StartIdentityProviderIntentResponse, err error) { - switch t := req.GetContent().(type) { - case *user.StartIdentityProviderIntentRequest_Urls: - return s.startIDPIntent(ctx, req.GetIdpId(), t.Urls) - case *user.StartIdentityProviderIntentRequest_Ldap: - return s.startLDAPIntent(ctx, req.GetIdpId(), t.Ldap) - default: - return nil, zerrors.ThrowUnimplementedf(nil, "USERv2-S2g21", "type oneOf %T in method StartIdentityProviderIntent not implemented", t) - } -} - -func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { - state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) - if err != nil { - return nil, err - } - _, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters()) - if err != nil { - return nil, err - } - content, redirect := session.GetAuth(ctx) - if redirect { - return &user.StartIdentityProviderIntentResponse{ - Details: object.DomainToDetailsPb(details), - NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content}, - }, nil - } - return &user.StartIdentityProviderIntentResponse{ - Details: object.DomainToDetailsPb(details), - NextStep: &user.StartIdentityProviderIntentResponse_PostForm{ - PostForm: []byte(content), - }, - }, nil -} - -func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { - intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil) - if err != nil { - return nil, err - } - externalUser, userID, attributes, err := s.ldapLogin(ctx, intentWriteModel.IDPID, ldapCredentials.GetUsername(), ldapCredentials.GetPassword()) - if err != nil { - if err := s.command.FailIDPIntent(ctx, intentWriteModel, err.Error()); err != nil { - return nil, err - } - return nil, err - } - token, err := s.command.SucceedLDAPIDPIntent(ctx, intentWriteModel, externalUser, userID, attributes) - if err != nil { - return nil, err - } - return &user.StartIdentityProviderIntentResponse{ - Details: object.DomainToDetailsPb(details), - NextStep: &user.StartIdentityProviderIntentResponse_IdpIntent{ - IdpIntent: &user.IDPIntent{ - IdpIntentId: intentWriteModel.AggregateID, - IdpIntentToken: token, - UserId: userID, - }, - }, - }, nil -} - -func (s *Server) checkLinkedExternalUser(ctx context.Context, idpID, externalUserID string) (string, error) { - idQuery, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID) - if err != nil { - return "", err - } - externalIDQuery, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID) - if err != nil { - return "", err - } - queries := []query.SearchQuery{ - idQuery, externalIDQuery, - } - links, err := s.query.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: queries}, nil) - if err != nil { - return "", err - } - if len(links.Links) == 1 { - return links.Links[0].UserID, nil - } - return "", nil -} - -func (s *Server) ldapLogin(ctx context.Context, idpID, username, password string) (idp.User, string, map[string][]string, error) { - provider, err := s.command.GetProvider(ctx, idpID, "", "") - if err != nil { - return nil, "", nil, err - } - ldapProvider, ok := provider.(*ldap.Provider) - if !ok { - return nil, "", nil, zerrors.ThrowInvalidArgument(nil, "IDP-9a02j2n2bh", "Errors.ExternalIDP.IDPTypeNotImplemented") - } - session := ldapProvider.GetSession(username, password) - externalUser, err := session.FetchUser(ctx) - if errors.Is(err, ldap.ErrFailedLogin) || errors.Is(err, ldap.ErrNoSingleUser) { - return nil, "", nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-nzun2i", "Errors.User.ExternalIDP.LoginFailed") - } - if err != nil { - return nil, "", nil, err - } - userID, err := s.checkLinkedExternalUser(ctx, idpID, externalUser.GetID()) - if err != nil { - return nil, "", nil, err - } - - attributes := make(map[string][]string, 0) - for _, item := range session.Entry.Attributes { - attributes[item.Name] = item.Values - } - return externalUser, userID, attributes, nil -} - -func (s *Server) RetrieveIdentityProviderIntent(ctx context.Context, req *user.RetrieveIdentityProviderIntentRequest) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { - intent, err := s.command.GetIntentWriteModel(ctx, req.GetIdpIntentId(), "") - if err != nil { - return nil, err - } - if err := s.checkIntentToken(req.GetIdpIntentToken(), intent.AggregateID); err != nil { - return nil, err - } - if intent.State != domain.IDPIntentStateSucceeded { - return nil, zerrors.ThrowPreconditionFailed(nil, "IDP-nme4gszsvx", "Errors.Intent.NotSucceeded") - } - return idpIntentToIDPIntentPb(intent, s.idpAlg) -} - -func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.EncryptionAlgorithm) (_ *user.RetrieveIdentityProviderIntentResponse, err error) { - rawInformation := new(structpb.Struct) - err = rawInformation.UnmarshalJSON(intent.IDPUser) - if err != nil { - return nil, err - } - information := &user.RetrieveIdentityProviderIntentResponse{ - Details: intentToDetailsPb(intent), - IdpInformation: &user.IDPInformation{ - IdpId: intent.IDPID, - UserId: intent.IDPUserID, - UserName: intent.IDPUserName, - RawInformation: rawInformation, - }, - UserId: intent.UserID, - } - if intent.IDPIDToken != "" || intent.IDPAccessToken != nil { - information.IdpInformation.Access, err = idpOAuthTokensToPb(intent.IDPIDToken, intent.IDPAccessToken, alg) - if err != nil { - return nil, err - } - } - - if intent.IDPEntryAttributes != nil { - access, err := IDPEntryAttributesToPb(intent.IDPEntryAttributes) - if err != nil { - return nil, err - } - information.IdpInformation.Access = access - } - - if intent.Assertion != nil { - assertion, err := crypto.Decrypt(intent.Assertion, alg) - if err != nil { - return nil, err - } - information.IdpInformation.Access = IDPSAMLResponseToPb(assertion) - } - - return information, nil -} - -func idpOAuthTokensToPb(idpIDToken string, idpAccessToken *crypto.CryptoValue, alg crypto.EncryptionAlgorithm) (_ *user.IDPInformation_Oauth, err error) { - var idToken *string - if idpIDToken != "" { - idToken = &idpIDToken - } - var accessToken string - if idpAccessToken != nil { - accessToken, err = crypto.DecryptString(idpAccessToken, alg) - if err != nil { - return nil, err - } - } - return &user.IDPInformation_Oauth{ - Oauth: &user.IDPOAuthAccessInformation{ - AccessToken: accessToken, - IdToken: idToken, - }, - }, nil -} - -func intentToDetailsPb(intent *command.IDPIntentWriteModel) *object_pb.Details { - return &object_pb.Details{ - Sequence: intent.ProcessedSequence, - ChangeDate: timestamppb.New(intent.ChangeDate), - ResourceOwner: intent.ResourceOwner, - } -} - -func IDPEntryAttributesToPb(entryAttributes map[string][]string) (*user.IDPInformation_Ldap, error) { - values := make(map[string]interface{}, 0) - for k, v := range entryAttributes { - intValues := make([]interface{}, len(v)) - for i, value := range v { - intValues[i] = value - } - values[k] = intValues - } - attributes, err := structpb.NewStruct(values) - if err != nil { - return nil, err - } - return &user.IDPInformation_Ldap{ - Ldap: &user.IDPLDAPAccessInformation{ - Attributes: attributes, - }, - }, nil -} - -func IDPSAMLResponseToPb(assertion []byte) *user.IDPInformation_Saml { - return &user.IDPInformation_Saml{ - Saml: &user.IDPSAMLAccessInformation{ - Assertion: assertion, - }, - } -} - -func (s *Server) checkIntentToken(token string, intentID string) error { - return crypto.CheckToken(s.idpAlg, token, intentID) -} - func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *user.ListAuthenticationMethodTypesRequest) (*user.ListAuthenticationMethodTypesResponse, error) { authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.GetUserId(), true, req.GetDomainQuery().GetIncludeWithoutDomain(), req.GetDomainQuery().GetDomain()) if err != nil { diff --git a/internal/api/grpc/user/v2/user_test.go b/internal/api/grpc/user/v2/user_test.go index 9e7a5a5ab0..9408b3acf9 100644 --- a/internal/api/grpc/user/v2/user_test.go +++ b/internal/api/grpc/user/v2/user_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/grpc" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" @@ -322,7 +321,7 @@ func Test_idpIntentToIDPIntentPb(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := idpIntentToIDPIntentPb(tt.args.intent, tt.args.alg) require.ErrorIs(t, err, tt.res.err) - grpc.AllFieldsEqual(t, tt.res.resp.ProtoReflect(), got.ProtoReflect(), grpc.CustomMappers) + assert.EqualExportedValues(t, tt.res.resp, got) }) } } diff --git a/internal/api/grpc/user/v2beta/integration_test/user_test.go b/internal/api/grpc/user/v2beta/integration_test/user_test.go index ab2e3215ee..a81de58761 100644 --- a/internal/api/grpc/user/v2beta/integration_test/user_test.go +++ b/internal/api/grpc/user/v2beta/integration_test/user_test.go @@ -2146,17 +2146,29 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) { } func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { - idpID := Instance.AddGenericOAuthProvider(IamCTX, gofakeit.AppName()).GetId() - intentID := Instance.CreateIntent(CTX, idpID).GetIdpIntent().GetIdpIntentId() - successfulID, token, changeDate, sequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), idpID, "id", "") + oauthIdpID := Instance.AddGenericOAuthProvider(IamCTX, gofakeit.AppName()).GetId() + oidcIdpID := Instance.AddGenericOIDCProvider(IamCTX, gofakeit.AppName()).GetId() + samlIdpID := Instance.AddSAMLPostProvider(IamCTX) + ldapIdpID := Instance.AddLDAPProvider(IamCTX) + authURL, err := url.Parse(Instance.CreateIntent(CTX, oauthIdpID).GetAuthUrl()) require.NoError(t, err) - successfulWithUserID, withUsertoken, withUserchangeDate, withUsersequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), idpID, "id", "user") + intentID := authURL.Query().Get("state") + + successfulID, token, changeDate, sequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), oauthIdpID, "id", "") require.NoError(t, err) - ldapSuccessfulID, ldapToken, ldapChangeDate, ldapSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), idpID, "id", "") + successfulWithUserID, withUsertoken, withUserchangeDate, withUsersequence, err := sink.SuccessfulOAuthIntent(Instance.ID(), oauthIdpID, "id", "user") require.NoError(t, err) - ldapSuccessfulWithUserID, ldapWithUserToken, ldapWithUserChangeDate, ldapWithUserSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), idpID, "id", "user") + oidcSuccessful, oidcToken, oidcChangeDate, oidcSequence, err := sink.SuccessfulOIDCIntent(Instance.ID(), oidcIdpID, "id", "") require.NoError(t, err) - samlSuccessfulID, samlToken, samlChangeDate, samlSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), idpID, "id", "") + oidcSuccessfulWithUserID, oidcWithUserIDToken, oidcWithUserIDChangeDate, oidcWithUserIDSequence, err := sink.SuccessfulOIDCIntent(Instance.ID(), oidcIdpID, "id", "user") + require.NoError(t, err) + ldapSuccessfulID, ldapToken, ldapChangeDate, ldapSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), ldapIdpID, "id", "") + require.NoError(t, err) + ldapSuccessfulWithUserID, ldapWithUserToken, ldapWithUserChangeDate, ldapWithUserSequence, err := sink.SuccessfulLDAPIntent(Instance.ID(), ldapIdpID, "id", "user") + require.NoError(t, err) + samlSuccessfulID, samlToken, samlChangeDate, samlSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), samlIdpID, "id", "") + require.NoError(t, err) + samlSuccessfulWithUserID, samlWithUserToken, samlWithUserChangeDate, samlWithUserSequence, err := sink.SuccessfulSAMLIntent(Instance.ID(), samlIdpID, "id", "user") require.NoError(t, err) type args struct { ctx context.Context @@ -2191,7 +2203,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { wantErr: true, }, { - name: "retrieve successful intent", + name: "retrieve successful oauth intent", args: args{ CTX, &user.RetrieveIdentityProviderIntentRequest{ @@ -2212,13 +2224,15 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { IdToken: gu.Ptr("idToken"), }, }, - IdpId: idpID, + IdpId: oauthIdpID, UserId: "id", - UserName: "username", + UserName: "", RawInformation: func() *structpb.Struct { s, err := structpb.NewStruct(map[string]interface{}{ - "sub": "id", - "preferred_username": "username", + "RawInfo": map[string]interface{}{ + "id": "id", + "preferred_username": "username", + }, }) require.NoError(t, err) return s @@ -2250,7 +2264,85 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { IdToken: gu.Ptr("idToken"), }, }, - IdpId: idpID, + IdpId: oauthIdpID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "RawInfo": map[string]interface{}{ + "id": "id", + "preferred_username": "username", + }, + }) + require.NoError(t, err) + return s + }(), + }, + }, + wantErr: false, + }, + { + name: "retrieve successful oidc intent", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: oidcSuccessful, + IdpIntentToken: oidcToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(oidcChangeDate), + ResourceOwner: Instance.ID(), + Sequence: oidcSequence, + }, + UserId: "", + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + AccessToken: "accessToken", + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: oidcIdpID, + UserId: "id", + UserName: "username", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "sub": "id", + "preferred_username": "username", + }) + require.NoError(t, err) + return s + }(), + }, + }, + wantErr: false, + }, + { + name: "retrieve successful oidc intent with linked user", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: oidcSuccessfulWithUserID, + IdpIntentToken: oidcWithUserIDToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(oidcWithUserIDChangeDate), + ResourceOwner: Instance.ID(), + Sequence: oidcWithUserIDSequence, + }, + UserId: "user", + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Oauth{ + Oauth: &user.IDPOAuthAccessInformation{ + AccessToken: "accessToken", + IdToken: gu.Ptr("idToken"), + }, + }, + IdpId: oidcIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2294,7 +2386,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }(), }, }, - IdpId: idpID, + IdpId: ldapIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2340,7 +2432,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }(), }, }, - IdpId: idpID, + IdpId: ldapIdpID, UserId: "id", UserName: "username", RawInformation: func() *structpb.Struct { @@ -2377,7 +2469,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { Assertion: []byte(""), }, }, - IdpId: idpID, + IdpId: samlIdpID, UserId: "id", UserName: "", RawInformation: func() *structpb.Struct { @@ -2394,6 +2486,45 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) { }, wantErr: false, }, + { + name: "retrieve successful saml intent with linked user", + args: args{ + CTX, + &user.RetrieveIdentityProviderIntentRequest{ + IdpIntentId: samlSuccessfulWithUserID, + IdpIntentToken: samlWithUserToken, + }, + }, + want: &user.RetrieveIdentityProviderIntentResponse{ + Details: &object.Details{ + ChangeDate: timestamppb.New(samlWithUserChangeDate), + ResourceOwner: Instance.ID(), + Sequence: samlWithUserSequence, + }, + IdpInformation: &user.IDPInformation{ + Access: &user.IDPInformation_Saml{ + Saml: &user.IDPSAMLAccessInformation{ + Assertion: []byte(""), + }, + }, + IdpId: samlIdpID, + UserId: "id", + UserName: "", + RawInformation: func() *structpb.Struct { + s, err := structpb.NewStruct(map[string]interface{}{ + "id": "id", + "attributes": map[string]interface{}{ + "attribute1": []interface{}{"value1"}, + }, + }) + require.NoError(t, err) + return s + }(), + }, + UserId: "user", + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/api/oidc/userinfo.go b/internal/api/oidc/userinfo.go index b2121a73a2..61f03b6d0f 100644 --- a/internal/api/oidc/userinfo.go +++ b/internal/api/oidc/userinfo.go @@ -20,7 +20,9 @@ import ( "github.com/zitadel/zitadel/internal/actions/object" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/execution" "github.com/zitadel/zitadel/internal/query" + exec_repo "github.com/zitadel/zitadel/internal/repository/execution" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -410,5 +412,104 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user } } + var function string + switch triggerType { + case domain.TriggerTypePreUserinfoCreation: + function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey()) + case domain.TriggerTypePreAccessTokenCreation: + function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey()) + case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation: + // added for linting, there should never be any trigger type be used here besides PreUserinfo and PreAccessToken + return err + } + + if function == "" { + return nil + } + executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, s.query, function) + if err != nil { + return err + } + info := &ContextInfo{ + Function: function, + UserInfo: userInfo, + User: qu.User, + UserMetadata: qu.Metadata, + Org: qu.Org, + UserGrants: qu.UserGrants, + } + + resp, err := execution.CallTargets(ctx, executionTargets, info) + if err != nil { + return err + } + contextInfoResponse, ok := resp.(*ContextInfoResponse) + if !ok || contextInfoResponse == nil { + return nil + } + claimLogs := make([]string, 0) + for _, metadata := range contextInfoResponse.SetUserMetadata { + if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil { + claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key)) + } + } + for _, claim := range contextInfoResponse.AppendClaims { + if strings.HasPrefix(claim.Key, ClaimPrefix) { + continue + } + if userInfo.Claims[claim.Key] == nil { + userInfo.AppendClaims(claim.Key, claim.Value) + continue + } + claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key)) + } + claimLogs = append(claimLogs, contextInfoResponse.AppendLogClaims...) + if len(claimLogs) > 0 { + userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs) + } + return nil } + +type ContextInfo struct { + Function string `json:"function,omitempty"` + UserInfo *oidc.UserInfo `json:"userinfo,omitempty"` + User *query.User `json:"user,omitempty"` + UserMetadata []query.UserMetadata `json:"user_metadata,omitempty"` + Org *query.UserInfoOrg `json:"org,omitempty"` + UserGrants []query.UserGrant `json:"user_grants,omitempty"` + Response *ContextInfoResponse `json:"response,omitempty"` +} + +type ContextInfoResponse struct { + SetUserMetadata []*domain.Metadata `json:"set_user_metadata,omitempty"` + AppendClaims []*AppendClaim `json:"append_claims,omitempty"` + AppendLogClaims []string `json:"append_log_claims,omitempty"` +} + +type AppendClaim struct { + Key string `json:"key"` + Value any `json:"value"` +} + +func (c *ContextInfo) GetHTTPRequestBody() []byte { + data, err := json.Marshal(c) + if err != nil { + return nil + } + return data +} + +func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error { + if !json.Valid(resp) { + return zerrors.ThrowPreconditionFailed(nil, "ACTION-4m9s2", "Errors.Execution.ResponseIsNotValidJSON") + } + if c.Response == nil { + c.Response = &ContextInfoResponse{} + } + return json.Unmarshal(resp, c.Response) +} + +func (c *ContextInfo) GetContent() any { + return c.Response +} diff --git a/internal/api/saml/auth_request.go b/internal/api/saml/auth_request.go index a846cd090b..f31647f705 100644 --- a/internal/api/saml/auth_request.go +++ b/internal/api/saml/auth_request.go @@ -3,6 +3,7 @@ package saml import ( "context" "encoding/base64" + "net/http" "net/url" "github.com/zitadel/saml/pkg/provider" @@ -32,9 +33,16 @@ func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthReques RelayState: authReq.GetRelayState(), AcsUrl: authReq.GetAccessConsumerServiceURL(), RequestID: authReq.GetAuthRequestID(), - Issuer: authReq.GetDestination(), Audience: authReq.GetIssuer(), } + + issuer := ContextToIssuer(ctx) + req, err := http.NewRequestWithContext(provider.ContextWithIssuer(ctx, issuer), http.MethodGet, issuer, nil) + if err != nil { + return "", "", err + } + resp.Issuer = p.GetEntityID(req) + samlResponse, err := p.AuthCallbackResponse(ctx, authReq, resp) if err != nil { return "", "", err diff --git a/internal/api/saml/provider.go b/internal/api/saml/provider.go index edf713456c..0b056797d5 100644 --- a/internal/api/saml/provider.go +++ b/internal/api/saml/provider.go @@ -1,6 +1,7 @@ package saml import ( + "context" "fmt" "net/http" @@ -83,7 +84,7 @@ func NewProvider( p, err := provider.NewProvider( provStorage, - HandlerPrefix, + IssuerFromContext, conf.ProviderConfig, options..., ) @@ -96,6 +97,16 @@ func NewProvider( }, nil } +func ContextToIssuer(ctx context.Context) string { + return http_utils.DomainContext(ctx).Origin() + HandlerPrefix +} + +func IssuerFromContext(_ bool) (provider.IssuerFromRequest, error) { + return func(r *http.Request) string { + return ContextToIssuer(r.Context()) + }, nil +} + func newStorage( command *command.Commands, query *query.Queries, diff --git a/internal/api/saml/storage.go b/internal/api/saml/storage.go index 5a02619d93..834cc7b392 100644 --- a/internal/api/saml/storage.go +++ b/internal/api/saml/storage.go @@ -3,6 +3,7 @@ package saml import ( "context" "encoding/json" + "fmt" "strings" "time" @@ -26,7 +27,9 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore/handler/crdb" + "github.com/zitadel/zitadel/internal/execution" "github.com/zitadel/zitadel/internal/query" + exec_repo "github.com/zitadel/zitadel/internal/repository/execution" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -37,7 +40,8 @@ var _ provider.AuthStorage = &Storage{} var _ provider.UserStorage = &Storage{} const ( - LoginClientHeader = "x-zitadel-login-client" + LoginClientHeader = "x-zitadel-login-client" + AttributeActionLogFormat = "urn:zitadel:iam:action:%s:log" ) type Storage struct { @@ -380,9 +384,86 @@ func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, use return nil, err } } + + function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey()) + executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, p.query, function) + if err != nil { + return nil, err + } + + // correct time for utc + user.CreationDate = user.CreationDate.UTC() + user.ChangeDate = user.ChangeDate.UTC() + + info := &ContextInfo{ + Function: function, + User: user, + UserGrants: userGrants.UserGrants, + } + + resp, err := execution.CallTargets(ctx, executionTargets, info) + if err != nil { + return nil, err + } + contextInfoResponse, ok := resp.(*ContextInfoResponse) + if !ok || contextInfoResponse == nil { + return customAttributes, nil + } + attributeLogs := make([]string, 0) + for _, metadata := range contextInfoResponse.SetUserMetadata { + if _, err = p.command.SetUserMetadata(ctx, metadata, user.ID, user.ResourceOwner); err != nil { + attributeLogs = append(attributeLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key)) + } + } + for _, attribute := range contextInfoResponse.AppendAttribute { + customAttributes = appendCustomAttribute(customAttributes, attribute.Name, attribute.NameFormat, attribute.Value) + } + if len(attributeLogs) > 0 { + customAttributes = appendCustomAttribute(customAttributes, fmt.Sprintf(AttributeActionLogFormat, function), "", attributeLogs) + } return customAttributes, nil } +type ContextInfo struct { + Function string `json:"function,omitempty"` + User *query.User `json:"user,omitempty"` + UserGrants []*query.UserGrant `json:"user_grants,omitempty"` + Response *ContextInfoResponse `json:"response,omitempty"` +} + +type ContextInfoResponse struct { + SetUserMetadata []*domain.Metadata `json:"set_user_metadata,omitempty"` + AppendAttribute []*AppendAttribute `json:"append_attribute,omitempty"` +} + +type AppendAttribute struct { + Name string `json:"name"` + NameFormat string `json:"name_format"` + Value []string `json:"value"` +} + +func (c *ContextInfo) GetHTTPRequestBody() []byte { + data, err := json.Marshal(c) + if err != nil { + return nil + } + return data +} + +func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error { + if !json.Valid(resp) { + return zerrors.ThrowPreconditionFailed(nil, "ACTION-4m9s2", "Errors.Execution.ResponseIsNotValidJSON") + } + if c.Response == nil { + c.Response = &ContextInfoResponse{} + } + return json.Unmarshal(resp, c.Response) +} + +func (c *ContextInfo) GetContent() interface{} { + return c.Response +} + func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) { projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID) if err != nil { diff --git a/internal/command/command.go b/internal/command/command.go index 17f6641caf..f9c78fbaab 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -183,7 +183,7 @@ func StartCommands( EventGroupExisting: func(group string) bool { return true }, GrpcServiceExisting: func(service string) bool { return false }, GrpcMethodExisting: func(method string) bool { return false }, - ActionFunctionExisting: domain.FunctionExists(), + ActionFunctionExisting: domain.ActionFunctionExists(), multifactors: domain.MultifactorConfigs{ OTP: domain.OTPConfig{ CryptoMFA: otpEncryption, diff --git a/internal/domain/action.go b/internal/domain/action.go index 18dd23e8c5..b57dde6289 100644 --- a/internal/domain/action.go +++ b/internal/domain/action.go @@ -1,6 +1,7 @@ package domain import ( + "slices" "time" "github.com/zitadel/zitadel/internal/eventstore/v1/models" @@ -45,3 +46,51 @@ const ( ActionsMaxAllowed ActionsAllowedUnlimited ) + +type ActionFunction int32 + +const ( + ActionFunctionUnspecified ActionFunction = iota + ActionFunctionPreUserinfo + ActionFunctionPreAccessToken + ActionFunctionPreSAMLResponse + actionFunctionCount +) + +func (s ActionFunction) Valid() bool { + return s >= 0 && s < actionFunctionCount +} + +func (s ActionFunction) LocalizationKey() string { + if !s.Valid() { + return ActionFunctionUnspecified.LocalizationKey() + } + + switch s { + case ActionFunctionPreUserinfo: + return "preuserinfo" + case ActionFunctionPreAccessToken: + return "preaccesstoken" + case ActionFunctionPreSAMLResponse: + return "presamlresponse" + case ActionFunctionUnspecified, actionFunctionCount: + fallthrough + default: + return "unspecified" + } +} + +func AllActionFunctions() []string { + return []string{ + ActionFunctionPreUserinfo.LocalizationKey(), + ActionFunctionPreAccessToken.LocalizationKey(), + ActionFunctionPreSAMLResponse.LocalizationKey(), + } +} + +func ActionFunctionExists() func(string) bool { + functions := AllActionFunctions() + return func(s string) bool { + return slices.Contains(functions, s) + } +} diff --git a/internal/domain/flow.go b/internal/domain/flow.go index 143ce6bd0b..39cb13fc1e 100644 --- a/internal/domain/flow.go +++ b/internal/domain/flow.go @@ -1,7 +1,6 @@ package domain import ( - "slices" "strconv" ) @@ -150,20 +149,3 @@ func (s TriggerType) LocalizationKey() string { return "Action.TriggerType.Unspecified" } } - -func AllFunctions() []string { - functions := make([]string, 0) - for _, flowType := range AllFlowTypes() { - for _, triggerType := range flowType.TriggerTypes() { - functions = append(functions, flowType.LocalizationKey()+"."+triggerType.LocalizationKey()) - } - } - return functions -} - -func FunctionExists() func(string) bool { - functions := AllFunctions() - return func(s string) bool { - return slices.Contains(functions, s) - } -} diff --git a/internal/execution/execution.go b/internal/execution/execution.go index 99d7f6182f..116f377e17 100644 --- a/internal/execution/execution.go +++ b/internal/execution/execution.go @@ -6,12 +6,15 @@ import ( "encoding/json" "io" "net/http" + "strings" "time" "github.com/zitadel/logging" zhttp "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/repository/execution" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/pkg/actions" @@ -153,3 +156,59 @@ type ErrorBody struct { ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"` ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"` } + +type ExecutionTargetsQueries interface { + TargetsByExecutionID(ctx context.Context, ids []string) (execution []*query.ExecutionTarget, err error) + TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error) +} + +func QueryExecutionTargetsForRequestAndResponse( + ctx context.Context, + queries ExecutionTargetsQueries, + fullMethod string, +) ([]Target, []Target) { + ctx, span := tracing.NewSpan(ctx) + defer span.End() + + targets, err := queries.TargetsByExecutionIDs(ctx, + idsForFullMethod(fullMethod, domain.ExecutionTypeRequest), + idsForFullMethod(fullMethod, domain.ExecutionTypeResponse), + ) + requestTargets := make([]Target, 0, len(targets)) + responseTargets := make([]Target, 0, len(targets)) + if err != nil { + logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets") + return requestTargets, responseTargets + } + + for _, target := range targets { + if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeRequest)) { + requestTargets = append(requestTargets, target) + } else if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeResponse)) { + responseTargets = append(responseTargets, target) + } + } + + return requestTargets, responseTargets +} + +func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string { + return []string{execution.ID(executionType, fullMethod), execution.ID(executionType, serviceFromFullMethod(fullMethod)), execution.IDAll(executionType)} +} + +func serviceFromFullMethod(s string) string { + parts := strings.Split(s, "/") + return parts[1] +} + +func QueryExecutionTargetsForFunction(ctx context.Context, query ExecutionTargetsQueries, function string) ([]Target, error) { + queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function}) + if err != nil { + return nil, err + } + executionTargets := make([]Target, len(queriedActionsV2)) + for i, action := range queriedActionsV2 { + executionTargets[i] = action + } + return executionTargets, nil +} diff --git a/internal/idp/providers/azuread/azuread.go b/internal/idp/providers/azuread/azuread.go index 65f38ede5b..a15f793e37 100644 --- a/internal/idp/providers/azuread/azuread.go +++ b/internal/idp/providers/azuread/azuread.go @@ -152,6 +152,10 @@ func ensureMinimalScope(scopes []string) []string { return scopes } +func (p *Provider) User() idp.User { + return p.Provider.User() +} + // User represents the structure return on the userinfo endpoint and implements the [idp.User] interface // // AzureAD does not return an `email_verified` claim. diff --git a/internal/idp/providers/oauth/oauth2.go b/internal/idp/providers/oauth/oauth2.go index e9c627509a..a790c550f5 100644 --- a/internal/idp/providers/oauth/oauth2.go +++ b/internal/idp/providers/oauth/oauth2.go @@ -18,7 +18,7 @@ type Provider struct { options []rp.Option name string userEndpoint string - userMapper func() idp.User + user func() idp.User isLinkingAllowed bool isCreationAllowed bool isAutoCreation bool @@ -65,11 +65,11 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts { } // New creates a generic OAuth 2.0 provider -func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) { +func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) { provider = &Provider{ name: name, userEndpoint: userEndpoint, - userMapper: userMapper, + user: user, generateVerifier: oauth2.GenerateVerifier, } for _, option := range options { @@ -137,3 +137,7 @@ func (p *Provider) IsAutoCreation() bool { func (p *Provider) IsAutoUpdate() bool { return p.isAutoUpdate } + +func (p *Provider) User() idp.User { + return p.user() +} diff --git a/internal/idp/providers/oauth/session.go b/internal/idp/providers/oauth/session.go index aca22234a2..247a7f8710 100644 --- a/internal/idp/providers/oauth/session.go +++ b/internal/idp/providers/oauth/session.go @@ -51,7 +51,7 @@ func (s *Session) PersistentParameters() map[string]any { // FetchUser implements the [idp.Session] interface. // It will execute an OAuth 2.0 code exchange if needed to retrieve the access token, // call the specified userEndpoint and map the received information into an [idp.User]. -func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) { +func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) { if s.Tokens == nil { if err = s.authorize(ctx); err != nil { return nil, err @@ -62,11 +62,11 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) { return nil, err } req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken) - mapper := s.Provider.userMapper() - if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil { + user := s.Provider.User() + if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil { return nil, err } - return mapper, nil + return user, nil } func (s *Session) authorize(ctx context.Context) (err error) { diff --git a/internal/integration/client.go b/internal/integration/client.go index a480a86ce0..abc774f452 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -472,6 +472,26 @@ func (i *Instance) AddOrgGenericOAuthProvider(ctx context.Context, name string) return resp } +func (i *Instance) AddGenericOIDCProvider(ctx context.Context, name string) *admin.AddGenericOIDCProviderResponse { + resp, err := i.Client.Admin.AddGenericOIDCProvider(ctx, &admin.AddGenericOIDCProviderRequest{ + Name: name, + Issuer: "https://example.com", + ClientId: "clientID", + ClientSecret: "clientSecret", + Scopes: []string{"openid", "profile", "email"}, + ProviderOptions: &idp.Options{ + IsLinkingAllowed: true, + IsCreationAllowed: true, + IsAutoCreation: true, + IsAutoUpdate: true, + AutoLinking: idp.AutoLinkingOption_AUTO_LINKING_OPTION_USERNAME, + }, + IsIdTokenMapping: false, + }) + logging.OnError(err).Panic("create generic oidc idp") + return resp +} + func (i *Instance) AddSAMLProvider(ctx context.Context) string { resp, err := i.Client.Admin.AddSAMLProvider(ctx, &admin.AddSAMLProviderRequest{ Name: "saml-idp", @@ -526,6 +546,32 @@ func (i *Instance) AddSAMLPostProvider(ctx context.Context) string { return resp.GetId() } +func (i *Instance) AddLDAPProvider(ctx context.Context) string { + resp, err := i.Client.Admin.AddLDAPProvider(ctx, &admin.AddLDAPProviderRequest{ + Name: "ldap-idp-post", + Servers: []string{"https://localhost:8000"}, + StartTls: false, + BaseDn: "baseDn", + BindDn: "admin", + BindPassword: "admin", + UserBase: "dn", + UserObjectClasses: []string{"user"}, + UserFilters: []string{"(objectclass=*)"}, + Timeout: durationpb.New(10 * time.Second), + Attributes: &idp.LDAPAttributes{ + IdAttribute: "id", + }, + ProviderOptions: &idp.Options{ + IsLinkingAllowed: true, + IsCreationAllowed: true, + IsAutoCreation: true, + IsAutoUpdate: true, + }, + }) + logging.OnError(err).Panic("create ldap idp") + return resp.GetId() +} + func (i *Instance) CreateIntent(ctx context.Context, idpID string) *user_v2.StartIdentityProviderIntentResponse { resp, err := i.Client.UserV2.StartIdentityProviderIntent(ctx, &user_v2.StartIdentityProviderIntentRequest{ IdpId: idpID, diff --git a/internal/integration/saml.go b/internal/integration/saml.go index 533b0ee515..28934ac421 100644 --- a/internal/integration/saml.go +++ b/internal/integration/saml.go @@ -17,6 +17,7 @@ import ( "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/zitadel/logging" + "github.com/zitadel/saml/pkg/provider" http_util "github.com/zitadel/zitadel/internal/api/http" oidc_internal "github.com/zitadel/zitadel/internal/api/oidc" @@ -220,8 +221,15 @@ func (i *Instance) SuccessfulSAMLAuthRequest(ctx context.Context, userId, id str } func (i *Instance) GetSAMLIDPMetadata() (*saml.EntityDescriptor, error) { - idpEntityID := http_util.BuildHTTP(i.Domain, i.Config.Port, i.Config.Secure) + "/saml/v2/metadata" - resp, err := http.Get(idpEntityID) + issuer := i.Issuer() + "/saml/v2" + idpEntityID := issuer + "/metadata" + + req, err := http.NewRequestWithContext(provider.ContextWithIssuer(context.Background(), issuer), http.MethodGet, idpEntityID, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } diff --git a/internal/integration/sink/server.go b/internal/integration/sink/server.go index 2c79081e98..633ebf424f 100644 --- a/internal/integration/sink/server.go +++ b/internal/integration/sink/server.go @@ -27,6 +27,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/idp/providers/ldap" + "github.com/zitadel/zitadel/internal/idp/providers/oauth" openid "github.com/zitadel/zitadel/internal/idp/providers/oidc" "github.com/zitadel/zitadel/internal/idp/providers/saml" ) @@ -65,6 +66,24 @@ func SuccessfulOAuthIntent(instanceID, idpID, idpUserID, userID string) (string, return resp.IntentID, resp.Token, resp.ChangeDate, resp.Sequence, nil } +func SuccessfulOIDCIntent(instanceID, idpID, idpUserID, userID string) (string, string, time.Time, uint64, error) { + u := url.URL{ + Scheme: "http", + Host: host, + Path: successfulIntentOIDCPath(), + } + resp, err := callIntent(u.String(), &SuccessfulIntentRequest{ + InstanceID: instanceID, + IDPID: idpID, + IDPUserID: idpUserID, + UserID: userID, + }) + if err != nil { + return "", "", time.Time{}, uint64(0), err + } + return resp.IntentID, resp.Token, resp.ChangeDate, resp.Sequence, nil +} + func SuccessfulSAMLIntent(instanceID, idpID, idpUserID, userID string) (string, string, time.Time, uint64, error) { u := url.URL{ Scheme: "http", @@ -119,6 +138,7 @@ func StartServer(commands *command.Commands) (close func()) { router.HandleFunc(rootPath(ch), fwd.receiveHandler) router.HandleFunc(subscribePath(ch), fwd.subscriptionHandler) router.HandleFunc(successfulIntentOAuthPath(), successfulIntentHandler(commands, createSuccessfulOAuthIntent)) + router.HandleFunc(successfulIntentOIDCPath(), successfulIntentHandler(commands, createSuccessfulOIDCIntent)) router.HandleFunc(successfulIntentSAMLPath(), successfulIntentHandler(commands, createSuccessfulSAMLIntent)) router.HandleFunc(successfulIntentLDAPPath(), successfulIntentHandler(commands, createSuccessfulLDAPIntent)) } @@ -159,6 +179,10 @@ func successfulIntentOAuthPath() string { return path.Join(successfulIntentPath(), "/", "oauth") } +func successfulIntentOIDCPath() string { + return path.Join(successfulIntentPath(), "/", "oidc") +} + func successfulIntentSAMLPath() string { return path.Join(successfulIntentPath(), "/", "saml") } @@ -334,6 +358,41 @@ func createIntent(ctx context.Context, cmd *command.Commands, instanceID, idpID } func createSuccessfulOAuthIntent(ctx context.Context, cmd *command.Commands, req *SuccessfulIntentRequest) (*SuccessfulIntentResponse, error) { + intentID, err := createIntent(ctx, cmd, req.InstanceID, req.IDPID) + if err != nil { + return nil, err + } + writeModel, err := cmd.GetIntentWriteModel(ctx, intentID, req.InstanceID) + if err != nil { + return nil, err + } + idAttribute := "id" + idpUser := oauth.NewUserMapper(idAttribute) + idpUser.RawInfo = map[string]interface{}{ + idAttribute: req.IDPUserID, + "preferred_username": "username", + } + idpSession := &oauth.Session{ + Tokens: &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: &oauth2.Token{ + AccessToken: "accessToken", + }, + IDToken: "idToken", + }, + } + token, err := cmd.SucceedIDPIntent(ctx, writeModel, idpUser, idpSession, req.UserID) + if err != nil { + return nil, err + } + return &SuccessfulIntentResponse{ + intentID, + token, + writeModel.ChangeDate, + writeModel.ProcessedSequence, + }, nil +} + +func createSuccessfulOIDCIntent(ctx context.Context, cmd *command.Commands, req *SuccessfulIntentRequest) (*SuccessfulIntentResponse, error) { intentID, err := createIntent(ctx, cmd, req.InstanceID, req.IDPID) writeModel, err := cmd.GetIntentWriteModel(ctx, intentID, req.InstanceID) idpUser := openid.NewUser( diff --git a/proto/zitadel/user/v2/user_service.proto b/proto/zitadel/user/v2/user_service.proto index 5457efd64e..d59d6e67ec 100644 --- a/proto/zitadel/user/v2/user_service.proto +++ b/proto/zitadel/user/v2/user_service.proto @@ -2094,6 +2094,7 @@ message RetrieveIdentityProviderIntentResponse{ example: "\"163840776835432345\""; } ]; + AddHumanUserRequest add_human_user = 4; } message AddIDPLinkRequest{