feat(api): add OIDC session service (#6157)

This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow.


Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2023-07-10 15:27:00 +02:00
committed by GitHub
parent be1fe36776
commit 14b8cf4894
69 changed files with 5948 additions and 106 deletions

View File

@@ -2,6 +2,7 @@ package oidc
import (
"context"
"encoding/base64"
"strings"
"time"
@@ -10,16 +11,75 @@ import (
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/user/model"
)
const (
LoginClientHeader = "x-zitadel-login-client"
)
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
}
return o.createAuthRequest(ctx, req, userID)
}
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
project, err := o.query.ProjectByClientID(ctx, req.ClientID, false)
if err != nil {
return nil, err
}
scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
if err != nil {
return nil, err
}
audience, err := o.audienceFromProjectID(ctx, project.ID)
if err != nil {
return nil, err
}
audience = domain.AddAudScopeToAudience(ctx, audience, scope)
authRequest := &command.AuthRequest{
LoginClient: loginClient,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
State: req.State,
Nonce: req.Nonce,
Scope: scope,
Audience: audience,
ResponseType: ResponseTypeToBusiness(req.ResponseType),
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
Prompt: PromptToBusiness(req.Prompt),
UILocales: UILocalesToBusiness(req.UILocales),
MaxAge: MaxAgeToBusiness(req.MaxAge),
}
if req.LoginHint != "" {
authRequest.LoginHint = &req.LoginHint
}
if hintUserID != "" {
authRequest.HintUserID = &hintUserID
}
aar, err := o.command.AddAuthRequest(ctx, authRequest)
if err != nil {
return nil, err
}
return &AuthRequestV2{aar}, nil
}
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
@@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
return AuthRequestFromBusiness(resp)
}
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID)
if err != nil {
return nil, err
}
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
return append(appIDs, projectID), nil
}
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := o.command.GetCurrentAuthRequest(ctx, id)
if err != nil {
return nil, err
}
return &AuthRequestV2{req}, nil
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
@@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(code)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
if err != nil {
return nil, err
}
return &AuthRequestV2{authReq}, nil
}
resp, err := o.repo.AuthRequestByCode(ctx, code)
if err != nil {
return nil, err
@@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
return AuthRequestFromBusiness(resp)
}
// decryptGrant decrypts a code or refresh_token
func (o *OPStorage) decryptGrant(grant string) (string, error) {
decodedGrant, err := base64.RawURLEncoding.DecodeString(grant)
if err != nil {
return "", err
}
return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID())
}
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
return o.command.AddAuthRequestCode(ctx, id, code)
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
@@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var userAgentID, applicationID, userOrgID string
authReq, ok := req.(*AuthRequest)
if ok {
switch authReq := req.(type) {
case *AuthRequest:
userAgentID = authReq.AgentID
applicationID = authReq.ApplicationID
userOrgID = authReq.UserOrgID
case *AuthRequestV2:
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
}
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
@@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// handle V2 request directly
switch tokenReq := req.(type) {
case *AuthRequestV2:
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
case *RefreshTokenRequestV2:
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
}
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req)
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
if err != nil {
@@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
return "", "", "", time.Time{}, nil
}
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(refreshToken)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode)
if err != nil {
return nil, err
}
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
}
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
if err != nil {
return nil, err
@@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
return scopes, nil
}
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
for _, scope := range scopes {
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {
return scopes, nil
}
}
if !project.ProjectRoleAssertion {
return scopes, nil
}
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
}
roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
for _, role := range roles.ProjectRoles {
scopes = append(scopes, ScopeProjectRolePrefix+role.Key)
}
return scopes, nil
}
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
token.Audience = append(token.Audience, clientID)
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
@@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i
}
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
}
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
e := struct {
Error string `schema:"error"`
Description string `schema:"error_description,omitempty"`
URI string `schema:"error_uri,omitempty"`
State string `schema:"state,omitempty"`
}{
Error: reason,
Description: description,
URI: uri,
State: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, nil
}
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
return "", err
}
codeResponse := struct {
code string
state string
}{
code: code,
state: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
return "", err
}
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
if err != nil {
return "", err
}
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}