mirror of
https://github.com/zitadel/zitadel.git
synced 2025-03-01 08:17:23 +00:00
fix: added changed from review
This commit is contained in:
parent
a226de975b
commit
784b3afb4f
@ -3,23 +3,19 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/zitadel/logging"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
"github.com/zitadel/zitadel/internal/domain"
|
|
||||||
"github.com/zitadel/zitadel/internal/execution"
|
"github.com/zitadel/zitadel/internal/execution"
|
||||||
"github.com/zitadel/zitadel/internal/query"
|
"github.com/zitadel/zitadel/internal/query"
|
||||||
exec_repo "github.com/zitadel/zitadel/internal/repository/execution"
|
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
|
func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
|
||||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||||
requestTargets, responseTargets := queryTargets(ctx, queries, info.FullMethod)
|
requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, info.FullMethod)
|
||||||
|
|
||||||
// call targets otherwise return req
|
// call targets otherwise return req
|
||||||
handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req)
|
handledReq, err := executeTargetsForRequest(ctx, requestTargets, info.FullMethod, req)
|
||||||
@ -81,49 +77,6 @@ func executeTargetsForResponse(ctx context.Context, targets []execution.Target,
|
|||||||
return execution.CallTargets(ctx, targets, info)
|
return execution.CallTargets(ctx, targets, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExecutionQueries interface {
|
|
||||||
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryTargets(
|
|
||||||
ctx context.Context,
|
|
||||||
queries ExecutionQueries,
|
|
||||||
fullMethod string,
|
|
||||||
) ([]execution.Target, []execution.Target) {
|
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
targets, err := queries.TargetsByExecutionIDs(ctx,
|
|
||||||
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
|
|
||||||
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
|
|
||||||
)
|
|
||||||
requestTargets := make([]execution.Target, 0, len(targets))
|
|
||||||
responseTargets := make([]execution.Target, 0, len(targets))
|
|
||||||
if err != nil {
|
|
||||||
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
|
|
||||||
return requestTargets, responseTargets
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, target := range targets {
|
|
||||||
if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeRequest)) {
|
|
||||||
requestTargets = append(requestTargets, target)
|
|
||||||
} else if strings.HasPrefix(target.GetExecutionID(), exec_repo.IDAll(domain.ExecutionTypeResponse)) {
|
|
||||||
responseTargets = append(responseTargets, target)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return requestTargets, responseTargets
|
|
||||||
}
|
|
||||||
|
|
||||||
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
|
|
||||||
return []string{exec_repo.ID(executionType, fullMethod), exec_repo.ID(executionType, serviceFromFullMethod(fullMethod)), exec_repo.IDAll(executionType)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func serviceFromFullMethod(s string) string {
|
|
||||||
parts := strings.Split(s, "/")
|
|
||||||
return parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ execution.ContextInfo = &ContextInfoRequest{}
|
var _ execution.ContextInfo = &ContextInfoRequest{}
|
||||||
|
|
||||||
type ContextInfoRequest struct {
|
type ContextInfoRequest struct {
|
||||||
|
@ -44,31 +44,31 @@ func (s *Server) StartIdentityProviderIntent(ctx context.Context, req *user.Star
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) {
|
func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.RedirectURLs) (*user.StartIdentityProviderIntentResponse, error) {
|
||||||
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID())
|
state, session, err := s.command.AuthFromProvider(ctx, idpID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
content, redirect, err := s.command.AuthFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
|
_, details, err := s.command.CreateIntent(ctx, state, idpID, urls.GetSuccessUrl(), urls.GetFailureUrl(), authz.GetInstance(ctx).InstanceID(), session.PersistentParameters())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
content, redirect := session.GetAuth(ctx)
|
||||||
if redirect {
|
if redirect {
|
||||||
return &user.StartIdentityProviderIntentResponse{
|
return &user.StartIdentityProviderIntentResponse{
|
||||||
Details: object.DomainToDetailsPb(details),
|
Details: object.DomainToDetailsPb(details),
|
||||||
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
|
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
|
||||||
}, nil
|
}, nil
|
||||||
} else {
|
|
||||||
return &user.StartIdentityProviderIntentResponse{
|
|
||||||
Details: object.DomainToDetailsPb(details),
|
|
||||||
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
|
|
||||||
PostForm: []byte(content),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
return &user.StartIdentityProviderIntentResponse{
|
||||||
|
Details: object.DomainToDetailsPb(details),
|
||||||
|
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
|
||||||
|
PostForm: []byte(content),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
|
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
|
||||||
intentWriteModel, details, err := s.command.CreateIntent(ctx, idpID, "", "", authz.GetInstance(ctx).InstanceID())
|
intentWriteModel, details, err := s.command.CreateIntent(ctx, "", idpID, "", "", authz.GetInstance(ctx).InstanceID(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -412,62 +412,64 @@ func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, user
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function := ""
|
var function string
|
||||||
switch triggerType {
|
switch triggerType {
|
||||||
case domain.TriggerTypePreUserinfoCreation:
|
case domain.TriggerTypePreUserinfoCreation:
|
||||||
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey())
|
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreUserinfo.LocalizationKey())
|
||||||
case domain.TriggerTypePreAccessTokenCreation:
|
case domain.TriggerTypePreAccessTokenCreation:
|
||||||
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey())
|
function = exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreAccessToken.LocalizationKey())
|
||||||
case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation:
|
case domain.TriggerTypeUnspecified, domain.TriggerTypePostAuthentication, domain.TriggerTypePreCreation, domain.TriggerTypePostCreation, domain.TriggerTypePreSAMLResponseCreation:
|
||||||
fallthrough
|
// added for linting, there should never be any trigger type be used here besides PreUserinfo and PreAccessToken
|
||||||
default:
|
return
|
||||||
function = ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if function != "" {
|
if function == "" {
|
||||||
executionTargets, err := queryExecutionTargets(ctx, s.query, function)
|
return nil
|
||||||
if err != nil {
|
}
|
||||||
return err
|
executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, s.query, function)
|
||||||
}
|
if err != nil {
|
||||||
info := &ContextInfo{
|
return err
|
||||||
Function: function,
|
}
|
||||||
UserInfo: userInfo,
|
info := &ContextInfo{
|
||||||
User: qu.User,
|
Function: function,
|
||||||
UserMetadata: qu.Metadata,
|
UserInfo: userInfo,
|
||||||
Org: qu.Org,
|
User: qu.User,
|
||||||
UserGrants: qu.UserGrants,
|
UserMetadata: qu.Metadata,
|
||||||
}
|
Org: qu.Org,
|
||||||
|
UserGrants: qu.UserGrants,
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := execution.CallTargets(ctx, executionTargets, info)
|
resp, err := execution.CallTargets(ctx, executionTargets, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
contextInfoResponse, ok := resp.(*ContextInfoResponse)
|
contextInfoResponse, ok := resp.(*ContextInfoResponse)
|
||||||
if ok && contextInfoResponse != nil {
|
if !ok || contextInfoResponse == nil {
|
||||||
claimLogs := make([]string, 0)
|
return nil
|
||||||
for _, metadata := range contextInfoResponse.SetUserMetadata {
|
}
|
||||||
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
|
claimLogs := make([]string, 0)
|
||||||
claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key))
|
for _, metadata := range contextInfoResponse.SetUserMetadata {
|
||||||
}
|
if _, err = s.command.SetUserMetadata(ctx, metadata, userInfo.Subject, qu.User.ResourceOwner); err != nil {
|
||||||
}
|
claimLogs = append(claimLogs, fmt.Sprintf("failed to set user metadata key %q", metadata.Key))
|
||||||
for _, claim := range contextInfoResponse.AppendClaims {
|
|
||||||
if strings.HasPrefix(claim.Key, ClaimPrefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if userInfo.Claims[claim.Key] == nil {
|
|
||||||
userInfo.AppendClaims(claim.Key, claim.Value)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
|
|
||||||
}
|
|
||||||
for _, log := range contextInfoResponse.AppendLogClaims {
|
|
||||||
claimLogs = append(claimLogs, log)
|
|
||||||
}
|
|
||||||
if len(claimLogs) > 0 {
|
|
||||||
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, claim := range contextInfoResponse.AppendClaims {
|
||||||
|
if strings.HasPrefix(claim.Key, ClaimPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if userInfo.Claims[claim.Key] == nil {
|
||||||
|
userInfo.AppendClaims(claim.Key, claim.Value)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
claimLogs = append(claimLogs, fmt.Sprintf("key %q already exists", claim.Key))
|
||||||
|
}
|
||||||
|
for _, log := range contextInfoResponse.AppendLogClaims {
|
||||||
|
claimLogs = append(claimLogs, log)
|
||||||
|
}
|
||||||
|
if len(claimLogs) > 0 {
|
||||||
|
userInfo.AppendClaims(fmt.Sprintf(ClaimActionLogFormat, function), claimLogs)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -510,32 +512,6 @@ func (c *ContextInfo) SetHTTPResponseBody(resp []byte) error {
|
|||||||
return json.Unmarshal(resp, c.Response)
|
return json.Unmarshal(resp, c.Response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ContextInfo) GetContent() interface{} {
|
func (c *ContextInfo) GetContent() any {
|
||||||
return c.Response
|
return c.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
|
|
||||||
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
executionTargets := make([]execution.Target, len(queriedActionsV2))
|
|
||||||
for i, action := range queriedActionsV2 {
|
|
||||||
executionTargets[i] = action
|
|
||||||
}
|
|
||||||
return executionTargets, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) queryOrgMetadata(ctx context.Context, organizationID string) ([]*query.OrgMetadata, error) {
|
|
||||||
metadata, err := s.query.SearchOrgMetadata(
|
|
||||||
ctx,
|
|
||||||
true,
|
|
||||||
organizationID,
|
|
||||||
&query.OrgMetadataSearchQueries{},
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return metadata.Metadata, nil
|
|
||||||
}
|
|
||||||
|
@ -385,7 +385,7 @@ func (p *Storage) getCustomAttributes(ctx context.Context, user *query.User, use
|
|||||||
}
|
}
|
||||||
|
|
||||||
function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey())
|
function := exec_repo.ID(domain.ExecutionTypeFunction, domain.ActionFunctionPreSAMLResponse.LocalizationKey())
|
||||||
executionTargets, err := queryExecutionTargets(ctx, p.query, function)
|
executionTargets, err := execution.QueryExecutionTargetsForFunction(ctx, p.query, function)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -461,18 +461,6 @@ func (c *ContextInfo) GetContent() interface{} {
|
|||||||
return c.Response
|
return c.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryExecutionTargets(ctx context.Context, query *query.Queries, function string) ([]execution.Target, error) {
|
|
||||||
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
executionTargets := make([]execution.Target, len(queriedActionsV2))
|
|
||||||
for i, action := range queriedActionsV2 {
|
|
||||||
executionTargets[i] = action
|
|
||||||
}
|
|
||||||
return executionTargets, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
|
func (p *Storage) getGrants(ctx context.Context, userID, applicationID string) (*query.UserGrants, error) {
|
||||||
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID)
|
projectID, err := p.query.ProjectIDFromClientID(ctx, applicationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -22,7 +22,6 @@ type SAMLRequest struct {
|
|||||||
Binding string
|
Binding string
|
||||||
Issuer string
|
Issuer string
|
||||||
Destination string
|
Destination string
|
||||||
EntityID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CurrentSAMLRequest struct {
|
type CurrentSAMLRequest struct {
|
||||||
|
@ -6,12 +6,15 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zitadel/logging"
|
"github.com/zitadel/logging"
|
||||||
|
|
||||||
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
zhttp "github.com/zitadel/zitadel/internal/api/http"
|
||||||
"github.com/zitadel/zitadel/internal/domain"
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/query"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/execution"
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
"github.com/zitadel/zitadel/pkg/actions"
|
"github.com/zitadel/zitadel/pkg/actions"
|
||||||
@ -153,3 +156,59 @@ type ErrorBody struct {
|
|||||||
ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"`
|
ForwardedStatusCode int `json:"forwardedStatusCode,omitempty"`
|
||||||
ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"`
|
ForwardedErrorMessage string `json:"forwardedErrorMessage,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ExecutionTargetsQueries interface {
|
||||||
|
TargetsByExecutionID(ctx context.Context, ids []string) (execution []*query.ExecutionTarget, err error)
|
||||||
|
TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string) (execution []*query.ExecutionTarget, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryExecutionTargetsForRequestAndResponse(
|
||||||
|
ctx context.Context,
|
||||||
|
queries ExecutionTargetsQueries,
|
||||||
|
fullMethod string,
|
||||||
|
) ([]Target, []Target) {
|
||||||
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
targets, err := queries.TargetsByExecutionIDs(ctx,
|
||||||
|
idsForFullMethod(fullMethod, domain.ExecutionTypeRequest),
|
||||||
|
idsForFullMethod(fullMethod, domain.ExecutionTypeResponse),
|
||||||
|
)
|
||||||
|
requestTargets := make([]Target, 0, len(targets))
|
||||||
|
responseTargets := make([]Target, 0, len(targets))
|
||||||
|
if err != nil {
|
||||||
|
logging.WithFields("fullMethod", fullMethod).WithError(err).Info("unable to query targets")
|
||||||
|
return requestTargets, responseTargets
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeRequest)) {
|
||||||
|
requestTargets = append(requestTargets, target)
|
||||||
|
} else if strings.HasPrefix(target.GetExecutionID(), execution.IDAll(domain.ExecutionTypeResponse)) {
|
||||||
|
responseTargets = append(responseTargets, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return requestTargets, responseTargets
|
||||||
|
}
|
||||||
|
|
||||||
|
func idsForFullMethod(fullMethod string, executionType domain.ExecutionType) []string {
|
||||||
|
return []string{execution.ID(executionType, fullMethod), execution.ID(executionType, serviceFromFullMethod(fullMethod)), execution.IDAll(executionType)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serviceFromFullMethod(s string) string {
|
||||||
|
parts := strings.Split(s, "/")
|
||||||
|
return parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryExecutionTargetsForFunction(ctx context.Context, query ExecutionTargetsQueries, function string) ([]Target, error) {
|
||||||
|
queriedActionsV2, err := query.TargetsByExecutionID(ctx, []string{function})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
executionTargets := make([]Target, len(queriedActionsV2))
|
||||||
|
for i, action := range queriedActionsV2 {
|
||||||
|
executionTargets[i] = action
|
||||||
|
}
|
||||||
|
return executionTargets, nil
|
||||||
|
}
|
||||||
|
@ -17,9 +17,9 @@ type Provider struct {
|
|||||||
rp.RelyingParty
|
rp.RelyingParty
|
||||||
options []rp.Option
|
options []rp.Option
|
||||||
name string
|
name string
|
||||||
userEndpoint string
|
userEndpoint string
|
||||||
userMapper func() idp.User
|
user func() idp.User
|
||||||
isLinkingAllowed bool
|
isLinkingAllowed bool
|
||||||
isCreationAllowed bool
|
isCreationAllowed bool
|
||||||
isAutoCreation bool
|
isAutoCreation bool
|
||||||
isAutoUpdate bool
|
isAutoUpdate bool
|
||||||
@ -65,11 +65,11 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New creates a generic OAuth 2.0 provider
|
// New creates a generic OAuth 2.0 provider
|
||||||
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
|
func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
|
||||||
provider = &Provider{
|
provider = &Provider{
|
||||||
name: name,
|
name: name,
|
||||||
userEndpoint: userEndpoint,
|
userEndpoint: userEndpoint,
|
||||||
userMapper: userMapper,
|
user: user,
|
||||||
generateVerifier: oauth2.GenerateVerifier,
|
generateVerifier: oauth2.GenerateVerifier,
|
||||||
}
|
}
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
@ -139,5 +139,5 @@ func (p *Provider) IsAutoUpdate() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) User() idp.User {
|
func (p *Provider) User() idp.User {
|
||||||
return p.userMapper()
|
return p.user()
|
||||||
}
|
}
|
||||||
|
@ -51,7 +51,7 @@ func (s *Session) PersistentParameters() map[string]any {
|
|||||||
// FetchUser implements the [idp.Session] interface.
|
// FetchUser implements the [idp.Session] interface.
|
||||||
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||||
// call the specified userEndpoint and map the received information into an [idp.User].
|
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
|
||||||
if s.Tokens == nil {
|
if s.Tokens == nil {
|
||||||
if err = s.authorize(ctx); err != nil {
|
if err = s.authorize(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -62,11 +62,11 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
|
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
|
||||||
mapper := s.Provider.User()
|
user := s.Provider.User()
|
||||||
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil {
|
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return mapper, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) authorize(ctx context.Context) (err error) {
|
func (s *Session) authorize(ctx context.Context) (err error) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user