mirror of
https://github.com/zitadel/zitadel.git
synced 2025-03-01 05:17:23 +00:00
fix: added changed from review
This commit is contained in:
parent
a226de975b
commit
784b3afb4f
@ -3,23 +3,19 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/execution"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
exec_repo "github.com/zitadel/zitadel/internal/repository/execution"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
requestTargets, responseTargets := queryTargets(ctx, queries, info.FullMethod)
|
||||
requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, info.FullMethod)
|
||||
|
||||
// call targets otherwise return req
|
||||
handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req)
|
||||
@ -81,49 +77,6 @@ func executeTargetsForResponse(ctx context.Context, targets []execution.Target,
|
||||
return execution.CallTargets(ctx, targets, info)
|
||||
}
|
||||
|
||||
type ExecutionQueries interface {
|
||||
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
|
||||
}
|
||||
|
||||
func queryTargets(
|
||||
ctx context.Context,
|
||||
queries ExecutionQueries,
|
||||
fullMethod string,
|
||||
) ([]execution.Target, []execution.Target) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
targets, err := queries.TargetsByExecutionIDs(ctx,
|
||||
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
|
||||
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
|
||||
)
|
||||
requestTargets := make([]execution.Target, 0, len(targets))
|
||||
responseTargets := make([]execution.Target, 0, len(targets))
|
||||
if err != nil {
|
||||
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
|
||||
return requestTargets, responseTargets
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeRequest)) {
|
||||
requestTargets = append(requestTargets, target)
|
||||
} else if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeResponse)) {
|
||||
responseTargets = append(responseTargets, target)
|
||||
}
|
||||
}
|
||||
|
||||
return requestTargets, responseTargets
|
||||
}
|
||||
|
||||
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
|
||||
return []string{exec_repo.ID(executionType, fullMethod), exec_repo.ID(executionType, serviceFromFullMethod(fullMethod)), exec_repo.IDAll(executionType)}
|
||||
}
|
||||
|
||||
func serviceFromFullMethod(s string) string {
|
||||
parts := strings.Split(s, "/")
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
var _ execution.ContextInfo = &ContextInfoRequest{}
|
||||
|
||||
type ContextInfoRequest struct {
|
||||
|
@ -44,31 +44,31 @@ func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.Star
|
||||
}
|
||||
|
||||
func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) {
|
||||
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID())
|
||||
state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, redirect, err := s.command.AuthFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
|
||||
_, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
content, redirect := session.GetAuth(ctx)
|
||||
if redirect {
|
||||
return &user.StartIdentityProviderIntentResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
|
||||
}, nil
|
||||
} else {
|
||||
return &user.StartIdentityProviderIntentResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
|
||||
PostForm: []byte(content),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &user.StartIdentityProviderIntentResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
|
||||
PostForm: []byte(content),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
|
||||
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, "", "", authz.GetInstance(ctx).InstanceID())
|
||||
intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -412,62 +412,64 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
|
||||
}
|
||||
}
|
||||
|
||||
function := ""
|
||||
var function string
|
||||
switch triggerType {
|
||||
case domain.TriggerTypePreUserinfoCreation:
|
||||
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey())
|
||||
case domain.TriggerTypePreAccessTokenCreation:
|
||||
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey())
|
||||
case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation:
|
||||
fallthrough
|
||||
default:
|
||||
function = ""
|
||||
// added for linting, there should never be any trigger type be used here besides PreUserinfo and PreAccessToken
|
||||
return
|
||||
}
|
||||
|
||||
if function != "" {
|
||||
executionTargets, err := queryExecutionTargets(ctx, s.query, function)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info := &ContextInfo{
|
||||
Function: function,
|
||||
UserInfo: userInfo,
|
||||
User: qu.User,
|
||||
UserMetadata: qu.Metadata,
|
||||
Org: qu.Org,
|
||||
UserGrants: qu.UserGrants,
|
||||
}
|
||||
if function == "" {
|
||||
return nil
|
||||
}
|
||||
executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, s.query, function)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info := &ContextInfo{
|
||||
Function: function,
|
||||
UserInfo: userInfo,
|
||||
User: qu.User,
|
||||
UserMetadata: qu.Metadata,
|
||||
Org: qu.Org,
|
||||
UserGrants: qu.UserGrants,
|
||||
}
|
||||
|
||||
resp, err := execution.CallTargets(ctx, executionTargets, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contextInfoResponse, ok := resp.(*ContextInfoResponse)
|
||||
if ok && contextInfoResponse != nil {
|
||||
claimLogs := make([]string, 0)
|
||||
for _, metadata := range contextInfoResponse.SetUserMetadata {
|
||||
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key))
|
||||
}
|
||||
}
|
||||
for _, claim := range contextInfoResponse.AppendClaims {
|
||||
if strings.HasPrefix(claim.Key, ClaimPrefix) {
|
||||
continue
|
||||
}
|
||||
if userInfo.Claims[claim.Key] == nil {
|
||||
userInfo.AppendClaims(claim.Key, claim.Value)
|
||||
continue
|
||||
}
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
|
||||
}
|
||||
for _, log := range contextInfoResponse.AppendLogClaims {
|
||||
claimLogs = append(claimLogs, log)
|
||||
}
|
||||
if len(claimLogs) > 0 {
|
||||
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
|
||||
}
|
||||
resp, err := execution.CallTargets(ctx, executionTargets, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contextInfoResponse, ok := resp.(*ContextInfoResponse)
|
||||
if !ok || contextInfoResponse == nil {
|
||||
return nil
|
||||
}
|
||||
claimLogs := make([]string, 0)
|
||||
for _, metadata := range contextInfoResponse.SetUserMetadata {
|
||||
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key))
|
||||
}
|
||||
}
|
||||
for _, claim := range contextInfoResponse.AppendClaims {
|
||||
if strings.HasPrefix(claim.Key, ClaimPrefix) {
|
||||
continue
|
||||
}
|
||||
if userInfo.Claims[claim.Key] == nil {
|
||||
userInfo.AppendClaims(claim.Key, claim.Value)
|
||||
continue
|
||||
}
|
||||
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
|
||||
}
|
||||
for _, log := range contextInfoResponse.AppendLogClaims {
|
||||
claimLogs = append(claimLogs, log)
|
||||
}
|
||||
if len(claimLogs) > 0 {
|
||||
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -510,32 +512,6 @@ func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error {
|
||||
return json.Unmarshal(resp, c.Response)
|
||||
}
|
||||
|
||||
func (c *ContextInfo) GetContent() interface{} {
|
||||
func (c *ContextInfo) GetContent() any {
|
||||
return c.Response
|
||||
}
|
||||
|
||||
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
|
||||
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
executionTargets := make([]execution.Target, len(queriedActionsV2))
|
||||
for i, action := range queriedActionsV2 {
|
||||
executionTargets[i] = action
|
||||
}
|
||||
return executionTargets, nil
|
||||
}
|
||||
|
||||
func (s *Server) queryOrgMetadata(ctx context.Context, organizationID string) ([]*query.OrgMetadata, error) {
|
||||
metadata, err := s.query.SearchOrgMetadata(
|
||||
ctx,
|
||||
true,
|
||||
organizationID,
|
||||
&query.OrgMetadataSearchQueries{},
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return metadata.Metadata, nil
|
||||
}
|
||||
|
@ -385,7 +385,7 @@ func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, use
|
||||
}
|
||||
|
||||
function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey())
|
||||
executionTargets, err := queryExecutionTargets(ctx, p.query, function)
|
||||
executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, p.query, function)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -461,18 +461,6 @@ func (c *ContextInfo) GetContent() interface{} {
|
||||
return c.Response
|
||||
}
|
||||
|
||||
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
|
||||
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
executionTargets := make([]execution.Target, len(queriedActionsV2))
|
||||
for i, action := range queriedActionsV2 {
|
||||
executionTargets[i] = action
|
||||
}
|
||||
return executionTargets, nil
|
||||
}
|
||||
|
||||
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
|
||||
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID)
|
||||
if err != nil {
|
||||
|
@ -22,7 +22,6 @@ type SAMLRequest struct {
|
||||
Binding string
|
||||
Issuer string
|
||||
Destination string
|
||||
EntityID string
|
||||
}
|
||||
|
||||
type CurrentSAMLRequest struct {
|
||||
|
@ -6,12 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/repository/execution"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
"github.com/zitadel/zitadel/pkg/actions"
|
||||
@ -153,3 +156,59 @@ type ErrorBody struct {
|
||||
ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"`
|
||||
ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"`
|
||||
}
|
||||
|
||||
type ExecutionTargetsQueries interface {
|
||||
TargetsByExecutionID(ctx context.Context, ids []string) (execution []*query.ExecutionTarget, err error)
|
||||
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
|
||||
}
|
||||
|
||||
func QueryExecutionTargetsForRequestAndResponse(
|
||||
ctx context.Context,
|
||||
queries ExecutionTargetsQueries,
|
||||
fullMethod string,
|
||||
) ([]Target, []Target) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
targets, err := queries.TargetsByExecutionIDs(ctx,
|
||||
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
|
||||
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
|
||||
)
|
||||
requestTargets := make([]Target, 0, len(targets))
|
||||
responseTargets := make([]Target, 0, len(targets))
|
||||
if err != nil {
|
||||
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
|
||||
return requestTargets, responseTargets
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeRequest)) {
|
||||
requestTargets = append(requestTargets, target)
|
||||
} else if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeResponse)) {
|
||||
responseTargets = append(responseTargets, target)
|
||||
}
|
||||
}
|
||||
|
||||
return requestTargets, responseTargets
|
||||
}
|
||||
|
||||
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
|
||||
return []string{execution.ID(executionType, fullMethod), execution.ID(executionType, serviceFromFullMethod(fullMethod)), execution.IDAll(executionType)}
|
||||
}
|
||||
|
||||
func serviceFromFullMethod(s string) string {
|
||||
parts := strings.Split(s, "/")
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
func QueryExecutionTargetsForFunction(ctx context.Context, query ExecutionTargetsQueries, function string) ([]Target, error) {
|
||||
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
executionTargets := make([]Target, len(queriedActionsV2))
|
||||
for i, action := range queriedActionsV2 {
|
||||
executionTargets[i] = action
|
||||
}
|
||||
return executionTargets, nil
|
||||
}
|
||||
|
@ -17,9 +17,9 @@ type Provider struct {
|
||||
rp.RelyingParty
|
||||
options []rp.Option
|
||||
name string
|
||||
userEndpoint string
|
||||
userMapper func() idp.User
|
||||
isLinkingAllowed bool
|
||||
userEndpoint string
|
||||
user func() idp.User
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
@ -65,11 +65,11 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
||||
}
|
||||
|
||||
// New creates a generic OAuth 2.0 provider
|
||||
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
|
||||
func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
|
||||
provider = &Provider{
|
||||
name: name,
|
||||
userEndpoint: userEndpoint,
|
||||
userMapper: userMapper,
|
||||
user: user,
|
||||
generateVerifier: oauth2.GenerateVerifier,
|
||||
}
|
||||
for _, option := range options {
|
||||
@ -139,5 +139,5 @@ func (p *Provider) IsAutoUpdate() bool {
|
||||
}
|
||||
|
||||
func (p *Provider) User() idp.User {
|
||||
return p.userMapper()
|
||||
return p.user()
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ func (s *Session) PersistentParameters() map[string]any {
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
|
||||
if s.Tokens == nil {
|
||||
if err = s.authorize(ctx); err != nil {
|
||||
return nil, err
|
||||
@ -62,11 +62,11 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
|
||||
mapper := s.Provider.User()
|
||||
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil {
|
||||
user := s.Provider.User()
|
||||
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mapper, nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Session) authorize(ctx context.Context) (err error) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user