mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 07:37:31 +00:00
feat(api): add OIDC session service (#6157)
This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow. Co-authored-by: Livio Spring <livio.a@gmail.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com> Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
@@ -4,11 +4,11 @@ import "context"
|
||||
|
||||
func NewMockContext(instanceID, orgID, userID string) context.Context {
|
||||
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
|
||||
return context.WithValue(ctx, instanceKey, instanceID)
|
||||
return context.WithValue(ctx, instanceKey, &instance{id: instanceID})
|
||||
}
|
||||
|
||||
func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context {
|
||||
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
|
||||
ctx = context.WithValue(ctx, instanceKey, instanceID)
|
||||
ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID})
|
||||
return context.WithValue(ctx, requestPermissionsKey, permissions)
|
||||
}
|
||||
|
204
internal/api/grpc/oidc/v2/oidc.go
Normal file
204
internal/api/grpc/oidc/v2/oidc.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
|
||||
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("query authRequest by ID")
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.GetAuthRequestResponse{
|
||||
AuthRequest: authRequestToPb(authRequest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
|
||||
pba := &oidc_pb.AuthRequest{
|
||||
Id: a.ID,
|
||||
CreationDate: timestamppb.New(a.CreationDate),
|
||||
ClientId: a.ClientID,
|
||||
Scope: a.Scope,
|
||||
RedirectUri: a.RedirectURI,
|
||||
Prompt: promptsToPb(a.Prompt),
|
||||
UiLocales: a.UiLocales,
|
||||
LoginHint: a.LoginHint,
|
||||
HintUserId: a.HintUserID,
|
||||
}
|
||||
if a.MaxAge != nil {
|
||||
pba.MaxAge = durationpb.New(*a.MaxAge)
|
||||
}
|
||||
return pba
|
||||
}
|
||||
|
||||
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
|
||||
out := make([]oidc_pb.Prompt, len(promps))
|
||||
for i, p := range promps {
|
||||
out[i] = promptToPb(p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
|
||||
switch p {
|
||||
case domain.PromptUnspecified:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
case domain.PromptNone:
|
||||
return oidc_pb.Prompt_PROMPT_NONE
|
||||
case domain.PromptLogin:
|
||||
return oidc_pb.Prompt_PROMPT_LOGIN
|
||||
case domain.PromptConsent:
|
||||
return oidc_pb.Prompt_PROMPT_CONSENT
|
||||
case domain.PromptSelectAccount:
|
||||
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
|
||||
case domain.PromptCreate:
|
||||
return oidc_pb.Prompt_PROMPT_CREATE
|
||||
default:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
switch v := req.GetCallbackKind().(type) {
|
||||
case *oidc_pb.CreateCallbackRequest_Error:
|
||||
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
|
||||
case *oidc_pb.CreateCallbackRequest_Session:
|
||||
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
|
||||
default:
|
||||
return nil, errors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
ctx = op.ContextWithIssuer(ctx, http.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), s.externalSecure))
|
||||
var callback string
|
||||
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op)
|
||||
} else {
|
||||
callback, err = oidc.CreateTokenCallbackURL(ctx, authReq, s.op)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
|
||||
switch errorReason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return domain.OIDCErrorReasonInvalidRequest
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return domain.OIDCErrorReasonUnauthorizedClient
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return domain.OIDCErrorReasonAccessDenied
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return domain.OIDCErrorReasonUnsupportedResponseType
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return domain.OIDCErrorReasonInvalidScope
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
return domain.OIDCErrorReasonServerError
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return domain.OIDCErrorReasonTemporaryUnavailable
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonInteractionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return domain.OIDCErrorReasonLoginRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonAccountSelectionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return domain.OIDCErrorReasonConsentRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return domain.OIDCErrorReasonInvalidRequestURI
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return domain.OIDCErrorReasonInvalidRequestObject
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestNotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestURINotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRegistrationNotSupported
|
||||
default:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
|
||||
switch reason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return "invalid_request"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return "unauthorized_client"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return "access_denied"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return "unsupported_response_type"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return "invalid_scope"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return "temporarily_unavailable"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return "interaction_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return "login_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return "account_selection_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return "consent_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return "invalid_request_uri"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return "invalid_request_object"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return "request_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return "request_uri_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return "registration_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
fallthrough
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
249
internal/api/grpc/oidc/v2/oidc_integration_test.go
Normal file
249
internal/api/grpc/oidc/v2/oidc_integration_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client oidc_pb.OIDCServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, _, cancel := integration.Contexts(5 * time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Tester = integration.NewTester(ctx)
|
||||
defer Tester.Done()
|
||||
Client = Tester.Client.OIDCv2
|
||||
|
||||
CTX = Tester.WithAuthorization(ctx, integration.OrgOwner)
|
||||
User = Tester.CreateHumanUser(CTX)
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestServer_GetAuthRequest(t *testing.T) {
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
AuthRequestID string
|
||||
want *oidc_pb.GetAuthRequestResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
AuthRequestID: "123",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
AuthRequestID: authRequestID,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.GetAuthRequest(CTX, &oidc_pb.GetAuthRequestRequest{
|
||||
AuthRequestId: tt.AuthRequestID,
|
||||
})
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
authRequest := got.GetAuthRequest()
|
||||
assert.NotNil(t, authRequest)
|
||||
assert.Equal(t, authRequestID, authRequest.GetId())
|
||||
assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
|
||||
assert.Contains(t, authRequest.GetScope(), "openid")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateCallback(t *testing.T) {
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
sessionResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{
|
||||
UserId: Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *oidc_pb.CreateCallbackRequest
|
||||
AuthError string
|
||||
want *oidc_pb.CreateCallbackResponse
|
||||
wantURL *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: "123",
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: "foo",
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session token invalid",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Error{
|
||||
Error: &oidc_pb.AuthorizationError{
|
||||
Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
ErrorUri: gu.Ptr("https://example.com/docs"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`),
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "code callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`,
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "implicit",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
client, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequestImplicit(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`,
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.CreateCallback(CTX, tt.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
150
internal/api/grpc/oidc/v2/oidc_test.go
Normal file
150
internal/api/grpc/oidc/v2/oidc_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
func Test_authRequestToPb(t *testing.T) {
|
||||
now := time.Now()
|
||||
arg := &query.AuthRequest{
|
||||
ID: "authID",
|
||||
CreationDate: now,
|
||||
ClientID: "clientID",
|
||||
Scope: []string{"a", "b", "c"},
|
||||
RedirectURI: "callbackURI",
|
||||
Prompt: []domain.Prompt{
|
||||
domain.PromptUnspecified,
|
||||
domain.PromptNone,
|
||||
domain.PromptLogin,
|
||||
domain.PromptConsent,
|
||||
domain.PromptSelectAccount,
|
||||
domain.PromptCreate,
|
||||
999,
|
||||
},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
LoginHint: gu.Ptr("foo@bar.com"),
|
||||
MaxAge: gu.Ptr(time.Minute),
|
||||
HintUserID: gu.Ptr("userID"),
|
||||
}
|
||||
want := &oidc_pb.AuthRequest{
|
||||
Id: "authID",
|
||||
CreationDate: timestamppb.New(now),
|
||||
ClientId: "clientID",
|
||||
RedirectUri: "callbackURI",
|
||||
Prompt: []oidc_pb.Prompt{
|
||||
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
|
||||
oidc_pb.Prompt_PROMPT_NONE,
|
||||
oidc_pb.Prompt_PROMPT_LOGIN,
|
||||
oidc_pb.Prompt_PROMPT_CONSENT,
|
||||
oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT,
|
||||
oidc_pb.Prompt_PROMPT_CREATE,
|
||||
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
|
||||
},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
Scope: []string{"a", "b", "c"},
|
||||
LoginHint: gu.Ptr("foo@bar.com"),
|
||||
MaxAge: durationpb.New(time.Minute),
|
||||
HintUserId: gu.Ptr("userID"),
|
||||
}
|
||||
got := authRequestToPb(arg)
|
||||
if !proto.Equal(want, got) {
|
||||
t.Errorf("authRequestToPb() =\n%v\nwant\n%v\n", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errorReasonToOIDC(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason oidc_pb.ErrorReason
|
||||
want string
|
||||
}{
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED,
|
||||
want: "server_error",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST,
|
||||
want: "invalid_request",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT,
|
||||
want: "unauthorized_client",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
|
||||
want: "access_denied",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE,
|
||||
want: "unsupported_response_type",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE,
|
||||
want: "invalid_scope",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR,
|
||||
want: "server_error",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE,
|
||||
want: "temporarily_unavailable",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED,
|
||||
want: "interaction_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED,
|
||||
want: "login_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED,
|
||||
want: "account_selection_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED,
|
||||
want: "consent_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI,
|
||||
want: "invalid_request_uri",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT,
|
||||
want: "invalid_request_object",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED,
|
||||
want: "request_not_supported",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED,
|
||||
want: "request_uri_not_supported",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED,
|
||||
want: "registration_not_supported",
|
||||
},
|
||||
{
|
||||
reason: 99999,
|
||||
want: "server_error",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.reason.String(), func(t *testing.T) {
|
||||
got := errorReasonToOIDC(tt.reason)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
59
internal/api/grpc/oidc/v2/server.go
Normal file
59
internal/api/grpc/oidc/v2/server.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
var _ oidc_pb.OIDCServiceServer = (*Server)(nil)
|
||||
|
||||
type Server struct {
|
||||
oidc_pb.UnimplementedOIDCServiceServer
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
op op.OpenIDProvider
|
||||
externalSecure bool
|
||||
}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func CreateServer(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
op op.OpenIDProvider,
|
||||
externalSecure bool,
|
||||
) *Server {
|
||||
return &Server{
|
||||
command: command,
|
||||
query: query,
|
||||
op: op,
|
||||
externalSecure: externalSecure,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
|
||||
oidc_pb.RegisterOIDCServiceServer(grpcServer, s)
|
||||
}
|
||||
|
||||
func (s *Server) AppName() string {
|
||||
return oidc_pb.OIDCService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) MethodPrefix() string {
|
||||
return oidc_pb.OIDCService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) AuthMethods() authz.MethodMapping {
|
||||
return oidc_pb.OIDCService_AuthMethods
|
||||
}
|
||||
|
||||
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
|
||||
return oidc_pb.RegisterOIDCServiceHandler
|
||||
}
|
@@ -18,11 +18,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client session.SessionServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
GenericOAuthIDPID string
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client session.SessionServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@@ -540,7 +540,7 @@ func TestServer_StartIdentityProviderFlow(t *testing.T) {
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
|
43
internal/api/oidc/amr/amr.go
Normal file
43
internal/api/oidc/amr/amr.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Package amr maps zitadel session factors to Authentication Method Reference Values
|
||||
// as defined in [RFC 8176, section 2].
|
||||
//
|
||||
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
|
||||
package amr
|
||||
|
||||
const (
|
||||
// Password states that the users password has been verified
|
||||
// Deprecated: use `PWD` instead
|
||||
Password = "password"
|
||||
// PWD states that the users password has been verified
|
||||
PWD = "pwd"
|
||||
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
|
||||
MFA = "mfa"
|
||||
// OTP states that a one time password has been verified (e.g. TOTP)
|
||||
OTP = "otp"
|
||||
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
|
||||
UserPresence = "user"
|
||||
)
|
||||
|
||||
type AuthenticationMethodReference interface {
|
||||
IsPasswordChecked() bool
|
||||
IsPasskeyChecked() bool
|
||||
IsU2FChecked() bool
|
||||
IsOTPChecked() bool
|
||||
}
|
||||
|
||||
func List(model AuthenticationMethodReference) []string {
|
||||
amr := make([]string, 0)
|
||||
if model.IsPasswordChecked() {
|
||||
amr = append(amr, PWD)
|
||||
}
|
||||
if model.IsPasskeyChecked() || model.IsU2FChecked() {
|
||||
amr = append(amr, UserPresence)
|
||||
}
|
||||
if model.IsOTPChecked() {
|
||||
amr = append(amr, OTP)
|
||||
}
|
||||
if model.IsPasskeyChecked() || len(amr) >= 2 {
|
||||
amr = append(amr, MFA)
|
||||
}
|
||||
return amr
|
||||
}
|
93
internal/api/oidc/amr/amr_test.go
Normal file
93
internal/api/oidc/amr/amr_test.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package amr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAMR(t *testing.T) {
|
||||
type args struct {
|
||||
model AuthenticationMethodReference
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"no checks, empty",
|
||||
args{
|
||||
new(test),
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"pw checked",
|
||||
args{
|
||||
&test{pwChecked: true},
|
||||
},
|
||||
[]string{PWD},
|
||||
},
|
||||
{
|
||||
"passkey checked",
|
||||
args{
|
||||
&test{passkeyChecked: true},
|
||||
},
|
||||
[]string{UserPresence, MFA},
|
||||
},
|
||||
{
|
||||
"u2f checked",
|
||||
args{
|
||||
&test{u2fChecked: true},
|
||||
},
|
||||
[]string{UserPresence},
|
||||
},
|
||||
{
|
||||
"otp checked",
|
||||
args{
|
||||
&test{otpChecked: true},
|
||||
},
|
||||
[]string{OTP},
|
||||
},
|
||||
{
|
||||
"multiple checked",
|
||||
args{
|
||||
&test{
|
||||
pwChecked: true,
|
||||
u2fChecked: true,
|
||||
},
|
||||
},
|
||||
[]string{PWD, UserPresence, MFA},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := List(tt.args.model)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type test struct {
|
||||
pwChecked bool
|
||||
passkeyChecked bool
|
||||
u2fChecked bool
|
||||
otpChecked bool
|
||||
}
|
||||
|
||||
func (t test) IsPasswordChecked() bool {
|
||||
return t.pwChecked
|
||||
}
|
||||
|
||||
func (t test) IsPasskeyChecked() bool {
|
||||
return t.passkeyChecked
|
||||
}
|
||||
|
||||
func (t test) IsU2FChecked() bool {
|
||||
return t.u2fChecked
|
||||
}
|
||||
|
||||
func (t test) IsOTPChecked() bool {
|
||||
return t.otpChecked
|
||||
}
|
@@ -2,6 +2,7 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -10,16 +11,75 @@ import (
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
const (
|
||||
LoginClientHeader = "x-zitadel-login-client"
|
||||
)
|
||||
|
||||
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||
return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
|
||||
}
|
||||
|
||||
return o.createAuthRequest(ctx, req, userID)
|
||||
}
|
||||
|
||||
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
|
||||
project, err := o.query.ProjectByClientID(ctx, req.ClientID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
audience, err := o.audienceFromProjectID(ctx, project.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
audience = domain.AddAudScopeToAudience(ctx, audience, scope)
|
||||
authRequest := &command.AuthRequest{
|
||||
LoginClient: loginClient,
|
||||
ClientID: req.ClientID,
|
||||
RedirectURI: req.RedirectURI,
|
||||
State: req.State,
|
||||
Nonce: req.Nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: ResponseTypeToBusiness(req.ResponseType),
|
||||
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
|
||||
Prompt: PromptToBusiness(req.Prompt),
|
||||
UILocales: UILocalesToBusiness(req.UILocales),
|
||||
MaxAge: MaxAgeToBusiness(req.MaxAge),
|
||||
}
|
||||
if req.LoginHint != "" {
|
||||
authRequest.LoginHint = &req.LoginHint
|
||||
}
|
||||
if hintUserID != "" {
|
||||
authRequest.HintUserID = &hintUserID
|
||||
}
|
||||
|
||||
aar, err := o.command.AddAuthRequest(ctx, authRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{aar}, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
|
||||
@@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
|
||||
projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(appIDs, projectID), nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
req, err := o.command.GetCurrentAuthRequest(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{req}, nil
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
|
||||
@@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
plainCode, err := o.decryptGrant(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{authReq}, nil
|
||||
}
|
||||
resp, err := o.repo.AuthRequestByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
// decryptGrant decrypts a code or refresh_token
|
||||
func (o *OPStorage) decryptGrant(grant string) (string, error) {
|
||||
decodedGrant, err := base64.RawURLEncoding.DecodeString(grant)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID())
|
||||
}
|
||||
|
||||
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return o.command.AddAuthRequestCode(ctx, id, code)
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
|
||||
@@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error
|
||||
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var userAgentID, applicationID, userOrgID string
|
||||
authReq, ok := req.(*AuthRequest)
|
||||
if ok {
|
||||
switch authReq := req.(type) {
|
||||
case *AuthRequest:
|
||||
userAgentID = authReq.AgentID
|
||||
applicationID = authReq.ApplicationID
|
||||
userOrgID = authReq.UserOrgID
|
||||
case *AuthRequestV2:
|
||||
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
|
||||
}
|
||||
|
||||
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
|
||||
@@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
|
||||
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// handle V2 request directly
|
||||
switch tokenReq := req.(type) {
|
||||
case *AuthRequestV2:
|
||||
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
|
||||
case *RefreshTokenRequestV2:
|
||||
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
|
||||
}
|
||||
|
||||
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req)
|
||||
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
|
||||
if err != nil {
|
||||
@@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
|
||||
return "", "", "", time.Time{}, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
plainCode, err := o.decryptGrant(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
|
||||
}
|
||||
|
||||
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
|
||||
for _, scope := range scopes {
|
||||
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {
|
||||
return scopes, nil
|
||||
}
|
||||
}
|
||||
if !project.ProjectRoleAssertion {
|
||||
return scopes, nil
|
||||
}
|
||||
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
|
||||
}
|
||||
roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, role := range roles.ProjectRoles {
|
||||
scopes = append(scopes, ScopeProjectRolePrefix+role.Key)
|
||||
}
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
|
||||
token.Audience = append(token.Audience, clientID)
|
||||
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
|
||||
@@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i
|
||||
}
|
||||
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
|
||||
}
|
||||
|
||||
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
|
||||
e := struct {
|
||||
Error string `schema:"error"`
|
||||
Description string `schema:"error_description,omitempty"`
|
||||
URI string `schema:"error_uri,omitempty"`
|
||||
State string `schema:"state,omitempty"`
|
||||
}{
|
||||
Error: reason,
|
||||
Description: description,
|
||||
URI: uri,
|
||||
State: authReq.GetState(),
|
||||
}
|
||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, nil
|
||||
}
|
||||
|
||||
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
|
||||
code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeResponse := struct {
|
||||
code string
|
||||
state string
|
||||
}{
|
||||
code: code,
|
||||
state: authReq.GetState(),
|
||||
}
|
||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, err
|
||||
}
|
||||
|
||||
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, err
|
||||
}
|
||||
|
@@ -12,20 +12,12 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// DEPRECATED: use `amrPWD` instead
|
||||
amrPassword = "password"
|
||||
amrPWD = "pwd"
|
||||
amrMFA = "mfa"
|
||||
amrOTP = "otp"
|
||||
amrUserPresence = "user"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
*domain.AuthRequest
|
||||
}
|
||||
@@ -40,19 +32,19 @@ func (a *AuthRequest) GetACR() string {
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAMR() []string {
|
||||
amr := make([]string, 0)
|
||||
list := make([]string, 0)
|
||||
if a.PasswordVerified {
|
||||
amr = append(amr, amrPassword, amrPWD)
|
||||
list = append(list, amr.Password, amr.PWD)
|
||||
}
|
||||
if len(a.MFAsVerified) > 0 {
|
||||
amr = append(amr, amrMFA)
|
||||
list = append(list, amr.MFA)
|
||||
for _, mfa := range a.MFAsVerified {
|
||||
if amrMFA := AMRFromMFAType(mfa); amrMFA != "" {
|
||||
amr = append(amr, amrMFA)
|
||||
list = append(list, amrMFA)
|
||||
}
|
||||
}
|
||||
}
|
||||
return amr
|
||||
return list
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAudience() []string {
|
||||
@@ -271,10 +263,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng
|
||||
func AMRFromMFAType(mfaType domain.MFAType) string {
|
||||
switch mfaType {
|
||||
case domain.MFATypeOTP:
|
||||
return amrOTP
|
||||
return amr.OTP
|
||||
case domain.MFATypeU2F,
|
||||
domain.MFATypeU2FUserVerification:
|
||||
return amrUserPresence
|
||||
return amr.UserPresence
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
106
internal/api/oidc/auth_request_converter_v2.go
Normal file
106
internal/api/oidc/auth_request_converter_v2.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
)
|
||||
|
||||
type AuthRequestV2 struct {
|
||||
*command.CurrentAuthRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetACR() string {
|
||||
return "" //PLANNED: impl
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAMR() []string {
|
||||
return a.AMR
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAudience() []string {
|
||||
return a.Audience
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAuthTime() time.Time {
|
||||
return a.AuthTime
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetClientID() string {
|
||||
return a.ClientID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetCodeChallenge() *oidc.CodeChallenge {
|
||||
return CodeChallengeToOIDC(a.CodeChallenge)
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetNonce() string {
|
||||
return a.Nonce
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetResponseType() oidc.ResponseType {
|
||||
return ResponseTypeToOIDC(a.ResponseType)
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetResponseMode() oidc.ResponseMode {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetScopes() []string {
|
||||
return a.Scope
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetState() string {
|
||||
return a.State
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetSubject() string {
|
||||
return a.UserID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) Done() bool {
|
||||
return a.UserID != "" && a.SessionID != ""
|
||||
}
|
||||
|
||||
type RefreshTokenRequestV2 struct {
|
||||
*command.OIDCSessionWriteModel
|
||||
RequestedScopes []string
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAMR() []string {
|
||||
return r.AuthMethodsReferences
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAudience() []string {
|
||||
return r.Audience
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAuthTime() time.Time {
|
||||
return r.AuthTime
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetClientID() string {
|
||||
return r.ClientID
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetScopes() []string {
|
||||
return r.Scope
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetSubject() string {
|
||||
return r.UserID
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) SetCurrentScopes(scopes []string) {
|
||||
r.RequestedScopes = scopes
|
||||
}
|
275
internal/api/oidc/auth_request_integration_test.go
Normal file
275
internal/api/oidc/auth_request_integration_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
CTXLOGIN context.Context
|
||||
Tester *integration.Tester
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, errCtx, cancel := integration.Contexts(5 * time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Tester = integration.NewTester(ctx)
|
||||
defer Tester.Done()
|
||||
|
||||
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
|
||||
User = Tester.CreateHumanUser(CTX)
|
||||
Tester.RegisterUserPasskey(CTX, User.GetUserId())
|
||||
CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func createClient(t testing.TB) string {
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return app.GetClientId()
|
||||
}
|
||||
|
||||
func createImplicitClient(t testing.TB) string {
|
||||
app, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
return app.GetClientId()
|
||||
}
|
||||
|
||||
func createAuthRequest(t testing.TB, clientID, redirectURI string, scope ...string) string {
|
||||
redURL, err := Tester.CreateOIDCAuthRequest(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
|
||||
require.NoError(t, err)
|
||||
return redURL
|
||||
}
|
||||
|
||||
func createAuthRequestImplicit(t testing.TB, clientID, redirectURI string, scope ...string) string {
|
||||
redURL, err := Tester.CreateOIDCAuthRequestImplicit(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
|
||||
require.NoError(t, err)
|
||||
return redURL
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAuthRequest(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
|
||||
id := createAuthRequest(t, clientID, redirectURI)
|
||||
require.Contains(t, id, command.IDPrefixV2)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessToken_code(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// exchange with a used code must fail
|
||||
_, err = exchangeTokens(t, clientID, code)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
|
||||
clientID := createImplicitClient(t)
|
||||
authRequestID := createAuthRequestImplicit(t, clientID, redirectURIImplicit)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test implicit callback
|
||||
callback, err := url.Parse(linkResp.GetCallbackUrl())
|
||||
require.NoError(t, err)
|
||||
values, err := url.ParseQuery(callback.Fragment)
|
||||
require.NoError(t, err)
|
||||
accessToken := values.Get("access_token")
|
||||
idToken := values.Get("id_token")
|
||||
refreshToken := values.Get("refresh_token")
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, idToken)
|
||||
assert.Empty(t, refreshToken)
|
||||
assert.NotEmpty(t, values.Get("expires_in"))
|
||||
assert.Equal(t, oidc.BearerToken, values.Get("token_type"))
|
||||
assert.Equal(t, "state", values.Get("state"))
|
||||
|
||||
// check id_token / claims
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
assertTokenClaims(t, claims, startTime, changeTime)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange (expect refresh token to be returned)
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
|
||||
// test actual refresh grant
|
||||
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
|
||||
require.NoError(t, err)
|
||||
idToken, _ := newTokens.Extra("id_token").(string)
|
||||
assert.NotEmpty(t, idToken)
|
||||
assert.NotEmpty(t, newTokens.AccessToken)
|
||||
assert.NotEmpty(t, newTokens.RefreshToken)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
// auth time must still be the initial
|
||||
assertTokenClaims(t, claims, startTime, changeTime)
|
||||
|
||||
// refresh with an old refresh_token must fail
|
||||
_, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
codeVerifier := "codeVerifier"
|
||||
return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier))
|
||||
}
|
||||
|
||||
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
return rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
}
|
||||
|
||||
func assertCodeResponse(t *testing.T, callback string) string {
|
||||
callbackURL, err := url.Parse(callback)
|
||||
require.NoError(t, err)
|
||||
code := callbackURL.Query().Get("code")
|
||||
require.NotEmpty(t, code)
|
||||
assert.Equal(t, "state", callbackURL.Query().Get("state"))
|
||||
return code
|
||||
}
|
||||
|
||||
func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requireRefreshToken bool) {
|
||||
assert.NotEmpty(t, tokens.AccessToken)
|
||||
assert.NotEmpty(t, tokens.IDToken)
|
||||
if requireRefreshToken {
|
||||
assert.NotEmpty(t, tokens.RefreshToken)
|
||||
} else {
|
||||
assert.Empty(t, tokens.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) {
|
||||
assert.Equal(t, User.GetUserId(), claims.Subject)
|
||||
assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences)
|
||||
assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second))
|
||||
}
|
@@ -66,7 +66,7 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ClientFromBusiness(client, o.defaultLoginURL, accessTokenLifetime, idTokenLifetime, allowedScopes)
|
||||
return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2, accessTokenLifetime, idTokenLifetime, allowedScopes)
|
||||
}
|
||||
|
||||
func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) {
|
||||
@@ -153,7 +153,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
|
||||
return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil)
|
||||
}
|
||||
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
|
||||
token, err := o.repo.TokenByIDs(ctx, subject, tokenID)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
@@ -15,18 +16,20 @@ import (
|
||||
type Client struct {
|
||||
app *query.App
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultIdTokenLifetime time.Duration
|
||||
allowedScopes []string
|
||||
}
|
||||
|
||||
func ClientFromBusiness(app *query.App, defaultLoginURL string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
|
||||
func ClientFromBusiness(app *query.App, defaultLoginURL, defaultLoginURLV2 string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
|
||||
if app.OIDCConfig == nil {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-d5bhD", "client is not a proper oidc application")
|
||||
}
|
||||
return &Client{
|
||||
app: app,
|
||||
defaultLoginURL: defaultLoginURL,
|
||||
defaultLoginURLV2: defaultLoginURLV2,
|
||||
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
|
||||
defaultIdTokenLifetime: defaultIdTokenLifetime,
|
||||
allowedScopes: allowedScopes},
|
||||
@@ -46,6 +49,9 @@ func (c *Client) GetID() string {
|
||||
}
|
||||
|
||||
func (c *Client) LoginURL(id string) string {
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return c.defaultLoginURLV2 + id
|
||||
}
|
||||
return c.defaultLoginURL + id
|
||||
}
|
||||
|
||||
|
@@ -41,6 +41,7 @@ type Config struct {
|
||||
Cache *middleware.CacheConfig
|
||||
CustomEndpoints *EndpointConfig
|
||||
DeviceAuth *DeviceAuthorizationConfig
|
||||
DefaultLoginURLV2 string
|
||||
}
|
||||
|
||||
type EndpointConfig struct {
|
||||
@@ -65,6 +66,7 @@ type OPStorage struct {
|
||||
query *query.Queries
|
||||
eventstore *eventstore.Eventstore
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultIdTokenLifetime time.Duration
|
||||
signingKeyAlgorithm string
|
||||
@@ -181,6 +183,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
|
||||
query: query,
|
||||
eventstore: es,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
signingKeyAlgorithm: config.SigningKeyAlgorithm,
|
||||
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime,
|
||||
defaultIdTokenLifetime: config.DefaultIdTokenLifetime,
|
||||
|
Reference in New Issue
Block a user