fix: add spans in auth requests (#6368)

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Silvan 2023-08-18 09:21:31 +02:00 committed by GitHub
parent 52f68f8db8
commit 6672dcd87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 14 deletions

View File

@ -151,7 +151,7 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
logging.WithFields("login name", request.LoginHint, "id", request.ID, "applicationID", request.ApplicationID, "traceID", tracing.TraceIDFromCtx(ctx)).OnError(err).Info("login hint invalid")
}
if request.UserID == "" && request.LoginHint == "" && domain.IsPrompt(request.Prompt, domain.PromptNone) {
err = repo.tryUsingOnlyUserSession(request)
err = repo.tryUsingOnlyUserSession(ctx, request)
logging.WithFields("id", request.ID, "applicationID", request.ApplicationID, "traceID", tracing.TraceIDFromCtx(ctx)).OnError(err).Debug("unable to select only user session")
}
@ -592,7 +592,7 @@ func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authR
// If there's no user, checks if the user could be reused (from the session).
// (the nextStepsUser will update the userID in the request in that case)
if request.UserID == "" {
if _, err = repo.nextStepsUser(request); err != nil {
if _, err = repo.nextStepsUser(ctx, request); err != nil {
return nil, err
}
}
@ -606,8 +606,11 @@ func (repo *AuthRequestRepo) getAuthRequestEnsureUser(ctx context.Context, authR
return request, nil
}
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID string) (*domain.AuthRequest, error) {
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id, userAgentID string) (request *domain.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err = repo.AuthRequests.GetAuthRequestByID(ctx, id)
if err != nil {
return nil, err
}
@ -693,8 +696,11 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A
return nil
}
func (repo *AuthRequestRepo) tryUsingOnlyUserSession(request *domain.AuthRequest) error {
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
func (repo *AuthRequestRepo) tryUsingOnlyUserSession(ctx context.Context, request *domain.AuthRequest) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userSessions, err := userSessionsByUserAgentID(ctx, repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
if err != nil {
return err
}
@ -964,6 +970,9 @@ func (repo *AuthRequestRepo) checkExternalUserLogin(ctx context.Context, request
//nolint:gocognit
func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.AuthRequest, checkLoggedIn bool) (steps []domain.NextStep, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if request == nil {
return nil, errors.ThrowInvalidArgument(nil, "EVENT-ds27a", "Errors.Internal")
}
@ -972,7 +981,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
return append(steps, &domain.RedirectToCallbackStep{}), nil
}
if request.UserID == "" {
steps, err = repo.nextStepsUser(request)
steps, err = repo.nextStepsUser(ctx, request)
if err != nil || len(steps) > 0 {
return steps, err
}
@ -1066,7 +1075,10 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
return append(steps, &domain.RedirectToCallbackStep{}), nil
}
func (repo *AuthRequestRepo) nextStepsUser(request *domain.AuthRequest) ([]domain.NextStep, error) {
func (repo *AuthRequestRepo) nextStepsUser(ctx context.Context, request *domain.AuthRequest) (_ []domain.NextStep, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
steps := make([]domain.NextStep, 0)
if request.LinkingUsers != nil && len(request.LinkingUsers) > 0 {
steps = append(steps, new(domain.ExternalNotFoundOptionStep))
@ -1081,7 +1093,7 @@ func (repo *AuthRequestRepo) nextStepsUser(request *domain.AuthRequest) ([]domai
} else {
// if no user was specified, no prompt or select_account was provided,
// then check the active user sessions (of the user agent)
users, err := repo.usersForUserSelection(request)
users, err := repo.usersForUserSelection(ctx, request)
if err != nil {
return nil, err
}
@ -1115,8 +1127,8 @@ func checkExternalIDPsOfUser(ctx context.Context, idpUserLinksProvider idpUserLi
return idpUserLinksProvider.IDPUserLinks(ctx, &query.IDPUserLinksSearchQuery{Queries: []query.SearchQuery{userIDQuery}}, false)
}
func (repo *AuthRequestRepo) usersForUserSelection(request *domain.AuthRequest) ([]domain.UserSelection, error) {
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
func (repo *AuthRequestRepo) usersForUserSelection(ctx context.Context, request *domain.AuthRequest) ([]domain.UserSelection, error) {
userSessions, err := userSessionsByUserAgentID(ctx, repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
if err != nil {
return nil, err
}
@ -1384,7 +1396,11 @@ func checkVerificationTime(verificationTime time.Time, lifetime time.Duration) b
return verificationTime.Add(lifetime).After(time.Now().UTC())
}
func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instanceID string) ([]*user_model.UserSessionView, error) {
func userSessionsByUserAgentID(ctx context.Context, provider userSessionViewProvider, agentID, instanceID string) (_ []*user_model.UserSessionView, err error) {
//nolint
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
session, err := provider.UserSessionsByAgentID(agentID, instanceID)
if err != nil {
return nil, err
@ -1505,7 +1521,10 @@ func activeUserByID(ctx context.Context, userViewProvider userViewProvider, user
return user, nil
}
func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (*user_model.UserView, error) {
func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (_ *user_model.UserView, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr

View File

@ -7,6 +7,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository"
z_sql "github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository/sql"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type Eventstore interface {
@ -28,7 +29,10 @@ func Start(db *database.DB, allowOrderByCreationDate bool) (Eventstore, error) {
}, nil
}
func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) ([]*models.Event, error) {
func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (_ []*models.Event, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if err := searchQuery.Validate(); err != nil {
return nil, err
}