fix: added changed from review

This commit is contained in:
Stefan Benz 2025-02-27 10:54:32 +01:00
parent a226de975b
commit 784b3afb4f
No known key found for this signature in database
GPG Key ID: 071AA751ED4F9D31
8 changed files with 129 additions and 154 deletions

View File

@ -3,23 +3,19 @@ package middleware
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"strings"
"github.com/zitadel/logging"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/execution" "github.com/zitadel/zitadel/internal/execution"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
exec_repo "github.com/zitadel/zitadel/internal/repository/execution"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor { func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
requestTargets, responseTargets := queryTargets(ctx, queries, info.FullMethod) requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, info.FullMethod)
// call targets otherwise return req // call targets otherwise return req
handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req) handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req)
@ -81,49 +77,6 @@ func executeTargetsForResponse(ctx context.Context, targets []execution.Target,
return execution.CallTargets(ctx, targets, info) return execution.CallTargets(ctx, targets, info)
} }
type ExecutionQueries interface {
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
}
func queryTargets(
ctx context.Context,
queries ExecutionQueries,
fullMethod string,
) ([]execution.Target, []execution.Target) {
ctx, span := tracing.NewSpan(ctx)
defer span.End()
targets, err := queries.TargetsByExecutionIDs(ctx,
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
)
requestTargets := make([]execution.Target, 0, len(targets))
responseTargets := make([]execution.Target, 0, len(targets))
if err != nil {
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
return requestTargets, responseTargets
}
for _, target := range targets {
if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeRequest)) {
requestTargets = append(requestTargets, target)
} else if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeResponse)) {
responseTargets = append(responseTargets, target)
}
}
return requestTargets, responseTargets
}
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
return []string{exec_repo.ID(executionType, fullMethod), exec_repo.ID(executionType, serviceFromFullMethod(fullMethod)), exec_repo.IDAll(executionType)}
}
func serviceFromFullMethod(s string) string {
parts := strings.Split(s, "/")
return parts[1]
}
var _ execution.ContextInfo = &ContextInfoRequest{} var _ execution.ContextInfo = &ContextInfoRequest{}
type ContextInfoRequest struct { type ContextInfoRequest struct {

View File

@ -44,31 +44,31 @@ func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.Star
} }
func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) {
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID()) state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, redirect, err := s.command.AuthFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) _, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters())
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, redirect := session.GetAuth(ctx)
if redirect { if redirect {
return &user.StartIdentityProviderIntentResponse{ return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details), Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content}, NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
}, nil }, nil
} else {
return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
PostForm: []byte(content),
},
}, nil
} }
return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
PostForm: []byte(content),
},
}, nil
} }
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, "", "", authz.GetInstance(ctx).InstanceID()) intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -412,62 +412,64 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
} }
} }
function := "" var function string
switch triggerType { switch triggerType {
case domain.TriggerTypePreUserinfoCreation: case domain.TriggerTypePreUserinfoCreation:
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey()) function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey())
case domain.TriggerTypePreAccessTokenCreation: case domain.TriggerTypePreAccessTokenCreation:
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey()) function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey())
case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation: case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation:
fallthrough // added for linting, there should never be any trigger type be used here besides PreUserinfo and PreAccessToken
default: return
function = ""
} }
if function != "" { if function == "" {
executionTargets, err := queryExecutionTargets(ctx, s.query, function) return nil
if err != nil { }
return err executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, s.query, function)
} if err != nil {
info := &ContextInfo{ return err
Function: function, }
UserInfo: userInfo, info := &ContextInfo{
User: qu.User, Function: function,
UserMetadata: qu.Metadata, UserInfo: userInfo,
Org: qu.Org, User: qu.User,
UserGrants: qu.UserGrants, UserMetadata: qu.Metadata,
} Org: qu.Org,
UserGrants: qu.UserGrants,
}
resp, err := execution.CallTargets(ctx, executionTargets, info) resp, err := execution.CallTargets(ctx, executionTargets, info)
if err != nil { if err != nil {
return err return err
} }
contextInfoResponse, ok := resp.(*ContextInfoResponse) contextInfoResponse, ok := resp.(*ContextInfoResponse)
if ok && contextInfoResponse != nil { if !ok || contextInfoResponse == nil {
claimLogs := make([]string, 0) return nil
for _, metadata := range contextInfoResponse.SetUserMetadata { }
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil { claimLogs := make([]string, 0)
claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key)) for _, metadata := range contextInfoResponse.SetUserMetadata {
} if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
} claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key))
for _, claim := range contextInfoResponse.AppendClaims {
if strings.HasPrefix(claim.Key, ClaimPrefix) {
continue
}
if userInfo.Claims[claim.Key] == nil {
userInfo.AppendClaims(claim.Key, claim.Value)
continue
}
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
}
for _, log := range contextInfoResponse.AppendLogClaims {
claimLogs = append(claimLogs, log)
}
if len(claimLogs) > 0 {
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
}
} }
} }
for _, claim := range contextInfoResponse.AppendClaims {
if strings.HasPrefix(claim.Key, ClaimPrefix) {
continue
}
if userInfo.Claims[claim.Key] == nil {
userInfo.AppendClaims(claim.Key, claim.Value)
continue
}
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
}
for _, log := range contextInfoResponse.AppendLogClaims {
claimLogs = append(claimLogs, log)
}
if len(claimLogs) > 0 {
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
}
return nil return nil
} }
@ -510,32 +512,6 @@ func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error {
return json.Unmarshal(resp, c.Response) return json.Unmarshal(resp, c.Response)
} }
func (c *ContextInfo) GetContent() interface{} { func (c *ContextInfo) GetContent() any {
return c.Response return c.Response
} }
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]execution.Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}
func (s *Server) queryOrgMetadata(ctx context.Context, organizationID string) ([]*query.OrgMetadata, error) {
metadata, err := s.query.SearchOrgMetadata(
ctx,
true,
organizationID,
&query.OrgMetadataSearchQueries{},
false,
)
if err != nil {
return nil, err
}
return metadata.Metadata, nil
}

View File

@ -385,7 +385,7 @@ func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, use
} }
function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey()) function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey())
executionTargets, err := queryExecutionTargets(ctx, p.query, function) executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, p.query, function)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -461,18 +461,6 @@ func (c *ContextInfo) GetContent() interface{} {
return c.Response return c.Response
} }
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]execution.Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) { func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID) projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID)
if err != nil { if err != nil {

View File

@ -22,7 +22,6 @@ type SAMLRequest struct {
Binding string Binding string
Issuer string Issuer string
Destination string Destination string
EntityID string
} }
type CurrentSAMLRequest struct { type CurrentSAMLRequest struct {

View File

@ -6,12 +6,15 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/zitadel/logging" "github.com/zitadel/logging"
zhttp "github.com/zitadel/zitadel/internal/api/http" zhttp "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/execution"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
"github.com/zitadel/zitadel/pkg/actions" "github.com/zitadel/zitadel/pkg/actions"
@ -153,3 +156,59 @@ type ErrorBody struct {
ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"` ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"`
ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"` ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"`
} }
type ExecutionTargetsQueries interface {
TargetsByExecutionID(ctx context.Context, ids []string) (execution []*query.ExecutionTarget, err error)
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
}
func QueryExecutionTargetsForRequestAndResponse(
ctx context.Context,
queries ExecutionTargetsQueries,
fullMethod string,
) ([]Target, []Target) {
ctx, span := tracing.NewSpan(ctx)
defer span.End()
targets, err := queries.TargetsByExecutionIDs(ctx,
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
)
requestTargets := make([]Target, 0, len(targets))
responseTargets := make([]Target, 0, len(targets))
if err != nil {
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
return requestTargets, responseTargets
}
for _, target := range targets {
if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeRequest)) {
requestTargets = append(requestTargets, target)
} else if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeResponse)) {
responseTargets = append(responseTargets, target)
}
}
return requestTargets, responseTargets
}
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
return []string{execution.ID(executionType, fullMethod), execution.ID(executionType, serviceFromFullMethod(fullMethod)), execution.IDAll(executionType)}
}
func serviceFromFullMethod(s string) string {
parts := strings.Split(s, "/")
return parts[1]
}
func QueryExecutionTargetsForFunction(ctx context.Context, query ExecutionTargetsQueries, function string) ([]Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}

View File

@ -17,9 +17,9 @@ type Provider struct {
rp.RelyingParty rp.RelyingParty
options []rp.Option options []rp.Option
name string name string
userEndpoint string userEndpoint string
userMapper func() idp.User user func() idp.User
isLinkingAllowed bool isLinkingAllowed bool
isCreationAllowed bool isCreationAllowed bool
isAutoCreation bool isAutoCreation bool
isAutoUpdate bool isAutoUpdate bool
@ -65,11 +65,11 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
} }
// New creates a generic OAuth 2.0 provider // New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) { func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{ provider = &Provider{
name: name, name: name,
userEndpoint: userEndpoint, userEndpoint: userEndpoint,
userMapper: userMapper, user: user,
generateVerifier: oauth2.GenerateVerifier, generateVerifier: oauth2.GenerateVerifier,
} }
for _, option := range options { for _, option := range options {
@ -139,5 +139,5 @@ func (p *Provider) IsAutoUpdate() bool {
} }
func (p *Provider) User() idp.User { func (p *Provider) User() idp.User {
return p.userMapper() return p.user()
} }

View File

@ -51,7 +51,7 @@ func (s *Session) PersistentParameters() map[string]any {
// FetchUser implements the [idp.Session] interface. // FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token, // It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User]. // call the specified userEndpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) { func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
if s.Tokens == nil { if s.Tokens == nil {
if err = s.authorize(ctx); err != nil { if err = s.authorize(ctx); err != nil {
return nil, err return nil, err
@ -62,11 +62,11 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
return nil, err return nil, err
} }
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken) req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
mapper := s.Provider.User() user := s.Provider.User()
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil { if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
return nil, err return nil, err
} }
return mapper, nil return user, nil
} }
func (s *Session) authorize(ctx context.Context) (err error) { func (s *Session) authorize(ctx context.Context) (err error) {