fix: added changed from review

This commit is contained in:
Stefan Benz 2025-02-27 10:54:32 +01:00
parent a226de975b
commit 784b3afb4f
No known key found for this signature in database
GPG Key ID: 071AA751ED4F9D31
8 changed files with 129 additions and 154 deletions

View File

@ -3,23 +3,19 @@ package middleware
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"strings"
"github.com/zitadel/logging"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/execution" "github.com/zitadel/zitadel/internal/execution"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
exec_repo "github.com/zitadel/zitadel/internal/repository/execution"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor { func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
requestTargets, responseTargets := queryTargets(ctx, queries, info.FullMethod) requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, info.FullMethod)
// call targets otherwise return req // call targets otherwise return req
handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req) handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req)
@ -81,49 +77,6 @@ func executeTargetsForResponse(ctx context.Context, targets []execution.Target,
return execution.CallTargets(ctx, targets, info) return execution.CallTargets(ctx, targets, info)
} }
type ExecutionQueries interface {
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
}
func queryTargets(
ctx context.Context,
queries ExecutionQueries,
fullMethod string,
) ([]execution.Target, []execution.Target) {
ctx, span := tracing.NewSpan(ctx)
defer span.End()
targets, err := queries.TargetsByExecutionIDs(ctx,
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
)
requestTargets := make([]execution.Target, 0, len(targets))
responseTargets := make([]execution.Target, 0, len(targets))
if err != nil {
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
return requestTargets, responseTargets
}
for _, target := range targets {
if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeRequest)) {
requestTargets = append(requestTargets, target)
} else if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeResponse)) {
responseTargets = append(responseTargets, target)
}
}
return requestTargets, responseTargets
}
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
return []string{exec_repo.ID(executionType, fullMethod), exec_repo.ID(executionType, serviceFromFullMethod(fullMethod)), exec_repo.IDAll(executionType)}
}
func serviceFromFullMethod(s string) string {
parts := strings.Split(s, "/")
return parts[1]
}
var _ execution.ContextInfo = &ContextInfoRequest{} var _ execution.ContextInfo = &ContextInfoRequest{}
type ContextInfoRequest struct { type ContextInfoRequest struct {

View File

@ -44,31 +44,31 @@ func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.Star
} }
func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) { func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) {
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID()) state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, redirect, err := s.command.AuthFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID)) _, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters())
if err != nil { if err != nil {
return nil, err return nil, err
} }
content, redirect := session.GetAuth(ctx)
if redirect { if redirect {
return &user.StartIdentityProviderIntentResponse{ return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details), Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content}, NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
}, nil }, nil
} else { }
return &user.StartIdentityProviderIntentResponse{ return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details), Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{ NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
PostForm: []byte(content), PostForm: []byte(content),
}, },
}, nil }, nil
}
} }
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) { func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, "", "", authz.GetInstance(ctx).InstanceID()) intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -412,20 +412,21 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
} }
} }
function := "" var function string
switch triggerType { switch triggerType {
case domain.TriggerTypePreUserinfoCreation: case domain.TriggerTypePreUserinfoCreation:
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey()) function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey())
case domain.TriggerTypePreAccessTokenCreation: case domain.TriggerTypePreAccessTokenCreation:
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey()) function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey())
case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation: case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation:
fallthrough // added for linting, there should never be any trigger type be used here besides PreUserinfo and PreAccessToken
default: return
function = ""
} }
if function != "" { if function == "" {
executionTargets, err := queryExecutionTargets(ctx, s.query, function) return nil
}
executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, s.query, function)
if err != nil { if err != nil {
return err return err
} }
@ -443,7 +444,9 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
return err return err
} }
contextInfoResponse, ok := resp.(*ContextInfoResponse) contextInfoResponse, ok := resp.(*ContextInfoResponse)
if ok && contextInfoResponse != nil { if !ok || contextInfoResponse == nil {
return nil
}
claimLogs := make([]string, 0) claimLogs := make([]string, 0)
for _, metadata := range contextInfoResponse.SetUserMetadata { for _, metadata := range contextInfoResponse.SetUserMetadata {
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil { if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
@ -466,8 +469,7 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
if len(claimLogs) > 0 { if len(claimLogs) > 0 {
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs) userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
} }
}
}
return nil return nil
} }
@ -510,32 +512,6 @@ func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error {
return json.Unmarshal(resp, c.Response) return json.Unmarshal(resp, c.Response)
} }
func (c *ContextInfo) GetContent() interface{} { func (c *ContextInfo) GetContent() any {
return c.Response return c.Response
} }
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]execution.Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}
func (s *Server) queryOrgMetadata(ctx context.Context, organizationID string) ([]*query.OrgMetadata, error) {
metadata, err := s.query.SearchOrgMetadata(
ctx,
true,
organizationID,
&query.OrgMetadataSearchQueries{},
false,
)
if err != nil {
return nil, err
}
return metadata.Metadata, nil
}

View File

@ -385,7 +385,7 @@ func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, use
} }
function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey()) function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey())
executionTargets, err := queryExecutionTargets(ctx, p.query, function) executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, p.query, function)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -461,18 +461,6 @@ func (c *ContextInfo) GetContent() interface{} {
return c.Response return c.Response
} }
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]execution.Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) { func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID) projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID)
if err != nil { if err != nil {

View File

@ -22,7 +22,6 @@ type SAMLRequest struct {
Binding string Binding string
Issuer string Issuer string
Destination string Destination string
EntityID string
} }
type CurrentSAMLRequest struct { type CurrentSAMLRequest struct {

View File

@ -6,12 +6,15 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/zitadel/logging" "github.com/zitadel/logging"
zhttp "github.com/zitadel/zitadel/internal/api/http" zhttp "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/execution"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
"github.com/zitadel/zitadel/pkg/actions" "github.com/zitadel/zitadel/pkg/actions"
@ -153,3 +156,59 @@ type ErrorBody struct {
ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"` ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"`
ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"` ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"`
} }
type ExecutionTargetsQueries interface {
TargetsByExecutionID(ctx context.Context, ids []string) (execution []*query.ExecutionTarget, err error)
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
}
func QueryExecutionTargetsForRequestAndResponse(
ctx context.Context,
queries ExecutionTargetsQueries,
fullMethod string,
) ([]Target, []Target) {
ctx, span := tracing.NewSpan(ctx)
defer span.End()
targets, err := queries.TargetsByExecutionIDs(ctx,
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
)
requestTargets := make([]Target, 0, len(targets))
responseTargets := make([]Target, 0, len(targets))
if err != nil {
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
return requestTargets, responseTargets
}
for _, target := range targets {
if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeRequest)) {
requestTargets = append(requestTargets, target)
} else if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeResponse)) {
responseTargets = append(responseTargets, target)
}
}
return requestTargets, responseTargets
}
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
return []string{execution.ID(executionType, fullMethod), execution.ID(executionType, serviceFromFullMethod(fullMethod)), execution.IDAll(executionType)}
}
func serviceFromFullMethod(s string) string {
parts := strings.Split(s, "/")
return parts[1]
}
func QueryExecutionTargetsForFunction(ctx context.Context, query ExecutionTargetsQueries, function string) ([]Target, error) {
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
if err != nil {
return nil, err
}
executionTargets := make([]Target, len(queriedActionsV2))
for i, action := range queriedActionsV2 {
executionTargets[i] = action
}
return executionTargets, nil
}

View File

@ -18,7 +18,7 @@ type Provider struct {
options []rp.Option options []rp.Option
name string name string
userEndpoint string userEndpoint string
userMapper func() idp.User user func() idp.User
isLinkingAllowed bool isLinkingAllowed bool
isCreationAllowed bool isCreationAllowed bool
isAutoCreation bool isAutoCreation bool
@ -65,11 +65,11 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
} }
// New creates a generic OAuth 2.0 provider // New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) { func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{ provider = &Provider{
name: name, name: name,
userEndpoint: userEndpoint, userEndpoint: userEndpoint,
userMapper: userMapper, user: user,
generateVerifier: oauth2.GenerateVerifier, generateVerifier: oauth2.GenerateVerifier,
} }
for _, option := range options { for _, option := range options {
@ -139,5 +139,5 @@ func (p *Provider) IsAutoUpdate() bool {
} }
func (p *Provider) User() idp.User { func (p *Provider) User() idp.User {
return p.userMapper() return p.user()
} }

View File

@ -51,7 +51,7 @@ func (s *Session) PersistentParameters() map[string]any {
// FetchUser implements the [idp.Session] interface. // FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token, // It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User]. // call the specified userEndpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) { func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
if s.Tokens == nil { if s.Tokens == nil {
if err = s.authorize(ctx); err != nil { if err = s.authorize(ctx); err != nil {
return nil, err return nil, err
@ -62,11 +62,11 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
return nil, err return nil, err
} }
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken) req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
mapper := s.Provider.User() user := s.Provider.User()
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil { if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
return nil, err return nil, err
} }
return mapper, nil return user, nil
} }
func (s *Session) authorize(ctx context.Context) (err error) { func (s *Session) authorize(ctx context.Context) (err error) {