zitadel/internal/api/oidc/auth_request.go
Livio Spring d705cb11b7
fix: error handling to prevent panics (#8248)
# Which Problems Are Solved

We found multiple cases where errors were not properly handled,
which led to panics.

# How the Problems Are Solved

Handle the errors.

# Additional Changes

None.

# Additional Context

- noticed internally
2024-07-04 14:11:06 +00:00

594 lines
19 KiB
Go

package oidc
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/user/model"
"github.com/zitadel/zitadel/internal/zerrors"
)
// LoginClientHeader is the name of the request header that carries the login
// client. CreateAuthRequest and TerminateSessionFromRequest read it from the
// headers stored in the context.
const LoginClientHeader = "x-zitadel-login-client"
// CreateAuthRequest creates a new auth request for req.
//
// When the headers found in ctx contain a non-empty LoginClientHeader, the
// request is created through createAuthRequestLoginClient; otherwise
// createAuthRequest is used. The returned error is passed through oidcError
// and recorded on the tracing span.
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	// The second result of HeadersFromCtx is ignored.
	// NOTE(review): this relies on Get being safe on whatever headers value is
	// returned when the context carries none — confirm.
	headers, _ := http_utils.HeadersFromCtx(ctx)
	loginClient := headers.Get(LoginClientHeader)
	if loginClient == "" {
		return o.createAuthRequest(ctx, req, userID)
	}
	return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
}
// createAuthRequestScopeAndAudience resolves the scopes and the audience to be
// used for an auth request of the client identified by clientID.
//
// The project is looked up by clientID, reqScope is run through
// assertProjectRoleScopesByProject, and the audience is built from the project
// via audienceFromProjectID and then passed through
// domain.AddAudScopeToAudience together with the resulting scopes.
//
// On any error both slices are returned as nil.
func (o *OPStorage) createAuthRequestScopeAndAudience(ctx context.Context, clientID string, reqScope []string) (scope, audience []string, err error) {
	project, err := o.query.ProjectByClientID(ctx, clientID)
	if err != nil {
		return nil, nil, err
	}
	scope, err = o.assertProjectRoleScopesByProject(ctx, project, reqScope)
	if err != nil {
		return nil, nil, err
	}
	audience, err = o.audienceFromProjectID(ctx, project.ID)
	// Check the error before touching the audience: on failure the value is
	// meaningless and must not be processed any further.
	if err != nil {
		return nil, nil, err
	}
	return scope, domain.AddAudScopeToAudience(ctx, audience, scope), nil
}
// createAuthRequestLoginClient creates an auth request for the given login
// client through the command side and returns it wrapped in AuthRequestV2.
//
// Scope and audience come from createAuthRequestScopeAndAudience. LoginHint and
// HintUserID are only set when req.LoginHint respectively hintUserID are
// non-empty; otherwise they stay nil.
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
	scope, audience, err := o.createAuthRequestScopeAndAudience(ctx, req.ClientID, req.Scopes)
	if err != nil {
		return nil, err
	}

	var loginHint, userIDHint *string
	if req.LoginHint != "" {
		loginHint = &req.LoginHint
	}
	if hintUserID != "" {
		userIDHint = &hintUserID
	}

	added, err := o.command.AddAuthRequest(ctx, &command.AuthRequest{
		LoginClient: loginClient,
		ClientID:    req.ClientID,
		RedirectURI: req.RedirectURI,
		State:       req.State,
		Nonce:       req.Nonce,
		Scope:       scope,
		Audience:    audience,
		// a refresh token is needed when the resolved scopes contain offline_access
		NeedRefreshToken: slices.Contains(scope, oidc.ScopeOfflineAccess),
		ResponseType:     ResponseTypeToBusiness(req.ResponseType),
		ResponseMode:     ResponseModeToBusiness(req.ResponseMode),
		CodeChallenge:    CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
		Prompt:           PromptToBusiness(req.Prompt),
		UILocales:        UILocalesToBusiness(req.UILocales),
		MaxAge:           MaxAgeToBusiness(req.MaxAge),
		LoginHint:        loginHint,
		HintUserID:       userIDHint,
	})
	if err != nil {
		return nil, err
	}
	return &AuthRequestV2{added}, nil
}
// createAuthRequest creates an auth request through the repository.
//
// It requires a user agent ID in ctx and fails with a precondition error when
// there is none. The scopes resolved by createAuthRequestScopeAndAudience are
// written back to req.Scopes before the request is converted and stored.
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
	agentID, found := middleware.UserAgentIDFromCtx(ctx)
	if !found {
		return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
	}

	resolvedScopes, audience, err := o.createAuthRequestScopeAndAudience(ctx, req.ClientID, req.Scopes)
	if err != nil {
		return nil, err
	}
	// the caller's request is updated in place with the resolved scopes
	req.Scopes = resolvedScopes

	created, err := o.repo.CreateAuthRequest(ctx, CreateAuthRequestToBusiness(ctx, req, agentID, userID, audience))
	if err != nil {
		return nil, err
	}
	return AuthRequestFromBusiness(created)
}
// audienceFromProjectID returns the client IDs found for the project's apps,
// followed by the project ID itself.
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
	byProject, err := query.NewAppProjectIDSearchQuery(projectID)
	if err != nil {
		return nil, err
	}

	search := &query.AppSearchQueries{
		Queries: []query.SearchQuery{byProject},
	}
	clientIDs, err := o.query.SearchClientIDs(ctx, search, true)
	if err != nil {
		return nil, err
	}

	audience := append(clientIDs, projectID)
	return audience, nil
}
// AuthRequestByID loads the auth request with the given id.
//
// IDs starting with command.IDPrefixV2 are read through the command side and
// returned as AuthRequestV2. All other IDs are read from the repository, which
// requires a user agent ID in ctx. The returned error is passed through
// oidcError and recorded on the tracing span.
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	if strings.HasPrefix(id, command.IDPrefixV2) {
		current, getErr := o.command.GetCurrentAuthRequest(ctx, id)
		if getErr != nil {
			return nil, getErr
		}
		return &AuthRequestV2{current}, nil
	}

	agentID, found := middleware.UserAgentIDFromCtx(ctx)
	if !found {
		return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
	}
	stored, err := o.repo.AuthRequestByIDCheckLoggedIn(ctx, id, agentID)
	if err != nil {
		return nil, err
	}
	return AuthRequestFromBusiness(stored)
}
// AuthRequestByCode must not be called; it always panics with the error built
// by panicErr.
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
	bug := o.panicErr("AuthRequestByCode")
	panic(bug)
}
// decryptGrant decrypts a code or refresh_token.
//
// The grant is decoded as unpadded URL-safe base64 and the resulting bytes are
// decrypted with o.encAlg, using the key ID reported by the algorithm itself.
func (o *OPStorage) decryptGrant(grant string) (string, error) {
	raw, err := base64.RawURLEncoding.DecodeString(grant)
	if err != nil {
		return "", err
	}
	keyID := o.encAlg.EncryptionKeyID()
	return o.encAlg.DecryptString(raw, keyID)
}
// SaveAuthCode stores code for the auth request with the given id.
//
// IDs starting with command.IDPrefixV2 are handled by the command side. All
// other IDs are saved through the repository, which requires a user agent ID
// in ctx. The returned error is passed through oidcError and recorded on the
// tracing span.
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	isV2 := strings.HasPrefix(id, command.IDPrefixV2)
	if isV2 {
		return o.command.AddAuthRequestCode(ctx, id, code)
	}

	agentID, found := middleware.UserAgentIDFromCtx(ctx)
	if found {
		return o.repo.SaveAuthCode(ctx, id, code, agentID)
	}
	return zerrors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
}
// DeleteAuthRequest must not be called; it always panics with the error built
// by panicErr.
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
	bug := o.panicErr("DeleteAuthRequest")
	panic(bug)
}
// CreateAccessToken must not be called; it always panics with the error built
// by panicErr.
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (string, time.Time, error) {
	bug := o.panicErr("CreateAccessToken")
	panic(bug)
}
// CreateAccessAndRefreshTokens must not be called; it always panics with the
// error built by panicErr.
func (o *OPStorage) CreateAccessAndRefreshTokens(context.Context, op.TokenRequest, string) (string, string, time.Time, error) {
	bug := o.panicErr("CreateAccessAndRefreshTokens")
	panic(bug)
}
// panicErr builds the error used as panic value by OPStorage methods that are
// no longer expected to be called. method is the name of the offending method.
func (*OPStorage) panicErr(method string) error {
	const format = "OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues"
	return fmt.Errorf(format, method)
}
// TokenRequestByRefreshToken must not be called; it always panics.
//
// The panic value is built by panicErr so that it matches the other OPStorage
// methods that are no longer expected to be called.
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
	panic(o.panicErr("TokenRequestByRefreshToken"))
}
// TerminateSession signs out all users that have a session on the user agent
// found in ctx.
//
// It fails with a precondition error when ctx carries no user agent ID and
// returns nil without doing anything when the agent has no user sessions. The
// sign-out command runs with userID set as the context user. clientID is not
// used. The returned error is passed through oidcError and recorded on the
// tracing span.
func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	agentID, found := middleware.UserAgentIDFromCtx(ctx)
	if !found {
		logging.Error("no user agent id")
		return zerrors.ThrowPreconditionFailed(nil, "OIDC-fso7F", "no user agent id")
	}

	sessionUserIDs, err := o.repo.UserSessionUserIDsByAgentID(ctx, agentID)
	if err != nil {
		logging.WithError(err).Error("error retrieving user sessions")
		return err
	}
	if len(sessionUserIDs) == 0 {
		return nil
	}

	signOutCtx := authz.SetCtxData(ctx, authz.CtxData{UserID: userID})
	err = o.command.HumansSignOut(signOutCtx, agentID, sessionUserIDs)
	logging.OnError(err).Error("error signing out")
	return err
}
// TerminateSessionFromRequest terminates the session described by
// endSessionRequest and returns the URI to redirect to.
//
// Three cases are distinguished:
//   - login client header set and no id_token_hint: nothing is terminated and
//     the default v2 logout URL with the request's redirect URI appended is
//     returned, so the UI can decide which session to terminate;
//   - no id_token_hint, or an id_token_hint without session ID: a v1
//     TerminateSession is done;
//   - otherwise the v2 session named in the id_token_hint is terminated.
//
// The returned error is passed through oidcError and recorded on the tracing
// span.
func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	// check for the login client header
	headers, _ := http_utils.HeadersFromCtx(ctx)
	hasLoginClient := headers.Get(LoginClientHeader) != ""
	hint := endSessionRequest.IDTokenHintClaims

	switch {
	case hasLoginClient && hint == nil:
		// without an id_token_hint the UI has to pick the session
		return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
	case hint == nil || hint.SessionID == "":
		// v1 termination; the redirect URI is returned even if it fails
		return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
	}

	// terminate the v2 session of the id_token_hint
	if _, err = o.command.TerminateSessionWithoutTokenCheck(ctx, hint.SessionID); err != nil {
		return "", err
	}
	return endSessionRequest.RedirectURI, nil
}
// RevokeToken revokes the given token for clientID.
//
// Tokens starting with command.IDPrefixV2 are revoked as OIDC session tokens:
// a precondition failure is reported as invalid client, any other failure as
// server error. All other tokens are handled by revokeTokenV1.
func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		// check for nil, because `err` is not an error and EndWithError would panic
		if err != nil {
			span.EndWithError(err)
			return
		}
		span.End()
	}()

	if !strings.HasPrefix(token, command.IDPrefixV2) {
		return o.revokeTokenV1(ctx, token, userID, clientID)
	}

	revokeErr := o.command.RevokeOIDCSessionToken(ctx, token, clientID)
	switch {
	case revokeErr == nil:
		return nil
	case zerrors.IsPreconditionFailed(revokeErr):
		return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
	default:
		return oidc.ErrServerError().WithParent(revokeErr)
	}
}
// revokeTokenV1 revokes a v1 refresh or access token.
//
// The token is first looked up as a refresh token. Only when that lookup
// fails is it looked up as an access token. A token belonging to a different
// client yields an invalid client error. Not-found results, both on access
// token lookup and on revocation, are treated as success; every other failure
// is reported as server error.
func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID string) *oidc.Error {
	refreshToken, lookupErr := o.repo.RefreshTokenByID(ctx, token, userID)
	if lookupErr == nil {
		if refreshToken.ClientID != clientID {
			return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
		}
		_, revokeErr := o.command.RevokeRefreshToken(ctx, refreshToken.UserID, refreshToken.ResourceOwner, refreshToken.ID)
		if revokeErr != nil && !zerrors.IsNotFound(revokeErr) {
			return oidc.ErrServerError().WithParent(revokeErr)
		}
		return nil
	}

	// not a known refresh token: try it as an access token
	accessToken, lookupErr := o.repo.TokenByIDs(ctx, userID, token)
	switch {
	case lookupErr == nil:
		// found, continue below
	case zerrors.IsNotFound(lookupErr):
		return nil
	default:
		return oidc.ErrServerError().WithParent(lookupErr)
	}
	if accessToken.ApplicationID != clientID {
		return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
	}

	_, revokeErr := o.command.RevokeAccessToken(ctx, userID, accessToken.ResourceOwner, accessToken.ID)
	if revokeErr != nil && !zerrors.IsNotFound(revokeErr) {
		return oidc.ErrServerError().WithParent(revokeErr)
	}
	return nil
}
// GetRefreshTokenInfo returns the user ID and token ID belonging to the given
// refresh token.
//
// The token is decrypted first. If the plain token starts with
// command.IDPrefixV2, the information is taken from the OIDC session;
// otherwise the refresh token is read from the repository and must have been
// issued for clientID. Decryption and lookup failures are all reported as
// op.ErrInvalidRefreshToken. The returned error is passed through oidcError
// and recorded on the tracing span.
func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() {
		err = oidcError(err)
		span.EndWithError(err)
	}()

	plainToken, err := o.decryptGrant(token)
	if err != nil {
		return "", "", op.ErrInvalidRefreshToken
	}

	if strings.HasPrefix(plainToken, command.IDPrefixV2) {
		session, sessionErr := o.command.OIDCSessionByRefreshToken(ctx, plainToken)
		if sessionErr != nil {
			return "", "", op.ErrInvalidRefreshToken
		}
		return session.UserID, session.OIDCRefreshTokenID(session.RefreshTokenID), nil
	}

	// NOTE(review): the still-encrypted token is handed to the repository here,
	// presumably because it decrypts the token itself — confirm.
	stored, err := o.repo.RefreshTokenByToken(ctx, token)
	switch {
	case err != nil:
		return "", "", op.ErrInvalidRefreshToken
	case stored.ClientID != clientID:
		return "", "", oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
	}
	return stored.UserID, stored.ID, nil
}
// assertProjectRoleScopes adds a project role scope for every role of the
// client's project to scopes.
//
// scopes is returned unchanged when it already contains a scope with the
// ScopeProjectRolePrefix or when the project does not have
// ProjectRoleAssertion set. The project is looked up by the OIDC client ID.
func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string, scopes []string) ([]string, error) {
	hasRoleScope := slices.ContainsFunc(scopes, func(scope string) bool {
		return strings.HasPrefix(scope, ScopeProjectRolePrefix)
	})
	if hasRoleScope {
		return scopes, nil
	}

	project, err := o.query.ProjectByOIDCClientID(ctx, clientID)
	if err != nil {
		// NOTE(review): the lookup error is not attached as parent here, so the
		// cause is lost — confirm whether that is intended.
		return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-w4wIn", "Errors.Internal")
	}
	if !project.ProjectRoleAssertion {
		return scopes, nil
	}

	byProject, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
	}
	search := &query.ProjectRoleSearchQueries{
		Queries: []query.SearchQuery{byProject},
	}
	found, err := o.query.SearchProjectRoles(ctx, true, search)
	if err != nil {
		return nil, err
	}

	for _, projectRole := range found.ProjectRoles {
		scopes = append(scopes, ScopeProjectRolePrefix+projectRole.Key)
	}
	return scopes, nil
}
// assertProjectRoleScopesByProject adds a project role scope for every role of
// the given project to scopes.
//
// scopes is returned unchanged when it already contains a scope with the
// ScopeProjectRolePrefix or when the project does not have
// ProjectRoleAssertion set.
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
	hasRoleScope := slices.ContainsFunc(scopes, func(scope string) bool {
		return strings.HasPrefix(scope, ScopeProjectRolePrefix)
	})
	// project is only inspected when no role scope was requested
	if hasRoleScope || !project.ProjectRoleAssertion {
		return scopes, nil
	}

	byProject, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
	}
	search := &query.ProjectRoleSearchQueries{
		Queries: []query.SearchQuery{byProject},
	}
	found, err := o.query.SearchProjectRoles(ctx, true, search)
	if err != nil {
		return nil, err
	}

	for _, projectRole := range found.ProjectRoles {
		scopes = append(scopes, ScopeProjectRolePrefix+projectRole.Key)
	}
	return scopes, nil
}
// assertClientScopesForPAT extends token in place: clientID is appended to its
// audience and a project role scope is appended to its scopes for every role
// of the project identified by projectID.
//
// The audience is extended before the roles are searched, so it is modified
// even when the function returns an error.
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
	token.Audience = append(token.Audience, clientID)

	byProject, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
	if err != nil {
		return zerrors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
	}
	search := &query.ProjectRoleSearchQueries{
		Queries: []query.SearchQuery{byProject},
	}
	found, err := o.query.SearchProjectRoles(ctx, true, search)
	if err != nil {
		return err
	}

	for _, projectRole := range found.ProjectRoles {
		token.Scopes = append(token.Scopes, ScopeProjectRolePrefix+projectRole.Key)
	}
	return nil
}
// setContextUserSystem returns a context derived from ctx whose authz data has
// the user ID "SYSTEM".
func setContextUserSystem(ctx context.Context) context.Context {
	return authz.SetCtxData(ctx, authz.CtxData{UserID: "SYSTEM"})
}
// CreateErrorCallbackURL builds the callback URL that reports an error for
// authReq to the client.
//
// reason, description and uri are sent as error, error_description and
// error_uri; the state is taken from authReq. The URL is built by
// op.AuthResponseURL from the request's redirect URI, response type and
// response mode, using the authorizer's encoder.
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
	type errorResponse struct {
		Error       string `schema:"error"`
		Description string `schema:"error_description,omitempty"`
		URI         string `schema:"error_uri,omitempty"`
		State       string `schema:"state,omitempty"`
	}

	params := errorResponse{
		Error:       reason,
		Description: description,
		URI:         uri,
		State:       authReq.GetState(),
	}
	callbackURL, err := op.AuthResponseURL(
		authReq.GetRedirectURI(),
		authReq.GetResponseType(),
		authReq.GetResponseMode(),
		params,
		authorizer.Encoder(),
	)
	if err != nil {
		return "", err
	}
	return callbackURL, nil
}
// CreateCodeCallbackURL creates an authorization code for authReq and builds
// the callback URL that hands the code and the request's state to the client.
//
// The code is created by op.CreateAuthRequestCode with the authorizer's
// storage and crypto; the URL is built by op.AuthResponseURL with the
// authorizer's encoder.
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
	authCode, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
	if err != nil {
		return "", err
	}

	// NOTE(review): the fields are unexported and carry no schema tag;
	// presumably the encoder uses the field names as parameter names, so they
	// must stay exactly `code` and `state` — confirm against the encoder.
	response := &struct {
		code  string
		state string
	}{
		code:  authCode,
		state: authReq.GetState(),
	}
	callbackURL, err := op.AuthResponseURL(
		authReq.GetRedirectURI(),
		authReq.GetResponseType(),
		authReq.GetResponseMode(),
		response,
		authorizer.Encoder(),
	)
	if err != nil {
		return "", err
	}
	return callbackURL, nil
}
// CreateTokenCallbackURL creates an OIDC session from the auth request req and
// builds the callback URL carrying the resulting token response.
//
// The client is loaded from the provider's storage and must be a *Client. The
// session is created with the system user set in the context, using
// implicitFlowComplianceChecker; a refresh token is requested when the client
// has the refresh_token grant type.
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {
	provider := s.Provider()

	storedClient, err := provider.Storage().GetClientByClientID(ctx, req.GetClientID())
	if err != nil {
		return "", err
	}
	client, isClient := storedClient.(*Client)
	if !isClient {
		return "", zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
	}

	session, state, err := s.command.CreateOIDCSessionFromAuthRequest(
		setContextUserSystem(ctx),
		req.GetID(),
		implicitFlowComplianceChecker(),
		slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken),
	)
	if err != nil {
		return "", err
	}

	tokenResponse, err := s.accessTokenResponseFromSession(
		ctx,
		client,
		session,
		state,
		client.client.ProjectID,
		client.client.ProjectRoleAssertion,
		client.client.AccessTokenRoleAssertion,
		client.client.IDTokenRoleAssertion,
		client.client.IDTokenUserinfoAssertion,
	)
	if err != nil {
		return "", err
	}

	callbackURL, err := op.AuthResponseURL(
		req.GetRedirectURI(),
		req.GetResponseType(),
		req.GetResponseMode(),
		tokenResponse,
		provider.Encoder(),
	)
	if err != nil {
		return "", err
	}
	return callbackURL, nil
}
// implicitFlowComplianceChecker returns a compliance checker that only
// verifies that the auth request is authenticated.
func implicitFlowComplianceChecker() command.AuthRequestComplianceChecker {
	check := func(_ context.Context, authReq *command.AuthRequestWriteModel) error {
		err := authReq.CheckAuthenticated()
		if err != nil {
			return err
		}
		return nil
	}
	return check
}
// authorizeCallbackHandler handles the callback after the login: it loads the
// v1 auth request named in the callback request and, if the request is done,
// sends the auth response.
//
// Any error returned by the inner function is reported to the client with
// op.AuthRequestError.
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
	authorizer := s.Provider()

	// The work runs in a closure so that the tracing span covers it and is
	// ended with the resulting error. The closure replaces the outer r with a
	// request carrying the span context.
	handle := func(ctx context.Context) (authReq *AuthRequest, err error) {
		ctx, span := tracing.NewSpan(ctx)
		r = r.WithContext(ctx)
		defer func() { span.EndWithError(err) }()

		id, err := op.ParseAuthorizeCallbackRequest(r)
		if err != nil {
			return nil, err
		}
		authReq, err = s.getAuthRequestV1ByID(ctx, id)
		if err != nil {
			return nil, err
		}
		if authReq.Done() {
			return authReq, s.authResponse(authReq, authorizer, w, r)
		}
		return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
	}

	authReq, err := handle(r.Context())
	if err == nil {
		return
	}
	// we need to make sure there's no empty interface passed
	if authReq != nil {
		op.AuthRequestError(w, r, authReq, err, authorizer)
		return
	}
	op.AuthRequestError(w, r, nil, err, authorizer)
}
// authResponse sends the response for a finished auth request.
//
// For the code response type the response is written by op.AuthResponseCode
// and nil is returned. For all other response types the work is delegated to
// authResponseToken.
//
// A failing client lookup is only returned, not written to w: the caller
// (authorizeCallbackHandler) reports every returned error with
// op.AuthRequestError, so writing it here as well would send the error
// response twice.
func (s *Server) authResponse(authReq *AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
	ctx, span := tracing.NewSpan(r.Context())
	r = r.WithContext(ctx)
	defer func() { span.EndWithError(err) }()

	client, err := authorizer.Storage().GetClientByClientID(ctx, authReq.GetClientID())
	if err != nil {
		return err
	}
	if authReq.GetResponseType() == oidc.ResponseTypeCode {
		op.AuthResponseCode(w, r, authReq, authorizer)
		return nil
	}
	return s.authResponseToken(authReq, authorizer, client, w, r)
}
// authResponseToken creates an OIDC session for the finished auth request and
// sends the resulting token response to the client.
//
// opClient must be a *Client. A refresh token is requested when the request's
// scopes contain offline_access. With response mode form_post the response is
// written by op.AuthResponseFormPost; otherwise the client is redirected to
// the URL built by op.AuthResponseURL.
//
// Errors are only returned, never written to w: the call chain ends in
// authorizeCallbackHandler, which reports every returned error with
// op.AuthRequestError. Writing the error here as well would send the error
// response twice, and the type assertion failure below was already
// return-only.
func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
	ctx, span := tracing.NewSpan(r.Context())
	r = r.WithContext(ctx)
	defer func() { span.EndWithError(err) }()

	client, ok := opClient.(*Client)
	if !ok {
		return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
	}

	scope := authReq.GetScopes()
	session, err := s.command.CreateOIDCSession(ctx,
		authReq.UserID,
		authReq.UserOrgID,
		client.client.ClientID,
		scope,
		authReq.Audience,
		authReq.AuthMethods(),
		authReq.AuthTime,
		authReq.GetNonce(),
		authReq.PreferredLanguage,
		authReq.ToUserAgent(),
		domain.TokenReasonAuthRequest,
		nil,
		slices.Contains(scope, oidc.ScopeOfflineAccess),
	)
	if err != nil {
		return err
	}

	resp, err := s.accessTokenResponseFromSession(
		ctx,
		client,
		session,
		authReq.GetState(),
		client.client.ProjectID,
		client.client.ProjectRoleAssertion,
		client.client.AccessTokenRoleAssertion,
		client.client.IDTokenRoleAssertion,
		client.client.IDTokenUserinfoAssertion,
	)
	if err != nil {
		return err
	}

	if authReq.GetResponseMode() == oidc.ResponseModeFormPost {
		return op.AuthResponseFormPost(w, authReq.GetRedirectURI(), resp, authorizer.Encoder())
	}

	callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), resp, authorizer.Encoder())
	if err != nil {
		return err
	}
	http.Redirect(w, r, callback, http.StatusFound)
	return nil
}