feat(api): add OIDC session service (#6157)

This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow.


Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring 2023-07-10 15:27:00 +02:00 committed by GitHub
parent be1fe36776
commit 14b8cf4894
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
69 changed files with 5948 additions and 106 deletions

View File

@ -43,7 +43,7 @@ jobs:
go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
- name: Run integration tests
run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/...
run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/... ./internal/api/oidc
- name: Publish go coverage
uses: codecov/codecov-action@v3.1.0
with:

View File

@ -224,7 +224,7 @@ export INTEGRATION_DB_FLAVOR="cockroach" ZITADEL_MASTERKEY="MasterkeyNeedsToHave
docker compose -f internal/integration/config/docker-compose.yaml up --wait ${INTEGRATION_DB_FLAVOR}
go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/...
go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/... ./internal/api/oidc
docker compose -f internal/integration/config/docker-compose.yaml down
```

View File

@ -108,4 +108,16 @@ protoc \
--validate_out=lang=go:${GOPATH}/src \
${PROTO_PATH}/settings/v2alpha/settings_service.proto
protoc \
-I=/proto/include \
--grpc-gateway_out ${GOPATH}/src \
--grpc-gateway_opt logtostderr=true \
--grpc-gateway_opt allow_delete_body=true \
--openapiv2_out ${OPENAPI_PATH} \
--openapiv2_opt logtostderr=true \
--openapiv2_opt allow_delete_body=true \
--zitadel_out=${GOPATH}/src \
--validate_out=lang=go:${GOPATH}/src \
${PROTO_PATH}/oidc/v2alpha/oidc_service.proto
echo "done generating grpc"

View File

@ -270,6 +270,7 @@ OIDC:
Path: /oauth/v2/keys
DeviceAuth:
Path: /oauth/v2/device_authorization
DefaultLoginURLV2: "/login?authRequest="
SAML:
ProviderConfig:

View File

@ -88,6 +88,9 @@ func (mig *FirstInstance) Execute(ctx context.Context) error {
nil,
nil,
nil,
0,
0,
0,
)
if err != nil {
return err

View File

@ -53,6 +53,9 @@ func (mig *externalConfigChange) Execute(ctx context.Context) error {
nil,
nil,
nil,
0,
0,
0,
)
if err != nil {

View File

@ -32,6 +32,7 @@ import (
"github.com/zitadel/zitadel/internal/api/grpc/admin"
"github.com/zitadel/zitadel/internal/api/grpc/auth"
"github.com/zitadel/zitadel/internal/api/grpc/management"
oidc_v2 "github.com/zitadel/zitadel/internal/api/grpc/oidc/v2"
"github.com/zitadel/zitadel/internal/api/grpc/session/v2"
"github.com/zitadel/zitadel/internal/api/grpc/settings/v2"
"github.com/zitadel/zitadel/internal/api/grpc/system"
@ -192,6 +193,9 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
&http.Client{},
permissionCheck,
sessionTokenVerifier,
config.OIDC.DefaultAccessTokenLifetime,
config.OIDC.DefaultRefreshTokenExpiration,
config.OIDC.DefaultRefreshTokenIdleExpiration,
)
if err != nil {
return fmt.Errorf("cannot start commands: %w", err)
@ -344,6 +348,7 @@ func startAPIs(
if err := apis.RegisterService(ctx, session.CreateServer(commands, queries, permissionCheck)); err != nil {
return err
}
if err := apis.RegisterService(ctx, settings.CreateServer(commands, queries, config.ExternalSecure)); err != nil {
return err
}
@ -397,6 +402,11 @@ func startAPIs(
apis.RegisterHandlerOnPrefix(login.HandlerPrefix, l.Handler())
apis.HandleFunc(login.EndpointDeviceAuth, login.RedirectDeviceAuthToPrefix)
// After OIDC provider so that the callback endpoint can be used
if err := apis.RegisterService(ctx, oidc_v2.CreateServer(commands, queries, oidcProvider, config.ExternalSecure)); err != nil {
return err
}
// handle grpc at last to be able to handle the root, because grpc and gateway require a lot of different prefixes
apis.RouteGRPC()
return nil

View File

@ -274,6 +274,13 @@ module.exports = {
groupPathsBy: "tag",
},
},
oidc: {
specPath: ".artifacts/openapi/zitadel/oidc/v2alpha/oidc_service.swagger.json",
outputDir: "docs/apis/resources/oidc_service",
sidebarOptions: {
groupPathsBy: "tag",
},
},
settings: {
specPath: ".artifacts/openapi/zitadel/settings/v2alpha/settings_service.swagger.json",
outputDir: "docs/apis/resources/settings_service",

View File

@ -484,6 +484,20 @@ module.exports = {
},
items: require("./docs/apis/resources/session_service/sidebar.js"),
},
{
type: "category",
label: "OIDC lifecycle (Alpha)",
link: {
type: "generated-index",
title: "OIDC service API (Alpha)",
slug: "/apis/resources/oidc_service",
description:
"Get OIDC Auth Request details and create callback URLs.\n"+
"\n"+
"This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login.",
},
items: require("./docs/apis/resources/oidc_service/sidebar.js"),
},
{
type: "category",
label: "Settings lifecycle (alpha)",

View File

@ -4,11 +4,11 @@ import "context"
func NewMockContext(instanceID, orgID, userID string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
return context.WithValue(ctx, instanceKey, instanceID)
return context.WithValue(ctx, instanceKey, &instance{id: instanceID})
}
func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context {
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
ctx = context.WithValue(ctx, instanceKey, instanceID)
ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID})
return context.WithValue(ctx, requestPermissionsKey, permissions)
}

View File

@ -0,0 +1,204 @@
package oidc
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v2/pkg/op"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
)
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
if err != nil {
logging.WithError(err).Error("query authRequest by ID")
return nil, err
}
return &oidc_pb.GetAuthRequestResponse{
AuthRequest: authRequestToPb(authRequest),
}, nil
}
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
pba := &oidc_pb.AuthRequest{
Id: a.ID,
CreationDate: timestamppb.New(a.CreationDate),
ClientId: a.ClientID,
Scope: a.Scope,
RedirectUri: a.RedirectURI,
Prompt: promptsToPb(a.Prompt),
UiLocales: a.UiLocales,
LoginHint: a.LoginHint,
HintUserId: a.HintUserID,
}
if a.MaxAge != nil {
pba.MaxAge = durationpb.New(*a.MaxAge)
}
return pba
}
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
out := make([]oidc_pb.Prompt, len(promps))
for i, p := range promps {
out[i] = promptToPb(p)
}
return out
}
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
switch p {
case domain.PromptUnspecified:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
case domain.PromptNone:
return oidc_pb.Prompt_PROMPT_NONE
case domain.PromptLogin:
return oidc_pb.Prompt_PROMPT_LOGIN
case domain.PromptConsent:
return oidc_pb.Prompt_PROMPT_CONSENT
case domain.PromptSelectAccount:
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
case domain.PromptCreate:
return oidc_pb.Prompt_PROMPT_CREATE
default:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
}
}
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
switch v := req.GetCallbackKind().(type) {
case *oidc_pb.CreateCallbackRequest_Error:
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
case *oidc_pb.CreateCallbackRequest_Session:
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
default:
return nil, errors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
}
}
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op)
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
ctx = op.ContextWithIssuer(ctx, http.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), s.externalSecure))
var callback string
if aar.ResponseType == domain.OIDCResponseTypeCode {
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op)
} else {
callback, err = oidc.CreateTokenCallbackURL(ctx, authReq, s.op)
}
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
switch errorReason {
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return domain.OIDCErrorReasonUnspecified
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return domain.OIDCErrorReasonInvalidRequest
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return domain.OIDCErrorReasonUnauthorizedClient
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return domain.OIDCErrorReasonAccessDenied
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return domain.OIDCErrorReasonUnsupportedResponseType
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return domain.OIDCErrorReasonInvalidScope
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
return domain.OIDCErrorReasonServerError
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return domain.OIDCErrorReasonTemporaryUnavailable
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return domain.OIDCErrorReasonInteractionRequired
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return domain.OIDCErrorReasonLoginRequired
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return domain.OIDCErrorReasonAccountSelectionRequired
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return domain.OIDCErrorReasonConsentRequired
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return domain.OIDCErrorReasonInvalidRequestURI
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return domain.OIDCErrorReasonInvalidRequestObject
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestNotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestURINotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return domain.OIDCErrorReasonRegistrationNotSupported
default:
return domain.OIDCErrorReasonUnspecified
}
}
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
switch reason {
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return "invalid_request"
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return "unauthorized_client"
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return "access_denied"
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return "unsupported_response_type"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return "invalid_scope"
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return "temporarily_unavailable"
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return "interaction_required"
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return "login_required"
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return "account_selection_required"
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return "consent_required"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return "invalid_request_uri"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return "invalid_request_object"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return "request_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return "request_uri_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return "registration_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
fallthrough
default:
return "server_error"
}
}

View File

@ -0,0 +1,249 @@
//go:build integration
package oidc_test
import (
"context"
"net/url"
"os"
"regexp"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/integration"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
var (
CTX context.Context
Tester *integration.Tester
Client oidc_pb.OIDCServiceClient
User *user.AddHumanUserResponse
)
const (
redirectURI = "oidcIntegrationTest://callback"
redirectURIImplicit = "http://localhost:9999/callback"
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, _, cancel := integration.Contexts(5 * time.Minute)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
Client = Tester.Client.OIDCv2
CTX = Tester.WithAuthorization(ctx, integration.OrgOwner)
User = Tester.CreateHumanUser(CTX)
return m.Run()
}())
}
func TestServer_GetAuthRequest(t *testing.T) {
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
require.NoError(t, err)
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
require.NoError(t, err)
now := time.Now()
tests := []struct {
name string
AuthRequestID string
want *oidc_pb.GetAuthRequestResponse
wantErr bool
}{
{
name: "Not found",
AuthRequestID: "123",
wantErr: true,
},
{
name: "success",
AuthRequestID: authRequestID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.GetAuthRequest(CTX, &oidc_pb.GetAuthRequestRequest{
AuthRequestId: tt.AuthRequestID,
})
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
authRequest := got.GetAuthRequest()
assert.NotNil(t, authRequest)
assert.Equal(t, authRequestID, authRequest.GetId())
assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
assert.Contains(t, authRequest.GetScope(), "openid")
})
}
}
func TestServer_CreateCallback(t *testing.T) {
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
require.NoError(t, err)
sessionResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID,
},
},
},
})
require.NoError(t, err)
tests := []struct {
name string
req *oidc_pb.CreateCallbackRequest
AuthError string
want *oidc_pb.CreateCallbackResponse
wantURL *url.URL
wantErr bool
}{
{
name: "Not found",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: "123",
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
wantErr: true,
},
{
name: "session not found",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: "foo",
SessionToken: "bar",
},
},
},
wantErr: true,
},
{
name: "session token invalid",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: "bar",
},
},
},
wantErr: true,
},
{
name: "fail callback",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Error{
Error: &oidc_pb.AuthorizationError{
Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
ErrorDescription: gu.Ptr("nope"),
ErrorUri: gu.Ptr("https://example.com/docs"),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`),
Details: &object.Details{
ResourceOwner: Tester.Instance.InstanceID(),
},
},
wantErr: false,
},
{
name: "code callback",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`,
Details: &object.Details{
ResourceOwner: Tester.Instance.InstanceID(),
},
},
wantErr: false,
},
{
name: "implicit",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
client, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
require.NoError(t, err)
authRequestID, err := Tester.CreateOIDCAuthRequestImplicit(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURIImplicit)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`,
Details: &object.Details{
ResourceOwner: Tester.Instance.InstanceID(),
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.CreateCallback(CTX, tt.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
integration.AssertDetails(t, tt.want, got)
if tt.want != nil {
assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl())
}
})
}
}

View File

@ -0,0 +1,150 @@
package oidc
import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
)
func Test_authRequestToPb(t *testing.T) {
now := time.Now()
arg := &query.AuthRequest{
ID: "authID",
CreationDate: now,
ClientID: "clientID",
Scope: []string{"a", "b", "c"},
RedirectURI: "callbackURI",
Prompt: []domain.Prompt{
domain.PromptUnspecified,
domain.PromptNone,
domain.PromptLogin,
domain.PromptConsent,
domain.PromptSelectAccount,
domain.PromptCreate,
999,
},
UiLocales: []string{"en", "fi"},
LoginHint: gu.Ptr("foo@bar.com"),
MaxAge: gu.Ptr(time.Minute),
HintUserID: gu.Ptr("userID"),
}
want := &oidc_pb.AuthRequest{
Id: "authID",
CreationDate: timestamppb.New(now),
ClientId: "clientID",
RedirectUri: "callbackURI",
Prompt: []oidc_pb.Prompt{
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
oidc_pb.Prompt_PROMPT_NONE,
oidc_pb.Prompt_PROMPT_LOGIN,
oidc_pb.Prompt_PROMPT_CONSENT,
oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT,
oidc_pb.Prompt_PROMPT_CREATE,
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
},
UiLocales: []string{"en", "fi"},
Scope: []string{"a", "b", "c"},
LoginHint: gu.Ptr("foo@bar.com"),
MaxAge: durationpb.New(time.Minute),
HintUserId: gu.Ptr("userID"),
}
got := authRequestToPb(arg)
if !proto.Equal(want, got) {
t.Errorf("authRequestToPb() =\n%v\nwant\n%v\n", got, want)
}
}
func Test_errorReasonToOIDC(t *testing.T) {
tests := []struct {
reason oidc_pb.ErrorReason
want string
}{
{
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED,
want: "server_error",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST,
want: "invalid_request",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT,
want: "unauthorized_client",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
want: "access_denied",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE,
want: "unsupported_response_type",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE,
want: "invalid_scope",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR,
want: "server_error",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE,
want: "temporarily_unavailable",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED,
want: "interaction_required",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED,
want: "login_required",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED,
want: "account_selection_required",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED,
want: "consent_required",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI,
want: "invalid_request_uri",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT,
want: "invalid_request_object",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED,
want: "request_not_supported",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED,
want: "request_uri_not_supported",
},
{
reason: oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED,
want: "registration_not_supported",
},
{
reason: 99999,
want: "server_error",
},
}
for _, tt := range tests {
t.Run(tt.reason.String(), func(t *testing.T) {
got := errorReasonToOIDC(tt.reason)
assert.Equal(t, tt.want, got)
})
}
}

View File

@ -0,0 +1,59 @@
package oidc
import (
"github.com/zitadel/oidc/v2/pkg/op"
"google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/query"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
)
var _ oidc_pb.OIDCServiceServer = (*Server)(nil)
type Server struct {
oidc_pb.UnimplementedOIDCServiceServer
command *command.Commands
query *query.Queries
op op.OpenIDProvider
externalSecure bool
}
type Config struct{}
func CreateServer(
command *command.Commands,
query *query.Queries,
op op.OpenIDProvider,
externalSecure bool,
) *Server {
return &Server{
command: command,
query: query,
op: op,
externalSecure: externalSecure,
}
}
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
oidc_pb.RegisterOIDCServiceServer(grpcServer, s)
}
func (s *Server) AppName() string {
return oidc_pb.OIDCService_ServiceDesc.ServiceName
}
func (s *Server) MethodPrefix() string {
return oidc_pb.OIDCService_ServiceDesc.ServiceName
}
func (s *Server) AuthMethods() authz.MethodMapping {
return oidc_pb.OIDCService_AuthMethods
}
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
return oidc_pb.RegisterOIDCServiceHandler
}

View File

@ -18,11 +18,10 @@ import (
)
var (
CTX context.Context
Tester *integration.Tester
Client session.SessionServiceClient
User *user.AddHumanUserResponse
GenericOAuthIDPID string
CTX context.Context
Tester *integration.Tester
Client session.SessionServiceClient
User *user.AddHumanUserResponse
)
func TestMain(m *testing.M) {

View File

@ -540,7 +540,7 @@ func TestServer_StartIdentityProviderFlow(t *testing.T) {
ResourceOwner: Tester.Organisation.ID,
},
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
},
},
wantErr: false,

View File

@ -0,0 +1,43 @@
// Package amr maps zitadel session factors to Authentication Method Reference Values
// as defined in [RFC 8176, section 2].
//
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
package amr
const (
// Password states that the users password has been verified
// Deprecated: use `PWD` instead
Password = "password"
// PWD states that the users password has been verified
PWD = "pwd"
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
MFA = "mfa"
// OTP states that a one time password has been verified (e.g. TOTP)
OTP = "otp"
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
UserPresence = "user"
)
type AuthenticationMethodReference interface {
IsPasswordChecked() bool
IsPasskeyChecked() bool
IsU2FChecked() bool
IsOTPChecked() bool
}
func List(model AuthenticationMethodReference) []string {
amr := make([]string, 0)
if model.IsPasswordChecked() {
amr = append(amr, PWD)
}
if model.IsPasskeyChecked() || model.IsU2FChecked() {
amr = append(amr, UserPresence)
}
if model.IsOTPChecked() {
amr = append(amr, OTP)
}
if model.IsPasskeyChecked() || len(amr) >= 2 {
amr = append(amr, MFA)
}
return amr
}

View File

@ -0,0 +1,93 @@
package amr
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAMR(t *testing.T) {
type args struct {
model AuthenticationMethodReference
}
tests := []struct {
name string
args args
want []string
}{
{
"no checks, empty",
args{
new(test),
},
[]string{},
},
{
"pw checked",
args{
&test{pwChecked: true},
},
[]string{PWD},
},
{
"passkey checked",
args{
&test{passkeyChecked: true},
},
[]string{UserPresence, MFA},
},
{
"u2f checked",
args{
&test{u2fChecked: true},
},
[]string{UserPresence},
},
{
"otp checked",
args{
&test{otpChecked: true},
},
[]string{OTP},
},
{
"multiple checked",
args{
&test{
pwChecked: true,
u2fChecked: true,
},
},
[]string{PWD, UserPresence, MFA},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := List(tt.args.model)
assert.Equal(t, tt.want, got)
})
}
}
type test struct {
pwChecked bool
passkeyChecked bool
u2fChecked bool
otpChecked bool
}
func (t test) IsPasswordChecked() bool {
return t.pwChecked
}
func (t test) IsPasskeyChecked() bool {
return t.passkeyChecked
}
func (t test) IsU2FChecked() bool {
return t.u2fChecked
}
func (t test) IsOTPChecked() bool {
return t.otpChecked
}

View File

@ -2,6 +2,7 @@ package oidc
import (
"context"
"encoding/base64"
"strings"
"time"
@ -10,16 +11,75 @@ import (
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/user/model"
)
const (
LoginClientHeader = "x-zitadel-login-client"
)
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
}
return o.createAuthRequest(ctx, req, userID)
}
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
project, err := o.query.ProjectByClientID(ctx, req.ClientID, false)
if err != nil {
return nil, err
}
scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
if err != nil {
return nil, err
}
audience, err := o.audienceFromProjectID(ctx, project.ID)
if err != nil {
return nil, err
}
audience = domain.AddAudScopeToAudience(ctx, audience, scope)
authRequest := &command.AuthRequest{
LoginClient: loginClient,
ClientID: req.ClientID,
RedirectURI: req.RedirectURI,
State: req.State,
Nonce: req.Nonce,
Scope: scope,
Audience: audience,
ResponseType: ResponseTypeToBusiness(req.ResponseType),
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
Prompt: PromptToBusiness(req.Prompt),
UILocales: UILocalesToBusiness(req.UILocales),
MaxAge: MaxAgeToBusiness(req.MaxAge),
}
if req.LoginHint != "" {
authRequest.LoginHint = &req.LoginHint
}
if hintUserID != "" {
authRequest.HintUserID = &hintUserID
}
aar, err := o.command.AddAuthRequest(ctx, authRequest)
if err != nil {
return nil, err
}
return &AuthRequestV2{aar}, nil
}
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
return AuthRequestFromBusiness(resp)
}
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID)
if err != nil {
return nil, err
}
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
return append(appIDs, projectID), nil
}
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := o.command.GetCurrentAuthRequest(ctx, id)
if err != nil {
return nil, err
}
return &AuthRequestV2{req}, nil
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(code)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
if err != nil {
return nil, err
}
return &AuthRequestV2{authReq}, nil
}
resp, err := o.repo.AuthRequestByCode(ctx, code)
if err != nil {
return nil, err
@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
return AuthRequestFromBusiness(resp)
}
// decryptGrant decrypts a code or refresh_token
func (o *OPStorage) decryptGrant(grant string) (string, error) {
decodedGrant, err := base64.RawURLEncoding.DecodeString(grant)
if err != nil {
return "", err
}
return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID())
}
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
return o.command.AddAuthRequestCode(ctx, id, code)
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var userAgentID, applicationID, userOrgID string
authReq, ok := req.(*AuthRequest)
if ok {
switch authReq := req.(type) {
case *AuthRequest:
userAgentID = authReq.AgentID
applicationID = authReq.ApplicationID
userOrgID = authReq.UserOrgID
case *AuthRequestV2:
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
}
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// handle V2 request directly
switch tokenReq := req.(type) {
case *AuthRequestV2:
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
case *RefreshTokenRequestV2:
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
}
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req)
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
if err != nil {
@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
return "", "", "", time.Time{}, nil
}
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(refreshToken)
if err != nil {
return nil, err
}
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode)
if err != nil {
return nil, err
}
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
}
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
if err != nil {
return nil, err
@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
return scopes, nil
}
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
for _, scope := range scopes {
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {
return scopes, nil
}
}
if !project.ProjectRoleAssertion {
return scopes, nil
}
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
}
roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
for _, role := range roles.ProjectRoles {
scopes = append(scopes, ScopeProjectRolePrefix+role.Key)
}
return scopes, nil
}
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
token.Audience = append(token.Audience, clientID)
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i
}
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
}
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
e := struct {
Error string `schema:"error"`
Description string `schema:"error_description,omitempty"`
URI string `schema:"error_uri,omitempty"`
State string `schema:"state,omitempty"`
}{
Error: reason,
Description: description,
URI: uri,
State: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, nil
}
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
if err != nil {
return "", err
}
codeResponse := struct {
code string
state string
}{
code: code,
state: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
return "", err
}
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
if err != nil {
return "", err
}
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
}

View File

@ -12,20 +12,12 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/user/model"
)
const (
// DEPRECATED: use `amrPWD` instead
amrPassword = "password"
amrPWD = "pwd"
amrMFA = "mfa"
amrOTP = "otp"
amrUserPresence = "user"
)
type AuthRequest struct {
*domain.AuthRequest
}
@ -40,19 +32,19 @@ func (a *AuthRequest) GetACR() string {
}
func (a *AuthRequest) GetAMR() []string {
amr := make([]string, 0)
list := make([]string, 0)
if a.PasswordVerified {
amr = append(amr, amrPassword, amrPWD)
list = append(list, amr.Password, amr.PWD)
}
if len(a.MFAsVerified) > 0 {
amr = append(amr, amrMFA)
list = append(list, amr.MFA)
for _, mfa := range a.MFAsVerified {
if amrMFA := AMRFromMFAType(mfa); amrMFA != "" {
amr = append(amr, amrMFA)
list = append(list, amrMFA)
}
}
}
return amr
return list
}
func (a *AuthRequest) GetAudience() []string {
@ -271,10 +263,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng
func AMRFromMFAType(mfaType domain.MFAType) string {
switch mfaType {
case domain.MFATypeOTP:
return amrOTP
return amr.OTP
case domain.MFATypeU2F,
domain.MFATypeU2FUserVerification:
return amrUserPresence
return amr.UserPresence
default:
return ""
}

View File

@ -0,0 +1,106 @@
package oidc
import (
"time"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/command"
)
type AuthRequestV2 struct {
*command.CurrentAuthRequest
}
func (a *AuthRequestV2) GetID() string {
return a.ID
}
func (a *AuthRequestV2) GetACR() string {
return "" //PLANNED: impl
}
func (a *AuthRequestV2) GetAMR() []string {
return a.AMR
}
func (a *AuthRequestV2) GetAudience() []string {
return a.Audience
}
func (a *AuthRequestV2) GetAuthTime() time.Time {
return a.AuthTime
}
func (a *AuthRequestV2) GetClientID() string {
return a.ClientID
}
func (a *AuthRequestV2) GetCodeChallenge() *oidc.CodeChallenge {
return CodeChallengeToOIDC(a.CodeChallenge)
}
func (a *AuthRequestV2) GetNonce() string {
return a.Nonce
}
func (a *AuthRequestV2) GetRedirectURI() string {
return a.RedirectURI
}
func (a *AuthRequestV2) GetResponseType() oidc.ResponseType {
return ResponseTypeToOIDC(a.ResponseType)
}
func (a *AuthRequestV2) GetResponseMode() oidc.ResponseMode {
return ""
}
func (a *AuthRequestV2) GetScopes() []string {
return a.Scope
}
func (a *AuthRequestV2) GetState() string {
return a.State
}
func (a *AuthRequestV2) GetSubject() string {
return a.UserID
}
func (a *AuthRequestV2) Done() bool {
return a.UserID != "" && a.SessionID != ""
}
type RefreshTokenRequestV2 struct {
*command.OIDCSessionWriteModel
RequestedScopes []string
}
func (r *RefreshTokenRequestV2) GetAMR() []string {
return r.AuthMethodsReferences
}
func (r *RefreshTokenRequestV2) GetAudience() []string {
return r.Audience
}
func (r *RefreshTokenRequestV2) GetAuthTime() time.Time {
return r.AuthTime
}
func (r *RefreshTokenRequestV2) GetClientID() string {
return r.ClientID
}
func (r *RefreshTokenRequestV2) GetScopes() []string {
return r.Scope
}
func (r *RefreshTokenRequestV2) GetSubject() string {
return r.UserID
}
func (r *RefreshTokenRequestV2) SetCurrentScopes(scopes []string) {
r.RequestedScopes = scopes
}

View File

@ -0,0 +1,275 @@
//go:build integration
package oidc_test
import (
"context"
"net/url"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/integration"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
)
var (
CTX context.Context
CTXLOGIN context.Context
Tester *integration.Tester
User *user.AddHumanUserResponse
)
const (
redirectURI = "oidcIntegrationTest://callback"
redirectURIImplicit = "http://localhost:9999/callback"
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, errCtx, cancel := integration.Contexts(5 * time.Minute)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
User = Tester.CreateHumanUser(CTX)
Tester.RegisterUserPasskey(CTX, User.GetUserId())
CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx
return m.Run()
}())
}
func createClient(t testing.TB) string {
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
require.NoError(t, err)
return app.GetClientId()
}
func createImplicitClient(t testing.TB) string {
app, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
require.NoError(t, err)
return app.GetClientId()
}
func createAuthRequest(t testing.TB, clientID, redirectURI string, scope ...string) string {
redURL, err := Tester.CreateOIDCAuthRequest(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
require.NoError(t, err)
return redURL
}
func createAuthRequestImplicit(t testing.TB, clientID, redirectURI string, scope ...string) string {
redURL, err := Tester.CreateOIDCAuthRequestImplicit(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
require.NoError(t, err)
return redURL
}
func TestOPStorage_CreateAuthRequest(t *testing.T) {
clientID := createClient(t)
id := createAuthRequest(t, clientID, redirectURI)
require.Contains(t, id, command.IDPrefixV2)
}
func TestOPStorage_CreateAccessToken_code(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.Error(t, err)
// exchange with a used code must fail
_, err = exchangeTokens(t, clientID, code)
require.Error(t, err)
}
func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
clientID := createImplicitClient(t)
authRequestID := createAuthRequestImplicit(t, clientID, redirectURIImplicit)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test implicit callback
callback, err := url.Parse(linkResp.GetCallbackUrl())
require.NoError(t, err)
values, err := url.ParseQuery(callback.Fragment)
require.NoError(t, err)
accessToken := values.Get("access_token")
idToken := values.Get("id_token")
refreshToken := values.Get("refresh_token")
assert.NotEmpty(t, accessToken)
assert.NotEmpty(t, idToken)
assert.Empty(t, refreshToken)
assert.NotEmpty(t, values.Get("expires_in"))
assert.Equal(t, oidc.BearerToken, values.Get("token_type"))
assert.Equal(t, "state", values.Get("state"))
// check id_token / claims
provider, err := Tester.CreateRelyingParty(clientID, redirectURIImplicit)
require.NoError(t, err)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
assertTokenClaims(t, claims, startTime, changeTime)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.Error(t, err)
}
func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// test code exchange (expect refresh token to be returned)
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
}
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
// test actual refresh grant
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
idToken, _ := newTokens.Extra("id_token").(string)
assert.NotEmpty(t, idToken)
assert.NotEmpty(t, newTokens.AccessToken)
assert.NotEmpty(t, newTokens.RefreshToken)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
// auth time must still be the initial
assertTokenClaims(t, claims, startTime, changeTime)
// refresh with an old refresh_token must fail
_, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "")
require.Error(t, err)
}
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
codeVerifier := "codeVerifier"
return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier))
}
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) {
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
return rp.RefreshAccessToken(provider, refreshToken, "", "")
}
func assertCodeResponse(t *testing.T, callback string) string {
callbackURL, err := url.Parse(callback)
require.NoError(t, err)
code := callbackURL.Query().Get("code")
require.NotEmpty(t, code)
assert.Equal(t, "state", callbackURL.Query().Get("state"))
return code
}
func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requireRefreshToken bool) {
assert.NotEmpty(t, tokens.AccessToken)
assert.NotEmpty(t, tokens.IDToken)
if requireRefreshToken {
assert.NotEmpty(t, tokens.RefreshToken)
} else {
assert.Empty(t, tokens.RefreshToken)
}
}
func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) {
assert.Equal(t, User.GetUserId(), claims.Subject)
assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences)
assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second))
}

View File

@ -66,7 +66,7 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl
return nil, err
}
return ClientFromBusiness(client, o.defaultLoginURL, accessTokenLifetime, idTokenLifetime, allowedScopes)
return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2, accessTokenLifetime, idTokenLifetime, allowedScopes)
}
func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) {
@ -153,7 +153,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil)
}
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
token, err := o.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")

View File

@ -7,6 +7,7 @@ import (
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query"
@ -15,18 +16,20 @@ import (
type Client struct {
app *query.App
defaultLoginURL string
defaultLoginURLV2 string
defaultAccessTokenLifetime time.Duration
defaultIdTokenLifetime time.Duration
allowedScopes []string
}
func ClientFromBusiness(app *query.App, defaultLoginURL string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
func ClientFromBusiness(app *query.App, defaultLoginURL, defaultLoginURLV2 string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
if app.OIDCConfig == nil {
return nil, errors.ThrowInvalidArgument(nil, "OIDC-d5bhD", "client is not a proper oidc application")
}
return &Client{
app: app,
defaultLoginURL: defaultLoginURL,
defaultLoginURLV2: defaultLoginURLV2,
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
defaultIdTokenLifetime: defaultIdTokenLifetime,
allowedScopes: allowedScopes},
@ -46,6 +49,9 @@ func (c *Client) GetID() string {
}
func (c *Client) LoginURL(id string) string {
if strings.HasPrefix(id, command.IDPrefixV2) {
return c.defaultLoginURLV2 + id
}
return c.defaultLoginURL + id
}

View File

@ -41,6 +41,7 @@ type Config struct {
Cache *middleware.CacheConfig
CustomEndpoints *EndpointConfig
DeviceAuth *DeviceAuthorizationConfig
DefaultLoginURLV2 string
}
type EndpointConfig struct {
@ -65,6 +66,7 @@ type OPStorage struct {
query *query.Queries
eventstore *eventstore.Eventstore
defaultLoginURL string
defaultLoginURLV2 string
defaultAccessTokenLifetime time.Duration
defaultIdTokenLifetime time.Duration
signingKeyAlgorithm string
@ -181,6 +183,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
query: query,
eventstore: es,
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2,
signingKeyAlgorithm: config.SigningKeyAlgorithm,
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime,
defaultIdTokenLifetime: config.DefaultIdTokenLifetime,

View File

@ -0,0 +1,215 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type AuthRequest struct {
ID string
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
}
type CurrentAuthRequest struct {
*AuthRequest
SessionID string
UserID string
AMR []string
AuthTime time.Time
}
const IDPrefixV2 = "V2_"
func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest) (_ *CurrentAuthRequest, err error) {
authRequestID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
authRequest.ID = IDPrefixV2 + authRequestID
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequest.ID)
if err != nil {
return nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateUnspecified {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewAddedEvent(
ctx,
&authrequest.NewAggregate(authRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
authRequest.LoginClient,
authRequest.ClientID,
authRequest.RedirectURI,
authRequest.State,
authRequest.Nonce,
authRequest.Scope,
authRequest.Audience,
authRequest.ResponseType,
authRequest.CodeChallenge,
authRequest.Prompt,
authRequest.UILocales,
authRequest.MaxAge,
authRequest.LoginHint,
authRequest.HintUserID,
))
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.AuthRequestState == domain.AuthRequestStateUnspecified {
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting")
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled")
}
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
return nil, nil, errors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient")
}
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, nil, err
}
if sessionWriteModel.State == domain.SessionStateUnspecified {
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting")
}
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
return nil, nil, err
}
if err := c.pushAppendAndReduce(ctx, writeModel, authrequest.NewSessionLinkedEvent(
ctx, &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
sessionID,
sessionWriteModel.UserID,
sessionWriteModel.AuthenticationTime(),
amr.List(sessionWriteModel),
)); err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) FailAuthRequest(ctx context.Context, id string, reason domain.OIDCErrorReason) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewFailedEvent(
ctx,
&authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
reason,
))
if err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code string) (err error) {
if code == "" {
return errors.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode")
}
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return err
}
if writeModel.AuthRequestState != domain.AuthRequestStateAdded || writeModel.SessionID == "" {
return errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled")
}
return c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeAddedEvent(ctx,
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
}
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
if code == "" {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
}
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
if err != nil {
return nil, err
}
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
}
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
}
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
return &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: writeModel.AggregateID,
LoginClient: writeModel.LoginClient,
ClientID: writeModel.ClientID,
RedirectURI: writeModel.RedirectURI,
State: writeModel.State,
Nonce: writeModel.Nonce,
Scope: writeModel.Scope,
Audience: writeModel.Audience,
ResponseType: writeModel.ResponseType,
CodeChallenge: writeModel.CodeChallenge,
Prompt: writeModel.Prompt,
UILocales: writeModel.UILocales,
MaxAge: writeModel.MaxAge,
LoginHint: writeModel.LoginHint,
HintUserID: writeModel.HintUserID,
},
SessionID: writeModel.SessionID,
UserID: writeModel.UserID,
AMR: writeModel.AMR,
AuthTime: writeModel.AuthTime,
}
}
func (c *Commands) GetCurrentAuthRequest(ctx context.Context, id string) (_ *CurrentAuthRequest, err error) {
wm, err := c.getAuthRequestWriteModel(ctx, id)
if err != nil {
return nil, err
}
return authRequestWriteModelToCurrentAuthRequest(wm), nil
}
func (c *Commands) getAuthRequestWriteModel(ctx context.Context, id string) (writeModel *AuthRequestWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel = NewAuthRequestWriteModel(ctx, id)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, err
}
return writeModel, nil
}

View File

@ -0,0 +1,110 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/authrequest"
)
type AuthRequestWriteModel struct {
eventstore.WriteModel
aggregate *eventstore.Aggregate
LoginClient string
ClientID string
RedirectURI string
State string
Nonce string
Scope []string
Audience []string
ResponseType domain.OIDCResponseType
CodeChallenge *domain.OIDCCodeChallenge
Prompt []domain.Prompt
UILocales []string
MaxAge *time.Duration
LoginHint *string
HintUserID *string
SessionID string
UserID string
AuthTime time.Time
AMR []string
AuthRequestState domain.AuthRequestState
}
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
return &AuthRequestWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
},
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
}
}
func (m *AuthRequestWriteModel) Reduce() error {
for _, event := range m.Events {
switch e := event.(type) {
case *authrequest.AddedEvent:
m.LoginClient = e.LoginClient
m.ClientID = e.ClientID
m.RedirectURI = e.RedirectURI
m.State = e.State
m.Nonce = e.Nonce
m.Scope = e.Scope
m.Audience = e.Audience
m.ResponseType = e.ResponseType
m.CodeChallenge = e.CodeChallenge
m.Prompt = e.Prompt
m.UILocales = e.UILocales
m.MaxAge = e.MaxAge
m.LoginHint = e.LoginHint
m.HintUserID = e.HintUserID
m.AuthRequestState = domain.AuthRequestStateAdded
case *authrequest.SessionLinkedEvent:
m.SessionID = e.SessionID
m.UserID = e.UserID
m.AuthTime = e.AuthTime
m.AMR = e.AMR
case *authrequest.CodeAddedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeAdded
case *authrequest.FailedEvent:
m.AuthRequestState = domain.AuthRequestStateFailed
case *authrequest.CodeExchangedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeExchanged
case *authrequest.SucceededEvent:
m.AuthRequestState = domain.AuthRequestStateSucceeded
}
}
return m.WriteModel.Reduce()
}
func (m *AuthRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(authrequest.AggregateType).
AggregateIDs(m.AggregateID).
Builder()
}
// CheckAuthenticated checks that the auth request exists, a session must have been linked
// and in case of a Code Flow the code must have been exchanged
func (m *AuthRequestWriteModel) CheckAuthenticated() error {
if m.SessionID == "" {
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated")
}
// in case of OIDC Code Flow, the code must have been exchanged
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
return nil
}
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
m.AuthRequestState == domain.AuthRequestStateAdded {
return nil
}
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.AuthRequest.NotAuthenticated")
}

View File

@ -0,0 +1,998 @@
package command
import (
"context"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/session"
)
func TestCommands_AddAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
}
type args struct {
ctx context.Context
request *AuthRequest
}
tests := []struct {
name string
fields fields
args args
want *CurrentAuthRequest
wantErr error
}{
{
"already exists error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &AuthRequest{},
},
nil,
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
},
{
"added",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
}),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &AuthRequest{
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
},
&CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := c.AddAuthRequest(tt.args.ctx, tt.args.request)
require.ErrorIs(t, tt.wantErr, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
checkPermission domain.PermissionCheck
}
type args struct {
ctx context.Context
id string
sessionID string
sessionToken string
checkLoginClient bool
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentAuthRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not found",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
},
},
{
"authRequest not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
eventFromEventPusher(
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
domain.OIDCErrorReasonUnspecified),
),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"wrong login client",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
id: "id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
},
},
{
"session not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"),
},
},
{
"missing permission",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
},
},
{
"invalid session token",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
},
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "invalid",
},
res{
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
},
},
{
"linked",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
)}),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
},
},
},
{
"linked with login client check",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
testNow),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
)}),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
return nil
},
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{amr.PWD},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
sessionTokenVerifier: tt.fields.tokenVerifier,
checkPermission: tt.fields.checkPermission,
}
details, got, err := c.LinkSessionToAuthRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
require.ErrorIs(t, err, tt.res.wantErr)
assert.Equal(t, tt.res.details, details)
if err == nil {
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authReq, got)
})
}
}
func TestCommands_FailAuthRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
reason domain.OIDCErrorReason
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentAuthRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not existing",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: mockCtx,
id: "foo",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"failed",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
nil,
nil,
nil,
nil,
nil,
nil,
),
),
),
expectPush(
[]*repository.Event{eventFromEventPusherWithInstanceID(
"instanceID",
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
domain.OIDCErrorReasonLoginRequired),
)}),
),
},
args{
ctx: mockCtx,
id: "V2_id",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
require.ErrorIs(t, err, tt.res.wantErr)
assert.Equal(t, tt.res.details, details)
assert.Equal(t, tt.res.authReq, got)
})
}
}
func TestCommands_AddAuthRequestCode(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
code string
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
"empty code error",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "",
},
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode"),
},
{
"no session linked error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
),
),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "V2_authRequestID",
},
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled"),
},
{
"success",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
},
args{
ctx: mockCtx,
id: "V2_authRequestID",
code: "V2_authRequestID",
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
err := c.AddAuthRequestCode(tt.args.ctx, tt.args.id, tt.args.code)
assert.ErrorIs(t, tt.wantErr, err)
})
}
}
func TestCommands_ExchangeAuthCode(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
code string
}
type res struct {
authRequest *CurrentAuthRequest
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"empty code error",
fields{
eventstore: eventstoreExpect(t),
},
args{
ctx: mockCtx,
code: "",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
},
},
{
"no code added error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
},
},
{
"code exchanged",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
},
args{
ctx: mockCtx,
code: "V2_authRequestID",
},
res{
authRequest: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_authRequestID",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
CodeChallenge: &domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
Prompt: []domain.Prompt{domain.PromptNone},
UILocales: []string{"en", "de"},
MaxAge: gu.Ptr(time.Duration(0)),
LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"),
},
SessionID: "sessionID",
UserID: "userID",
AMR: []string{"pwd"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
assert.ErrorIs(t, tt.res.err, err)
if err == nil {
// equal on time won't work -> test separately and clear it before comparing the rest
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authRequest, got)
})
}
}

View File

@ -15,10 +15,12 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/idpintent"
instance_repo "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
@ -43,18 +45,21 @@ type Commands struct {
externalSecure bool
externalPort uint16
idpConfigEncryption crypto.EncryptionAlgorithm
smtpEncryption crypto.EncryptionAlgorithm
smsEncryption crypto.EncryptionAlgorithm
userEncryption crypto.EncryptionAlgorithm
userPasswordAlg crypto.HashAlgorithm
machineKeySize int
applicationKeySize int
domainVerificationAlg crypto.EncryptionAlgorithm
domainVerificationGenerator crypto.Generator
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
sessionTokenCreator func(sessionID string) (id string, token string, err error)
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
idpConfigEncryption crypto.EncryptionAlgorithm
smtpEncryption crypto.EncryptionAlgorithm
smsEncryption crypto.EncryptionAlgorithm
userEncryption crypto.EncryptionAlgorithm
userPasswordAlg crypto.HashAlgorithm
machineKeySize int
applicationKeySize int
domainVerificationAlg crypto.EncryptionAlgorithm
domainVerificationGenerator crypto.Generator
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
sessionTokenCreator func(sessionID string) (id string, token string, err error)
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
multifactors domain.MultifactorConfigs
webauthnConfig *webauthn_helper.Config
@ -80,6 +85,9 @@ func StartCommands(
httpClient *http.Client,
permissionCheck domain.PermissionCheck,
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
defaultAccessTokenLifetime,
defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime time.Duration,
) (repo *Commands, err error) {
if externalDomain == "" {
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
@ -88,31 +96,34 @@ func StartCommands(
// reuse the oidcEncryption to be able to handle both tokens in the interceptor later on
sessionAlg := oidcEncryption
repo = &Commands{
eventstore: es,
static: staticStore,
idGenerator: idGenerator,
zitadelRoles: zitadelRoles,
externalDomain: externalDomain,
externalSecure: externalSecure,
externalPort: externalPort,
keySize: defaults.KeyConfig.Size,
certKeySize: defaults.KeyConfig.CertificateSize,
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
idpConfigEncryption: idpConfigEncryption,
smtpEncryption: smtpEncryption,
smsEncryption: smsEncryption,
userEncryption: userEncryption,
domainVerificationAlg: domainVerificationEncryption,
keyAlgorithm: oidcEncryption,
certificateAlgorithm: samlEncryption,
webauthnConfig: webAuthN,
httpClient: httpClient,
checkPermission: permissionCheck,
newCode: newCryptoCode,
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
sessionTokenVerifier: sessionTokenVerifier,
eventstore: es,
static: staticStore,
idGenerator: idGenerator,
zitadelRoles: zitadelRoles,
externalDomain: externalDomain,
externalSecure: externalSecure,
externalPort: externalPort,
keySize: defaults.KeyConfig.Size,
certKeySize: defaults.KeyConfig.CertificateSize,
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
idpConfigEncryption: idpConfigEncryption,
smtpEncryption: smtpEncryption,
smsEncryption: smsEncryption,
userEncryption: userEncryption,
domainVerificationAlg: domainVerificationEncryption,
keyAlgorithm: oidcEncryption,
certificateAlgorithm: samlEncryption,
webauthnConfig: webAuthN,
httpClient: httpClient,
checkPermission: permissionCheck,
newCode: newCryptoCode,
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
sessionTokenVerifier: sessionTokenVerifier,
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: defaultRefreshTokenIdleLifetime,
}
instance_repo.RegisterEventMappers(repo.eventstore)
@ -125,6 +136,8 @@ func StartCommands(
quota.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
idpintent.RegisterEventMappers(repo.eventstore)
authrequest.RegisterEventMappers(repo.eventstore)
oidcsession.RegisterEventMappers(repo.eventstore)
milestone.RegisterEventMappers(repo.eventstore)
repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)

View File

@ -16,9 +16,11 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
action_repo "github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
@ -43,6 +45,8 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
action_repo.RegisterEventMappers(es)
session.RegisterEventMappers(es)
idpintent.RegisterEventMappers(es)
authrequest.RegisterEventMappers(es)
oidcsession.RegisterEventMappers(es)
return es
}

View File

@ -0,0 +1,281 @@
package command
import (
"context"
"encoding/base64"
"strings"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
return accessTokenID, accessTokenExpiration, err
}
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
// It returns the access token id, expiration and the refresh token.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
return cmd.PushEvents(ctx)
}
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
// It returns the access token id and expiration and the new refresh token.
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
if err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddAccessToken(ctx, scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.RenewRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
return cmd.PushEvents(ctx)
}
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
split := strings.Split(refreshToken, ":")
if len(split) != 2 {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
}
writeModel := NewOIDCSessionWriteModel(split[0], "")
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
}
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
return nil, err
}
return writeModel, nil
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return nil, err
}
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
return nil, err
}
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
if sessionWriteModel.State != domain.SessionStateActive {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()),
sessionWriteModel: sessionWriteModel,
authRequestWriteModel: authRequestWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
if err != nil {
return "", err
}
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
if err != nil {
return "", err
}
split := strings.Split(decrypted, ":")
if len(split) != 2 {
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
}
return split[1], nil
}
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
if err != nil {
return nil, err
}
sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID())
if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
return nil, err
}
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
return nil, err
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel
sessionWriteModel *SessionWriteModel
authRequestWriteModel *AuthRequestWriteModel
accessTokenLifetime time.Duration
refreshTokenLifeTime time.Duration
refreshTokenIdleLifetime time.Duration
// accessTokenID is set by the command
accessTokenID string
// refreshToken is set by the command
refreshToken string
}
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
c.events = append(c.events, oidcsession.NewAddedEvent(
ctx,
c.oidcSessionWriteModel.aggregate,
c.sessionWriteModel.UserID,
c.sessionWriteModel.AggregateID,
c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope,
amr.List(c.sessionWriteModel),
c.sessionWriteModel.AuthenticationTime(),
))
}
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
}
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
c.accessTokenID, err = c.idGenerator.Next()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
return nil
}
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
refreshTokenID, err = c.idGenerator.Next()
if err != nil {
return "", "", err
}
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
if err != nil {
return "", "", err
}
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
if err != nil {
return "", "", time.Time{}, err
}
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
if err != nil {
return "", "", time.Time{}, err
}
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
}
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx)
err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings)
if err != nil {
return 0, 0, 0, err
}
accessTokenLifetime = c.defaultAccessTokenLifetime
refreshTokenLifetime = c.defaultRefreshTokenLifetime
refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime
if oidcSettings.AccessTokenLifetime > 0 {
accessTokenLifetime = oidcSettings.AccessTokenLifetime
}
if oidcSettings.RefreshTokenExpiration > 0 {
refreshTokenLifetime = oidcSettings.RefreshTokenExpiration
}
if oidcSettings.RefreshTokenIdleExpiration > 0 {
refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration
}
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
}

View File

@ -0,0 +1,114 @@
package command
import (
"time"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
type OIDCSessionWriteModel struct {
eventstore.WriteModel
UserID string
SessionID string
ClientID string
Audience []string
Scope []string
AuthMethodsReferences []string
AuthTime time.Time
State domain.OIDCSessionState
AccessTokenExpiration time.Time
RefreshTokenID string
RefreshTokenExpiration time.Time
RefreshTokenIdleExpiration time.Time
aggregate *eventstore.Aggregate
}
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
return &OIDCSessionWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
ResourceOwner: resourceOwner,
},
aggregate: &oidcsession.NewAggregate(id, resourceOwner).Aggregate,
}
}
func (wm *OIDCSessionWriteModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *oidcsession.AddedEvent:
wm.reduceAdded(e)
case *oidcsession.AccessTokenAddedEvent:
wm.reduceAccessTokenAdded(e)
case *oidcsession.RefreshTokenAddedEvent:
wm.reduceRefreshTokenAdded(e)
case *oidcsession.RefreshTokenRenewedEvent:
wm.reduceRefreshTokenRenewed(e)
}
}
return wm.WriteModel.Reduce()
}
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(oidcsession.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(
oidcsession.AddedType,
oidcsession.AccessTokenAddedType,
oidcsession.RefreshTokenAddedType,
oidcsession.RefreshTokenRenewedType,
).
Builder()
if wm.ResourceOwner != "" {
query.ResourceOwner(wm.ResourceOwner)
}
return query
}
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.UserID = e.UserID
wm.SessionID = e.SessionID
wm.ClientID = e.ClientID
wm.Audience = e.Audience
wm.Scope = e.Scope
wm.AuthMethodsReferences = e.AuthMethodsReferences
wm.AuthTime = e.AuthTime
wm.State = domain.OIDCSessionStateActive
}
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
}
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
wm.RefreshTokenID = e.ID
wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime)
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
}
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
wm.RefreshTokenID = e.ID
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
}
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
if wm.State != domain.OIDCSessionStateActive {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
}
if wm.RefreshTokenID != refreshTokenID {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
}
now := time.Now()
if wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
}
return nil
}

View File

@ -0,0 +1,795 @@
package command
import (
"context"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/session"
)
var (
testNow = time.Now()
tokenCreationNow = time.Time{}
)
func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
authRequestID string
}
type res struct {
id string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"unauthenticated error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
},
},
{
"add successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
defaultAccessTokenLifetime: time.Hour,
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
id: "V2_oidcSessionID-accessTokenID",
expiration: tokenCreationNow.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotExpiration, err := c.AddOIDCSessionAccessToken(tt.args.ctx, tt.args.authRequestID)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
authRequestID string
}
type res struct {
id string
refreshToken string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"unauthenticated error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid", "offline_access"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(),
),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
},
},
{
"add successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid", "offline_access"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]string{amr.PWD},
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
"userID", testNow),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
},
res{
id: "V2_oidcSessionID-accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID
expiration: tokenCreationNow.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotRefreshToken, gotExpiration, err := c.AddOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.authRequestID)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
oidcSessionID string
refreshToken string
scope []string
}
type res struct {
id string
refreshToken string
expiration time.Time
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"invalid refresh token format error",
fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "aW52YWxpZA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"invalid refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"expired refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"refresh successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
expectFilter(), // token lifetime
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
),
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID2", 24*time.Hour),
),
},
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "accessTokenID", "refreshTokenID2"),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
scope: []string{"openid", "offline_access"},
},
res{
id: "V2_oidcSessionID-accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI",
expiration: time.Time{}.Add(time.Hour),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
gotID, gotRefreshToken, gotExpiration, err := c.ExchangeOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.oidcSessionID, tt.args.refreshToken, tt.args.scope)
assert.Equal(t, tt.res.id, gotID)
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
assert.Equal(t, tt.res.expiration, gotExpiration)
assert.ErrorIs(t, err, tt.res.err)
})
}
}
func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration
defaultRefreshTokenIdleLifetime time.Duration
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
refreshToken string
}
type res struct {
model *OIDCSessionWriteModel
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"invalid refresh token format error",
fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "invalid",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"inactive session error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"invalid refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"expired refresh token error",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
},
},
{
"get successful",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID",
},
res{
model: &OIDCSessionWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: "V2_oidcSessionID",
ChangeDate: testNow,
},
UserID: "userID",
SessionID: "sessionID",
ClientID: "clientID",
Audience: []string{"audience"},
Scope: []string{"openid", "profile", "offline_access"},
AuthMethodsReferences: []string{amr.PWD},
AuthTime: testNow,
State: domain.OIDCSessionStateActive,
RefreshTokenID: "refreshTokenID",
RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour),
RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
got, err := c.OIDCSessionByRefreshToken(tt.args.ctx, tt.args.refreshToken)
require.ErrorIs(t, err, tt.res.err)
if tt.res.err == nil {
assert.WithinRange(t, got.ChangeDate, tt.res.model.ChangeDate.Add(-2*time.Second), tt.res.model.ChangeDate.Add(2*time.Second))
assert.Equal(t, tt.res.model.AggregateID, got.AggregateID)
assert.Equal(t, tt.res.model.UserID, got.UserID)
assert.Equal(t, tt.res.model.SessionID, got.SessionID)
assert.Equal(t, tt.res.model.ClientID, got.ClientID)
assert.Equal(t, tt.res.model.Audience, got.Audience)
assert.Equal(t, tt.res.model.Scope, got.Scope)
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences)
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
assert.Equal(t, tt.res.model.State, got.State)
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)
assert.WithinRange(t, got.RefreshTokenExpiration, tt.res.model.RefreshTokenExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenExpiration.Add(2*time.Second))
assert.WithinRange(t, got.RefreshTokenIdleExpiration, tt.res.model.RefreshTokenIdleExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenIdleExpiration.Add(2*time.Second))
}
})
}
}

View File

@ -52,6 +52,24 @@ type SessionWriteModel struct {
aggregate *eventstore.Aggregate
}
func (wm *SessionWriteModel) IsPasswordChecked() bool {
return !wm.PasswordCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
return !wm.PasskeyCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsU2FChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func (wm *SessionWriteModel) IsOTPChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
return &SessionWriteModel{
WriteModel: eventstore.WriteModel{
@ -210,3 +228,19 @@ func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[st
wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata))
}
}
// AuthenticationTime returns the time the user authenticated using the latest time of all checks
func (wm *SessionWriteModel) AuthenticationTime() time.Time {
var authTime time.Time
for _, check := range []time.Time{
wm.PasswordCheckedAt,
wm.PasskeyCheckedAt,
wm.IntentCheckedAt,
// TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477
} {
if check.After(authTime) {
authTime = check
}
}
return authTime
}

View File

@ -528,7 +528,7 @@ func TestCommands_createHumanOTP(t *testing.T) {
}
func TestCommands_HumanCheckMFAOTPSetup(t *testing.T) {
ctx := authz.NewMockContext("inst1", "org1", "user1")
ctx := authz.NewMockContext("", "org1", "user1")
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)

View File

@ -188,7 +188,7 @@ func TestCommands_AddUserTOTP(t *testing.T) {
}
func TestCommands_CheckUserTOTP(t *testing.T) {
ctx := authz.NewMockContext("inst1", "org1", "user1")
ctx := authz.NewMockContext("", "org1", "user1")
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)

View File

@ -116,6 +116,17 @@ const (
MFALevelMultiFactorCertified
)
type AuthRequestState int
const (
AuthRequestStateUnspecified AuthRequestState = iota
AuthRequestStateAdded
AuthRequestStateCodeAdded
AuthRequestStateCodeExchanged
AuthRequestStateFailed
AuthRequestStateSucceeded
)
func NewAuthRequestFromType(requestType AuthRequestType) (*AuthRequest, error) {
switch requestType {
case AuthRequestTypeOIDC:

View File

@ -0,0 +1,23 @@
package domain
type OIDCErrorReason int32
const (
OIDCErrorReasonUnspecified OIDCErrorReason = iota
OIDCErrorReasonInvalidRequest
OIDCErrorReasonUnauthorizedClient
OIDCErrorReasonAccessDenied
OIDCErrorReasonUnsupportedResponseType
OIDCErrorReasonInvalidScope
OIDCErrorReasonServerError
OIDCErrorReasonTemporaryUnavailable
OIDCErrorReasonInteractionRequired
OIDCErrorReasonLoginRequired
OIDCErrorReasonAccountSelectionRequired
OIDCErrorReasonConsentRequired
OIDCErrorReasonInvalidRequestURI
OIDCErrorReasonInvalidRequestObject
OIDCErrorReasonRequestNotSupported
OIDCErrorReasonRequestURINotSupported
OIDCErrorReasonRegistrationNotSupported
)

View File

@ -0,0 +1,9 @@
package domain
type OIDCSessionState int32
const (
OIDCSessionStateUnspecified OIDCSessionState = iota
OIDCSessionStateActive
OIDCSessionStateTerminated
)

View File

@ -19,6 +19,7 @@ import (
"github.com/zitadel/zitadel/pkg/grpc/admin"
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
"github.com/zitadel/zitadel/pkg/grpc/system"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
@ -30,6 +31,7 @@ type Client struct {
Mgmt mgmt.ManagementServiceClient
UserV2 user.UserServiceClient
SessionV2 session.SessionServiceClient
OIDCv2 oidc_pb.OIDCServiceClient
System system.SystemServiceClient
}
@ -40,6 +42,7 @@ func newClient(cc *grpc.ClientConn) Client {
Mgmt: mgmt.NewManagementServiceClient(cc),
UserV2: user.NewUserServiceClient(cc),
SessionV2: session.NewSessionServiceClient(cc),
OIDCv2: oidc_pb.NewOIDCServiceClient(cc),
System: system.NewSystemServiceClient(cc),
}
}
@ -62,11 +65,9 @@ func (t *Tester) UseIsolatedInstance(iamOwnerCtx, systemCtx context.Context) (pr
}
t.createClientConn(iamOwnerCtx, grpc.WithAuthority(primaryDomain))
instanceId = instance.GetInstanceId()
t.Users[instanceId] = map[UserType]User{
IAMOwner: {
Token: instance.GetPat(),
},
}
t.Users.Set(instanceId, IAMOwner, &User{
Token: instance.GetPat(),
})
return primaryDomain, instanceId, t.WithInstanceAuthorization(iamOwnerCtx, IAMOwner, instanceId)
}
@ -187,3 +188,34 @@ func (s *Tester) CreateSuccessfulIntent(t *testing.T, idpID, userID, idpUserID s
require.NoError(t, err)
return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence
}
func (s *Tester) CreatePasskeySession(t *testing.T, ctx context.Context, userID string) (id, token string, start, change time.Time) {
createResp, err := s.Client.SessionV2.CreateSession(ctx, &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{UserId: userID},
},
},
Challenges: []session.ChallengeKind{
session.ChallengeKind_CHALLENGE_KIND_PASSKEY,
},
Domain: s.Config.ExternalDomain,
})
require.NoError(t, err)
assertion, err := s.WebAuthN.CreateAssertionResponse(createResp.GetChallenges().GetPasskey().GetPublicKeyCredentialRequestOptions())
require.NoError(t, err)
updateResp, err := s.Client.SessionV2.SetSession(ctx, &session.SetSessionRequest{
SessionId: createResp.GetSessionId(),
SessionToken: createResp.GetSessionToken(),
Checks: &session.Checks{
Passkey: &session.CheckPasskey{
CredentialAssertionData: assertion,
},
},
})
require.NoError(t, err)
return createResp.GetSessionId(), updateResp.GetSessionToken(),
createResp.GetDetails().GetChangeDate().AsTime(), updateResp.GetDetails().GetChangeDate().AsTime()
}

View File

@ -1,6 +1,8 @@
Log:
Level: debug
ExternalSecure: false
TLS:
Enabled: false

View File

@ -57,6 +57,7 @@ type UserType int
const (
Unspecified UserType = iota
OrgOwner
Login
IAMOwner
SystemUser // SystemUser is a user with access to the system service.
)
@ -71,13 +72,29 @@ type User struct {
Token string
}
type InstanceUserMap map[string]map[UserType]*User
func (m InstanceUserMap) Set(instanceID string, typ UserType, user *User) {
if m[instanceID] == nil {
m[instanceID] = make(map[UserType]*User)
}
m[instanceID][typ] = user
}
func (m InstanceUserMap) Get(instanceID string, typ UserType) *User {
if users, ok := m[instanceID]; ok {
return users[typ]
}
return nil
}
// Tester is a Zitadel server and client with all resources available for testing.
type Tester struct {
*start.Server
Instance authz.Instance
Organisation *query.Org
Users map[string]map[UserType]User
Users InstanceUserMap
Client Client
WebAuthN *webauthn.Client
@ -135,6 +152,7 @@ func (s *Tester) pollHealth(ctx context.Context) (err error) {
}
const (
LoginUser = "loginClient"
MachineUser = "integration"
)
@ -148,10 +166,9 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID())
logging.OnError(err).Fatal("query organisation")
query, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals)
usernameQuery, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals)
logging.OnError(err).Fatal("user query")
user, err := s.Queries.GetUser(ctx, true, true, query)
user, err := s.Queries.GetUser(ctx, true, true, usernameQuery)
if errors.Is(err, sql.ErrNoRows) {
_, err = s.Commands.AddMachine(ctx, &command.Machine{
ObjectRoot: models.ObjectRoot{
@ -162,11 +179,10 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
Description: "who cares?",
AccessTokenType: domain.OIDCTokenTypeJWT,
})
logging.OnError(err).Fatal("add machine user")
user, err = s.Queries.GetUser(ctx, true, true, query)
logging.WithFields("username", SystemUser).OnError(err).Fatal("add machine user")
user, err = s.Queries.GetUser(ctx, true, true, usernameQuery)
}
logging.OnError(err).Fatal("get user")
logging.WithFields("username", SystemUser).OnError(err).Fatal("get user")
_, err = s.Commands.AddOrgMember(ctx, s.Organisation.ID, user.ID, "ORG_OWNER")
target := new(caos_errs.AlreadyExistsError)
@ -177,18 +193,50 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner}
pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine)
_, err = s.Commands.AddPersonalAccessToken(ctx, pat)
logging.OnError(err).Fatal("add pat")
if s.Users == nil {
s.Users = make(map[string]map[UserType]User)
}
if s.Users[instanceId] == nil {
s.Users[instanceId] = make(map[UserType]User)
}
s.Users[instanceId][OrgOwner] = User{
logging.WithFields("username", SystemUser).OnError(err).Fatal("add pat")
s.Users.Set(instanceId, OrgOwner, &User{
User: user,
Token: pat.Token,
})
}
func (s *Tester) createLoginClient(ctx context.Context) {
var err error
s.Instance, err = s.Queries.InstanceByHost(ctx, s.Host())
logging.OnError(err).Fatal("query instance")
ctx = authz.WithInstance(ctx, s.Instance)
s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID())
logging.OnError(err).Fatal("query organisation")
usernameQuery, err := query.NewUserUsernameSearchQuery(LoginUser, query.TextEquals)
logging.WithFields("username", LoginUser).OnError(err).Fatal("user query")
user, err := s.Queries.GetUser(ctx, true, true, usernameQuery)
if errors.Is(err, sql.ErrNoRows) {
_, err = s.Commands.AddMachine(ctx, &command.Machine{
ObjectRoot: models.ObjectRoot{
ResourceOwner: s.Organisation.ID,
},
Username: LoginUser,
Name: LoginUser,
Description: "who cares?",
AccessTokenType: domain.OIDCTokenTypeJWT,
})
logging.WithFields("username", LoginUser).OnError(err).Fatal("add machine user")
user, err = s.Queries.GetUser(ctx, true, true, usernameQuery)
}
logging.WithFields("username", LoginUser).OnError(err).Fatal("get user")
scopes := []string{oidc.ScopeOpenID, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner}
pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine)
_, err = s.Commands.AddPersonalAccessToken(ctx, pat)
logging.OnError(err).Fatal("add pat")
s.Users.Set(FirstInstanceUsersKey, Login, &User{
User: user,
Token: pat.Token,
})
}
func (s *Tester) WithAuthorization(ctx context.Context, u UserType) context.Context {
@ -199,15 +247,12 @@ func (s *Tester) WithInstanceAuthorization(ctx context.Context, u UserType, inst
if u == SystemUser {
s.ensureSystemUser()
}
return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users[instanceID][u].Token))
return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users.Get(instanceID, u).Token))
}
func (s *Tester) ensureSystemUser() {
const ISSUER = "tester"
if s.Users[FirstInstanceUsersKey] == nil {
s.Users[FirstInstanceUsersKey] = make(map[UserType]User)
}
if _, ok := s.Users[FirstInstanceUsersKey][SystemUser]; ok {
if s.Users.Get(FirstInstanceUsersKey, SystemUser) != nil {
return
}
audience := http_util.BuildOrigin(s.Host(), s.Server.Config.ExternalSecure)
@ -215,7 +260,11 @@ func (s *Tester) ensureSystemUser() {
logging.OnError(err).Fatal("system key signer")
jwt, err := client.SignedJWTProfileAssertion(ISSUER, []string{audience}, time.Hour, signer)
logging.OnError(err).Fatal("system key jwt")
s.Users[FirstInstanceUsersKey][SystemUser] = User{Token: jwt}
s.Users.Set(FirstInstanceUsersKey, SystemUser, &User{Token: jwt})
}
func (s *Tester) WithSystemAuthorizationHTTP(u UserType) map[string]string {
return map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.Users.Get(FirstInstanceUsersKey, u).Token)}
}
// Done send an interrupt signal to cleanly shutdown the server.
@ -263,9 +312,7 @@ func NewTester(ctx context.Context) *Tester {
logging.OnError(err).Fatal()
tester := Tester{
Users: map[string]map[UserType]User{
FirstInstanceUsersKey: make(map[UserType]User),
},
Users: make(InstanceUserMap),
}
tester.wg.Add(1)
go func(wg *sync.WaitGroup) {
@ -279,6 +326,8 @@ func NewTester(ctx context.Context) *Tester {
logging.OnError(ctx.Err()).Fatal("waiting for integration tester server")
}
tester.createClientConn(ctx)
tester.createLoginClient(ctx)
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, http_util.BuildOrigin(tester.Host(), tester.Config.ExternalSecure))
tester.createMachineUser(ctx, FirstInstanceUsersKey)
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host())

View File

@ -0,0 +1,163 @@
package integration
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
http_util "github.com/zitadel/zitadel/internal/api/http"
oidc_internal "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/pkg/grpc/app"
"github.com/zitadel/zitadel/pkg/grpc/management"
)
func (s *Tester) CreateOIDCNativeClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) {
project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
})
if err != nil {
return nil, err
}
return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{
ProjectId: project.GetId(),
Name: fmt.Sprintf("app-%d", time.Now().UnixNano()),
RedirectUris: []string{redirectURI},
ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_CODE},
GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_AUTHORIZATION_CODE, app.OIDCGrantType_OIDC_GRANT_TYPE_REFRESH_TOKEN},
AppType: app.OIDCAppType_OIDC_APP_TYPE_NATIVE,
AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE,
PostLogoutRedirectUris: nil,
Version: app.OIDCVersion_OIDC_VERSION_1_0,
DevMode: false,
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
AccessTokenRoleAssertion: false,
IdTokenRoleAssertion: false,
IdTokenUserinfoAssertion: false,
ClockSkew: nil,
AdditionalOrigins: nil,
SkipNativeAppSuccessPage: false,
})
}
func (s *Tester) CreateOIDCImplicitFlowClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) {
project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
})
if err != nil {
return nil, err
}
return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{
ProjectId: project.GetId(),
Name: fmt.Sprintf("app-%d", time.Now().UnixNano()),
RedirectUris: []string{redirectURI},
ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_ID_TOKEN_TOKEN},
GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_IMPLICIT},
AppType: app.OIDCAppType_OIDC_APP_TYPE_USER_AGENT,
AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE,
PostLogoutRedirectUris: nil,
Version: app.OIDCVersion_OIDC_VERSION_1_0,
DevMode: true,
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
AccessTokenRoleAssertion: false,
IdTokenRoleAssertion: false,
IdTokenUserinfoAssertion: false,
ClockSkew: nil,
AdditionalOrigins: nil,
SkipNativeAppSuccessPage: false,
})
}
func (s *Tester) CreateOIDCAuthRequest(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) {
provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...)
if err != nil {
return "", err
}
codeVerifier := "codeVerifier"
codeChallenge := oidc.NewSHACodeChallenge(codeVerifier)
authURL := rp.AuthURL("state", provider, rp.WithCodeChallenge(codeChallenge))
loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient})
if err != nil {
return "", err
}
prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2
if !strings.HasPrefix(loc.String(), prefixWithHost) {
return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String())
}
return strings.TrimPrefix(loc.String(), prefixWithHost), nil
}
func (s *Tester) CreateOIDCAuthRequestImplicit(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) {
provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...)
if err != nil {
return "", err
}
authURL := rp.AuthURL("state", provider)
// implicit is not natively supported so let's just overwrite the response type
parsed, _ := url.Parse(authURL)
queries := parsed.Query()
queries.Set("response_type", string(oidc.ResponseTypeIDToken))
parsed.RawQuery = queries.Encode()
authURL = parsed.String()
loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient})
if err != nil {
return "", err
}
prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2
if !strings.HasPrefix(loc.String(), prefixWithHost) {
return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String())
}
return strings.TrimPrefix(loc.String(), prefixWithHost), nil
}
func (s *Tester) CreateRelyingParty(clientID, redirectURI string, scope ...string) (rp.RelyingParty, error) {
issuer := http_util.BuildHTTP(s.Config.ExternalDomain, s.Config.Port, s.Config.ExternalSecure)
if len(scope) == 0 {
scope = []string{oidc.ScopeOpenID}
}
return rp.NewRelyingPartyOIDC(issuer, clientID, "", redirectURI, scope)
}
func CheckRedirect(url string, headers map[string]string) (*url.URL, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
for key, value := range headers {
req.Header.Set(key, value)
}
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return resp.Location()
}
func (s *Tester) CreateSession(ctx context.Context, userID string) (string, string, error) {
session, err := s.Commands.CreateSession(ctx, []command.SessionCommand{command.CheckUser(userID)}, "domain.tld", nil)
if err != nil {
return "", "", err
}
return session.ID, session.NewToken, nil
}

View File

@ -0,0 +1,88 @@
package query
import (
"context"
"database/sql"
_ "embed"
errs "errors"
"fmt"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type AuthRequest struct {
ID string
CreationDate time.Time
LoginClient string
ClientID string
Scope []string
RedirectURI string
Prompt []domain.Prompt
UiLocales []string
LoginHint *string
MaxAge *time.Duration
HintUserID *string
}
func (a *AuthRequest) checkLoginClient(ctx context.Context) error {
if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient {
return errors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient")
}
return nil
}
//go:embed embed/auth_request_by_id.sql
var authRequestByIDQuery string
func (q *Queries) authRequestByIDQuery(ctx context.Context) string {
return fmt.Sprintf(authRequestByIDQuery, q.client.Timetravel(call.Took(ctx)))
}
func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *AuthRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if shouldTriggerBulk {
ctx = projection.AuthRequestProjection.Trigger(ctx)
}
var (
scope database.StringArray
prompt database.EnumArray[domain.Prompt]
locales database.StringArray
)
dst := new(AuthRequest)
err = q.client.DB.QueryRowContext(
ctx, q.authRequestByIDQuery(ctx),
id, authz.GetInstance(ctx).InstanceID(),
).Scan(
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI,
&prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID,
)
if errs.Is(err, sql.ErrNoRows) {
return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting")
}
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal")
}
dst.Scope = scope
dst.Prompt = prompt
dst.UiLocales = locales
if checkLoginClient {
if err = dst.checkLoginClient(ctx); err != nil {
return nil, err
}
}
return dst, nil
}

View File

@ -0,0 +1,180 @@
package query
import (
"database/sql"
"database/sql/driver"
_ "embed"
"fmt"
"regexp"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
)
func TestQueries_AuthRequestByID(t *testing.T) {
expQuery := regexp.QuoteMeta(fmt.Sprintf(
authRequestByIDQuery,
asOfSystemTime,
))
cols := []string{
projection.AuthRequestColumnID,
projection.AuthRequestColumnCreationDate,
projection.AuthRequestColumnLoginClient,
projection.AuthRequestColumnClientID,
projection.AuthRequestColumnScope,
projection.AuthRequestColumnRedirectURI,
projection.AuthRequestColumnPrompt,
projection.AuthRequestColumnUILocales,
projection.AuthRequestColumnLoginHint,
projection.AuthRequestColumnMaxAge,
projection.AuthRequestColumnHintUserID,
}
type args struct {
shouldTriggerBulk bool
id string
checkLoginClient bool
}
tests := []struct {
name string
args args
expect sqlExpectation
want *AuthRequest
wantErr error
}{
{
name: "success, all values",
args: args{
shouldTriggerBulk: false,
id: "123",
checkLoginClient: true,
},
expect: mockQuery(expQuery, cols, []driver.Value{
"id",
testNow,
"loginClient",
"clientID",
database.StringArray{"a", "b", "c"},
"example.com",
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.StringArray{"en", "fi"},
"me@example.com",
int64(time.Minute),
"userID",
}, "123", "instanceID"),
want: &AuthRequest{
ID: "id",
CreationDate: testNow,
LoginClient: "loginClient",
ClientID: "clientID",
Scope: []string{"a", "b", "c"},
RedirectURI: "example.com",
Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent},
UiLocales: []string{"en", "fi"},
LoginHint: gu.Ptr("me@example.com"),
MaxAge: gu.Ptr(time.Minute),
HintUserID: gu.Ptr("userID"),
},
},
{
name: "success, null values",
args: args{
shouldTriggerBulk: false,
id: "123",
checkLoginClient: true,
},
expect: mockQuery(expQuery, cols, []driver.Value{
"id",
testNow,
"loginClient",
"clientID",
database.StringArray{"a", "b", "c"},
"example.com",
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.StringArray{"en", "fi"},
sql.NullString{},
sql.NullInt64{},
sql.NullString{},
}, "123", "instanceID"),
want: &AuthRequest{
ID: "id",
CreationDate: testNow,
LoginClient: "loginClient",
ClientID: "clientID",
Scope: []string{"a", "b", "c"},
RedirectURI: "example.com",
Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent},
UiLocales: []string{"en", "fi"},
LoginHint: nil,
MaxAge: nil,
HintUserID: nil,
},
},
{
name: "no rows",
args: args{
shouldTriggerBulk: false,
id: "123",
},
expect: mockQuery(expQuery, cols, nil, "123", "instanceID"),
wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"),
},
{
name: "query error",
args: args{
shouldTriggerBulk: false,
id: "123",
},
expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"),
wantErr: errors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"),
},
{
name: "wrong login client",
args: args{
shouldTriggerBulk: false,
id: "123",
checkLoginClient: true,
},
expect: mockQuery(expQuery, cols, []driver.Value{
"id",
testNow,
"wrongLoginClient",
"clientID",
database.StringArray{"a", "b", "c"},
"example.com",
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.StringArray{"en", "fi"},
sql.NullString{},
sql.NullInt64{},
sql.NullString{},
}, "123", "instanceID"),
wantErr: errors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
execMock(t, tt.expect, func(db *sql.DB) {
q := &Queries{
client: &database.DB{
DB: db,
Database: &prepareDB{},
},
}
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
got, err := q.AuthRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
})
}
}

View File

@ -0,0 +1,15 @@
select
id,
creation_date,
login_client,
client_id,
scope,
redirect_uri,
prompt,
ui_locales,
login_hint,
max_age,
hint_user_id
from projections.auth_requests %s
where id = $1 and instance_id = $2
limit 1;

View File

@ -14,6 +14,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
@ -74,9 +75,9 @@ type checkErr func(error) (err error, ok bool)
type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
func mockQuery(stmt string, cols []string, row []driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
q := m.ExpectQuery(stmt)
q := m.ExpectQuery(stmt).WithArgs(args...)
result := sqlmock.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
@ -111,6 +112,15 @@ func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.S
}
}
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock = exp(mock)
run(db)
assert.NoError(t, mock.ExpectationsWereMet())
}
var (
rowType = reflect.TypeOf(&sql.Row{})
rowsType = reflect.TypeOf(&sql.Rows{})
@ -317,7 +327,9 @@ func TestValidatePrepare(t *testing.T) {
type prepareDB struct{}
func (_ *prepareDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "
func (*prepareDB) Timetravel(time.Duration) string { return asOfSystemTime }
var defaultPrepareArgs = []reflect.Value{reflect.ValueOf(context.Background()), reflect.ValueOf(new(prepareDB))}

View File

@ -0,0 +1,142 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/instance"
)
const (
AuthRequestsProjectionTable = "projections.auth_requests"
AuthRequestColumnID = "id"
AuthRequestColumnCreationDate = "creation_date"
AuthRequestColumnChangeDate = "change_date"
AuthRequestColumnSequence = "sequence"
AuthRequestColumnResourceOwner = "resource_owner"
AuthRequestColumnInstanceID = "instance_id"
AuthRequestColumnLoginClient = "login_client"
AuthRequestColumnClientID = "client_id"
AuthRequestColumnRedirectURI = "redirect_uri"
AuthRequestColumnScope = "scope"
AuthRequestColumnPrompt = "prompt"
AuthRequestColumnUILocales = "ui_locales"
AuthRequestColumnMaxAge = "max_age"
AuthRequestColumnLoginHint = "login_hint"
AuthRequestColumnHintUserID = "hint_user_id"
)
type authRequestProjection struct {
crdb.StatementHandler
}
func newAuthRequestProjection(ctx context.Context, config crdb.StatementHandlerConfig) *authRequestProjection {
p := new(authRequestProjection)
config.ProjectionName = AuthRequestsProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewMultiTableCheck(
crdb.NewTable([]*crdb.Column{
crdb.NewColumn(AuthRequestColumnID, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnCreationDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(AuthRequestColumnChangeDate, crdb.ColumnTypeTimestamp),
crdb.NewColumn(AuthRequestColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(AuthRequestColumnResourceOwner, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnLoginClient, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnClientID, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnRedirectURI, crdb.ColumnTypeText),
crdb.NewColumn(AuthRequestColumnScope, crdb.ColumnTypeTextArray),
crdb.NewColumn(AuthRequestColumnPrompt, crdb.ColumnTypeEnumArray, crdb.Nullable()),
crdb.NewColumn(AuthRequestColumnUILocales, crdb.ColumnTypeTextArray, crdb.Nullable()),
crdb.NewColumn(AuthRequestColumnMaxAge, crdb.ColumnTypeInt64, crdb.Nullable()),
crdb.NewColumn(AuthRequestColumnLoginHint, crdb.ColumnTypeText, crdb.Nullable()),
crdb.NewColumn(AuthRequestColumnHintUserID, crdb.ColumnTypeText, crdb.Nullable()),
},
crdb.NewPrimaryKey(AuthRequestColumnInstanceID, AuthRequestColumnID),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}
func (p *authRequestProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: authrequest.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: authrequest.AddedType,
Reduce: p.reduceAuthRequestAdded,
},
{
Event: authrequest.SucceededType,
Reduce: p.reduceAuthRequestEnded,
},
{
Event: authrequest.FailedType,
Reduce: p.reduceAuthRequestEnded,
},
},
},
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: reduceInstanceRemovedHelper(AuthRequestColumnInstanceID),
},
},
},
}
}
func (p *authRequestProjection) reduceAuthRequestAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*authrequest.AddedEvent)
if !ok {
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", authrequest.AddedType)
}
return crdb.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(AuthRequestColumnID, e.Aggregate().ID),
handler.NewCol(AuthRequestColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(AuthRequestColumnCreationDate, e.CreationDate()),
handler.NewCol(AuthRequestColumnChangeDate, e.CreationDate()),
handler.NewCol(AuthRequestColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(AuthRequestColumnSequence, e.Sequence()),
handler.NewCol(AuthRequestColumnLoginClient, e.LoginClient),
handler.NewCol(AuthRequestColumnClientID, e.ClientID),
handler.NewCol(AuthRequestColumnRedirectURI, e.RedirectURI),
handler.NewCol(AuthRequestColumnScope, e.Scope),
handler.NewCol(AuthRequestColumnPrompt, e.Prompt),
handler.NewCol(AuthRequestColumnUILocales, e.UILocales),
handler.NewCol(AuthRequestColumnMaxAge, e.MaxAge),
handler.NewCol(AuthRequestColumnLoginHint, e.LoginHint),
handler.NewCol(AuthRequestColumnHintUserID, e.HintUserID),
},
), nil
}
func (p *authRequestProjection) reduceAuthRequestEnded(event eventstore.Event) (*handler.Statement, error) {
switch event.(type) {
case *authrequest.SucceededEvent,
*authrequest.FailedEvent:
break
default:
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{authrequest.SucceededType, authrequest.FailedType})
}
return crdb.NewDeleteStatement(
event,
[]handler.Condition{
handler.NewCond(AuthRequestColumnID, event.Aggregate().ID),
handler.NewCond(AuthRequestColumnInstanceID, event.Aggregate().InstanceID),
},
), nil
}

View File

@ -0,0 +1,134 @@
package projection
import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/repository/authrequest"
)
func TestAuthRequestProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "reduceAuthRequestAdded",
args: args{
event: getEvent(testEvent(
authrequest.AddedType,
authrequest.AggregateType,
[]byte(`{"login_client": "loginClient", "client_id":"clientId","redirect_uri": "redirectURI", "scope": ["openid"], "prompt": [1], "ui_locales": ["en","de"], "max_age": 0, "login_hint": "loginHint", "hint_user_id": "hintUserID"}`),
), authrequest.AddedEventMapper),
},
reduce: (&authRequestProjection{}).reduceAuthRequestAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("auth_request"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.auth_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, client_id, redirect_uri, scope, prompt, ui_locales, max_age, login_hint, hint_user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
anyArg{},
anyArg{},
"ro-id",
uint64(15),
"loginClient",
"clientId",
"redirectURI",
[]string{"openid"},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
},
},
},
},
},
},
{
name: "reduceAuthRequestFailed",
args: args{
event: getEvent(testEvent(
authrequest.FailedType,
authrequest.AggregateType,
[]byte(`{"reason": 0}`),
), authrequest.FailedEventMapper),
},
reduce: (&authRequestProjection{}).reduceAuthRequestEnded,
want: wantReduce{
aggregateType: eventstore.AggregateType("auth_request"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "reduceAuthRequestSucceeded",
args: args{
event: getEvent(testEvent(
authrequest.SucceededType,
authrequest.AggregateType,
nil,
), authrequest.SucceededEventMapper),
},
reduce: (&authRequestProjection{}).reduceAuthRequestEnded,
want: wantReduce{
aggregateType: eventstore.AggregateType("auth_request"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !errors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, AuthRequestsProjectionTable, tt.want)
})
}
}

View File

@ -67,6 +67,7 @@ var (
TelemetryPusherProjection interface{}
DeviceAuthProjection *deviceAuthProjection
SessionProjection *sessionProjection
AuthRequestProjection *authRequestProjection
MilestoneProjection *milestoneProjection
)
@ -145,6 +146,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"]))
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
newProjectionsList()
return nil
@ -243,6 +245,7 @@ func newProjectionsList() {
NotificationPolicyProjection,
DeviceAuthProjection,
SessionProjection,
AuthRequestProjection,
MilestoneProjection,
}
}

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"regexp"
"sync"
"time"
@ -18,9 +19,11 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/action"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/idpintent"
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/session"
@ -88,6 +91,8 @@ func StartQueries(
usergrant.RegisterEventMappers(repo.eventstore)
session.RegisterEventMappers(repo.eventstore)
idpintent.RegisterEventMappers(repo.eventstore)
authrequest.RegisterEventMappers(repo.eventstore)
oidcsession.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{
@ -115,3 +120,19 @@ func (q *Queries) Health(ctx context.Context) error {
type prepareDatabase interface {
Timetravel(d time.Duration) string
}
// cleanStaticQueries removes whitespaces,
// such as ` `, \t, \n, from queries to improve
// readability in logs and errors.
func cleanStaticQueries(qs ...*string) {
regex := regexp.MustCompile(`\s+`)
for _, q := range qs {
*q = regex.ReplaceAllString(*q, " ")
}
}
func init() {
cleanStaticQueries(
&authRequestByIDQuery,
)
}

View File

@ -0,0 +1,17 @@
package query
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_cleanStaticQueries(t *testing.T) {
query := `select
foo,
bar
from table;`
want := "select foo, bar from table;"
cleanStaticQueries(&query)
assert.Equal(t, want, query)
}

View File

@ -0,0 +1,26 @@
package authrequest
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
AggregateType = "auth_request"
AggregateVersion = "v1"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, instanceID string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ResourceOwner: instanceID,
InstanceID: instanceID,
},
}
}

View File

@ -0,0 +1,287 @@
package authrequest
import (
"context"
"encoding/json"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
const (
authRequestEventPrefix = "auth_request."
AddedType = authRequestEventPrefix + "added"
FailedType = authRequestEventPrefix + "failed"
CodeAddedType = authRequestEventPrefix + "code.added"
SessionLinkedType = authRequestEventPrefix + "session.linked"
CodeExchangedType = authRequestEventPrefix + "code.exchanged"
SucceededType = authRequestEventPrefix + "succeeded"
)
type AddedEvent struct {
eventstore.BaseEvent `json:"-"`
LoginClient string `json:"login_client"`
ClientID string `json:"client_id"`
RedirectURI string `json:"redirect_uri"`
State string `json:"state,omitempty"`
Nonce string `json:"nonce,omitempty"`
Scope []string `json:"scope,omitempty"`
Audience []string `json:"audience,omitempty"`
ResponseType domain.OIDCResponseType `json:"response_type,omitempty"`
CodeChallenge *domain.OIDCCodeChallenge `json:"code_challenge,omitempty"`
Prompt []domain.Prompt `json:"prompt,omitempty"`
UILocales []string `json:"ui_locales,omitempty"`
MaxAge *time.Duration `json:"max_age,omitempty"`
LoginHint *string `json:"login_hint,omitempty"`
HintUserID *string `json:"hint_user_id,omitempty"`
}
func (e *AddedEvent) Data() interface{} {
return e
}
func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
loginClient,
clientID,
redirectURI,
state,
nonce string,
scope,
audience []string,
responseType domain.OIDCResponseType,
codeChallenge *domain.OIDCCodeChallenge,
prompt []domain.Prompt,
uiLocales []string,
maxAge *time.Duration,
loginHint,
hintUserID *string,
) *AddedEvent {
return &AddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
AddedType,
),
LoginClient: loginClient,
ClientID: clientID,
RedirectURI: redirectURI,
State: state,
Nonce: nonce,
Scope: scope,
Audience: audience,
ResponseType: responseType,
CodeChallenge: codeChallenge,
Prompt: prompt,
UILocales: uiLocales,
MaxAge: maxAge,
LoginHint: loginHint,
HintUserID: hintUserID,
}
}
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &AddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "AUTHR-DG4gn", "unable to unmarshal auth request added")
}
return added, nil
}
type SessionLinkedEvent struct {
eventstore.BaseEvent `json:"-"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
AuthTime time.Time `json:"auth_time"`
AMR []string `json:"amr"`
}
func (e *SessionLinkedEvent) Data() interface{} {
return e
}
func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewSessionLinkedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
sessionID,
userID string,
authTime time.Time,
amr []string,
) *SessionLinkedEvent {
return &SessionLinkedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
SessionLinkedType,
),
SessionID: sessionID,
UserID: userID,
AuthTime: authTime,
AMR: amr,
}
}
func SessionLinkedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &SessionLinkedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked")
}
return added, nil
}
type FailedEvent struct {
eventstore.BaseEvent `json:"-"`
Reason domain.OIDCErrorReason `json:"reason,omitempty"`
}
func (e *FailedEvent) Data() interface{} {
return e
}
func (e *FailedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewFailedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
reason domain.OIDCErrorReason,
) *FailedEvent {
return &FailedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
FailedType,
),
Reason: reason,
}
}
func FailedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &FailedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked")
}
return added, nil
}
type CodeAddedEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *CodeAddedEvent) Data() interface{} {
return e
}
func (e *CodeAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewCodeAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *CodeAddedEvent {
return &CodeAddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
CodeAddedType,
),
}
}
func CodeAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &CodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request code added")
}
return added, nil
}
type CodeExchangedEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *CodeExchangedEvent) Data() interface{} {
return nil
}
func (e *CodeExchangedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewCodeExchangedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *CodeExchangedEvent {
return &CodeExchangedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
CodeExchangedType,
),
}
}
func CodeExchangedEventMapper(event *repository.Event) (eventstore.Event, error) {
return &CodeExchangedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}, nil
}
type SucceededEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *SucceededEvent) Data() interface{} {
return nil
}
func (e *SucceededEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewSucceededEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *SucceededEvent {
return &SucceededEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
SucceededType,
),
}
}
func SucceededEventMapper(event *repository.Event) (eventstore.Event, error) {
return &SucceededEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}, nil
}

View File

@ -0,0 +1,12 @@
package authrequest
import "github.com/zitadel/zitadel/internal/eventstore"
func RegisterEventMappers(es *eventstore.Eventstore) {
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper).
RegisterFilterEventMapper(AggregateType, SessionLinkedType, SessionLinkedEventMapper).
RegisterFilterEventMapper(AggregateType, CodeAddedType, CodeAddedEventMapper).
RegisterFilterEventMapper(AggregateType, CodeExchangedType, CodeExchangedEventMapper).
RegisterFilterEventMapper(AggregateType, FailedType, FailedEventMapper).
RegisterFilterEventMapper(AggregateType, SucceededType, SucceededEventMapper)
}

View File

@ -0,0 +1,25 @@
package oidcsession
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
AggregateType = "oidc_session"
AggregateVersion = "v1"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, resourceOwner string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ResourceOwner: resourceOwner,
},
}
}

View File

@ -0,0 +1,11 @@
package oidcsession
import "github.com/zitadel/zitadel/internal/eventstore"
func RegisterEventMappers(es *eventstore.Eventstore) {
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper).
RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, AccessTokenAddedEventMapper).
RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, RefreshTokenAddedEventMapper).
RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, RefreshTokenRenewedEventMapper)
}

View File

@ -0,0 +1,215 @@
package oidcsession
import (
"context"
"encoding/json"
"time"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
const (
oidcSessionEventPrefix = "oidc_session."
AddedType = oidcSessionEventPrefix + "added"
AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added"
RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added"
RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed"
)
type AddedEvent struct {
eventstore.BaseEvent `json:"-"`
UserID string `json:"userID"`
SessionID string `json:"sessionID"`
ClientID string `json:"clientID"`
Audience []string `json:"audience"`
Scope []string `json:"scope"`
AuthMethodsReferences []string `json:"authMethodsReferences"`
AuthTime time.Time `json:"authTime"`
}
func (e *AddedEvent) Data() interface{} {
return e
}
func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
userID,
sessionID,
clientID string,
audience,
scope []string,
authMethodsReferences []string,
authTime time.Time,
) *AddedEvent {
return &AddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
AddedType,
),
UserID: userID,
SessionID: sessionID,
ClientID: clientID,
Audience: audience,
Scope: scope,
AuthMethodsReferences: authMethodsReferences,
AuthTime: authTime,
}
}
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &AddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-DG4gn", "unable to unmarshal oidc session added")
}
return added, nil
}
type AccessTokenAddedEvent struct {
eventstore.BaseEvent `json:"-"`
ID string `json:"id"`
Scope []string `json:"scope"`
Lifetime time.Duration `json:"lifetime"`
}
func (e *AccessTokenAddedEvent) Data() interface{} {
return e
}
func (e *AccessTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewAccessTokenAddedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
id string,
scope []string,
lifetime time.Duration,
) *AccessTokenAddedEvent {
return &AccessTokenAddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
AccessTokenAddedType,
),
ID: id,
Scope: scope,
Lifetime: lifetime,
}
}
func AccessTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &AccessTokenAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-DSGn5", "unable to unmarshal access token added")
}
return added, nil
}
type RefreshTokenAddedEvent struct {
eventstore.BaseEvent `json:"-"`
ID string `json:"id"`
Lifetime time.Duration `json:"lifetime"`
IdleLifetime time.Duration `json:"idleLifetime"`
}
func (e *RefreshTokenAddedEvent) Data() interface{} {
return e
}
func (e *RefreshTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewRefreshTokenAddedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
id string,
lifetime,
idleLifetime time.Duration,
) *RefreshTokenAddedEvent {
return &RefreshTokenAddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
RefreshTokenAddedType,
),
ID: id,
Lifetime: lifetime,
IdleLifetime: idleLifetime,
}
}
func RefreshTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &RefreshTokenAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-aW3gqq", "unable to unmarshal refresh token added")
}
return added, nil
}
type RefreshTokenRenewedEvent struct {
eventstore.BaseEvent `json:"-"`
ID string `json:"id"`
IdleLifetime time.Duration `json:"idleLifetime"`
}
func (e *RefreshTokenRenewedEvent) Data() interface{} {
return e
}
func (e *RefreshTokenRenewedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func NewRefreshTokenRenewedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
id string,
idleLifetime time.Duration,
) *RefreshTokenRenewedEvent {
return &RefreshTokenRenewedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
RefreshTokenRenewedType,
),
ID: id,
IdleLifetime: idleLifetime,
}
}
func RefreshTokenRenewedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &RefreshTokenRenewedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-SF3fc", "unable to unmarshal refresh token renewed")
}
return added, nil
}

View File

@ -506,6 +506,12 @@ Errors:
TokenCreationFailed: Неуспешно създаване на токен
InvalidToken: Знакът за намерение е невалиден
OtherUser: Намерение, предназначено за друг потребител
AuthRequest:
AlreadyExists: Auth Request вече съществува
NotExisting: Auth Request не съществува
WrongLoginClient: Auth Request, създаден от друг клиент за влизане
OIDCSession:
RefreshTokenInvalid: Токенът за опресняване е невалиден
AggregateTypes:
action: Действие

View File

@ -488,7 +488,12 @@ Errors:
TokenCreationFailed: Tokenerstellung schlug fehl
InvalidToken: Intent Token ist ungültig
OtherUser: Intent ist für anderen Benutzer gedacht
AuthRequest:
AlreadyExists: Auth Request existiert bereits
NotExisting: Auth Request existiert nicht
WrongLoginClient: Auth Request wurde von einem anderen Login-Client erstellt
OIDCSession:
RefreshTokenInvalid: Refresh Token ist ungültig
AggregateTypes:
action: Action
instance: Instanz

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: Token creation failed
InvalidToken: Intent Token is invalid
OtherUser: Intent meant for another user
AuthRequest:
AlreadyExists: Auth Request already exists
NotExisting: Auth Request does not exist
WrongLoginClient: Auth Request created by other login client
OIDCSession:
RefreshTokenInvalid: Refresh Token is invalid
AggregateTypes:
action: Action

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: Fallo en la creación del token
InvalidToken: El token de la intención no es válido
OtherUser: Destinado a otro usuario
AuthRequest:
AlreadyExists: Auth Request ya existe
NotExisting: Auth Request no existe
WrongLoginClient: Auth Request creado por otro cliente de inicio de sesión
OIDCSession:
RefreshTokenInvalid: El token de refresco no es válido
AggregateTypes:
action: Acción

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: La création du token a échoué
InvalidToken: Le jeton d'intention n'est pas valide
OtherUser: Intention destinée à un autre utilisateur
AuthRequest:
AlreadyExists: Auth Request existe déjà
NotExisting: Auth Request n'existe pas
WrongLoginClient: Auth Request créé par un autre client de connexion
OIDCSession:
RefreshTokenInvalid: Le jeton de rafraîchissement n'est pas valide
AggregateTypes:
action: Action

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: creazione del token fallita
InvalidToken: Il token dell'intento non è valido
OtherUser: Intento destinato a un altro utente
AuthRequest:
AlreadyExists: Auth Request esiste già
NotExisting: Auth Request non esiste
WrongLoginClient: Auth Request creato da un altro client di accesso
OIDCSession:
RefreshTokenInvalid: Refresh Token non è valido
AggregateTypes:
action: Azione

View File

@ -477,6 +477,12 @@ Errors:
TokenCreationFailed: トークンの作成に失敗しました
InvalidToken: インテントのトークンが無効である
OtherUser: 他のユーザーを意図している
AuthRequest:
AlreadyExists: AuthRequestはすでに存在する
NotExisting: AuthRequest が存在しません
WrongLoginClient: 他のログインクライアントによって作成された AuthRequest
OIDCSession:
RefreshTokenInvalid: 無効なリフレッシュトークンです
AggregateTypes:
action: アクション

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: Tworzenie tokena nie powiodło się
InvalidToken: Token intencji jest nieprawidłowy
OtherUser: Intencja przeznaczona dla innego użytkownika
AuthRequest:
AlreadyExists: Auth Request już istnieje
NotExisting: Auth Request nie istnieje
WrongLoginClient: Auth Request utworzony przez innego klienta logowania
OIDCSession:
RefreshTokenInvalid: Refresh Token jest nieprawidłowy
AggregateTypes:
action: Działanie

View File

@ -488,6 +488,12 @@ Errors:
TokenCreationFailed: 令牌创建失败
InvalidToken: 意图令牌是无效的
OtherUser: 意图是为另一个用户准备的
AuthRequest:
AlreadyExists: AuthRequest已经存在
NotExisting: AuthRequest不存在
WrongLoginClient: 其他登录客户端创建的AuthRequest
OIDCSession:
RefreshTokenInvalid: Refresh Token 无效
AggregateTypes:
action: 动作

View File

@ -0,0 +1,117 @@
syntax = "proto3";
package zitadel.oidc.v2alpha;
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc";
message AuthRequest{
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = {
external_docs: {
url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest";
description: "Find out more about OIDC Auth Request parameters";
}
};
string id = 1 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "ID of the authorization request";
}
];
google.protobuf.Timestamp creation_date = 2 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Time when the auth request was created";
}
];
string client_id = 3 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "OIDC client ID of the application that created the auth request";
}
];
repeated string scope = 4 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Requested scopes by the application, which the user must consent to.";
}
];
string redirect_uri = 5 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Base URI that points back to the application";
}
];
repeated Prompt prompt = 6 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Prompts that must be displayed to the user";
}
];
repeated string ui_locales = 7 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "End-User's preferred languages and scripts for the user interface, represented as a list of BCP47 [RFC5646] language tag values, ordered by preference. For instance, the value [fr-CA, fr, en] represents a preference for French as spoken in Canada, then French (without a region designation), followed by English (without a region designation). An error SHOULD NOT result if some or all of the requested locales are not supported.";
}
];
optional string login_hint = 8 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Login hint can be set by the application with a user identifier such as an email or phone number.";
}
];
optional google.protobuf.Duration max_age = 9 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Specifies the allowable elapsed time in seconds since the last time the End-User was actively authenticated. If the elapsed time is greater than this value, or the field is present with 0 duration, the user must be re-authenticated.";
}
];
optional string hint_user_id = 10 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "User ID taken from a ID Token Hint if it was present and valid.";
}
];
}
enum Prompt {
PROMPT_UNSPECIFIED = 0;
PROMPT_NONE = 1;
PROMPT_LOGIN = 2;
PROMPT_CONSENT = 3;
PROMPT_SELECT_ACCOUNT = 4;
PROMPT_CREATE = 5;
}
message AuthorizationError {
ErrorReason error = 1;
optional string error_description = 2;
optional string error_uri = 3;
}
enum ErrorReason {
ERROR_REASON_UNSPECIFIED = 0;
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
ERROR_REASON_INVALID_REQUEST = 1;
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
ERROR_REASON_ACCESS_DENIED = 3;
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
ERROR_REASON_INVALID_SCOPE = 5;
ERROR_REASON_SERVER_ERROR = 6;
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
ERROR_REASON_INTERACTION_REQUIRED = 8;
ERROR_REASON_LOGIN_REQUIRED = 9;
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
ERROR_REASON_CONSENT_REQUIRED = 11;
ERROR_REASON_INVALID_REQUEST_URI = 12;
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
}

View File

@ -0,0 +1,190 @@
syntax = "proto3";
package zitadel.oidc.v2alpha;
import "zitadel/object/v2alpha/object.proto";
import "zitadel/protoc_gen_zitadel/v2/options.proto";
import "zitadel/oidc/v2alpha/authorization.proto";
import "google/api/annotations.proto";
import "google/api/field_behavior.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
import "validate/validate.proto";
option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc";
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
info: {
title: "OIDC Service";
version: "2.0-alpha";
description: "Get OIDC Auth Request details and create callback URLs. This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login.";
contact:{
name: "ZITADEL"
url: "https://zitadel.com"
email: "hi@zitadel.com"
}
license: {
name: "Apache 2.0",
url: "https://github.com/zitadel/zitadel/blob/main/LICENSE";
};
};
schemes: HTTPS;
schemes: HTTP;
consumes: "application/json";
consumes: "application/grpc";
produces: "application/json";
produces: "application/grpc";
consumes: "application/grpc-web+proto";
produces: "application/grpc-web+proto";
host: "$ZITADEL_DOMAIN";
base_path: "/";
external_docs: {
description: "Detailed information about ZITADEL",
url: "https://zitadel.com/docs"
}
responses: {
key: "403";
value: {
description: "Returned when the user does not have permission to access the resource.";
schema: {
json_schema: {
ref: "#/definitions/rpcStatus";
}
}
}
}
responses: {
key: "404";
value: {
description: "Returned when the resource does not exist.";
schema: {
json_schema: {
ref: "#/definitions/rpcStatus";
}
}
}
}
};
service OIDCService {
rpc GetAuthRequest (GetAuthRequestRequest) returns (GetAuthRequestResponse) {
option (google.api.http) = {
get: "/v2alpha/oidc/auth_requests/{auth_request_id}"
};
option (zitadel.protoc_gen_zitadel.v2.options) = {
auth_option: {
permission: "authenticated"
}
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "Get OIDC Auth Request details";
description: "Get OIDC Auth Request details by ID, obtained from the redirect URL. Returns details that are parsed from the application's Auth Request."
responses: {
key: "200"
value: {
description: "OK";
}
};
};
}
rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) {
option (google.api.http) = {
post: "/v2alpha/oidc/auth_requests/{auth_request_id}"
body: "*"
};
option (zitadel.protoc_gen_zitadel.v2.options) = {
auth_option: {
permission: "authenticated"
}
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "Finalize an Auth Request and get the callback URL.";
description: "Finalize an Auth Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an Auth request."
responses: {
key: "200"
value: {
description: "OK";
}
};
};
}
}
message GetAuthRequestRequest {
string auth_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "ID of the Auth Request, as obtained from the redirect URL.";
example: "\"163840776835432705\"";
}
];
}
message GetAuthRequestResponse {
AuthRequest auth_request = 1;
}
message CreateCallbackRequest {
string auth_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
}
];
oneof callback_kind {
option (validate.required) = true;
Session session = 2;
AuthorizationError error = 3 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
}
];
}
}
message Session {
string session_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "ID of the session, used to login the user. Connects the session to the Auth Request.";
example: "\"163840776835432705\"";
}
];
string session_token = 2 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "Token to verify the session is valid";
}
];
}
message CreateCallbackResponse {
zitadel.object.v2alpha.Details details = 1;
string callback_url = 2 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user.";
example: "\"https://client.example.org/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=af0ifjsldkj\""
}
];
}