mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-04 23:45:07 +00:00
feat(api): add OIDC session service (#6157)
This PR starts the OIDC implementation for the API V2 including the Implicit and Code Flow. Co-authored-by: Livio Spring <livio.a@gmail.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com> Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
parent
be1fe36776
commit
14b8cf4894
2
.github/workflows/integration.yml
vendored
2
.github/workflows/integration.yml
vendored
@ -43,7 +43,7 @@ jobs:
|
||||
go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
|
||||
go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
|
||||
- name: Run integration tests
|
||||
run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/...
|
||||
run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/... ./internal/api/oidc
|
||||
- name: Publish go coverage
|
||||
uses: codecov/codecov-action@v3.1.0
|
||||
with:
|
||||
|
@ -224,7 +224,7 @@ export INTEGRATION_DB_FLAVOR="cockroach" ZITADEL_MASTERKEY="MasterkeyNeedsToHave
|
||||
docker compose -f internal/integration/config/docker-compose.yaml up --wait ${INTEGRATION_DB_FLAVOR}
|
||||
go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
|
||||
go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml
|
||||
go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/...
|
||||
go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/... ./internal/api/oidc
|
||||
docker compose -f internal/integration/config/docker-compose.yaml down
|
||||
```
|
||||
|
||||
|
@ -108,4 +108,16 @@ protoc \
|
||||
--validate_out=lang=go:${GOPATH}/src \
|
||||
${PROTO_PATH}/settings/v2alpha/settings_service.proto
|
||||
|
||||
protoc \
|
||||
-I=/proto/include \
|
||||
--grpc-gateway_out ${GOPATH}/src \
|
||||
--grpc-gateway_opt logtostderr=true \
|
||||
--grpc-gateway_opt allow_delete_body=true \
|
||||
--openapiv2_out ${OPENAPI_PATH} \
|
||||
--openapiv2_opt logtostderr=true \
|
||||
--openapiv2_opt allow_delete_body=true \
|
||||
--zitadel_out=${GOPATH}/src \
|
||||
--validate_out=lang=go:${GOPATH}/src \
|
||||
${PROTO_PATH}/oidc/v2alpha/oidc_service.proto
|
||||
|
||||
echo "done generating grpc"
|
||||
|
@ -270,6 +270,7 @@ OIDC:
|
||||
Path: /oauth/v2/keys
|
||||
DeviceAuth:
|
||||
Path: /oauth/v2/device_authorization
|
||||
DefaultLoginURLV2: "/login?authRequest="
|
||||
|
||||
SAML:
|
||||
ProviderConfig:
|
||||
|
@ -88,6 +88,9 @@ func (mig *FirstInstance) Execute(ctx context.Context) error {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -53,6 +53,9 @@ func (mig *externalConfigChange) Execute(ctx context.Context) error {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
|
@ -32,6 +32,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/admin"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/auth"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/management"
|
||||
oidc_v2 "github.com/zitadel/zitadel/internal/api/grpc/oidc/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/session/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/settings/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/system"
|
||||
@ -192,6 +193,9 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
|
||||
&http.Client{},
|
||||
permissionCheck,
|
||||
sessionTokenVerifier,
|
||||
config.OIDC.DefaultAccessTokenLifetime,
|
||||
config.OIDC.DefaultRefreshTokenExpiration,
|
||||
config.OIDC.DefaultRefreshTokenIdleExpiration,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot start commands: %w", err)
|
||||
@ -344,6 +348,7 @@ func startAPIs(
|
||||
if err := apis.RegisterService(ctx, session.CreateServer(commands, queries, permissionCheck)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := apis.RegisterService(ctx, settings.CreateServer(commands, queries, config.ExternalSecure)); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -397,6 +402,11 @@ func startAPIs(
|
||||
apis.RegisterHandlerOnPrefix(login.HandlerPrefix, l.Handler())
|
||||
apis.HandleFunc(login.EndpointDeviceAuth, login.RedirectDeviceAuthToPrefix)
|
||||
|
||||
// After OIDC provider so that the callback endpoint can be used
|
||||
if err := apis.RegisterService(ctx, oidc_v2.CreateServer(commands, queries, oidcProvider, config.ExternalSecure)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// handle grpc at last to be able to handle the root, because grpc and gateway require a lot of different prefixes
|
||||
apis.RouteGRPC()
|
||||
return nil
|
||||
|
@ -274,6 +274,13 @@ module.exports = {
|
||||
groupPathsBy: "tag",
|
||||
},
|
||||
},
|
||||
oidc: {
|
||||
specPath: ".artifacts/openapi/zitadel/oidc/v2alpha/oidc_service.swagger.json",
|
||||
outputDir: "docs/apis/resources/oidc_service",
|
||||
sidebarOptions: {
|
||||
groupPathsBy: "tag",
|
||||
},
|
||||
},
|
||||
settings: {
|
||||
specPath: ".artifacts/openapi/zitadel/settings/v2alpha/settings_service.swagger.json",
|
||||
outputDir: "docs/apis/resources/settings_service",
|
||||
|
@ -484,6 +484,20 @@ module.exports = {
|
||||
},
|
||||
items: require("./docs/apis/resources/session_service/sidebar.js"),
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "OIDC lifecycle (Alpha)",
|
||||
link: {
|
||||
type: "generated-index",
|
||||
title: "OIDC service API (Alpha)",
|
||||
slug: "/apis/resources/oidc_service",
|
||||
description:
|
||||
"Get OIDC Auth Request details and create callback URLs.\n"+
|
||||
"\n"+
|
||||
"This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login.",
|
||||
},
|
||||
items: require("./docs/apis/resources/oidc_service/sidebar.js"),
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "Settings lifecycle (alpha)",
|
||||
|
@ -4,11 +4,11 @@ import "context"
|
||||
|
||||
func NewMockContext(instanceID, orgID, userID string) context.Context {
|
||||
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
|
||||
return context.WithValue(ctx, instanceKey, instanceID)
|
||||
return context.WithValue(ctx, instanceKey, &instance{id: instanceID})
|
||||
}
|
||||
|
||||
func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context {
|
||||
ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID})
|
||||
ctx = context.WithValue(ctx, instanceKey, instanceID)
|
||||
ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID})
|
||||
return context.WithValue(ctx, requestPermissionsKey, permissions)
|
||||
}
|
||||
|
204
internal/api/grpc/oidc/v2/oidc.go
Normal file
204
internal/api/grpc/oidc/v2/oidc.go
Normal file
@ -0,0 +1,204 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
|
||||
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("query authRequest by ID")
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.GetAuthRequestResponse{
|
||||
AuthRequest: authRequestToPb(authRequest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
|
||||
pba := &oidc_pb.AuthRequest{
|
||||
Id: a.ID,
|
||||
CreationDate: timestamppb.New(a.CreationDate),
|
||||
ClientId: a.ClientID,
|
||||
Scope: a.Scope,
|
||||
RedirectUri: a.RedirectURI,
|
||||
Prompt: promptsToPb(a.Prompt),
|
||||
UiLocales: a.UiLocales,
|
||||
LoginHint: a.LoginHint,
|
||||
HintUserId: a.HintUserID,
|
||||
}
|
||||
if a.MaxAge != nil {
|
||||
pba.MaxAge = durationpb.New(*a.MaxAge)
|
||||
}
|
||||
return pba
|
||||
}
|
||||
|
||||
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
|
||||
out := make([]oidc_pb.Prompt, len(promps))
|
||||
for i, p := range promps {
|
||||
out[i] = promptToPb(p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
|
||||
switch p {
|
||||
case domain.PromptUnspecified:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
case domain.PromptNone:
|
||||
return oidc_pb.Prompt_PROMPT_NONE
|
||||
case domain.PromptLogin:
|
||||
return oidc_pb.Prompt_PROMPT_LOGIN
|
||||
case domain.PromptConsent:
|
||||
return oidc_pb.Prompt_PROMPT_CONSENT
|
||||
case domain.PromptSelectAccount:
|
||||
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
|
||||
case domain.PromptCreate:
|
||||
return oidc_pb.Prompt_PROMPT_CREATE
|
||||
default:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
switch v := req.GetCallbackKind().(type) {
|
||||
case *oidc_pb.CreateCallbackRequest_Error:
|
||||
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
|
||||
case *oidc_pb.CreateCallbackRequest_Session:
|
||||
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
|
||||
default:
|
||||
return nil, errors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
ctx = op.ContextWithIssuer(ctx, http.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), s.externalSecure))
|
||||
var callback string
|
||||
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op)
|
||||
} else {
|
||||
callback, err = oidc.CreateTokenCallbackURL(ctx, authReq, s.op)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
|
||||
switch errorReason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return domain.OIDCErrorReasonInvalidRequest
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return domain.OIDCErrorReasonUnauthorizedClient
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return domain.OIDCErrorReasonAccessDenied
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return domain.OIDCErrorReasonUnsupportedResponseType
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return domain.OIDCErrorReasonInvalidScope
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
return domain.OIDCErrorReasonServerError
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return domain.OIDCErrorReasonTemporaryUnavailable
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonInteractionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return domain.OIDCErrorReasonLoginRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonAccountSelectionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return domain.OIDCErrorReasonConsentRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return domain.OIDCErrorReasonInvalidRequestURI
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return domain.OIDCErrorReasonInvalidRequestObject
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestNotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestURINotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRegistrationNotSupported
|
||||
default:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
|
||||
switch reason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return "invalid_request"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return "unauthorized_client"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return "access_denied"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return "unsupported_response_type"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return "invalid_scope"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return "temporarily_unavailable"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return "interaction_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return "login_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return "account_selection_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return "consent_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return "invalid_request_uri"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return "invalid_request_object"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return "request_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return "request_uri_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return "registration_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
fallthrough
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
249
internal/api/grpc/oidc/v2/oidc_integration_test.go
Normal file
249
internal/api/grpc/oidc/v2/oidc_integration_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client oidc_pb.OIDCServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, _, cancel := integration.Contexts(5 * time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Tester = integration.NewTester(ctx)
|
||||
defer Tester.Done()
|
||||
Client = Tester.Client.OIDCv2
|
||||
|
||||
CTX = Tester.WithAuthorization(ctx, integration.OrgOwner)
|
||||
User = Tester.CreateHumanUser(CTX)
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestServer_GetAuthRequest(t *testing.T) {
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
AuthRequestID string
|
||||
want *oidc_pb.GetAuthRequestResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
AuthRequestID: "123",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
AuthRequestID: authRequestID,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.GetAuthRequest(CTX, &oidc_pb.GetAuthRequestRequest{
|
||||
AuthRequestId: tt.AuthRequestID,
|
||||
})
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
authRequest := got.GetAuthRequest()
|
||||
assert.NotNil(t, authRequest)
|
||||
assert.Equal(t, authRequestID, authRequest.GetId())
|
||||
assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
|
||||
assert.Contains(t, authRequest.GetScope(), "openid")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateCallback(t *testing.T) {
|
||||
client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
sessionResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{
|
||||
UserId: Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *oidc_pb.CreateCallbackRequest
|
||||
AuthError string
|
||||
want *oidc_pb.CreateCallbackResponse
|
||||
wantURL *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: "123",
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: "foo",
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session token invalid",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Error{
|
||||
Error: &oidc_pb.AuthorizationError{
|
||||
Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
ErrorUri: gu.Ptr("https://example.com/docs"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`),
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "code callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`,
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "implicit",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
client, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Tester.CreateOIDCAuthRequestImplicit(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`,
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Tester.Instance.InstanceID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.CreateCallback(CTX, tt.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
150
internal/api/grpc/oidc/v2/oidc_test.go
Normal file
150
internal/api/grpc/oidc/v2/oidc_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
func Test_authRequestToPb(t *testing.T) {
|
||||
now := time.Now()
|
||||
arg := &query.AuthRequest{
|
||||
ID: "authID",
|
||||
CreationDate: now,
|
||||
ClientID: "clientID",
|
||||
Scope: []string{"a", "b", "c"},
|
||||
RedirectURI: "callbackURI",
|
||||
Prompt: []domain.Prompt{
|
||||
domain.PromptUnspecified,
|
||||
domain.PromptNone,
|
||||
domain.PromptLogin,
|
||||
domain.PromptConsent,
|
||||
domain.PromptSelectAccount,
|
||||
domain.PromptCreate,
|
||||
999,
|
||||
},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
LoginHint: gu.Ptr("foo@bar.com"),
|
||||
MaxAge: gu.Ptr(time.Minute),
|
||||
HintUserID: gu.Ptr("userID"),
|
||||
}
|
||||
want := &oidc_pb.AuthRequest{
|
||||
Id: "authID",
|
||||
CreationDate: timestamppb.New(now),
|
||||
ClientId: "clientID",
|
||||
RedirectUri: "callbackURI",
|
||||
Prompt: []oidc_pb.Prompt{
|
||||
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
|
||||
oidc_pb.Prompt_PROMPT_NONE,
|
||||
oidc_pb.Prompt_PROMPT_LOGIN,
|
||||
oidc_pb.Prompt_PROMPT_CONSENT,
|
||||
oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT,
|
||||
oidc_pb.Prompt_PROMPT_CREATE,
|
||||
oidc_pb.Prompt_PROMPT_UNSPECIFIED,
|
||||
},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
Scope: []string{"a", "b", "c"},
|
||||
LoginHint: gu.Ptr("foo@bar.com"),
|
||||
MaxAge: durationpb.New(time.Minute),
|
||||
HintUserId: gu.Ptr("userID"),
|
||||
}
|
||||
got := authRequestToPb(arg)
|
||||
if !proto.Equal(want, got) {
|
||||
t.Errorf("authRequestToPb() =\n%v\nwant\n%v\n", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_errorReasonToOIDC(t *testing.T) {
|
||||
tests := []struct {
|
||||
reason oidc_pb.ErrorReason
|
||||
want string
|
||||
}{
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED,
|
||||
want: "server_error",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST,
|
||||
want: "invalid_request",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT,
|
||||
want: "unauthorized_client",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
|
||||
want: "access_denied",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE,
|
||||
want: "unsupported_response_type",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE,
|
||||
want: "invalid_scope",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR,
|
||||
want: "server_error",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE,
|
||||
want: "temporarily_unavailable",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED,
|
||||
want: "interaction_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED,
|
||||
want: "login_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED,
|
||||
want: "account_selection_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED,
|
||||
want: "consent_required",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI,
|
||||
want: "invalid_request_uri",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT,
|
||||
want: "invalid_request_object",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED,
|
||||
want: "request_not_supported",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED,
|
||||
want: "request_uri_not_supported",
|
||||
},
|
||||
{
|
||||
reason: oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED,
|
||||
want: "registration_not_supported",
|
||||
},
|
||||
{
|
||||
reason: 99999,
|
||||
want: "server_error",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.reason.String(), func(t *testing.T) {
|
||||
got := errorReasonToOIDC(tt.reason)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
59
internal/api/grpc/oidc/v2/server.go
Normal file
59
internal/api/grpc/oidc/v2/server.go
Normal file
@ -0,0 +1,59 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
)
|
||||
|
||||
var _ oidc_pb.OIDCServiceServer = (*Server)(nil)
|
||||
|
||||
type Server struct {
|
||||
oidc_pb.UnimplementedOIDCServiceServer
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
op op.OpenIDProvider
|
||||
externalSecure bool
|
||||
}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func CreateServer(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
op op.OpenIDProvider,
|
||||
externalSecure bool,
|
||||
) *Server {
|
||||
return &Server{
|
||||
command: command,
|
||||
query: query,
|
||||
op: op,
|
||||
externalSecure: externalSecure,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
|
||||
oidc_pb.RegisterOIDCServiceServer(grpcServer, s)
|
||||
}
|
||||
|
||||
func (s *Server) AppName() string {
|
||||
return oidc_pb.OIDCService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) MethodPrefix() string {
|
||||
return oidc_pb.OIDCService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) AuthMethods() authz.MethodMapping {
|
||||
return oidc_pb.OIDCService_AuthMethods
|
||||
}
|
||||
|
||||
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
|
||||
return oidc_pb.RegisterOIDCServiceHandler
|
||||
}
|
@ -18,11 +18,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client session.SessionServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
GenericOAuthIDPID string
|
||||
CTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client session.SessionServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -540,7 +540,7 @@ func TestServer_StartIdentityProviderFlow(t *testing.T) {
|
||||
ResourceOwner: Tester.Organisation.ID,
|
||||
},
|
||||
NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
|
43
internal/api/oidc/amr/amr.go
Normal file
43
internal/api/oidc/amr/amr.go
Normal file
@ -0,0 +1,43 @@
|
||||
// Package amr maps zitadel session factors to Authentication Method Reference Values
|
||||
// as defined in [RFC 8176, section 2].
|
||||
//
|
||||
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
|
||||
package amr
|
||||
|
||||
const (
|
||||
// Password states that the users password has been verified
|
||||
// Deprecated: use `PWD` instead
|
||||
Password = "password"
|
||||
// PWD states that the users password has been verified
|
||||
PWD = "pwd"
|
||||
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
|
||||
MFA = "mfa"
|
||||
// OTP states that a one time password has been verified (e.g. TOTP)
|
||||
OTP = "otp"
|
||||
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
|
||||
UserPresence = "user"
|
||||
)
|
||||
|
||||
type AuthenticationMethodReference interface {
|
||||
IsPasswordChecked() bool
|
||||
IsPasskeyChecked() bool
|
||||
IsU2FChecked() bool
|
||||
IsOTPChecked() bool
|
||||
}
|
||||
|
||||
func List(model AuthenticationMethodReference) []string {
|
||||
amr := make([]string, 0)
|
||||
if model.IsPasswordChecked() {
|
||||
amr = append(amr, PWD)
|
||||
}
|
||||
if model.IsPasskeyChecked() || model.IsU2FChecked() {
|
||||
amr = append(amr, UserPresence)
|
||||
}
|
||||
if model.IsOTPChecked() {
|
||||
amr = append(amr, OTP)
|
||||
}
|
||||
if model.IsPasskeyChecked() || len(amr) >= 2 {
|
||||
amr = append(amr, MFA)
|
||||
}
|
||||
return amr
|
||||
}
|
93
internal/api/oidc/amr/amr_test.go
Normal file
93
internal/api/oidc/amr/amr_test.go
Normal file
@ -0,0 +1,93 @@
|
||||
package amr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAMR(t *testing.T) {
|
||||
type args struct {
|
||||
model AuthenticationMethodReference
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"no checks, empty",
|
||||
args{
|
||||
new(test),
|
||||
},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
"pw checked",
|
||||
args{
|
||||
&test{pwChecked: true},
|
||||
},
|
||||
[]string{PWD},
|
||||
},
|
||||
{
|
||||
"passkey checked",
|
||||
args{
|
||||
&test{passkeyChecked: true},
|
||||
},
|
||||
[]string{UserPresence, MFA},
|
||||
},
|
||||
{
|
||||
"u2f checked",
|
||||
args{
|
||||
&test{u2fChecked: true},
|
||||
},
|
||||
[]string{UserPresence},
|
||||
},
|
||||
{
|
||||
"otp checked",
|
||||
args{
|
||||
&test{otpChecked: true},
|
||||
},
|
||||
[]string{OTP},
|
||||
},
|
||||
{
|
||||
"multiple checked",
|
||||
args{
|
||||
&test{
|
||||
pwChecked: true,
|
||||
u2fChecked: true,
|
||||
},
|
||||
},
|
||||
[]string{PWD, UserPresence, MFA},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := List(tt.args.model)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type test struct {
|
||||
pwChecked bool
|
||||
passkeyChecked bool
|
||||
u2fChecked bool
|
||||
otpChecked bool
|
||||
}
|
||||
|
||||
func (t test) IsPasswordChecked() bool {
|
||||
return t.pwChecked
|
||||
}
|
||||
|
||||
func (t test) IsPasskeyChecked() bool {
|
||||
return t.passkeyChecked
|
||||
}
|
||||
|
||||
func (t test) IsU2FChecked() bool {
|
||||
return t.u2fChecked
|
||||
}
|
||||
|
||||
func (t test) IsOTPChecked() bool {
|
||||
return t.otpChecked
|
||||
}
|
@ -2,6 +2,7 @@ package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -10,16 +11,75 @@ import (
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
const (
|
||||
LoginClientHeader = "x-zitadel-login-client"
|
||||
)
|
||||
|
||||
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||
return o.createAuthRequestLoginClient(ctx, req, userID, loginClient)
|
||||
}
|
||||
|
||||
return o.createAuthRequest(ctx, req, userID)
|
||||
}
|
||||
|
||||
func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) {
|
||||
project, err := o.query.ProjectByClientID(ctx, req.ClientID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
audience, err := o.audienceFromProjectID(ctx, project.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
audience = domain.AddAudScopeToAudience(ctx, audience, scope)
|
||||
authRequest := &command.AuthRequest{
|
||||
LoginClient: loginClient,
|
||||
ClientID: req.ClientID,
|
||||
RedirectURI: req.RedirectURI,
|
||||
State: req.State,
|
||||
Nonce: req.Nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: ResponseTypeToBusiness(req.ResponseType),
|
||||
CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod),
|
||||
Prompt: PromptToBusiness(req.Prompt),
|
||||
UILocales: UILocalesToBusiness(req.UILocales),
|
||||
MaxAge: MaxAgeToBusiness(req.MaxAge),
|
||||
}
|
||||
if req.LoginHint != "" {
|
||||
authRequest.LoginHint = &req.LoginHint
|
||||
}
|
||||
if hintUserID != "" {
|
||||
authRequest.HintUserID = &hintUserID
|
||||
}
|
||||
|
||||
aar, err := o.command.AddAuthRequest(ctx, authRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{aar}, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id")
|
||||
@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) {
|
||||
projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(appIDs, projectID), nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
req, err := o.command.GetCurrentAuthRequest(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{req}, nil
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id")
|
||||
@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
plainCode, err := o.decryptGrant(code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
authReq, err := o.command.ExchangeAuthCode(ctx, plainCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{authReq}, nil
|
||||
}
|
||||
resp, err := o.repo.AuthRequestByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
// decryptGrant decrypts a code or refresh_token
|
||||
func (o *OPStorage) decryptGrant(grant string) (string, error) {
|
||||
decodedGrant, err := base64.RawURLEncoding.DecodeString(grant)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID())
|
||||
}
|
||||
|
||||
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return o.command.AddAuthRequestCode(ctx, id, code)
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id")
|
||||
@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error
|
||||
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var userAgentID, applicationID, userOrgID string
|
||||
authReq, ok := req.(*AuthRequest)
|
||||
if ok {
|
||||
switch authReq := req.(type) {
|
||||
case *AuthRequest:
|
||||
userAgentID = authReq.AgentID
|
||||
applicationID = authReq.ApplicationID
|
||||
userOrgID = authReq.UserOrgID
|
||||
case *AuthRequestV2:
|
||||
return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID())
|
||||
}
|
||||
|
||||
accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx)
|
||||
@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
|
||||
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// handle V2 request directly
|
||||
switch tokenReq := req.(type) {
|
||||
case *AuthRequestV2:
|
||||
return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID())
|
||||
case *RefreshTokenRequestV2:
|
||||
return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes)
|
||||
}
|
||||
|
||||
userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req)
|
||||
scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes())
|
||||
if err != nil {
|
||||
@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
|
||||
return "", "", "", time.Time{}, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
plainCode, err := o.decryptGrant(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasPrefix(plainCode, command.IDPrefixV2) {
|
||||
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil
|
||||
}
|
||||
|
||||
tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) {
|
||||
for _, scope := range scopes {
|
||||
if strings.HasPrefix(scope, ScopeProjectRolePrefix) {
|
||||
return scopes, nil
|
||||
}
|
||||
}
|
||||
if !project.ProjectRoleAssertion {
|
||||
return scopes, nil
|
||||
}
|
||||
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal")
|
||||
}
|
||||
roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, role := range roles.ProjectRoles {
|
||||
scopes = append(scopes, ScopeProjectRolePrefix+role.Key)
|
||||
}
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error {
|
||||
token.Audience = append(token.Audience, clientID)
|
||||
projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID)
|
||||
@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i
|
||||
}
|
||||
return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil
|
||||
}
|
||||
|
||||
func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) {
|
||||
e := struct {
|
||||
Error string `schema:"error"`
|
||||
Description string `schema:"error_description,omitempty"`
|
||||
URI string `schema:"error_uri,omitempty"`
|
||||
State string `schema:"state,omitempty"`
|
||||
}{
|
||||
Error: reason,
|
||||
Description: description,
|
||||
URI: uri,
|
||||
State: authReq.GetState(),
|
||||
}
|
||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, nil
|
||||
}
|
||||
|
||||
func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) {
|
||||
code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
codeResponse := struct {
|
||||
code string
|
||||
state string
|
||||
}{
|
||||
code: code,
|
||||
state: authReq.GetState(),
|
||||
}
|
||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, err
|
||||
}
|
||||
|
||||
func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) {
|
||||
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly
|
||||
resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return callback, err
|
||||
}
|
||||
|
@ -12,20 +12,12 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
const (
|
||||
// DEPRECATED: use `amrPWD` instead
|
||||
amrPassword = "password"
|
||||
amrPWD = "pwd"
|
||||
amrMFA = "mfa"
|
||||
amrOTP = "otp"
|
||||
amrUserPresence = "user"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
*domain.AuthRequest
|
||||
}
|
||||
@ -40,19 +32,19 @@ func (a *AuthRequest) GetACR() string {
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAMR() []string {
|
||||
amr := make([]string, 0)
|
||||
list := make([]string, 0)
|
||||
if a.PasswordVerified {
|
||||
amr = append(amr, amrPassword, amrPWD)
|
||||
list = append(list, amr.Password, amr.PWD)
|
||||
}
|
||||
if len(a.MFAsVerified) > 0 {
|
||||
amr = append(amr, amrMFA)
|
||||
list = append(list, amr.MFA)
|
||||
for _, mfa := range a.MFAsVerified {
|
||||
if amrMFA := AMRFromMFAType(mfa); amrMFA != "" {
|
||||
amr = append(amr, amrMFA)
|
||||
list = append(list, amrMFA)
|
||||
}
|
||||
}
|
||||
}
|
||||
return amr
|
||||
return list
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetAudience() []string {
|
||||
@ -271,10 +263,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng
|
||||
func AMRFromMFAType(mfaType domain.MFAType) string {
|
||||
switch mfaType {
|
||||
case domain.MFATypeOTP:
|
||||
return amrOTP
|
||||
return amr.OTP
|
||||
case domain.MFATypeU2F,
|
||||
domain.MFATypeU2FUserVerification:
|
||||
return amrUserPresence
|
||||
return amr.UserPresence
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
106
internal/api/oidc/auth_request_converter_v2.go
Normal file
106
internal/api/oidc/auth_request_converter_v2.go
Normal file
@ -0,0 +1,106 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
)
|
||||
|
||||
type AuthRequestV2 struct {
|
||||
*command.CurrentAuthRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetACR() string {
|
||||
return "" //PLANNED: impl
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAMR() []string {
|
||||
return a.AMR
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAudience() []string {
|
||||
return a.Audience
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAuthTime() time.Time {
|
||||
return a.AuthTime
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetClientID() string {
|
||||
return a.ClientID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetCodeChallenge() *oidc.CodeChallenge {
|
||||
return CodeChallengeToOIDC(a.CodeChallenge)
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetNonce() string {
|
||||
return a.Nonce
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetRedirectURI() string {
|
||||
return a.RedirectURI
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetResponseType() oidc.ResponseType {
|
||||
return ResponseTypeToOIDC(a.ResponseType)
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetResponseMode() oidc.ResponseMode {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetScopes() []string {
|
||||
return a.Scope
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetState() string {
|
||||
return a.State
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetSubject() string {
|
||||
return a.UserID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) Done() bool {
|
||||
return a.UserID != "" && a.SessionID != ""
|
||||
}
|
||||
|
||||
type RefreshTokenRequestV2 struct {
|
||||
*command.OIDCSessionWriteModel
|
||||
RequestedScopes []string
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAMR() []string {
|
||||
return r.AuthMethodsReferences
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAudience() []string {
|
||||
return r.Audience
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetAuthTime() time.Time {
|
||||
return r.AuthTime
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetClientID() string {
|
||||
return r.ClientID
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetScopes() []string {
|
||||
return r.Scope
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) GetSubject() string {
|
||||
return r.UserID
|
||||
}
|
||||
|
||||
func (r *RefreshTokenRequestV2) SetCurrentScopes(scopes []string) {
|
||||
r.RequestedScopes = scopes
|
||||
}
|
275
internal/api/oidc/auth_request_integration_test.go
Normal file
275
internal/api/oidc/auth_request_integration_test.go
Normal file
@ -0,0 +1,275 @@
|
||||
//go:build integration
|
||||
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
CTXLOGIN context.Context
|
||||
Tester *integration.Tester
|
||||
User *user.AddHumanUserResponse
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURI = "oidcIntegrationTest://callback"
|
||||
redirectURIImplicit = "http://localhost:9999/callback"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, errCtx, cancel := integration.Contexts(5 * time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Tester = integration.NewTester(ctx)
|
||||
defer Tester.Done()
|
||||
|
||||
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
|
||||
User = Tester.CreateHumanUser(CTX)
|
||||
Tester.RegisterUserPasskey(CTX, User.GetUserId())
|
||||
CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func createClient(t testing.TB) string {
|
||||
app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return app.GetClientId()
|
||||
}
|
||||
|
||||
func createImplicitClient(t testing.TB) string {
|
||||
app, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
return app.GetClientId()
|
||||
}
|
||||
|
||||
func createAuthRequest(t testing.TB, clientID, redirectURI string, scope ...string) string {
|
||||
redURL, err := Tester.CreateOIDCAuthRequest(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
|
||||
require.NoError(t, err)
|
||||
return redURL
|
||||
}
|
||||
|
||||
func createAuthRequestImplicit(t testing.TB, clientID, redirectURI string, scope ...string) string {
|
||||
redURL, err := Tester.CreateOIDCAuthRequestImplicit(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...)
|
||||
require.NoError(t, err)
|
||||
return redURL
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAuthRequest(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
|
||||
id := createAuthRequest(t, clientID, redirectURI)
|
||||
require.Contains(t, id, command.IDPrefixV2)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessToken_code(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// exchange with a used code must fail
|
||||
_, err = exchangeTokens(t, clientID, code)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
|
||||
clientID := createImplicitClient(t)
|
||||
authRequestID := createAuthRequestImplicit(t, clientID, redirectURIImplicit)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test implicit callback
|
||||
callback, err := url.Parse(linkResp.GetCallbackUrl())
|
||||
require.NoError(t, err)
|
||||
values, err := url.ParseQuery(callback.Fragment)
|
||||
require.NoError(t, err)
|
||||
accessToken := values.Get("access_token")
|
||||
idToken := values.Get("id_token")
|
||||
refreshToken := values.Get("refresh_token")
|
||||
assert.NotEmpty(t, accessToken)
|
||||
assert.NotEmpty(t, idToken)
|
||||
assert.Empty(t, refreshToken)
|
||||
assert.NotEmpty(t, values.Get("expires_in"))
|
||||
assert.Equal(t, oidc.BearerToken, values.Get("token_type"))
|
||||
assert.Equal(t, "state", values.Get("state"))
|
||||
|
||||
// check id_token / claims
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
assertTokenClaims(t, claims, startTime, changeTime)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// test code exchange (expect refresh token to be returned)
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
clientID := createClient(t)
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
|
||||
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
|
||||
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: authRequestID,
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionID,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime)
|
||||
|
||||
// test actual refresh grant
|
||||
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
|
||||
require.NoError(t, err)
|
||||
idToken, _ := newTokens.Extra("id_token").(string)
|
||||
assert.NotEmpty(t, idToken)
|
||||
assert.NotEmpty(t, newTokens.AccessToken)
|
||||
assert.NotEmpty(t, newTokens.RefreshToken)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
// auth time must still be the initial
|
||||
assertTokenClaims(t, claims, startTime, changeTime)
|
||||
|
||||
// refresh with an old refresh_token must fail
|
||||
_, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
codeVerifier := "codeVerifier"
|
||||
return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier))
|
||||
}
|
||||
|
||||
func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) {
|
||||
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
return rp.RefreshAccessToken(provider, refreshToken, "", "")
|
||||
}
|
||||
|
||||
func assertCodeResponse(t *testing.T, callback string) string {
|
||||
callbackURL, err := url.Parse(callback)
|
||||
require.NoError(t, err)
|
||||
code := callbackURL.Query().Get("code")
|
||||
require.NotEmpty(t, code)
|
||||
assert.Equal(t, "state", callbackURL.Query().Get("state"))
|
||||
return code
|
||||
}
|
||||
|
||||
func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requireRefreshToken bool) {
|
||||
assert.NotEmpty(t, tokens.AccessToken)
|
||||
assert.NotEmpty(t, tokens.IDToken)
|
||||
if requireRefreshToken {
|
||||
assert.NotEmpty(t, tokens.RefreshToken)
|
||||
} else {
|
||||
assert.Empty(t, tokens.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) {
|
||||
assert.Equal(t, User.GetUserId(), claims.Subject)
|
||||
assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences)
|
||||
assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second))
|
||||
}
|
@ -66,7 +66,7 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ClientFromBusiness(client, o.defaultLoginURL, accessTokenLifetime, idTokenLifetime, allowedScopes)
|
||||
return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2, accessTokenLifetime, idTokenLifetime, allowedScopes)
|
||||
}
|
||||
|
||||
func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) {
|
||||
@ -153,7 +153,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us
|
||||
return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil)
|
||||
}
|
||||
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
|
||||
token, err := o.repo.TokenByIDs(ctx, subject, tokenID)
|
||||
if err != nil {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v2/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
@ -15,18 +16,20 @@ import (
|
||||
type Client struct {
|
||||
app *query.App
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultIdTokenLifetime time.Duration
|
||||
allowedScopes []string
|
||||
}
|
||||
|
||||
func ClientFromBusiness(app *query.App, defaultLoginURL string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
|
||||
func ClientFromBusiness(app *query.App, defaultLoginURL, defaultLoginURLV2 string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) {
|
||||
if app.OIDCConfig == nil {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "OIDC-d5bhD", "client is not a proper oidc application")
|
||||
}
|
||||
return &Client{
|
||||
app: app,
|
||||
defaultLoginURL: defaultLoginURL,
|
||||
defaultLoginURLV2: defaultLoginURLV2,
|
||||
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
|
||||
defaultIdTokenLifetime: defaultIdTokenLifetime,
|
||||
allowedScopes: allowedScopes},
|
||||
@ -46,6 +49,9 @@ func (c *Client) GetID() string {
|
||||
}
|
||||
|
||||
func (c *Client) LoginURL(id string) string {
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return c.defaultLoginURLV2 + id
|
||||
}
|
||||
return c.defaultLoginURL + id
|
||||
}
|
||||
|
||||
|
@ -41,6 +41,7 @@ type Config struct {
|
||||
Cache *middleware.CacheConfig
|
||||
CustomEndpoints *EndpointConfig
|
||||
DeviceAuth *DeviceAuthorizationConfig
|
||||
DefaultLoginURLV2 string
|
||||
}
|
||||
|
||||
type EndpointConfig struct {
|
||||
@ -65,6 +66,7 @@ type OPStorage struct {
|
||||
query *query.Queries
|
||||
eventstore *eventstore.Eventstore
|
||||
defaultLoginURL string
|
||||
defaultLoginURLV2 string
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultIdTokenLifetime time.Duration
|
||||
signingKeyAlgorithm string
|
||||
@ -181,6 +183,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries,
|
||||
query: query,
|
||||
eventstore: es,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
defaultLoginURLV2: config.DefaultLoginURLV2,
|
||||
signingKeyAlgorithm: config.SigningKeyAlgorithm,
|
||||
defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime,
|
||||
defaultIdTokenLifetime: config.DefaultIdTokenLifetime,
|
||||
|
215
internal/command/auth_request.go
Normal file
215
internal/command/auth_request.go
Normal file
@ -0,0 +1,215 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
}
|
||||
|
||||
type CurrentAuthRequest struct {
|
||||
*AuthRequest
|
||||
SessionID string
|
||||
UserID string
|
||||
AMR []string
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
const IDPrefixV2 = "V2_"
|
||||
|
||||
func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest) (_ *CurrentAuthRequest, err error) {
|
||||
authRequestID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authRequest.ID = IDPrefixV2 + authRequestID
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateUnspecified {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewAddedEvent(
|
||||
ctx,
|
||||
&authrequest.NewAggregate(authRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
authRequest.LoginClient,
|
||||
authRequest.ClientID,
|
||||
authRequest.RedirectURI,
|
||||
authRequest.State,
|
||||
authRequest.Nonce,
|
||||
authRequest.Scope,
|
||||
authRequest.Audience,
|
||||
authRequest.ResponseType,
|
||||
authRequest.CodeChallenge,
|
||||
authRequest.Prompt,
|
||||
authRequest.UILocales,
|
||||
authRequest.MaxAge,
|
||||
authRequest.LoginHint,
|
||||
authRequest.HintUserID,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState == domain.AuthRequestStateUnspecified {
|
||||
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting")
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
|
||||
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
|
||||
return nil, nil, errors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient")
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if sessionWriteModel.State == domain.SessionStateUnspecified {
|
||||
return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting")
|
||||
}
|
||||
if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := c.pushAppendAndReduce(ctx, writeModel, authrequest.NewSessionLinkedEvent(
|
||||
ctx, &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
sessionID,
|
||||
sessionWriteModel.UserID,
|
||||
sessionWriteModel.AuthenticationTime(),
|
||||
amr.List(sessionWriteModel),
|
||||
)); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) FailAuthRequest(ctx context.Context, id string, reason domain.OIDCErrorReason) (*domain.ObjectDetails, *CurrentAuthRequest, error) {
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded {
|
||||
return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewFailedEvent(
|
||||
ctx,
|
||||
&authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
reason,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code string) (err error) {
|
||||
if code == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateAdded || writeModel.SessionID == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled")
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeAddedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) {
|
||||
if code == "" {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode")
|
||||
}
|
||||
writeModel, err := c.getAuthRequestWriteModel(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx,
|
||||
&authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) {
|
||||
return &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: writeModel.AggregateID,
|
||||
LoginClient: writeModel.LoginClient,
|
||||
ClientID: writeModel.ClientID,
|
||||
RedirectURI: writeModel.RedirectURI,
|
||||
State: writeModel.State,
|
||||
Nonce: writeModel.Nonce,
|
||||
Scope: writeModel.Scope,
|
||||
Audience: writeModel.Audience,
|
||||
ResponseType: writeModel.ResponseType,
|
||||
CodeChallenge: writeModel.CodeChallenge,
|
||||
Prompt: writeModel.Prompt,
|
||||
UILocales: writeModel.UILocales,
|
||||
MaxAge: writeModel.MaxAge,
|
||||
LoginHint: writeModel.LoginHint,
|
||||
HintUserID: writeModel.HintUserID,
|
||||
},
|
||||
SessionID: writeModel.SessionID,
|
||||
UserID: writeModel.UserID,
|
||||
AMR: writeModel.AMR,
|
||||
AuthTime: writeModel.AuthTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) GetCurrentAuthRequest(ctx context.Context, id string) (_ *CurrentAuthRequest, err error) {
|
||||
wm, err := c.getAuthRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authRequestWriteModelToCurrentAuthRequest(wm), nil
|
||||
}
|
||||
|
||||
func (c *Commands) getAuthRequestWriteModel(ctx context.Context, id string) (writeModel *AuthRequestWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel = NewAuthRequestWriteModel(ctx, id)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
110
internal/command/auth_request_model.go
Normal file
110
internal/command/auth_request_model.go
Normal file
@ -0,0 +1,110 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
)
|
||||
|
||||
type AuthRequestWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
aggregate *eventstore.Aggregate
|
||||
|
||||
LoginClient string
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
State string
|
||||
Nonce string
|
||||
Scope []string
|
||||
Audience []string
|
||||
ResponseType domain.OIDCResponseType
|
||||
CodeChallenge *domain.OIDCCodeChallenge
|
||||
Prompt []domain.Prompt
|
||||
UILocales []string
|
||||
MaxAge *time.Duration
|
||||
LoginHint *string
|
||||
HintUserID *string
|
||||
SessionID string
|
||||
UserID string
|
||||
AuthTime time.Time
|
||||
AMR []string
|
||||
AuthRequestState domain.AuthRequestState
|
||||
}
|
||||
|
||||
func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel {
|
||||
return &AuthRequestWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
},
|
||||
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthRequestWriteModel) Reduce() error {
|
||||
for _, event := range m.Events {
|
||||
switch e := event.(type) {
|
||||
case *authrequest.AddedEvent:
|
||||
m.LoginClient = e.LoginClient
|
||||
m.ClientID = e.ClientID
|
||||
m.RedirectURI = e.RedirectURI
|
||||
m.State = e.State
|
||||
m.Nonce = e.Nonce
|
||||
m.Scope = e.Scope
|
||||
m.Audience = e.Audience
|
||||
m.ResponseType = e.ResponseType
|
||||
m.CodeChallenge = e.CodeChallenge
|
||||
m.Prompt = e.Prompt
|
||||
m.UILocales = e.UILocales
|
||||
m.MaxAge = e.MaxAge
|
||||
m.LoginHint = e.LoginHint
|
||||
m.HintUserID = e.HintUserID
|
||||
m.AuthRequestState = domain.AuthRequestStateAdded
|
||||
case *authrequest.SessionLinkedEvent:
|
||||
m.SessionID = e.SessionID
|
||||
m.UserID = e.UserID
|
||||
m.AuthTime = e.AuthTime
|
||||
m.AMR = e.AMR
|
||||
case *authrequest.CodeAddedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateCodeAdded
|
||||
case *authrequest.FailedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateFailed
|
||||
case *authrequest.CodeExchangedEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateCodeExchanged
|
||||
case *authrequest.SucceededEvent:
|
||||
m.AuthRequestState = domain.AuthRequestStateSucceeded
|
||||
}
|
||||
}
|
||||
|
||||
return m.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (m *AuthRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(authrequest.AggregateType).
|
||||
AggregateIDs(m.AggregateID).
|
||||
Builder()
|
||||
}
|
||||
|
||||
// CheckAuthenticated checks that the auth request exists, a session must have been linked
|
||||
// and in case of a Code Flow the code must have been exchanged
|
||||
func (m *AuthRequestWriteModel) CheckAuthenticated() error {
|
||||
if m.SessionID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated")
|
||||
}
|
||||
// in case of OIDC Code Flow, the code must have been exchanged
|
||||
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
|
||||
return nil
|
||||
}
|
||||
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
|
||||
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
|
||||
m.AuthRequestState == domain.AuthRequestStateAdded {
|
||||
return nil
|
||||
}
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.AuthRequest.NotAuthenticated")
|
||||
}
|
998
internal/command/auth_request_test.go
Normal file
998
internal/command/auth_request_test.go
Normal file
@ -0,0 +1,998 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
)
|
||||
|
||||
func TestCommands_AddAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
request *AuthRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *CurrentAuthRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"already exists error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &AuthRequest{},
|
||||
},
|
||||
nil,
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
|
||||
},
|
||||
{
|
||||
"added",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
}),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &AuthRequest{
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
},
|
||||
&CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := c.AddAuthRequest(tt.args.ctx, tt.args.request)
|
||||
require.ErrorIs(t, tt.wantErr, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
checkPermission domain.PermissionCheck
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
sessionID string
|
||||
sessionToken string
|
||||
checkLoginClient bool
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
authReq *CurrentAuthRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"authRequest not found",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"authRequest not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
domain.OIDCErrorReasonUnspecified),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"wrong login client",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"session not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"missing permission",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid session token",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
|
||||
),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid")
|
||||
},
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "invalid",
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
)}),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{amr.PWD},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked with login client check",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
)}),
|
||||
),
|
||||
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
|
||||
return nil
|
||||
},
|
||||
checkPermission: newMockPermissionCheckAllowed(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{amr.PWD},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
sessionTokenVerifier: tt.fields.tokenVerifier,
|
||||
checkPermission: tt.fields.checkPermission,
|
||||
}
|
||||
details, got, err := c.LinkSessionToAuthRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assert.Equal(t, tt.res.details, details)
|
||||
if err == nil {
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authReq, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_FailAuthRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
reason domain.OIDCErrorReason
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
authReq *CurrentAuthRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"authRequest not existing",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "foo",
|
||||
reason: domain.OIDCErrorReasonLoginRequired,
|
||||
},
|
||||
res{
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"failed",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{eventFromEventPusherWithInstanceID(
|
||||
"instanceID",
|
||||
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
domain.OIDCErrorReasonLoginRequired),
|
||||
)}),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
reason: domain.OIDCErrorReasonLoginRequired,
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assert.Equal(t, tt.res.details, details)
|
||||
assert.Equal(t, tt.res.authReq, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_AddAuthRequestCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
code string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "",
|
||||
},
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
{
|
||||
"no session linked error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled"),
|
||||
},
|
||||
{
|
||||
"success",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_authRequestID",
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
err := c.AddAuthRequestCode(tt.args.ctx, tt.args.id, tt.args.code)
|
||||
assert.ErrorIs(t, tt.wantErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeAuthCode(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
authRequest *CurrentAuthRequest
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"empty code error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"no code added error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"code exchanged",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
code: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
authRequest: &CurrentAuthRequest{
|
||||
AuthRequest: &AuthRequest{
|
||||
ID: "V2_authRequestID",
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
RedirectURI: "redirectURI",
|
||||
State: "state",
|
||||
Nonce: "nonce",
|
||||
Scope: []string{"openid"},
|
||||
Audience: []string{"audience"},
|
||||
ResponseType: domain.OIDCResponseTypeCode,
|
||||
CodeChallenge: &domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
Prompt: []domain.Prompt{domain.PromptNone},
|
||||
UILocales: []string{"en", "de"},
|
||||
MaxAge: gu.Ptr(time.Duration(0)),
|
||||
LoginHint: gu.Ptr("loginHint"),
|
||||
HintUserID: gu.Ptr("hintUserID"),
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AMR: []string{"pwd"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code)
|
||||
assert.ErrorIs(t, tt.res.err, err)
|
||||
|
||||
if err == nil {
|
||||
// equal on time won't work -> test separately and clear it before comparing the rest
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authRequest, got)
|
||||
})
|
||||
}
|
||||
}
|
@ -15,10 +15,12 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/action"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
instance_repo "github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
@ -43,18 +45,21 @@ type Commands struct {
|
||||
externalSecure bool
|
||||
externalPort uint16
|
||||
|
||||
idpConfigEncryption crypto.EncryptionAlgorithm
|
||||
smtpEncryption crypto.EncryptionAlgorithm
|
||||
smsEncryption crypto.EncryptionAlgorithm
|
||||
userEncryption crypto.EncryptionAlgorithm
|
||||
userPasswordAlg crypto.HashAlgorithm
|
||||
machineKeySize int
|
||||
applicationKeySize int
|
||||
domainVerificationAlg crypto.EncryptionAlgorithm
|
||||
domainVerificationGenerator crypto.Generator
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
|
||||
sessionTokenCreator func(sessionID string) (id string, token string, err error)
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
idpConfigEncryption crypto.EncryptionAlgorithm
|
||||
smtpEncryption crypto.EncryptionAlgorithm
|
||||
smsEncryption crypto.EncryptionAlgorithm
|
||||
userEncryption crypto.EncryptionAlgorithm
|
||||
userPasswordAlg crypto.HashAlgorithm
|
||||
machineKeySize int
|
||||
applicationKeySize int
|
||||
domainVerificationAlg crypto.EncryptionAlgorithm
|
||||
domainVerificationGenerator crypto.Generator
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
|
||||
sessionTokenCreator func(sessionID string) (id string, token string, err error)
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
|
||||
multifactors domain.MultifactorConfigs
|
||||
webauthnConfig *webauthn_helper.Config
|
||||
@ -80,6 +85,9 @@ func StartCommands(
|
||||
httpClient *http.Client,
|
||||
permissionCheck domain.PermissionCheck,
|
||||
sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error),
|
||||
defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime time.Duration,
|
||||
) (repo *Commands, err error) {
|
||||
if externalDomain == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
|
||||
@ -88,31 +96,34 @@ func StartCommands(
|
||||
// reuse the oidcEncryption to be able to handle both tokens in the interceptor later on
|
||||
sessionAlg := oidcEncryption
|
||||
repo = &Commands{
|
||||
eventstore: es,
|
||||
static: staticStore,
|
||||
idGenerator: idGenerator,
|
||||
zitadelRoles: zitadelRoles,
|
||||
externalDomain: externalDomain,
|
||||
externalSecure: externalSecure,
|
||||
externalPort: externalPort,
|
||||
keySize: defaults.KeyConfig.Size,
|
||||
certKeySize: defaults.KeyConfig.CertificateSize,
|
||||
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
|
||||
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
|
||||
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
|
||||
idpConfigEncryption: idpConfigEncryption,
|
||||
smtpEncryption: smtpEncryption,
|
||||
smsEncryption: smsEncryption,
|
||||
userEncryption: userEncryption,
|
||||
domainVerificationAlg: domainVerificationEncryption,
|
||||
keyAlgorithm: oidcEncryption,
|
||||
certificateAlgorithm: samlEncryption,
|
||||
webauthnConfig: webAuthN,
|
||||
httpClient: httpClient,
|
||||
checkPermission: permissionCheck,
|
||||
newCode: newCryptoCode,
|
||||
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
|
||||
sessionTokenVerifier: sessionTokenVerifier,
|
||||
eventstore: es,
|
||||
static: staticStore,
|
||||
idGenerator: idGenerator,
|
||||
zitadelRoles: zitadelRoles,
|
||||
externalDomain: externalDomain,
|
||||
externalSecure: externalSecure,
|
||||
externalPort: externalPort,
|
||||
keySize: defaults.KeyConfig.Size,
|
||||
certKeySize: defaults.KeyConfig.CertificateSize,
|
||||
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
|
||||
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
|
||||
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
|
||||
idpConfigEncryption: idpConfigEncryption,
|
||||
smtpEncryption: smtpEncryption,
|
||||
smsEncryption: smsEncryption,
|
||||
userEncryption: userEncryption,
|
||||
domainVerificationAlg: domainVerificationEncryption,
|
||||
keyAlgorithm: oidcEncryption,
|
||||
certificateAlgorithm: samlEncryption,
|
||||
webauthnConfig: webAuthN,
|
||||
httpClient: httpClient,
|
||||
checkPermission: permissionCheck,
|
||||
newCode: newCryptoCode,
|
||||
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
|
||||
sessionTokenVerifier: sessionTokenVerifier,
|
||||
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: defaultRefreshTokenIdleLifetime,
|
||||
}
|
||||
|
||||
instance_repo.RegisterEventMappers(repo.eventstore)
|
||||
@ -125,6 +136,8 @@ func StartCommands(
|
||||
quota.RegisterEventMappers(repo.eventstore)
|
||||
session.RegisterEventMappers(repo.eventstore)
|
||||
idpintent.RegisterEventMappers(repo.eventstore)
|
||||
authrequest.RegisterEventMappers(repo.eventstore)
|
||||
oidcsession.RegisterEventMappers(repo.eventstore)
|
||||
milestone.RegisterEventMappers(repo.eventstore)
|
||||
|
||||
repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost)
|
||||
|
@ -16,9 +16,11 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository/mock"
|
||||
action_repo "github.com/zitadel/zitadel/internal/repository/action"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
|
||||
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
@ -43,6 +45,8 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
|
||||
action_repo.RegisterEventMappers(es)
|
||||
session.RegisterEventMappers(es)
|
||||
idpintent.RegisterEventMappers(es)
|
||||
authrequest.RegisterEventMappers(es)
|
||||
oidcsession.RegisterEventMappers(es)
|
||||
return es
|
||||
}
|
||||
|
||||
|
281
internal/command/oidc_session.go
Normal file
281
internal/command/oidc_session.go
Normal file
@ -0,0 +1,281 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
)
|
||||
|
||||
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
|
||||
return accessTokenID, accessTokenExpiration, err
|
||||
}
|
||||
|
||||
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
|
||||
// It returns the access token id, expiration and the refresh token.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
cmd.AddSession(ctx)
|
||||
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.AddRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx)
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
|
||||
// It returns the access token id and expiration and the new refresh token.
|
||||
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
|
||||
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.AddAccessToken(ctx, scope); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
if err = cmd.RenewRefreshToken(ctx); err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
return cmd.PushEvents(ctx)
|
||||
}
|
||||
|
||||
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
|
||||
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
|
||||
split := strings.Split(refreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
writeModel := NewOIDCSessionWriteModel(split[0], "")
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
|
||||
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionWriteModel.State != domain.SessionStateActive {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
|
||||
}
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = IDPrefixV2 + sessionID
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()),
|
||||
sessionWriteModel: sessionWriteModel,
|
||||
authRequestWriteModel: authRequestWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
split := strings.Split(decrypted, ":")
|
||||
if len(split) != 2 {
|
||||
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
return split[1], nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID())
|
||||
if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
oidcSessionWriteModel: sessionWriteModel,
|
||||
accessTokenLifetime: accessTokenLifetime,
|
||||
refreshTokenLifeTime: refreshTokenLifeTime,
|
||||
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OIDCSessionEvents struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
oidcSessionWriteModel *OIDCSessionWriteModel
|
||||
sessionWriteModel *SessionWriteModel
|
||||
authRequestWriteModel *AuthRequestWriteModel
|
||||
accessTokenLifetime time.Duration
|
||||
refreshTokenLifeTime time.Duration
|
||||
refreshTokenIdleLifetime time.Duration
|
||||
|
||||
// accessTokenID is set by the command
|
||||
accessTokenID string
|
||||
|
||||
// refreshToken is set by the command
|
||||
refreshToken string
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
|
||||
c.events = append(c.events, oidcsession.NewAddedEvent(
|
||||
ctx,
|
||||
c.oidcSessionWriteModel.aggregate,
|
||||
c.sessionWriteModel.UserID,
|
||||
c.sessionWriteModel.AggregateID,
|
||||
c.authRequestWriteModel.ClientID,
|
||||
c.authRequestWriteModel.Audience,
|
||||
c.authRequestWriteModel.Scope,
|
||||
amr.List(c.sessionWriteModel),
|
||||
c.sessionWriteModel.AuthenticationTime(),
|
||||
))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
|
||||
c.accessTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
|
||||
refreshTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
|
||||
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
}
|
||||
|
||||
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
|
||||
oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings)
|
||||
if err != nil {
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
accessTokenLifetime = c.defaultAccessTokenLifetime
|
||||
refreshTokenLifetime = c.defaultRefreshTokenLifetime
|
||||
refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime
|
||||
if oidcSettings.AccessTokenLifetime > 0 {
|
||||
accessTokenLifetime = oidcSettings.AccessTokenLifetime
|
||||
}
|
||||
if oidcSettings.RefreshTokenExpiration > 0 {
|
||||
refreshTokenLifetime = oidcSettings.RefreshTokenExpiration
|
||||
}
|
||||
if oidcSettings.RefreshTokenIdleExpiration > 0 {
|
||||
refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration
|
||||
}
|
||||
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
|
||||
}
|
114
internal/command/oidc_session_model.go
Normal file
114
internal/command/oidc_session_model.go
Normal file
@ -0,0 +1,114 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
)
|
||||
|
||||
type OIDCSessionWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
SessionID string
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scope []string
|
||||
AuthMethodsReferences []string
|
||||
AuthTime time.Time
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenExpiration time.Time
|
||||
RefreshTokenID string
|
||||
RefreshTokenExpiration time.Time
|
||||
RefreshTokenIdleExpiration time.Time
|
||||
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
|
||||
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
|
||||
return &OIDCSessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
aggregate: &oidcsession.NewAggregate(id, resourceOwner).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *oidcsession.AddedEvent:
|
||||
wm.reduceAdded(e)
|
||||
case *oidcsession.AccessTokenAddedEvent:
|
||||
wm.reduceAccessTokenAdded(e)
|
||||
case *oidcsession.RefreshTokenAddedEvent:
|
||||
wm.reduceRefreshTokenAdded(e)
|
||||
case *oidcsession.RefreshTokenRenewedEvent:
|
||||
wm.reduceRefreshTokenRenewed(e)
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(oidcsession.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
oidcsession.AddedType,
|
||||
oidcsession.AccessTokenAddedType,
|
||||
oidcsession.RefreshTokenAddedType,
|
||||
oidcsession.RefreshTokenRenewedType,
|
||||
).
|
||||
Builder()
|
||||
|
||||
if wm.ResourceOwner != "" {
|
||||
query.ResourceOwner(wm.ResourceOwner)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.SessionID = e.SessionID
|
||||
wm.ClientID = e.ClientID
|
||||
wm.Audience = e.Audience
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethodsReferences = e.AuthMethodsReferences
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
|
||||
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
|
||||
wm.RefreshTokenID = e.ID
|
||||
wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
|
||||
wm.RefreshTokenID = e.ID
|
||||
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
|
||||
if wm.State != domain.OIDCSessionStateActive {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
if wm.RefreshTokenID != refreshTokenID {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
now := time.Now()
|
||||
if wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
return nil
|
||||
}
|
795
internal/command/oidc_session_test.go
Normal file
795
internal/command/oidc_session_test.go
Normal file
@ -0,0 +1,795 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc/amr"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
)
|
||||
|
||||
var (
|
||||
testNow = time.Now()
|
||||
tokenCreationNow = time.Time{}
|
||||
)
|
||||
|
||||
func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authRequestID string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"unauthenticated error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"add successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
expiration: tokenCreationNow.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotExpiration, err := c.AddOIDCSessionAccessToken(tt.args.ctx, tt.args.authRequestID)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
authRequestID string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
refreshToken string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"unauthenticated error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"add successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
"redirectURI",
|
||||
"state",
|
||||
"nonce",
|
||||
[]string{"openid", "offline_access"},
|
||||
[]string{"audience"},
|
||||
domain.OIDCResponseTypeCode,
|
||||
&domain.OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: domain.CodeChallengeMethodS256,
|
||||
},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]string{amr.PWD},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID
|
||||
expiration: tokenCreationNow.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotRefreshToken, gotExpiration, err := c.AddOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.authRequestID)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
oidcSessionID string
|
||||
refreshToken string
|
||||
scope []string
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
refreshToken string
|
||||
expiration time.Time
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid refresh token format error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "aW52YWxpZA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"expired refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"refresh successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
expectFilter(), // token lifetime
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID("instanceID",
|
||||
oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID2", 24*time.Hour),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "accessTokenID", "refreshTokenID2"),
|
||||
defaultAccessTokenLifetime: time.Hour,
|
||||
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
|
||||
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcSessionID: "V2_oidcSessionID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA",
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
res{
|
||||
id: "V2_oidcSessionID-accessTokenID",
|
||||
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI",
|
||||
expiration: time.Time{}.Add(time.Hour),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
gotID, gotRefreshToken, gotExpiration, err := c.ExchangeOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.oidcSessionID, tt.args.refreshToken, tt.args.scope)
|
||||
assert.Equal(t, tt.res.id, gotID)
|
||||
assert.Equal(t, tt.res.refreshToken, gotRefreshToken)
|
||||
assert.Equal(t, tt.res.expiration, gotExpiration)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
defaultAccessTokenLifetime time.Duration
|
||||
defaultRefreshTokenLifetime time.Duration
|
||||
defaultRefreshTokenIdleLifetime time.Duration
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
}
|
||||
type res struct {
|
||||
model *OIDCSessionWriteModel
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"invalid refresh token format error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "invalid",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"expired refresh token error",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"get successful",
|
||||
fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
|
||||
"refreshTokenID", 7*24*time.Hour, 24*time.Hour),
|
||||
),
|
||||
),
|
||||
),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
refreshToken: "V2_oidcSessionID:refreshTokenID",
|
||||
},
|
||||
res{
|
||||
model: &OIDCSessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: "V2_oidcSessionID",
|
||||
ChangeDate: testNow,
|
||||
},
|
||||
UserID: "userID",
|
||||
SessionID: "sessionID",
|
||||
ClientID: "clientID",
|
||||
Audience: []string{"audience"},
|
||||
Scope: []string{"openid", "profile", "offline_access"},
|
||||
AuthMethodsReferences: []string{amr.PWD},
|
||||
AuthTime: testNow,
|
||||
State: domain.OIDCSessionStateActive,
|
||||
RefreshTokenID: "refreshTokenID",
|
||||
RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour),
|
||||
RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
|
||||
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
got, err := c.OIDCSessionByRefreshToken(tt.args.ctx, tt.args.refreshToken)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
if tt.res.err == nil {
|
||||
assert.WithinRange(t, got.ChangeDate, tt.res.model.ChangeDate.Add(-2*time.Second), tt.res.model.ChangeDate.Add(2*time.Second))
|
||||
assert.Equal(t, tt.res.model.AggregateID, got.AggregateID)
|
||||
assert.Equal(t, tt.res.model.UserID, got.UserID)
|
||||
assert.Equal(t, tt.res.model.SessionID, got.SessionID)
|
||||
assert.Equal(t, tt.res.model.ClientID, got.ClientID)
|
||||
assert.Equal(t, tt.res.model.Audience, got.Audience)
|
||||
assert.Equal(t, tt.res.model.Scope, got.Scope)
|
||||
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences)
|
||||
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
|
||||
assert.Equal(t, tt.res.model.State, got.State)
|
||||
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)
|
||||
assert.WithinRange(t, got.RefreshTokenExpiration, tt.res.model.RefreshTokenExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenExpiration.Add(2*time.Second))
|
||||
assert.WithinRange(t, got.RefreshTokenIdleExpiration, tt.res.model.RefreshTokenIdleExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenIdleExpiration.Add(2*time.Second))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -52,6 +52,24 @@ type SessionWriteModel struct {
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsPasswordChecked() bool {
|
||||
return !wm.PasswordCheckedAt.IsZero()
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
|
||||
return !wm.PasskeyCheckedAt.IsZero()
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsU2FChecked() bool {
|
||||
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
|
||||
return false
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) IsOTPChecked() bool {
|
||||
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
|
||||
return false
|
||||
}
|
||||
|
||||
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
|
||||
return &SessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
@ -210,3 +228,19 @@ func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[st
|
||||
wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata))
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationTime returns the time the user authenticated using the latest time of all checks
|
||||
func (wm *SessionWriteModel) AuthenticationTime() time.Time {
|
||||
var authTime time.Time
|
||||
for _, check := range []time.Time{
|
||||
wm.PasswordCheckedAt,
|
||||
wm.PasskeyCheckedAt,
|
||||
wm.IntentCheckedAt,
|
||||
// TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477
|
||||
} {
|
||||
if check.After(authTime) {
|
||||
authTime = check
|
||||
}
|
||||
}
|
||||
return authTime
|
||||
}
|
||||
|
@ -528,7 +528,7 @@ func TestCommands_createHumanOTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCommands_HumanCheckMFAOTPSetup(t *testing.T) {
|
||||
ctx := authz.NewMockContext("inst1", "org1", "user1")
|
||||
ctx := authz.NewMockContext("", "org1", "user1")
|
||||
|
||||
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)
|
||||
|
@ -188,7 +188,7 @@ func TestCommands_AddUserTOTP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCommands_CheckUserTOTP(t *testing.T) {
|
||||
ctx := authz.NewMockContext("inst1", "org1", "user1")
|
||||
ctx := authz.NewMockContext("", "org1", "user1")
|
||||
|
||||
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg)
|
||||
|
@ -116,6 +116,17 @@ const (
|
||||
MFALevelMultiFactorCertified
|
||||
)
|
||||
|
||||
type AuthRequestState int
|
||||
|
||||
const (
|
||||
AuthRequestStateUnspecified AuthRequestState = iota
|
||||
AuthRequestStateAdded
|
||||
AuthRequestStateCodeAdded
|
||||
AuthRequestStateCodeExchanged
|
||||
AuthRequestStateFailed
|
||||
AuthRequestStateSucceeded
|
||||
)
|
||||
|
||||
func NewAuthRequestFromType(requestType AuthRequestType) (*AuthRequest, error) {
|
||||
switch requestType {
|
||||
case AuthRequestTypeOIDC:
|
||||
|
23
internal/domain/oidc_error_reason.go
Normal file
23
internal/domain/oidc_error_reason.go
Normal file
@ -0,0 +1,23 @@
|
||||
package domain
|
||||
|
||||
type OIDCErrorReason int32
|
||||
|
||||
const (
|
||||
OIDCErrorReasonUnspecified OIDCErrorReason = iota
|
||||
OIDCErrorReasonInvalidRequest
|
||||
OIDCErrorReasonUnauthorizedClient
|
||||
OIDCErrorReasonAccessDenied
|
||||
OIDCErrorReasonUnsupportedResponseType
|
||||
OIDCErrorReasonInvalidScope
|
||||
OIDCErrorReasonServerError
|
||||
OIDCErrorReasonTemporaryUnavailable
|
||||
OIDCErrorReasonInteractionRequired
|
||||
OIDCErrorReasonLoginRequired
|
||||
OIDCErrorReasonAccountSelectionRequired
|
||||
OIDCErrorReasonConsentRequired
|
||||
OIDCErrorReasonInvalidRequestURI
|
||||
OIDCErrorReasonInvalidRequestObject
|
||||
OIDCErrorReasonRequestNotSupported
|
||||
OIDCErrorReasonRequestURINotSupported
|
||||
OIDCErrorReasonRegistrationNotSupported
|
||||
)
|
9
internal/domain/oidc_session.go
Normal file
9
internal/domain/oidc_session.go
Normal file
@ -0,0 +1,9 @@
|
||||
package domain
|
||||
|
||||
type OIDCSessionState int32
|
||||
|
||||
const (
|
||||
OIDCSessionStateUnspecified OIDCSessionState = iota
|
||||
OIDCSessionStateActive
|
||||
OIDCSessionStateTerminated
|
||||
)
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/zitadel/zitadel/pkg/grpc/admin"
|
||||
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
|
||||
session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha"
|
||||
@ -30,6 +31,7 @@ type Client struct {
|
||||
Mgmt mgmt.ManagementServiceClient
|
||||
UserV2 user.UserServiceClient
|
||||
SessionV2 session.SessionServiceClient
|
||||
OIDCv2 oidc_pb.OIDCServiceClient
|
||||
System system.SystemServiceClient
|
||||
}
|
||||
|
||||
@ -40,6 +42,7 @@ func newClient(cc *grpc.ClientConn) Client {
|
||||
Mgmt: mgmt.NewManagementServiceClient(cc),
|
||||
UserV2: user.NewUserServiceClient(cc),
|
||||
SessionV2: session.NewSessionServiceClient(cc),
|
||||
OIDCv2: oidc_pb.NewOIDCServiceClient(cc),
|
||||
System: system.NewSystemServiceClient(cc),
|
||||
}
|
||||
}
|
||||
@ -62,11 +65,9 @@ func (t *Tester) UseIsolatedInstance(iamOwnerCtx, systemCtx context.Context) (pr
|
||||
}
|
||||
t.createClientConn(iamOwnerCtx, grpc.WithAuthority(primaryDomain))
|
||||
instanceId = instance.GetInstanceId()
|
||||
t.Users[instanceId] = map[UserType]User{
|
||||
IAMOwner: {
|
||||
Token: instance.GetPat(),
|
||||
},
|
||||
}
|
||||
t.Users.Set(instanceId, IAMOwner, &User{
|
||||
Token: instance.GetPat(),
|
||||
})
|
||||
return primaryDomain, instanceId, t.WithInstanceAuthorization(iamOwnerCtx, IAMOwner, instanceId)
|
||||
}
|
||||
|
||||
@ -187,3 +188,34 @@ func (s *Tester) CreateSuccessfulIntent(t *testing.T, idpID, userID, idpUserID s
|
||||
require.NoError(t, err)
|
||||
return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence
|
||||
}
|
||||
|
||||
func (s *Tester) CreatePasskeySession(t *testing.T, ctx context.Context, userID string) (id, token string, start, change time.Time) {
|
||||
createResp, err := s.Client.SessionV2.CreateSession(ctx, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{UserId: userID},
|
||||
},
|
||||
},
|
||||
Challenges: []session.ChallengeKind{
|
||||
session.ChallengeKind_CHALLENGE_KIND_PASSKEY,
|
||||
},
|
||||
Domain: s.Config.ExternalDomain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assertion, err := s.WebAuthN.CreateAssertionResponse(createResp.GetChallenges().GetPasskey().GetPublicKeyCredentialRequestOptions())
|
||||
require.NoError(t, err)
|
||||
|
||||
updateResp, err := s.Client.SessionV2.SetSession(ctx, &session.SetSessionRequest{
|
||||
SessionId: createResp.GetSessionId(),
|
||||
SessionToken: createResp.GetSessionToken(),
|
||||
Checks: &session.Checks{
|
||||
Passkey: &session.CheckPasskey{
|
||||
CredentialAssertionData: assertion,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return createResp.GetSessionId(), updateResp.GetSessionToken(),
|
||||
createResp.GetDetails().GetChangeDate().AsTime(), updateResp.GetDetails().GetChangeDate().AsTime()
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
Log:
|
||||
Level: debug
|
||||
|
||||
ExternalSecure: false
|
||||
|
||||
TLS:
|
||||
Enabled: false
|
||||
|
||||
|
@ -57,6 +57,7 @@ type UserType int
|
||||
const (
|
||||
Unspecified UserType = iota
|
||||
OrgOwner
|
||||
Login
|
||||
IAMOwner
|
||||
SystemUser // SystemUser is a user with access to the system service.
|
||||
)
|
||||
@ -71,13 +72,29 @@ type User struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
type InstanceUserMap map[string]map[UserType]*User
|
||||
|
||||
func (m InstanceUserMap) Set(instanceID string, typ UserType, user *User) {
|
||||
if m[instanceID] == nil {
|
||||
m[instanceID] = make(map[UserType]*User)
|
||||
}
|
||||
m[instanceID][typ] = user
|
||||
}
|
||||
|
||||
func (m InstanceUserMap) Get(instanceID string, typ UserType) *User {
|
||||
if users, ok := m[instanceID]; ok {
|
||||
return users[typ]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tester is a Zitadel server and client with all resources available for testing.
|
||||
type Tester struct {
|
||||
*start.Server
|
||||
|
||||
Instance authz.Instance
|
||||
Organisation *query.Org
|
||||
Users map[string]map[UserType]User
|
||||
Users InstanceUserMap
|
||||
|
||||
Client Client
|
||||
WebAuthN *webauthn.Client
|
||||
@ -135,6 +152,7 @@ func (s *Tester) pollHealth(ctx context.Context) (err error) {
|
||||
}
|
||||
|
||||
const (
|
||||
LoginUser = "loginClient"
|
||||
MachineUser = "integration"
|
||||
)
|
||||
|
||||
@ -148,10 +166,9 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
|
||||
s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID())
|
||||
logging.OnError(err).Fatal("query organisation")
|
||||
|
||||
query, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals)
|
||||
usernameQuery, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals)
|
||||
logging.OnError(err).Fatal("user query")
|
||||
user, err := s.Queries.GetUser(ctx, true, true, query)
|
||||
|
||||
user, err := s.Queries.GetUser(ctx, true, true, usernameQuery)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = s.Commands.AddMachine(ctx, &command.Machine{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
@ -162,11 +179,10 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
|
||||
Description: "who cares?",
|
||||
AccessTokenType: domain.OIDCTokenTypeJWT,
|
||||
})
|
||||
logging.OnError(err).Fatal("add machine user")
|
||||
user, err = s.Queries.GetUser(ctx, true, true, query)
|
||||
|
||||
logging.WithFields("username", SystemUser).OnError(err).Fatal("add machine user")
|
||||
user, err = s.Queries.GetUser(ctx, true, true, usernameQuery)
|
||||
}
|
||||
logging.OnError(err).Fatal("get user")
|
||||
logging.WithFields("username", SystemUser).OnError(err).Fatal("get user")
|
||||
|
||||
_, err = s.Commands.AddOrgMember(ctx, s.Organisation.ID, user.ID, "ORG_OWNER")
|
||||
target := new(caos_errs.AlreadyExistsError)
|
||||
@ -177,18 +193,50 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) {
|
||||
scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner}
|
||||
pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine)
|
||||
_, err = s.Commands.AddPersonalAccessToken(ctx, pat)
|
||||
logging.OnError(err).Fatal("add pat")
|
||||
|
||||
if s.Users == nil {
|
||||
s.Users = make(map[string]map[UserType]User)
|
||||
}
|
||||
if s.Users[instanceId] == nil {
|
||||
s.Users[instanceId] = make(map[UserType]User)
|
||||
}
|
||||
s.Users[instanceId][OrgOwner] = User{
|
||||
logging.WithFields("username", SystemUser).OnError(err).Fatal("add pat")
|
||||
s.Users.Set(instanceId, OrgOwner, &User{
|
||||
User: user,
|
||||
Token: pat.Token,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Tester) createLoginClient(ctx context.Context) {
|
||||
var err error
|
||||
|
||||
s.Instance, err = s.Queries.InstanceByHost(ctx, s.Host())
|
||||
logging.OnError(err).Fatal("query instance")
|
||||
ctx = authz.WithInstance(ctx, s.Instance)
|
||||
|
||||
s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID())
|
||||
logging.OnError(err).Fatal("query organisation")
|
||||
|
||||
usernameQuery, err := query.NewUserUsernameSearchQuery(LoginUser, query.TextEquals)
|
||||
logging.WithFields("username", LoginUser).OnError(err).Fatal("user query")
|
||||
user, err := s.Queries.GetUser(ctx, true, true, usernameQuery)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = s.Commands.AddMachine(ctx, &command.Machine{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
ResourceOwner: s.Organisation.ID,
|
||||
},
|
||||
Username: LoginUser,
|
||||
Name: LoginUser,
|
||||
Description: "who cares?",
|
||||
AccessTokenType: domain.OIDCTokenTypeJWT,
|
||||
})
|
||||
logging.WithFields("username", LoginUser).OnError(err).Fatal("add machine user")
|
||||
user, err = s.Queries.GetUser(ctx, true, true, usernameQuery)
|
||||
}
|
||||
logging.WithFields("username", LoginUser).OnError(err).Fatal("get user")
|
||||
|
||||
scopes := []string{oidc.ScopeOpenID, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner}
|
||||
pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine)
|
||||
_, err = s.Commands.AddPersonalAccessToken(ctx, pat)
|
||||
logging.OnError(err).Fatal("add pat")
|
||||
|
||||
s.Users.Set(FirstInstanceUsersKey, Login, &User{
|
||||
User: user,
|
||||
Token: pat.Token,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Tester) WithAuthorization(ctx context.Context, u UserType) context.Context {
|
||||
@ -199,15 +247,12 @@ func (s *Tester) WithInstanceAuthorization(ctx context.Context, u UserType, inst
|
||||
if u == SystemUser {
|
||||
s.ensureSystemUser()
|
||||
}
|
||||
return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users[instanceID][u].Token))
|
||||
return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users.Get(instanceID, u).Token))
|
||||
}
|
||||
|
||||
func (s *Tester) ensureSystemUser() {
|
||||
const ISSUER = "tester"
|
||||
if s.Users[FirstInstanceUsersKey] == nil {
|
||||
s.Users[FirstInstanceUsersKey] = make(map[UserType]User)
|
||||
}
|
||||
if _, ok := s.Users[FirstInstanceUsersKey][SystemUser]; ok {
|
||||
if s.Users.Get(FirstInstanceUsersKey, SystemUser) != nil {
|
||||
return
|
||||
}
|
||||
audience := http_util.BuildOrigin(s.Host(), s.Server.Config.ExternalSecure)
|
||||
@ -215,7 +260,11 @@ func (s *Tester) ensureSystemUser() {
|
||||
logging.OnError(err).Fatal("system key signer")
|
||||
jwt, err := client.SignedJWTProfileAssertion(ISSUER, []string{audience}, time.Hour, signer)
|
||||
logging.OnError(err).Fatal("system key jwt")
|
||||
s.Users[FirstInstanceUsersKey][SystemUser] = User{Token: jwt}
|
||||
s.Users.Set(FirstInstanceUsersKey, SystemUser, &User{Token: jwt})
|
||||
}
|
||||
|
||||
func (s *Tester) WithSystemAuthorizationHTTP(u UserType) map[string]string {
|
||||
return map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.Users.Get(FirstInstanceUsersKey, u).Token)}
|
||||
}
|
||||
|
||||
// Done send an interrupt signal to cleanly shutdown the server.
|
||||
@ -263,9 +312,7 @@ func NewTester(ctx context.Context) *Tester {
|
||||
logging.OnError(err).Fatal()
|
||||
|
||||
tester := Tester{
|
||||
Users: map[string]map[UserType]User{
|
||||
FirstInstanceUsersKey: make(map[UserType]User),
|
||||
},
|
||||
Users: make(InstanceUserMap),
|
||||
}
|
||||
tester.wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
@ -279,6 +326,8 @@ func NewTester(ctx context.Context) *Tester {
|
||||
logging.OnError(ctx.Err()).Fatal("waiting for integration tester server")
|
||||
}
|
||||
tester.createClientConn(ctx)
|
||||
tester.createLoginClient(ctx)
|
||||
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, http_util.BuildOrigin(tester.Host(), tester.Config.ExternalSecure))
|
||||
tester.createMachineUser(ctx, FirstInstanceUsersKey)
|
||||
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host())
|
||||
|
||||
|
163
internal/integration/oidc.go
Normal file
163
internal/integration/oidc.go
Normal file
@ -0,0 +1,163 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v2/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v2/pkg/oidc"
|
||||
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
oidc_internal "github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/app"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
)
|
||||
|
||||
func (s *Tester) CreateOIDCNativeClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) {
|
||||
project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{
|
||||
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{
|
||||
ProjectId: project.GetId(),
|
||||
Name: fmt.Sprintf("app-%d", time.Now().UnixNano()),
|
||||
RedirectUris: []string{redirectURI},
|
||||
ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_CODE},
|
||||
GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_AUTHORIZATION_CODE, app.OIDCGrantType_OIDC_GRANT_TYPE_REFRESH_TOKEN},
|
||||
AppType: app.OIDCAppType_OIDC_APP_TYPE_NATIVE,
|
||||
AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE,
|
||||
PostLogoutRedirectUris: nil,
|
||||
Version: app.OIDCVersion_OIDC_VERSION_1_0,
|
||||
DevMode: false,
|
||||
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
|
||||
AccessTokenRoleAssertion: false,
|
||||
IdTokenRoleAssertion: false,
|
||||
IdTokenUserinfoAssertion: false,
|
||||
ClockSkew: nil,
|
||||
AdditionalOrigins: nil,
|
||||
SkipNativeAppSuccessPage: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCImplicitFlowClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) {
|
||||
project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{
|
||||
Name: fmt.Sprintf("project-%d", time.Now().UnixNano()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{
|
||||
ProjectId: project.GetId(),
|
||||
Name: fmt.Sprintf("app-%d", time.Now().UnixNano()),
|
||||
RedirectUris: []string{redirectURI},
|
||||
ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_ID_TOKEN_TOKEN},
|
||||
GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_IMPLICIT},
|
||||
AppType: app.OIDCAppType_OIDC_APP_TYPE_USER_AGENT,
|
||||
AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE,
|
||||
PostLogoutRedirectUris: nil,
|
||||
Version: app.OIDCVersion_OIDC_VERSION_1_0,
|
||||
DevMode: true,
|
||||
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
|
||||
AccessTokenRoleAssertion: false,
|
||||
IdTokenRoleAssertion: false,
|
||||
IdTokenUserinfoAssertion: false,
|
||||
ClockSkew: nil,
|
||||
AdditionalOrigins: nil,
|
||||
SkipNativeAppSuccessPage: false,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCAuthRequest(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) {
|
||||
provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
codeVerifier := "codeVerifier"
|
||||
codeChallenge := oidc.NewSHACodeChallenge(codeVerifier)
|
||||
authURL := rp.AuthURL("state", provider, rp.WithCodeChallenge(codeChallenge))
|
||||
|
||||
loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2
|
||||
if !strings.HasPrefix(loc.String(), prefixWithHost) {
|
||||
return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String())
|
||||
}
|
||||
return strings.TrimPrefix(loc.String(), prefixWithHost), nil
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCAuthRequestImplicit(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) {
|
||||
provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
authURL := rp.AuthURL("state", provider)
|
||||
|
||||
// implicit is not natively supported so let's just overwrite the response type
|
||||
parsed, _ := url.Parse(authURL)
|
||||
queries := parsed.Query()
|
||||
queries.Set("response_type", string(oidc.ResponseTypeIDToken))
|
||||
parsed.RawQuery = queries.Encode()
|
||||
authURL = parsed.String()
|
||||
|
||||
loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2
|
||||
if !strings.HasPrefix(loc.String(), prefixWithHost) {
|
||||
return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String())
|
||||
}
|
||||
return strings.TrimPrefix(loc.String(), prefixWithHost), nil
|
||||
}
|
||||
|
||||
func (s *Tester) CreateRelyingParty(clientID, redirectURI string, scope ...string) (rp.RelyingParty, error) {
|
||||
issuer := http_util.BuildHTTP(s.Config.ExternalDomain, s.Config.Port, s.Config.ExternalSecure)
|
||||
if len(scope) == 0 {
|
||||
scope = []string{oidc.ScopeOpenID}
|
||||
}
|
||||
return rp.NewRelyingPartyOIDC(issuer, clientID, "", redirectURI, scope)
|
||||
}
|
||||
|
||||
func CheckRedirect(url string, headers map[string]string) (*url.URL, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.Location()
|
||||
}
|
||||
|
||||
func (s *Tester) CreateSession(ctx context.Context, userID string) (string, string, error) {
|
||||
session, err := s.Commands.CreateSession(ctx, []command.SessionCommand{command.CheckUser(userID)}, "domain.tld", nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return session.ID, session.NewToken, nil
|
||||
}
|
88
internal/query/auth_request.go
Normal file
88
internal/query/auth_request.go
Normal file
@ -0,0 +1,88 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
errs "errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
LoginClient string
|
||||
ClientID string
|
||||
Scope []string
|
||||
RedirectURI string
|
||||
Prompt []domain.Prompt
|
||||
UiLocales []string
|
||||
LoginHint *string
|
||||
MaxAge *time.Duration
|
||||
HintUserID *string
|
||||
}
|
||||
|
||||
func (a *AuthRequest) checkLoginClient(ctx context.Context) error {
|
||||
if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient {
|
||||
return errors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//go:embed embed/auth_request_by_id.sql
|
||||
var authRequestByIDQuery string
|
||||
|
||||
func (q *Queries) authRequestByIDQuery(ctx context.Context) string {
|
||||
return fmt.Sprintf(authRequestByIDQuery, q.client.Timetravel(call.Took(ctx)))
|
||||
}
|
||||
|
||||
func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
ctx = projection.AuthRequestProjection.Trigger(ctx)
|
||||
}
|
||||
|
||||
var (
|
||||
scope database.StringArray
|
||||
prompt database.EnumArray[domain.Prompt]
|
||||
locales database.StringArray
|
||||
)
|
||||
|
||||
dst := new(AuthRequest)
|
||||
err = q.client.DB.QueryRowContext(
|
||||
ctx, q.authRequestByIDQuery(ctx),
|
||||
id, authz.GetInstance(ctx).InstanceID(),
|
||||
).Scan(
|
||||
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI,
|
||||
&prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID,
|
||||
)
|
||||
if errs.Is(err, sql.ErrNoRows) {
|
||||
return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal")
|
||||
}
|
||||
|
||||
dst.Scope = scope
|
||||
dst.Prompt = prompt
|
||||
dst.UiLocales = locales
|
||||
|
||||
if checkLoginClient {
|
||||
if err = dst.checkLoginClient(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
180
internal/query/auth_request_test.go
Normal file
180
internal/query/auth_request_test.go
Normal file
@ -0,0 +1,180 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
)
|
||||
|
||||
func TestQueries_AuthRequestByID(t *testing.T) {
|
||||
expQuery := regexp.QuoteMeta(fmt.Sprintf(
|
||||
authRequestByIDQuery,
|
||||
asOfSystemTime,
|
||||
))
|
||||
|
||||
cols := []string{
|
||||
projection.AuthRequestColumnID,
|
||||
projection.AuthRequestColumnCreationDate,
|
||||
projection.AuthRequestColumnLoginClient,
|
||||
projection.AuthRequestColumnClientID,
|
||||
projection.AuthRequestColumnScope,
|
||||
projection.AuthRequestColumnRedirectURI,
|
||||
projection.AuthRequestColumnPrompt,
|
||||
projection.AuthRequestColumnUILocales,
|
||||
projection.AuthRequestColumnLoginHint,
|
||||
projection.AuthRequestColumnMaxAge,
|
||||
projection.AuthRequestColumnHintUserID,
|
||||
}
|
||||
type args struct {
|
||||
shouldTriggerBulk bool
|
||||
id string
|
||||
checkLoginClient bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expect sqlExpectation
|
||||
want *AuthRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "success, all values",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
database.StringArray{"a", "b", "c"},
|
||||
"example.com",
|
||||
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
|
||||
database.StringArray{"en", "fi"},
|
||||
"me@example.com",
|
||||
int64(time.Minute),
|
||||
"userID",
|
||||
}, "123", "instanceID"),
|
||||
want: &AuthRequest{
|
||||
ID: "id",
|
||||
CreationDate: testNow,
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
Scope: []string{"a", "b", "c"},
|
||||
RedirectURI: "example.com",
|
||||
Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
LoginHint: gu.Ptr("me@example.com"),
|
||||
MaxAge: gu.Ptr(time.Minute),
|
||||
HintUserID: gu.Ptr("userID"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success, null values",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"loginClient",
|
||||
"clientID",
|
||||
database.StringArray{"a", "b", "c"},
|
||||
"example.com",
|
||||
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
|
||||
database.StringArray{"en", "fi"},
|
||||
sql.NullString{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
}, "123", "instanceID"),
|
||||
want: &AuthRequest{
|
||||
ID: "id",
|
||||
CreationDate: testNow,
|
||||
LoginClient: "loginClient",
|
||||
ClientID: "clientID",
|
||||
Scope: []string{"a", "b", "c"},
|
||||
RedirectURI: "example.com",
|
||||
Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent},
|
||||
UiLocales: []string{"en", "fi"},
|
||||
LoginHint: nil,
|
||||
MaxAge: nil,
|
||||
HintUserID: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, nil, "123", "instanceID"),
|
||||
wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"),
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"),
|
||||
wantErr: errors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "wrong login client",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"wrongLoginClient",
|
||||
"clientID",
|
||||
database.StringArray{"a", "b", "c"},
|
||||
"example.com",
|
||||
database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
|
||||
database.StringArray{"en", "fi"},
|
||||
sql.NullString{},
|
||||
sql.NullInt64{},
|
||||
sql.NullString{},
|
||||
}, "123", "instanceID"),
|
||||
wantErr: errors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
execMock(t, tt.expect, func(db *sql.DB) {
|
||||
q := &Queries{
|
||||
client: &database.DB{
|
||||
DB: db,
|
||||
Database: &prepareDB{},
|
||||
},
|
||||
}
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
|
||||
got, err := q.AuthRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
15
internal/query/embed/auth_request_by_id.sql
Normal file
15
internal/query/embed/auth_request_by_id.sql
Normal file
@ -0,0 +1,15 @@
|
||||
select
|
||||
id,
|
||||
creation_date,
|
||||
login_client,
|
||||
client_id,
|
||||
scope,
|
||||
redirect_uri,
|
||||
prompt,
|
||||
ui_locales,
|
||||
login_hint,
|
||||
max_age,
|
||||
hint_user_id
|
||||
from projections.auth_requests %s
|
||||
where id = $1 and instance_id = $2
|
||||
limit 1;
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -74,9 +75,9 @@ type checkErr func(error) (err error, ok bool)
|
||||
|
||||
type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
|
||||
|
||||
func mockQuery(stmt string, cols []string, row []driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
q := m.ExpectQuery(stmt)
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
result := sqlmock.NewRows(cols)
|
||||
if len(row) > 0 {
|
||||
result.AddRow(row...)
|
||||
@ -111,6 +112,15 @@ func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.S
|
||||
}
|
||||
}
|
||||
|
||||
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
mock = exp(mock)
|
||||
run(db)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
var (
|
||||
rowType = reflect.TypeOf(&sql.Row{})
|
||||
rowsType = reflect.TypeOf(&sql.Rows{})
|
||||
@ -317,7 +327,9 @@ func TestValidatePrepare(t *testing.T) {
|
||||
|
||||
type prepareDB struct{}
|
||||
|
||||
func (_ *prepareDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
|
||||
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "
|
||||
|
||||
func (*prepareDB) Timetravel(time.Duration) string { return asOfSystemTime }
|
||||
|
||||
var defaultPrepareArgs = []reflect.Value{reflect.ValueOf(context.Background()), reflect.ValueOf(new(prepareDB))}
|
||||
|
||||
|
142
internal/query/projection/auth_request.go
Normal file
142
internal/query/projection/auth_request.go
Normal file
@ -0,0 +1,142 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthRequestsProjectionTable = "projections.auth_requests"
|
||||
|
||||
AuthRequestColumnID = "id"
|
||||
AuthRequestColumnCreationDate = "creation_date"
|
||||
AuthRequestColumnChangeDate = "change_date"
|
||||
AuthRequestColumnSequence = "sequence"
|
||||
AuthRequestColumnResourceOwner = "resource_owner"
|
||||
AuthRequestColumnInstanceID = "instance_id"
|
||||
AuthRequestColumnLoginClient = "login_client"
|
||||
AuthRequestColumnClientID = "client_id"
|
||||
AuthRequestColumnRedirectURI = "redirect_uri"
|
||||
AuthRequestColumnScope = "scope"
|
||||
AuthRequestColumnPrompt = "prompt"
|
||||
AuthRequestColumnUILocales = "ui_locales"
|
||||
AuthRequestColumnMaxAge = "max_age"
|
||||
AuthRequestColumnLoginHint = "login_hint"
|
||||
AuthRequestColumnHintUserID = "hint_user_id"
|
||||
)
|
||||
|
||||
type authRequestProjection struct {
|
||||
crdb.StatementHandler
|
||||
}
|
||||
|
||||
func newAuthRequestProjection(ctx context.Context, config crdb.StatementHandlerConfig) *authRequestProjection {
|
||||
p := new(authRequestProjection)
|
||||
config.ProjectionName = AuthRequestsProjectionTable
|
||||
config.Reducers = p.reducers()
|
||||
config.InitCheck = crdb.NewMultiTableCheck(
|
||||
crdb.NewTable([]*crdb.Column{
|
||||
crdb.NewColumn(AuthRequestColumnID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnCreationDate, crdb.ColumnTypeTimestamp),
|
||||
crdb.NewColumn(AuthRequestColumnChangeDate, crdb.ColumnTypeTimestamp),
|
||||
crdb.NewColumn(AuthRequestColumnSequence, crdb.ColumnTypeInt64),
|
||||
crdb.NewColumn(AuthRequestColumnResourceOwner, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnInstanceID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnLoginClient, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnClientID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnRedirectURI, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AuthRequestColumnScope, crdb.ColumnTypeTextArray),
|
||||
crdb.NewColumn(AuthRequestColumnPrompt, crdb.ColumnTypeEnumArray, crdb.Nullable()),
|
||||
crdb.NewColumn(AuthRequestColumnUILocales, crdb.ColumnTypeTextArray, crdb.Nullable()),
|
||||
crdb.NewColumn(AuthRequestColumnMaxAge, crdb.ColumnTypeInt64, crdb.Nullable()),
|
||||
crdb.NewColumn(AuthRequestColumnLoginHint, crdb.ColumnTypeText, crdb.Nullable()),
|
||||
crdb.NewColumn(AuthRequestColumnHintUserID, crdb.ColumnTypeText, crdb.Nullable()),
|
||||
},
|
||||
crdb.NewPrimaryKey(AuthRequestColumnInstanceID, AuthRequestColumnID),
|
||||
),
|
||||
)
|
||||
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *authRequestProjection) reducers() []handler.AggregateReducer {
|
||||
return []handler.AggregateReducer{
|
||||
{
|
||||
Aggregate: authrequest.AggregateType,
|
||||
EventRedusers: []handler.EventReducer{
|
||||
{
|
||||
Event: authrequest.AddedType,
|
||||
Reduce: p.reduceAuthRequestAdded,
|
||||
},
|
||||
{
|
||||
Event: authrequest.SucceededType,
|
||||
Reduce: p.reduceAuthRequestEnded,
|
||||
},
|
||||
{
|
||||
Event: authrequest.FailedType,
|
||||
Reduce: p.reduceAuthRequestEnded,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Aggregate: instance.AggregateType,
|
||||
EventRedusers: []handler.EventReducer{
|
||||
{
|
||||
Event: instance.InstanceRemovedEventType,
|
||||
Reduce: reduceInstanceRemovedHelper(AuthRequestColumnInstanceID),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *authRequestProjection) reduceAuthRequestAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*authrequest.AddedEvent)
|
||||
if !ok {
|
||||
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", authrequest.AddedType)
|
||||
}
|
||||
|
||||
return crdb.NewCreateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(AuthRequestColumnID, e.Aggregate().ID),
|
||||
handler.NewCol(AuthRequestColumnInstanceID, e.Aggregate().InstanceID),
|
||||
handler.NewCol(AuthRequestColumnCreationDate, e.CreationDate()),
|
||||
handler.NewCol(AuthRequestColumnChangeDate, e.CreationDate()),
|
||||
handler.NewCol(AuthRequestColumnResourceOwner, e.Aggregate().ResourceOwner),
|
||||
handler.NewCol(AuthRequestColumnSequence, e.Sequence()),
|
||||
handler.NewCol(AuthRequestColumnLoginClient, e.LoginClient),
|
||||
handler.NewCol(AuthRequestColumnClientID, e.ClientID),
|
||||
handler.NewCol(AuthRequestColumnRedirectURI, e.RedirectURI),
|
||||
handler.NewCol(AuthRequestColumnScope, e.Scope),
|
||||
handler.NewCol(AuthRequestColumnPrompt, e.Prompt),
|
||||
handler.NewCol(AuthRequestColumnUILocales, e.UILocales),
|
||||
handler.NewCol(AuthRequestColumnMaxAge, e.MaxAge),
|
||||
handler.NewCol(AuthRequestColumnLoginHint, e.LoginHint),
|
||||
handler.NewCol(AuthRequestColumnHintUserID, e.HintUserID),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *authRequestProjection) reduceAuthRequestEnded(event eventstore.Event) (*handler.Statement, error) {
|
||||
switch event.(type) {
|
||||
case *authrequest.SucceededEvent,
|
||||
*authrequest.FailedEvent:
|
||||
break
|
||||
default:
|
||||
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{authrequest.SucceededType, authrequest.FailedType})
|
||||
}
|
||||
|
||||
return crdb.NewDeleteStatement(
|
||||
event,
|
||||
[]handler.Condition{
|
||||
handler.NewCond(AuthRequestColumnID, event.Aggregate().ID),
|
||||
handler.NewCond(AuthRequestColumnInstanceID, event.Aggregate().InstanceID),
|
||||
},
|
||||
), nil
|
||||
}
|
134
internal/query/projection/auth_request_test.go
Normal file
134
internal/query/projection/auth_request_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
)
|
||||
|
||||
func TestAuthRequestProjection_reduces(t *testing.T) {
|
||||
type args struct {
|
||||
event func(t *testing.T) eventstore.Event
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
reduce func(event eventstore.Event) (*handler.Statement, error)
|
||||
want wantReduce
|
||||
}{
|
||||
{
|
||||
name: "reduceAuthRequestAdded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
authrequest.AddedType,
|
||||
authrequest.AggregateType,
|
||||
[]byte(`{"login_client": "loginClient", "client_id":"clientId","redirect_uri": "redirectURI", "scope": ["openid"], "prompt": [1], "ui_locales": ["en","de"], "max_age": 0, "login_hint": "loginHint", "hint_user_id": "hintUserID"}`),
|
||||
), authrequest.AddedEventMapper),
|
||||
},
|
||||
reduce: (&authRequestProjection{}).reduceAuthRequestAdded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("auth_request"),
|
||||
sequence: 15,
|
||||
previousSequence: 10,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.auth_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, client_id, redirect_uri, scope, prompt, ui_locales, max_age, login_hint, hint_user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
anyArg{},
|
||||
anyArg{},
|
||||
"ro-id",
|
||||
uint64(15),
|
||||
"loginClient",
|
||||
"clientId",
|
||||
"redirectURI",
|
||||
[]string{"openid"},
|
||||
[]domain.Prompt{domain.PromptNone},
|
||||
[]string{"en", "de"},
|
||||
gu.Ptr(time.Duration(0)),
|
||||
gu.Ptr("loginHint"),
|
||||
gu.Ptr("hintUserID"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceAuthRequestFailed",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
authrequest.FailedType,
|
||||
authrequest.AggregateType,
|
||||
[]byte(`{"reason": 0}`),
|
||||
), authrequest.FailedEventMapper),
|
||||
},
|
||||
reduce: (&authRequestProjection{}).reduceAuthRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("auth_request"),
|
||||
sequence: 15,
|
||||
previousSequence: 10,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceAuthRequestSucceeded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
authrequest.SucceededType,
|
||||
authrequest.AggregateType,
|
||||
nil,
|
||||
), authrequest.SucceededEventMapper),
|
||||
},
|
||||
reduce: (&authRequestProjection{}).reduceAuthRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("auth_request"),
|
||||
sequence: 15,
|
||||
previousSequence: 10,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event := baseEvent(t)
|
||||
got, err := tt.reduce(event)
|
||||
if !errors.IsErrorInvalidArgument(err) {
|
||||
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
|
||||
}
|
||||
|
||||
event = tt.args.event(t)
|
||||
got, err = tt.reduce(event)
|
||||
assertReduce(t, got, err, AuthRequestsProjectionTable, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
@ -67,6 +67,7 @@ var (
|
||||
TelemetryPusherProjection interface{}
|
||||
DeviceAuthProjection *deviceAuthProjection
|
||||
SessionProjection *sessionProjection
|
||||
AuthRequestProjection *authRequestProjection
|
||||
MilestoneProjection *milestoneProjection
|
||||
)
|
||||
|
||||
@ -145,6 +146,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
|
||||
NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"]))
|
||||
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
|
||||
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
|
||||
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
|
||||
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
|
||||
newProjectionsList()
|
||||
return nil
|
||||
@ -243,6 +245,7 @@ func newProjectionsList() {
|
||||
NotificationPolicyProjection,
|
||||
DeviceAuthProjection,
|
||||
SessionProjection,
|
||||
AuthRequestProjection,
|
||||
MilestoneProjection,
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -18,9 +19,11 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/repository/action"
|
||||
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
iam_repo "github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
@ -88,6 +91,8 @@ func StartQueries(
|
||||
usergrant.RegisterEventMappers(repo.eventstore)
|
||||
session.RegisterEventMappers(repo.eventstore)
|
||||
idpintent.RegisterEventMappers(repo.eventstore)
|
||||
authrequest.RegisterEventMappers(repo.eventstore)
|
||||
oidcsession.RegisterEventMappers(repo.eventstore)
|
||||
|
||||
repo.idpConfigEncryption = idpConfigEncryption
|
||||
repo.multifactors = domain.MultifactorConfigs{
|
||||
@ -115,3 +120,19 @@ func (q *Queries) Health(ctx context.Context) error {
|
||||
type prepareDatabase interface {
|
||||
Timetravel(d time.Duration) string
|
||||
}
|
||||
|
||||
// cleanStaticQueries removes whitespaces,
|
||||
// such as ` `, \t, \n, from queries to improve
|
||||
// readability in logs and errors.
|
||||
func cleanStaticQueries(qs ...*string) {
|
||||
regex := regexp.MustCompile(`\s+`)
|
||||
for _, q := range qs {
|
||||
*q = regex.ReplaceAllString(*q, " ")
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
cleanStaticQueries(
|
||||
&authRequestByIDQuery,
|
||||
)
|
||||
}
|
||||
|
17
internal/query/query_test.go
Normal file
17
internal/query/query_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_cleanStaticQueries(t *testing.T) {
|
||||
query := `select
|
||||
foo,
|
||||
bar
|
||||
from table;`
|
||||
want := "select foo, bar from table;"
|
||||
cleanStaticQueries(&query)
|
||||
assert.Equal(t, want, query)
|
||||
}
|
26
internal/repository/authrequest/aggregate.go
Normal file
26
internal/repository/authrequest/aggregate.go
Normal file
@ -0,0 +1,26 @@
|
||||
package authrequest
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
AggregateType = "auth_request"
|
||||
AggregateVersion = "v1"
|
||||
)
|
||||
|
||||
type Aggregate struct {
|
||||
eventstore.Aggregate
|
||||
}
|
||||
|
||||
func NewAggregate(id, instanceID string) *Aggregate {
|
||||
return &Aggregate{
|
||||
Aggregate: eventstore.Aggregate{
|
||||
Type: AggregateType,
|
||||
Version: AggregateVersion,
|
||||
ID: id,
|
||||
ResourceOwner: instanceID,
|
||||
InstanceID: instanceID,
|
||||
},
|
||||
}
|
||||
}
|
287
internal/repository/authrequest/auth_request.go
Normal file
287
internal/repository/authrequest/auth_request.go
Normal file
@ -0,0 +1,287 @@
|
||||
package authrequest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
authRequestEventPrefix = "auth_request."
|
||||
AddedType = authRequestEventPrefix + "added"
|
||||
FailedType = authRequestEventPrefix + "failed"
|
||||
CodeAddedType = authRequestEventPrefix + "code.added"
|
||||
SessionLinkedType = authRequestEventPrefix + "session.linked"
|
||||
CodeExchangedType = authRequestEventPrefix + "code.exchanged"
|
||||
SucceededType = authRequestEventPrefix + "succeeded"
|
||||
)
|
||||
|
||||
type AddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
LoginClient string `json:"login_client"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
State string `json:"state,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Scope []string `json:"scope,omitempty"`
|
||||
Audience []string `json:"audience,omitempty"`
|
||||
ResponseType domain.OIDCResponseType `json:"response_type,omitempty"`
|
||||
CodeChallenge *domain.OIDCCodeChallenge `json:"code_challenge,omitempty"`
|
||||
Prompt []domain.Prompt `json:"prompt,omitempty"`
|
||||
UILocales []string `json:"ui_locales,omitempty"`
|
||||
MaxAge *time.Duration `json:"max_age,omitempty"`
|
||||
LoginHint *string `json:"login_hint,omitempty"`
|
||||
HintUserID *string `json:"hint_user_id,omitempty"`
|
||||
}
|
||||
|
||||
func (e *AddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAddedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
loginClient,
|
||||
clientID,
|
||||
redirectURI,
|
||||
state,
|
||||
nonce string,
|
||||
scope,
|
||||
audience []string,
|
||||
responseType domain.OIDCResponseType,
|
||||
codeChallenge *domain.OIDCCodeChallenge,
|
||||
prompt []domain.Prompt,
|
||||
uiLocales []string,
|
||||
maxAge *time.Duration,
|
||||
loginHint,
|
||||
hintUserID *string,
|
||||
) *AddedEvent {
|
||||
return &AddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
AddedType,
|
||||
),
|
||||
LoginClient: loginClient,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
Scope: scope,
|
||||
Audience: audience,
|
||||
ResponseType: responseType,
|
||||
CodeChallenge: codeChallenge,
|
||||
Prompt: prompt,
|
||||
UILocales: uiLocales,
|
||||
MaxAge: maxAge,
|
||||
LoginHint: loginHint,
|
||||
HintUserID: hintUserID,
|
||||
}
|
||||
}
|
||||
|
||||
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &AddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "AUTHR-DG4gn", "unable to unmarshal auth request added")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type SessionLinkedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
AuthTime time.Time `json:"auth_time"`
|
||||
AMR []string `json:"amr"`
|
||||
}
|
||||
|
||||
func (e *SessionLinkedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSessionLinkedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
sessionID,
|
||||
userID string,
|
||||
authTime time.Time,
|
||||
amr []string,
|
||||
) *SessionLinkedEvent {
|
||||
return &SessionLinkedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
SessionLinkedType,
|
||||
),
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
AuthTime: authTime,
|
||||
AMR: amr,
|
||||
}
|
||||
}
|
||||
|
||||
func SessionLinkedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &SessionLinkedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type FailedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
Reason domain.OIDCErrorReason `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
func (e *FailedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *FailedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFailedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
reason domain.OIDCErrorReason,
|
||||
) *FailedEvent {
|
||||
return &FailedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
FailedType,
|
||||
),
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
func FailedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &FailedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type CodeAddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
}
|
||||
|
||||
func (e *CodeAddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *CodeAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewCodeAddedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
) *CodeAddedEvent {
|
||||
return &CodeAddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
CodeAddedType,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func CodeAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &CodeAddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request code added")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type CodeExchangedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
}
|
||||
|
||||
func (e *CodeExchangedEvent) Data() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *CodeExchangedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewCodeExchangedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
) *CodeExchangedEvent {
|
||||
return &CodeExchangedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
CodeExchangedType,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func CodeExchangedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
return &CodeExchangedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type SucceededEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
}
|
||||
|
||||
func (e *SucceededEvent) Data() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *SucceededEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSucceededEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
) *SucceededEvent {
|
||||
return &SucceededEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
SucceededType,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func SucceededEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
return &SucceededEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}, nil
|
||||
}
|
12
internal/repository/authrequest/eventstore.go
Normal file
12
internal/repository/authrequest/eventstore.go
Normal file
@ -0,0 +1,12 @@
|
||||
package authrequest
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||
|
||||
func RegisterEventMappers(es *eventstore.Eventstore) {
|
||||
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, SessionLinkedType, SessionLinkedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, CodeAddedType, CodeAddedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, CodeExchangedType, CodeExchangedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, FailedType, FailedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, SucceededType, SucceededEventMapper)
|
||||
}
|
25
internal/repository/oidcsession/aggregate.go
Normal file
25
internal/repository/oidcsession/aggregate.go
Normal file
@ -0,0 +1,25 @@
|
||||
package oidcsession
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
AggregateType = "oidc_session"
|
||||
AggregateVersion = "v1"
|
||||
)
|
||||
|
||||
type Aggregate struct {
|
||||
eventstore.Aggregate
|
||||
}
|
||||
|
||||
func NewAggregate(id, resourceOwner string) *Aggregate {
|
||||
return &Aggregate{
|
||||
Aggregate: eventstore.Aggregate{
|
||||
Type: AggregateType,
|
||||
Version: AggregateVersion,
|
||||
ID: id,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
}
|
||||
}
|
11
internal/repository/oidcsession/eventstore.go
Normal file
11
internal/repository/oidcsession/eventstore.go
Normal file
@ -0,0 +1,11 @@
|
||||
package oidcsession
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||
|
||||
func RegisterEventMappers(es *eventstore.Eventstore) {
|
||||
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, AccessTokenAddedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, RefreshTokenAddedEventMapper).
|
||||
RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, RefreshTokenRenewedEventMapper)
|
||||
|
||||
}
|
215
internal/repository/oidcsession/oidc_session.go
Normal file
215
internal/repository/oidcsession/oidc_session.go
Normal file
@ -0,0 +1,215 @@
|
||||
package oidcsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
oidcSessionEventPrefix = "oidc_session."
|
||||
AddedType = oidcSessionEventPrefix + "added"
|
||||
AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added"
|
||||
RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added"
|
||||
RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed"
|
||||
)
|
||||
|
||||
type AddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
UserID string `json:"userID"`
|
||||
SessionID string `json:"sessionID"`
|
||||
ClientID string `json:"clientID"`
|
||||
Audience []string `json:"audience"`
|
||||
Scope []string `json:"scope"`
|
||||
AuthMethodsReferences []string `json:"authMethodsReferences"`
|
||||
AuthTime time.Time `json:"authTime"`
|
||||
}
|
||||
|
||||
func (e *AddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAddedEvent(ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
userID,
|
||||
sessionID,
|
||||
clientID string,
|
||||
audience,
|
||||
scope []string,
|
||||
authMethodsReferences []string,
|
||||
authTime time.Time,
|
||||
) *AddedEvent {
|
||||
return &AddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
AddedType,
|
||||
),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
ClientID: clientID,
|
||||
Audience: audience,
|
||||
Scope: scope,
|
||||
AuthMethodsReferences: authMethodsReferences,
|
||||
AuthTime: authTime,
|
||||
}
|
||||
}
|
||||
|
||||
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &AddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDCS-DG4gn", "unable to unmarshal oidc session added")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type AccessTokenAddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
ID string `json:"id"`
|
||||
Scope []string `json:"scope"`
|
||||
Lifetime time.Duration `json:"lifetime"`
|
||||
}
|
||||
|
||||
func (e *AccessTokenAddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *AccessTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAccessTokenAddedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
id string,
|
||||
scope []string,
|
||||
lifetime time.Duration,
|
||||
) *AccessTokenAddedEvent {
|
||||
return &AccessTokenAddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
AccessTokenAddedType,
|
||||
),
|
||||
ID: id,
|
||||
Scope: scope,
|
||||
Lifetime: lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
func AccessTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &AccessTokenAddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDCS-DSGn5", "unable to unmarshal access token added")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type RefreshTokenAddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
ID string `json:"id"`
|
||||
Lifetime time.Duration `json:"lifetime"`
|
||||
IdleLifetime time.Duration `json:"idleLifetime"`
|
||||
}
|
||||
|
||||
func (e *RefreshTokenAddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *RefreshTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRefreshTokenAddedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
id string,
|
||||
lifetime,
|
||||
idleLifetime time.Duration,
|
||||
) *RefreshTokenAddedEvent {
|
||||
return &RefreshTokenAddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
RefreshTokenAddedType,
|
||||
),
|
||||
ID: id,
|
||||
Lifetime: lifetime,
|
||||
IdleLifetime: idleLifetime,
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &RefreshTokenAddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDCS-aW3gqq", "unable to unmarshal refresh token added")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
||||
|
||||
type RefreshTokenRenewedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
ID string `json:"id"`
|
||||
IdleLifetime time.Duration `json:"idleLifetime"`
|
||||
}
|
||||
|
||||
func (e *RefreshTokenRenewedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *RefreshTokenRenewedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRefreshTokenRenewedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
id string,
|
||||
idleLifetime time.Duration,
|
||||
) *RefreshTokenRenewedEvent {
|
||||
return &RefreshTokenRenewedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
RefreshTokenRenewedType,
|
||||
),
|
||||
ID: id,
|
||||
IdleLifetime: idleLifetime,
|
||||
}
|
||||
}
|
||||
|
||||
func RefreshTokenRenewedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
added := &RefreshTokenRenewedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := json.Unmarshal(event.Data, added)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "OIDCS-SF3fc", "unable to unmarshal refresh token renewed")
|
||||
}
|
||||
|
||||
return added, nil
|
||||
}
|
@ -506,6 +506,12 @@ Errors:
|
||||
TokenCreationFailed: Неуспешно създаване на токен
|
||||
InvalidToken: Знакът за намерение е невалиден
|
||||
OtherUser: Намерение, предназначено за друг потребител
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request вече съществува
|
||||
NotExisting: Auth Request не съществува
|
||||
WrongLoginClient: Auth Request, създаден от друг клиент за влизане
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Токенът за опресняване е невалиден
|
||||
|
||||
AggregateTypes:
|
||||
action: Действие
|
||||
|
@ -488,7 +488,12 @@ Errors:
|
||||
TokenCreationFailed: Tokenerstellung schlug fehl
|
||||
InvalidToken: Intent Token ist ungültig
|
||||
OtherUser: Intent ist für anderen Benutzer gedacht
|
||||
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request existiert bereits
|
||||
NotExisting: Auth Request existiert nicht
|
||||
WrongLoginClient: Auth Request wurde von einem anderen Login-Client erstellt
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Refresh Token ist ungültig
|
||||
AggregateTypes:
|
||||
action: Action
|
||||
instance: Instanz
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: Token creation failed
|
||||
InvalidToken: Intent Token is invalid
|
||||
OtherUser: Intent meant for another user
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request already exists
|
||||
NotExisting: Auth Request does not exist
|
||||
WrongLoginClient: Auth Request created by other login client
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Refresh Token is invalid
|
||||
|
||||
AggregateTypes:
|
||||
action: Action
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: Fallo en la creación del token
|
||||
InvalidToken: El token de la intención no es válido
|
||||
OtherUser: Destinado a otro usuario
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request ya existe
|
||||
NotExisting: Auth Request no existe
|
||||
WrongLoginClient: Auth Request creado por otro cliente de inicio de sesión
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: El token de refresco no es válido
|
||||
|
||||
AggregateTypes:
|
||||
action: Acción
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: La création du token a échoué
|
||||
InvalidToken: Le jeton d'intention n'est pas valide
|
||||
OtherUser: Intention destinée à un autre utilisateur
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request existe déjà
|
||||
NotExisting: Auth Request n'existe pas
|
||||
WrongLoginClient: Auth Request créé par un autre client de connexion
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Le jeton de rafraîchissement n'est pas valide
|
||||
|
||||
AggregateTypes:
|
||||
action: Action
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: creazione del token fallita
|
||||
InvalidToken: Il token dell'intento non è valido
|
||||
OtherUser: Intento destinato a un altro utente
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request esiste già
|
||||
NotExisting: Auth Request non esiste
|
||||
WrongLoginClient: Auth Request creato da un altro client di accesso
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Refresh Token non è valido
|
||||
|
||||
AggregateTypes:
|
||||
action: Azione
|
||||
|
@ -477,6 +477,12 @@ Errors:
|
||||
TokenCreationFailed: トークンの作成に失敗しました
|
||||
InvalidToken: インテントのトークンが無効である
|
||||
OtherUser: 他のユーザーを意図している
|
||||
AuthRequest:
|
||||
AlreadyExists: AuthRequestはすでに存在する
|
||||
NotExisting: AuthRequest が存在しません
|
||||
WrongLoginClient: 他のログインクライアントによって作成された AuthRequest
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: 無効なリフレッシュトークンです
|
||||
|
||||
AggregateTypes:
|
||||
action: アクション
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: Tworzenie tokena nie powiodło się
|
||||
InvalidToken: Token intencji jest nieprawidłowy
|
||||
OtherUser: Intencja przeznaczona dla innego użytkownika
|
||||
AuthRequest:
|
||||
AlreadyExists: Auth Request już istnieje
|
||||
NotExisting: Auth Request nie istnieje
|
||||
WrongLoginClient: Auth Request utworzony przez innego klienta logowania
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Refresh Token jest nieprawidłowy
|
||||
|
||||
AggregateTypes:
|
||||
action: Działanie
|
||||
|
@ -488,6 +488,12 @@ Errors:
|
||||
TokenCreationFailed: 令牌创建失败
|
||||
InvalidToken: 意图令牌是无效的
|
||||
OtherUser: 意图是为另一个用户准备的
|
||||
AuthRequest:
|
||||
AlreadyExists: AuthRequest已经存在
|
||||
NotExisting: AuthRequest不存在
|
||||
WrongLoginClient: 其他登录客户端创建的AuthRequest
|
||||
OIDCSession:
|
||||
RefreshTokenInvalid: Refresh Token 无效
|
||||
|
||||
AggregateTypes:
|
||||
action: 动作
|
||||
|
117
proto/zitadel/oidc/v2alpha/authorization.proto
Normal file
117
proto/zitadel/oidc/v2alpha/authorization.proto
Normal file
@ -0,0 +1,117 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package zitadel.oidc.v2alpha;
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "protoc-gen-openapiv2/options/annotations.proto";
|
||||
|
||||
option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc";
|
||||
|
||||
message AuthRequest{
|
||||
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = {
|
||||
external_docs: {
|
||||
url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest";
|
||||
description: "Find out more about OIDC Auth Request parameters";
|
||||
}
|
||||
};
|
||||
|
||||
string id = 1 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "ID of the authorization request";
|
||||
}
|
||||
];
|
||||
|
||||
google.protobuf.Timestamp creation_date = 2 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Time when the auth request was created";
|
||||
}
|
||||
];
|
||||
|
||||
string client_id = 3 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "OIDC client ID of the application that created the auth request";
|
||||
}
|
||||
];
|
||||
|
||||
repeated string scope = 4 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Requested scopes by the application, which the user must consent to.";
|
||||
}
|
||||
];
|
||||
|
||||
string redirect_uri = 5 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Base URI that points back to the application";
|
||||
}
|
||||
];
|
||||
|
||||
repeated Prompt prompt = 6 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Prompts that must be displayed to the user";
|
||||
}
|
||||
];
|
||||
|
||||
repeated string ui_locales = 7 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "End-User's preferred languages and scripts for the user interface, represented as a list of BCP47 [RFC5646] language tag values, ordered by preference. For instance, the value [fr-CA, fr, en] represents a preference for French as spoken in Canada, then French (without a region designation), followed by English (without a region designation). An error SHOULD NOT result if some or all of the requested locales are not supported.";
|
||||
}
|
||||
];
|
||||
|
||||
optional string login_hint = 8 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Login hint can be set by the application with a user identifier such as an email or phone number.";
|
||||
}
|
||||
];
|
||||
|
||||
optional google.protobuf.Duration max_age = 9 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Specifies the allowable elapsed time in seconds since the last time the End-User was actively authenticated. If the elapsed time is greater than this value, or the field is present with 0 duration, the user must be re-authenticated.";
|
||||
}
|
||||
];
|
||||
|
||||
optional string hint_user_id = 10 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "User ID taken from a ID Token Hint if it was present and valid.";
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
enum Prompt {
|
||||
PROMPT_UNSPECIFIED = 0;
|
||||
PROMPT_NONE = 1;
|
||||
PROMPT_LOGIN = 2;
|
||||
PROMPT_CONSENT = 3;
|
||||
PROMPT_SELECT_ACCOUNT = 4;
|
||||
PROMPT_CREATE = 5;
|
||||
}
|
||||
|
||||
message AuthorizationError {
|
||||
ErrorReason error = 1;
|
||||
optional string error_description = 2;
|
||||
optional string error_uri = 3;
|
||||
}
|
||||
|
||||
enum ErrorReason {
|
||||
ERROR_REASON_UNSPECIFIED = 0;
|
||||
|
||||
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
|
||||
ERROR_REASON_INVALID_REQUEST = 1;
|
||||
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
|
||||
ERROR_REASON_ACCESS_DENIED = 3;
|
||||
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
|
||||
ERROR_REASON_INVALID_SCOPE = 5;
|
||||
ERROR_REASON_SERVER_ERROR = 6;
|
||||
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
|
||||
|
||||
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||
ERROR_REASON_INTERACTION_REQUIRED = 8;
|
||||
ERROR_REASON_LOGIN_REQUIRED = 9;
|
||||
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
|
||||
ERROR_REASON_CONSENT_REQUIRED = 11;
|
||||
ERROR_REASON_INVALID_REQUEST_URI = 12;
|
||||
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
|
||||
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
|
||||
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
|
||||
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
|
||||
}
|
190
proto/zitadel/oidc/v2alpha/oidc_service.proto
Normal file
190
proto/zitadel/oidc/v2alpha/oidc_service.proto
Normal file
@ -0,0 +1,190 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package zitadel.oidc.v2alpha;
|
||||
|
||||
import "zitadel/object/v2alpha/object.proto";
|
||||
import "zitadel/protoc_gen_zitadel/v2/options.proto";
|
||||
import "zitadel/oidc/v2alpha/authorization.proto";
|
||||
import "google/api/annotations.proto";
|
||||
import "google/api/field_behavior.proto";
|
||||
import "protoc-gen-openapiv2/options/annotations.proto";
|
||||
import "validate/validate.proto";
|
||||
|
||||
option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc";
|
||||
|
||||
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
|
||||
info: {
|
||||
title: "OIDC Service";
|
||||
version: "2.0-alpha";
|
||||
description: "Get OIDC Auth Request details and create callback URLs. This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login.";
|
||||
contact:{
|
||||
name: "ZITADEL"
|
||||
url: "https://zitadel.com"
|
||||
email: "hi@zitadel.com"
|
||||
}
|
||||
license: {
|
||||
name: "Apache 2.0",
|
||||
url: "https://github.com/zitadel/zitadel/blob/main/LICENSE";
|
||||
};
|
||||
};
|
||||
schemes: HTTPS;
|
||||
schemes: HTTP;
|
||||
|
||||
consumes: "application/json";
|
||||
consumes: "application/grpc";
|
||||
|
||||
produces: "application/json";
|
||||
produces: "application/grpc";
|
||||
|
||||
consumes: "application/grpc-web+proto";
|
||||
produces: "application/grpc-web+proto";
|
||||
|
||||
host: "$ZITADEL_DOMAIN";
|
||||
base_path: "/";
|
||||
|
||||
external_docs: {
|
||||
description: "Detailed information about ZITADEL",
|
||||
url: "https://zitadel.com/docs"
|
||||
}
|
||||
|
||||
responses: {
|
||||
key: "403";
|
||||
value: {
|
||||
description: "Returned when the user does not have permission to access the resource.";
|
||||
schema: {
|
||||
json_schema: {
|
||||
ref: "#/definitions/rpcStatus";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
responses: {
|
||||
key: "404";
|
||||
value: {
|
||||
description: "Returned when the resource does not exist.";
|
||||
schema: {
|
||||
json_schema: {
|
||||
ref: "#/definitions/rpcStatus";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
service OIDCService {
|
||||
rpc GetAuthRequest (GetAuthRequestRequest) returns (GetAuthRequestResponse) {
|
||||
option (google.api.http) = {
|
||||
get: "/v2alpha/oidc/auth_requests/{auth_request_id}"
|
||||
};
|
||||
|
||||
option (zitadel.protoc_gen_zitadel.v2.options) = {
|
||||
auth_option: {
|
||||
permission: "authenticated"
|
||||
}
|
||||
};
|
||||
|
||||
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
|
||||
summary: "Get OIDC Auth Request details";
|
||||
description: "Get OIDC Auth Request details by ID, obtained from the redirect URL. Returns details that are parsed from the application's Auth Request."
|
||||
responses: {
|
||||
key: "200"
|
||||
value: {
|
||||
description: "OK";
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/v2alpha/oidc/auth_requests/{auth_request_id}"
|
||||
body: "*"
|
||||
};
|
||||
|
||||
option (zitadel.protoc_gen_zitadel.v2.options) = {
|
||||
auth_option: {
|
||||
permission: "authenticated"
|
||||
}
|
||||
};
|
||||
|
||||
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
|
||||
summary: "Finalize an Auth Request and get the callback URL.";
|
||||
description: "Finalize an Auth Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an Auth request."
|
||||
responses: {
|
||||
key: "200"
|
||||
value: {
|
||||
description: "OK";
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
message GetAuthRequestRequest {
|
||||
string auth_request_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
min_length: 1;
|
||||
max_length: 200;
|
||||
description: "ID of the Auth Request, as obtained from the redirect URL.";
|
||||
example: "\"163840776835432705\"";
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
message GetAuthRequestResponse {
|
||||
AuthRequest auth_request = 1;
|
||||
}
|
||||
|
||||
message CreateCallbackRequest {
|
||||
string auth_request_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
|
||||
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
|
||||
}
|
||||
];
|
||||
|
||||
oneof callback_kind {
|
||||
option (validate.required) = true;
|
||||
Session session = 2;
|
||||
AuthorizationError error = 3 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
|
||||
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
|
||||
}
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
message Session {
|
||||
string session_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
min_length: 1;
|
||||
max_length: 200;
|
||||
description: "ID of the session, used to login the user. Connects the session to the Auth Request.";
|
||||
example: "\"163840776835432705\"";
|
||||
}
|
||||
];
|
||||
|
||||
string session_token = 2 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
min_length: 1;
|
||||
max_length: 200;
|
||||
description: "Token to verify the session is valid";
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
message CreateCallbackResponse {
|
||||
zitadel.object.v2alpha.Details details = 1;
|
||||
string callback_url = 2 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user.";
|
||||
example: "\"https://client.example.org/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=af0ifjsldkj\""
|
||||
}
|
||||
];
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user